@@ -84,12 +84,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
8484            x .get ("continuous" , torch .empty (0 , 0 )),
8585            x .get ("categorical" , torch .empty (0 , 0 )),
8686        )
87-         assert  (
88-             categorical_data . shape [ 1 ]  ==   self . categorical_dim 
89-         ),  "categorical_data must have same number of columns as categorical embedding layers" 
90-         assert  (
91-             continuous_data . shape [ 1 ]  ==   self . continuous_dim 
92-         ),  "continuous_data must have same number of columns as continuous dim" 
87+         assert  categorical_data . shape [ 1 ]  ==   self . categorical_dim ,  (
88+             " categorical_data must have same number of columns as categorical embedding layers" 
89+         )
90+         assert  continuous_data . shape [ 1 ]  ==   self . continuous_dim ,  (
91+             " continuous_data must have same number of columns as continuous dim" 
92+         )
9393        embed  =  None 
9494        if  continuous_data .shape [1 ] >  0 :
9595            if  self .batch_norm_continuous_input :
@@ -141,12 +141,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
141141            x .get ("continuous" , torch .empty (0 , 0 )),
142142            x .get ("categorical" , torch .empty (0 , 0 )),
143143        )
144-         assert  categorical_data .shape [1 ] ==  len (
145-             self . cat_embedding_layers 
146-         ),  "categorical_data must have same number of columns as categorical embedding layers" 
147-         assert  (
148-             continuous_data . shape [ 1 ]  ==   self . continuous_dim 
149-         ),  "continuous_data must have same number of columns as continuous dim" 
144+         assert  categorical_data .shape [1 ] ==  len (self . cat_embedding_layers ), ( 
145+             "categorical_data must have same number of columns as categorical embedding layers" 
146+         )
147+         assert  continuous_data . shape [ 1 ]  ==   self . continuous_dim ,  (
148+             " continuous_data must have same number of columns as continuous dim" 
149+         )
150150        embed  =  None 
151151        if  continuous_data .shape [1 ] >  0 :
152152            if  self .batch_norm_continuous_input :
@@ -273,12 +273,12 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
273273            x .get ("continuous" , torch .empty (0 , 0 )),
274274            x .get ("categorical" , torch .empty (0 , 0 )),
275275        )
276-         assert  categorical_data .shape [1 ] ==  len (
277-             self . cat_embedding_layers 
278-         ),  "categorical_data must have same number of columns as categorical embedding layers" 
279-         assert  (
280-             continuous_data . shape [ 1 ]  ==   self . continuous_dim 
281-         ),  "continuous_data must have same number of columns as continuous dim" 
276+         assert  categorical_data .shape [1 ] ==  len (self . cat_embedding_layers ), ( 
277+             "categorical_data must have same number of columns as categorical embedding layers" 
278+         )
279+         assert  continuous_data . shape [ 1 ]  ==   self . continuous_dim ,  (
280+             " continuous_data must have same number of columns as continuous dim" 
281+         )
282282        embed  =  None 
283283        if  continuous_data .shape [1 ] >  0 :
284284            cont_idx  =  torch .arange (self .continuous_dim , device = continuous_data .device ).expand (
0 commit comments