File tree Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -170,23 +170,23 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
170170
171171
172172@pytest .mark .parametrize ('bsize' , [5 , 10 ]) 
173- def  test_batch_3d_squeeze_batch_dim ( sample_ds_3d , bsize ):
173+ def  test_batch_1d_squeeze_batch_dim ( sample_ds_1d , bsize ):
174174    xbsize  =  20 
175175    bg  =  BatchGenerator (
176-         sample_ds_3d ,
177-         input_dims = {'time'  :  1 ,  'y' :  bsize ,  ' x'xbsize },
176+         sample_ds_1d ,
177+         input_dims = {'x' : xbsize },
178178        squeeze_batch_dim = False ,
179179    )
180180    for  ds_batch  in  bg :
181-         assert  ds_batch ['x ' ].shape  ==  [1 ,  bsize , xbsize ]
181+         assert  list ( ds_batch ['foo ' ].shape )  ==  [1 , xbsize ]
182182
183183    bg2  =  BatchGenerator (
184-         sample_ds_3d ,
185-         input_dims = {'time'  :  1 ,  'y' :  bsize ,  ' x'xbsize },
184+         sample_ds_1d ,
185+         input_dims = {'x' : xbsize },
186186        squeeze_batch_dim = True ,
187187    )
188-     for  ds_batch  in  bg :
189-         assert  ds_batch ['x ' ].shape  ==  [bsize ,  xbsize ]
188+     for  ds_batch  in  bg2 :
189+         assert  list ( ds_batch ['foo ' ].shape )  ==  [xbsize ]
190190
191191
192192def  test_preload_batch_false (sample_ds_1d ):
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments