@@ -84,8 +84,8 @@ def update_leaf_distributions(
     # y_true_range = torch.arange(0, batch_size)
     # y_true_indices = torch.stack((y_true_range, y_true))
     # y_true_one_hot = torch.sparse_coo_tensor(y_true_indices,
-    #     torch.ones_like(y_true, dtype=torch.bool), logits.shape)  # Or other more suitable sparse format,
-    # or even better,
+    #     torch.ones_like(y_true, dtype=torch.bool), logits.shape)  # Might be better to use CSR or CSC,
+    # or better still,
     # y_true_one_hot = F.sparse_one_hot(y_true, num_classes=num_classes, dtype=torch.bool),
     # but PyTorch doesn't yet have sufficient sparse mask support for the logic in update_leaf to work.
     y_true_one_hot = F.one_hot(y_true, num_classes=num_classes).to(dtype=torch.bool)
@@ -113,6 +113,12 @@ def update_leaf(
     log_p_arrival = node_to_prob[leaf].log_p_arrival.unsqueeze(1)
     # shape (num_classes). Not the same as logits, which has shape (batch_size, num_classes)
     leaf_logits = leaf.logits()
+
+    # TODO: If PyTorch had more support for sparse masks, we might be able to do something like
+    #     masked_logits = logits.sparse_mask(y_true_one_hot),
+    # and perhaps, if necessary, combine it with
+    #     masked_log_p_arrival = y_true_one_hot * log_p_arrival  # sparse_mask can't broadcast
+    #     masked_leaf_logits = y_true_one_hot * leaf_logits  # sparse_mask can't broadcast.
     masked_logits = masked_tensor(logits, y_true_one_hot)

     masked_log_combined = log_p_arrival + (leaf_logits - masked_logits)
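
As a side note outside the diff: the commented-out sparse COO construction in the first hunk builds the same boolean mask as the dense F.one_hot line that was committed. A minimal standalone sketch, where batch_size, num_classes, and y_true are illustrative dummies rather than values from the repository:

import torch
import torch.nn.functional as F

batch_size, num_classes = 4, 3
y_true = torch.randint(0, num_classes, (batch_size,))

# Dense boolean one-hot mask, as committed.
y_true_one_hot = F.one_hot(y_true, num_classes=num_classes).to(dtype=torch.bool)

# Sparse COO equivalent from the commented-out code.
y_true_range = torch.arange(0, batch_size)
y_true_indices = torch.stack((y_true_range, y_true))  # shape (2, batch_size)
y_true_one_hot_sparse = torch.sparse_coo_tensor(
    y_true_indices,
    torch.ones_like(y_true, dtype=torch.bool),
    (batch_size, num_classes),
)

# Both constructions agree element-for-element.
assert bool((y_true_one_hot_sparse.to_dense() == y_true_one_hot).all())

The sparse form stores only batch_size True entries instead of batch_size * num_classes booleans, which is presumably why the comments flag it as preferable once PyTorch's sparse mask support matures.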
0 commit comments
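The masked_tensor call itself comes from PyTorch's prototype MaskedTensor API in torch.masked. Below is a self-contained sketch of the committed computation, with log_p_arrival and leaf_logits replaced by random stand-ins of the shapes the diff's comments describe; the names mirror the diff, but nothing here is taken from the actual repository state:

import torch
import torch.nn.functional as F
from torch.masked import masked_tensor

batch_size, num_classes = 4, 3
logits = torch.randn(batch_size, num_classes)
y_true = torch.randint(0, num_classes, (batch_size,))
y_true_one_hot = F.one_hot(y_true, num_classes=num_classes).to(dtype=torch.bool)

# Random stand-ins for the leaf's state.
log_p_arrival = torch.randn(batch_size, 1)  # node_to_prob[leaf].log_p_arrival.unsqueeze(1)
leaf_logits = torch.randn(num_classes)      # leaf.logits(), shape (num_classes,)

# Only each sample's true-class logit remains "specified"; everything else is
# masked out, so downstream reductions ignore the non-true classes.
masked_logits = masked_tensor(logits, y_true_one_hot)

# Mirrors the committed line; the result keeps the one-hot mask.
masked_log_combined = log_p_arrival + (leaf_logits - masked_logits)

Because only the true-class entries of masked_logits are specified, the subtraction and any later reduction in update_leaf only ever touch true-class logits; that is the behavior the TODO hopes to reproduce with logits.sparse_mask(y_true_one_hot) once sparse masks gain broadcasting support.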