@@ -348,6 +348,7 @@ def test_set_coordinates(self):
348348 self .assertRaises (ValueError , G .set_coordinates , 'invalid' )
349349
350350 def test_nngraph (self , n_vertices = 30 ):
351+ """Test all the combinations of metric, kind, backend."""
351352 features = np .random .RandomState (42 ).normal (size = (n_vertices , 3 ))
352353 metrics = ['euclidean' , 'manhattan' , 'max_dist' , 'minkowski' ]
353354 backends = ['scipy-kdtree' , 'scipy-ckdtree' , 'scipy-pdist' , 'nmslib' ,
@@ -356,46 +357,30 @@ def test_nngraph(self, n_vertices=30):
356357
357358 for backend in backends :
358359 for metric in metrics :
359- if ((backend == 'flann' and metric == 'max_dist' ) or
360- (backend == 'nmslib' and metric == 'minkowski' )):
361- self .assertRaises (ValueError , graphs .NNGraph , features ,
362- kind = 'knn' , backend = backend ,
363- metric = metric )
364- self .assertRaises (ValueError , graphs .NNGraph , features ,
365- kind = 'radius' , backend = backend ,
366- metric = metric )
367- else :
368- if backend == 'nmslib' :
369- self .assertRaises (ValueError , graphs .NNGraph , features ,
370- kind = 'radius' , backend = backend ,
371- metric = metric , order = order )
360+ for kind in ['knn' , 'radius' ]:
361+ params = dict (features = features , metric = metric ,
362+ order = order , kind = kind , backend = backend )
363+ # Unsupported combinations.
364+ if backend == 'flann' and metric == 'max_dist' :
365+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
366+ elif backend == 'nmslib' and metric == 'minkowski' :
367+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
368+ elif backend == 'nmslib' and kind == 'radius' :
369+ self .assertRaises (ValueError , graphs .NNGraph , ** params )
372370 else :
373- graphs .NNGraph (features , kind = 'radius' ,
374- backend = backend ,
375- metric = metric , order = order )
376- graphs .NNGraph (features , kind = 'knn' ,
377- backend = backend ,
378- metric = metric , order = order )
379- graphs .NNGraph (features , kind = 'knn' ,
380- backend = backend ,
381- metric = metric , order = order ,
382- center = False )
383- graphs .NNGraph (features , kind = 'knn' ,
384- backend = backend ,
385- metric = metric , order = order ,
386- rescale = False )
387- graphs .NNGraph (features , kind = 'knn' ,
388- backend = backend ,
389- metric = metric , order = order ,
390- rescale = False , center = False )
371+ graphs .NNGraph (** params , center = False )
372+ graphs .NNGraph (** params , rescale = False )
373+ graphs .NNGraph (** params , center = False , rescale = False )
374+
375+ # Invalid parameters.
376+ self .assertRaises (ValueError , graphs .NNGraph , features ,
377+ metric = 'invalid' )
391378 self .assertRaises (ValueError , graphs .NNGraph , features ,
392- kind = 'invalid' , backend = backend ,
393- metric = metric )
379+ kind = 'invalid' )
394380 self .assertRaises (ValueError , graphs .NNGraph , features ,
395- kind = 'knn' , backend = 'invalid' ,
396- metric = metric )
381+ backend = 'invalid' )
397382 self .assertRaises (ValueError , graphs .NNGraph , features ,
398- kind = 'knn' , k = n_vertices + 1 )
383+ kind = 'knn' , k = n_vertices + 1 )
399384
400385 def test_nngraph_consistency (self ):
401386 features = np .arange (90 ).reshape (30 , 3 )
0 commit comments