3131logging .getLogger ("p2ch14.dsets" ).setLevel (logging .WARNING )
3232
3333def  print_confusion (label , confusions , do_mal ):
34+     row_labels  =  ['Non-Nodules' , 'Benign' , 'Malignant' ]
35+ 
3436    if  do_mal :
35-         col_labels  =  ['' , 'Complete Miss' , 'Filtered' , 'Benign' , 'Malignant' ]
36-         row_labels  =  ['Non-Nodules' , 'Benign' , 'Malignant' ]
37+         col_labels  =  ['' , 'Complete Miss' , 'Filtered Out' , 'Pred. Benign' , 'Pred. Malignant' ]
3738    else :
38-         col_labels  =  ['' , 'Complete Miss' , 'Filtered' , 'Detected' ]
39-         row_labels  =  ['Non-Nodules' , 'Nodules' ]
40-         confusions [- 2 ] +=  confusions [- 1 ]
39+         col_labels  =  ['' , 'Complete Miss' , 'Filtered Out' , 'Pred. Nodule' ]
4140        confusions [:, - 2 ] +=  confusions [:, - 1 ]
42-         confusions  =  confusions [:- 1 , :- 1 ]
43-     cell_width  =  14 
41+         confusions  =  confusions [:, :- 1 ]
42+     cell_width  =  16 
4443    f  =  '{:>'  +  str (cell_width ) +  '}' 
4544    print (label )
4645    print (' | ' .join ([f .format (s ) for  s  in  col_labels ]))
@@ -72,7 +71,7 @@ def match_and_score(detections, truth, threshold=0.5, threshold_mal=0.5):
7271    confusion  =  np .zeros ((3 , 4 ), dtype = np .int )
7372    if  len (detected_xyz ) ==  0 :
7473        for  tn  in  true_nodules :
75-             confusiion [2  if  tn .isMal_bool  else  1 , 0 ] +=  1 
74+             confusion [2  if  tn .isMal_bool  else  1 , 0 ] +=  1 
7675    elif  len (truth_xyz ) ==  0 :
7776        for  dc  in  detected_classes :
7877            confusion [0 , dc ] +=  1 
@@ -124,7 +123,7 @@ def __init__(self, sys_argv=None):
124123        parser .add_argument ('--segmentation-path' ,
125124            help = "Path to the saved segmentation model" ,
126125            nargs = '?' ,
127-             default = None ,
126+             default = 'data/part2/models/seg_2020-01-26_19.45.12_w4d3c1-bal_1_nodupe-label_pos-d1_fn8-adam.best.state' ,
128127        )
129128
130129        parser .add_argument ('--cls-model' ,
@@ -135,13 +134,14 @@ def __init__(self, sys_argv=None):
135134        parser .add_argument ('--classification-path' ,
136135            help = "Path to the saved classification model" ,
137136            nargs = '?' ,
138-             default = None ,
137+             default = 'data/part2/models/cls_2020-02-06_14.16.55_final-nodule-nonnodule.best.state' ,
139138        )
140139
141140        parser .add_argument ('--malignancy-model' ,
142141            help = "What to model class name to use for the malignancy classifier." ,
143142            action = 'store' ,
144-             default = 'ModifiedLunaModel' ,
143+             default = 'LunaModel' ,
144+             # default='ModifiedLunaModel', 
145145        )
146146        parser .add_argument ('--malignancy-path' ,
147147            help = "Path to the saved malignancy classification model" ,
@@ -303,7 +303,6 @@ def main(self):
303303        val_list  =  sorted (series_set  &  val_set )
304304
305305
306-         candidateInfo_list  =  []
307306        candidateInfo_dict  =  getCandidateInfoDict ()
308307        series_iter  =  enumerateWithEstimate (
309308            val_list  +  train_list ,
@@ -314,10 +313,8 @@ def main(self):
314313            ct  =  getCt (series_uid )
315314            mask_a  =  self .segmentCt (ct , series_uid )
316315
317-             candidateInfo_list  =  self .clusterSegmentationOutput (
318-                 series_uid ,
319-                 ct ,
320-                 mask_a ,
316+             candidateInfo_list  =  self .groupSegmentationOutput (
317+                 series_uid , ct , mask_a 
321318            )
322319            classifications_list  =  self .classifyCandidates (ct , candidateInfo_list )
323320
@@ -339,7 +336,6 @@ def main(self):
339336        print_confusion ("Total" , all_confusion , self .malignancy_model  is  not None )
340337
341338
342- 
343339    def  classifyCandidates (self , ct , candidateInfo_list ):
344340        cls_dl  =  self .initClassificationDl (candidateInfo_list )
345341        classifications_list  =  []
@@ -348,49 +344,50 @@ def classifyCandidates(self, ct, candidateInfo_list):
348344
349345            input_g  =  input_t .to (self .device )
350346            with  torch .no_grad ():
351-                 _ , probability_g  =  self .cls_model (input_g )
347+                 _ , probability_nodule_g  =  self .cls_model (input_g )
352348                if  self .malignancy_model  is  not None :
353349                    _ , probability_mal_g  =  self .malignancy_model (input_g )
354350                else :
355-                     probability_mal_g  =  torch .zeros_like (probability_g )
351+                     probability_mal_g  =  torch .zeros_like (probability_nodule_g )
356352
357-             for  center_irc , prob , prob_mal  in  zip (center_list ,
358-                                                   probability_g [:,1 ].tolist (),
359-                                                   probability_mal_g [:,1 ].tolist ()
360-                                                   ):
353+             zip_iter  =  zip (
354+                 center_list ,
355+                 probability_nodule_g [:,1 ].tolist (),
356+                 probability_mal_g [:,1 ].tolist (),
357+             )
358+             for  center_irc , prob_nodule , prob_mal  in  zip_iter :
361359                center_xyz  =  irc2xyz (
362360                    center_irc ,
363361                    direction_a = ct .direction_a ,
364362                    origin_xyz = ct .origin_xyz ,
365-                     vxSize_xyz = ct .vxSize_xyz )
366-                 classifications_list .append ((prob , prob_mal , center_xyz , center_irc ))
363+                     vxSize_xyz = ct .vxSize_xyz ,
364+                 )
365+                 cls_tup  =  (prob_nodule , prob_mal , center_xyz , center_irc )
366+                 classifications_list .append (cls_tup )
367367        return  classifications_list 
368368
369369    def  segmentCt (self , ct , series_uid ):
370370        with  torch .no_grad ():
371371            output_a  =  np .zeros_like (ct .hu_a , dtype = np .float32 )
372372            seg_dl  =  self .initSegmentationDl (series_uid )
373373            for  batch_tup  in  seg_dl :
374-                 input_t  =  batch_tup [0 ]
375-                 ndx_list  =  batch_tup [4 ]
374+                 input_t , label_t , series_list , slice_ndx_list  =  batch_tup 
376375
377376                input_g  =  input_t .to (self .device )
378377                prediction_g  =  self .seg_model (input_g )
379378
380-                 for  i , sample_ndx  in  enumerate (ndx_list ):
381-                     output_a [sample_ndx ] =  prediction_g [i ].cpu ().numpy ()
379+                 for  i , slice_ndx  in  enumerate (slice_ndx_list ):
380+                     output_a [slice_ndx ] =  prediction_g [i ].cpu ().numpy ()
382381
383-             # mask_a = output_a > 0.25 
384382            mask_a  =  output_a  >  0.5 
385-             # mask_a = morphology.binary_erosion(mask_a, iterations=1) 
386-             # mask_a = morphology.binary_dilation(mask_a, iterations=2) 
383+             mask_a  =  morphology .binary_erosion (mask_a , iterations = 1 )
387384
388385        return  mask_a 
389386
390-     def  clusterSegmentationOutput (self , series_uid ,  ct , clean_a ):
387+     def  groupSegmentationOutput (self , series_uid ,  ct , clean_a ):
391388        candidateLabel_a , candidate_count  =  measurements .label (clean_a )
392389        centerIrc_list  =  measurements .center_of_mass (
393-             ct .hu_a  +  1001 ,
390+             ct .hu_a . clip ( - 1000 ,  1000 )  +  1001 ,
394391            labels = candidateLabel_a ,
395392            index = np .arange (1 , candidate_count + 1 ),
396393        )
0 commit comments