File tree Expand file tree Collapse file tree 3 files changed +26
-2
lines changed Expand file tree Collapse file tree 3 files changed +26
-2
lines changed Original file line number Diff line number Diff line change 55from luminoth .models .fasterrcnn .rcnn_target import RCNNTarget
66from luminoth .models .fasterrcnn .roi_pool import ROIPoolingLayer
77from luminoth .utils .losses import smooth_l1_loss
8+ from luminoth .utils .safe_wrappers import (
9+ safe_softmax_cross_entropy_with_logits
10+ )
811from luminoth .utils .vars import (
912 get_initializer , layer_summaries , variable_summaries ,
1013 get_activation_function
@@ -304,7 +307,7 @@ def loss(self, prediction_dict):
304307
305308 # We get cross entropy loss of each proposal.
306309 cross_entropy_per_proposal = (
307- tf . nn . softmax_cross_entropy_with_logits (
310+ safe_softmax_cross_entropy_with_logits (
308311 labels = cls_target_one_hot , logits = cls_score_labeled
309312 )
310313 )
Original file line number Diff line number Diff line change 1010from .rpn_target import RPNTarget
1111from .rpn_proposal import RPNProposal
1212from luminoth .utils .losses import smooth_l1_loss
13+ from luminoth .utils .safe_wrappers import (
14+ safe_softmax_cross_entropy_with_logits
15+ )
1316from luminoth .utils .vars import (
1417 get_initializer , layer_summaries , variable_summaries ,
1518 get_activation_function
@@ -257,7 +260,7 @@ def loss(self, prediction_dict):
257260 cls_target = tf .one_hot (labels , depth = 2 )
258261
259262 # Equivalent to log loss
260- ce_per_anchor = tf . nn . softmax_cross_entropy_with_logits (
263+ ce_per_anchor = safe_softmax_cross_entropy_with_logits (
261264 labels = cls_target , logits = cls_score
262265 )
263266 prediction_dict ['cross_entropy_per_anchor' ] = ce_per_anchor
Original file line number Diff line number Diff line change 1+ import tensorflow as tf
2+
3+
4+ def safe_softmax_cross_entropy_with_logits (
5+ labels , logits , name = 'safe_cross_entropy' ):
6+ with tf .name_scope (name ):
7+ safety_condition = tf .logical_and (
8+ tf .greater (tf .shape (labels )[0 ], 0 , name = 'labels_notzero' ),
9+ tf .greater (tf .shape (logits )[0 ], 0 , name = 'logits_notzero' ),
10+ name = 'safety_condition'
11+ )
12+ return tf .cond (
13+ safety_condition ,
14+ true_fn = lambda : tf .nn .softmax_cross_entropy_with_logits (
15+ labels = labels , logits = logits
16+ ),
17+ false_fn = lambda : tf .constant ([], dtype = logits .dtype )
18+ )
You can’t perform that action at this time.
0 commit comments