@@ -83,25 +83,30 @@ def update_leaf_distributions(
8383 """
8484 batch_size , num_classes = logits .shape
8585
86- y_true_one_hot = F .one_hot (y_true , num_classes = num_classes )
87- y_true_logits = torch .log (y_true_one_hot )
86+ # Other sparse formats may be better than COO.
87+ # TODO: This is a bit convoluted. Why is there no sparse version of torch.nn.functional.one_hot ?
88+ y_true_range = torch .arange (0 , batch_size )
89+ y_true_indices = torch .stack ((y_true_range , y_true ))
90+ y_true_one_hot = torch .sparse_coo_tensor (
91+ y_true_indices , torch .ones_like (y_true , dtype = torch .bool ), logits .shape
92+ )
8893
8994 for leaf in root .leaves :
90- update_leaf (leaf , node_to_prob , logits , y_true_logits , smoothing_factor )
95+ update_leaf (leaf , node_to_prob , logits , y_true_one_hot , smoothing_factor )
9196
9297
9398def update_leaf (
9499 leaf : Leaf ,
95100 node_to_prob : dict [Node , NodeProbabilities ],
96101 logits : torch .Tensor ,
97- y_true_logits : torch .Tensor ,
102+ y_true_one_hot : torch .Tensor ,
98103 smoothing_factor : float ,
99104):
100105 """
101106 :param leaf:
102107 :param node_to_prob:
103108 :param logits: of shape (batch_size, num_classes)
104- :param y_true_logits: of shape (batch_size, num_classes)
109+ :param y_true_one_hot: boolean tensor of shape (batch_size, num_classes)
105110 :param smoothing_factor:
106111 :return:
107112 """
@@ -110,15 +115,15 @@ def update_leaf(
     # shape (num_classes). Not the same as logits, which has (batch_size, num_classes)
     leaf_logits = leaf.y_logits()
 
-    # TODO: y_true_logits is mostly -Inf terms (the rest being 0s) that won't contribute to the total, and we are also
-    # summing together tensors of different shapes. We should be able to express this more clearly and efficiently by
-    # taking advantage of this sparsity.
-    log_dist_update = torch.logsumexp(
-        log_p_arrival + leaf_logits + y_true_logits - logits,
-        dim=0,
-    )
+    masked_logits = logits.sparse_mask(y_true_one_hot)
+    masked_log_p_arrival = y_true_one_hot * log_p_arrival
+    masked_leaf_logits = y_true_one_hot * leaf_logits
+    masked_log_combined = masked_log_p_arrival + masked_leaf_logits - masked_logits
+
+    # TODO: Can't use logsumexp because masked tensors don't support it.
+    masked_dist_update = torch.sum(torch.exp(masked_log_combined), dim=0)
 
-    dist_update = torch.exp(log_dist_update)
+    dist_update = masked_dist_update.to_tensor(0.0)
 
     # This scaling (subtraction of `-1/n_batches * c` in the ProtoTree paper) seems to be a form of exponentially
     # weighted moving average, designed to ensure stability of the leaf class probability distributions (
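
For context on the first hunk, here is a minimal, self-contained sketch (not part of the patch; the batch size, class count, and labels are made up) showing that the sparse COO construction produces the same one-hot pattern as the dense F.one_hot call it replaces.

import torch
import torch.nn.functional as F

# Illustrative sizes and labels only.
batch_size, num_classes = 4, 3
y_true = torch.tensor([2, 0, 1, 2])

# Dense one-hot, as in the old code path.
dense_one_hot = F.one_hot(y_true, num_classes=num_classes)

# Sparse COO equivalent, as in the new code path: each (row, label) pair
# becomes an index holding a True value; all other entries are implicit.
indices = torch.stack((torch.arange(0, batch_size), y_true))
sparse_one_hot = torch.sparse_coo_tensor(
    indices, torch.ones_like(y_true, dtype=torch.bool), (batch_size, num_classes)
)

assert torch.equal(sparse_one_hot.to_dense(), dense_one_hot.bool())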
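And a toy check of the identity the second hunk relies on, written with a dense boolean mask rather than the patch's sparse/masked tensors purely so it runs on stock PyTorch; all shapes and values are invented. Because log(one_hot) is zero at the true class and -inf elsewhere, the old dense logsumexp over the batch reduces to exponentiating and summing only the true-class entries, which is why a masked sum-of-exp can replace it.

import torch
import torch.nn.functional as F

# Invented shapes and values, for illustration only.
batch_size, num_classes = 4, 3
logits = torch.randn(batch_size, num_classes)
leaf_logits = torch.randn(num_classes)
log_p_arrival = torch.randn(batch_size, 1)
y_true = torch.tensor([2, 0, 1, 2])

# Old dense path: log(one_hot) contributes 0 at the true class and -inf
# elsewhere, so non-true-class terms vanish from the logsumexp.
y_true_logits = torch.log(F.one_hot(y_true, num_classes=num_classes).float())
dense_update = torch.exp(
    torch.logsumexp(log_p_arrival + leaf_logits + y_true_logits - logits, dim=0)
)

# Masked path: keep only the true-class entries, exponentiate, and sum over
# the batch; classes with no true samples in the batch come out as 0.
mask = F.one_hot(y_true, num_classes=num_classes).bool()
log_combined = log_p_arrival + leaf_logits - logits
masked_update = torch.where(
    mask, torch.exp(log_combined), torch.zeros_like(log_combined)
).sum(dim=0)

assert torch.allclose(dense_update, masked_update)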