Skip to content
This repository was archived by the owner on Jan 22, 2024. It is now read-only.

Commit a0c3384

Browse files
authored
Exception IPA-GNN localization (aka raise attributions) support try/except (#78)
1. Updates localization ("raise contribution") calculation to handle try/except blocks. If an exception is caught by a try/except block, but (via normal execution and branch decisions) the exception still finds its way to the exception node, the original node that raised gets credit. If another exception is raised, the new exception gets credit. 2. tests for get_nodes_at_lineno 3. Fix for target_lineno when docstring is included in source 4. New dataset (FULL_DATASET_PATH_WITH_DOCSTRINGS) using appropriate target_linenos and corresponding node indexes. 5. test_compute_localization_accuracy
1 parent 1d022c4 commit a0c3384

File tree

11 files changed

+425
-54
lines changed

11 files changed

+425
-54
lines changed

config/default.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def default_config():
6363
config.eval_metric_names: Tuple[str] = metrics.all_metric_names()
6464
config.eval_subsample = 1.0
6565
config.eval_max_batches = 30
66+
config.unsupervised_localization: bool = True # Must be set to True to compute localization logits.
6667

6768
# Logging
6869
config.printoptions_threshold = 256

core/data/codenet_paths.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
DEFAULT_DATASET_PATH = 'datasets/codenet/2021-11-01-f=0.01'
99
TEST_DATASET_PATH = 'datasets/codenet/2021-11-01-f=0.01'
1010
DEFAULT_TOKENIZER_PATH = 'out/tokenizers/train-1000000.json'
11+
DOCSTRING_TOKENIZER_PATH = 'out/tokenizers/train-docstrings-1000000.json'
1112
DEFAULT_SPLITS_PATH = 'out/splits/default.json'
1213
DEFAULT_EXPERIMENTS_DIR = 'out/experiments'
1314
EXPERIMENT_ID_PATH = 'out/experiment_id.txt'
1415

1516
FULL_DATASET_PATH = '/mnt/runtime-error-problems-experiments/datasets/project-codenet/2021-10-07-full'
16-
FULL_DATASET_PATH_WITH_DOCSTRINGS = '/mnt/runtime-error-problems-experiments/datasets/project-codenet/2021-11-01'
17+
FULL_DATASET_PATH_WITH_DOCSTRINGS = '/mnt/runtime-error-problems-experiments/datasets/project-codenet/2021-11-17'
1718
# Raw control_flow_programs data pattern:
1819
DEFAULT_CFP_DATA_PATTERN = '/mnt/runtime-error-problems-experiments/datasets/control_flow_programs/decimal-large-state-L10/0.0.48/control_flow_programs-train.tfrecord-*'
1920
# Processed control_flow_programs dataset path:

core/data/test_process.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,142 @@ def test_make_runtimeerrorproblem_try_finally_in_try_except(self):
314314
# a finally inside the try.
315315
# Can only get into "raising" territory via a finally block's true branch or via a raise edge.
316316

317+
def test_get_nodes_at_lineno_no_error(self):
318+
lineno = 0
319+
target = '1'
320+
source = """x = 1
321+
while x < 2:
322+
y = 3
323+
while y < 4:
324+
y += 5
325+
x += 6
326+
"""
327+
raw = process.make_rawruntimeerrorproblem(
328+
source, target, lineno)
329+
nodes = process.get_nodes_at_lineno(raw, lineno)
330+
self.assertEqual(nodes, [])
331+
332+
def test_get_nodes_at_lineno_1(self):
333+
lineno = 1 # x = 1
334+
target = '1'
335+
source = """x = 1
336+
while x < 2:
337+
y = 3
338+
while y < 4:
339+
y += 5
340+
x += 6
341+
"""
342+
raw = process.make_rawruntimeerrorproblem(
343+
source, target, lineno)
344+
nodes = process.get_nodes_at_lineno(raw, lineno)
345+
self.assertEqual(nodes, [0])
346+
347+
def test_get_nodes_at_lineno_2(self):
348+
lineno = 2 # while x < 2:
349+
target = '1'
350+
source = """x = 1
351+
while x < 2:
352+
y = 3
353+
while y < 4:
354+
y += 5
355+
x += 6
356+
"""
357+
raw = process.make_rawruntimeerrorproblem(
358+
source, target, lineno)
359+
nodes = process.get_nodes_at_lineno(raw, lineno)
360+
self.assertEqual(nodes, [1])
361+
362+
def test_get_nodes_at_lineno_docstring(self):
363+
lineno = 5 # while x < 2:
364+
target = '1'
365+
source = '''"""Example
366+
docstring
367+
"""
368+
x = 1
369+
while x < 2:
370+
y = 3
371+
while y < 4:
372+
y += 5
373+
x += 6
374+
'''
375+
raw = process.make_rawruntimeerrorproblem(
376+
source, target, lineno)
377+
nodes = process.get_nodes_at_lineno(raw, lineno)
378+
self.assertEqual(nodes, [2])
379+
380+
def test_get_nodes_at_lineno_for(self):
381+
lineno = 5 # for y in range(100):
382+
target = '1'
383+
source = '''"""Example
384+
docstring
385+
"""
386+
x = 1
387+
for y in range(100):
388+
while y < 4:
389+
y += 5
390+
x += 6
391+
'''
392+
raw = process.make_rawruntimeerrorproblem(
393+
source, target, lineno)
394+
nodes = process.get_nodes_at_lineno(raw, lineno)
395+
self.assertEqual(nodes, [2, 3])
396+
397+
def test_get_nodes_at_lineno_multiline(self):
398+
lineno = 6 # 100/0
399+
target = '1'
400+
source = '''"""Example
401+
docstring
402+
"""
403+
x = 1
404+
for y in range(
405+
100/0
406+
):
407+
while y < 4:
408+
y += 5
409+
x += 6
410+
'''
411+
raw = process.make_rawruntimeerrorproblem(
412+
source, target, lineno)
413+
nodes = process.get_nodes_at_lineno(raw, lineno)
414+
self.assertEqual(nodes, [2]) # range(100/0)
415+
416+
def test_get_nodes_at_lineno_multiline_unpack(self):
417+
lineno = 6 # for x,y in range(
418+
target = '1'
419+
source = r'''"""Example
420+
docstring
421+
"""
422+
x = 1
423+
for \
424+
x,y\
425+
in range(100):
426+
while y < 4:
427+
y += 5
428+
x += 6
429+
'''
430+
raw = process.make_rawruntimeerrorproblem(
431+
source, target, lineno)
432+
nodes = process.get_nodes_at_lineno(raw, lineno)
433+
self.assertEqual(nodes, [3])
434+
435+
def test_get_nodes_at_lineno_multiline_ambiguous(self):
436+
lineno = 5 # for x,y in range(
437+
target = '1'
438+
source = '''"""Example
439+
docstring
440+
"""
441+
x = 1
442+
for x,y in range(
443+
100
444+
):
445+
while y < 4:
446+
y += 5
447+
x += 6
448+
'''
449+
raw = process.make_rawruntimeerrorproblem(
450+
source, target, lineno)
451+
nodes = process.get_nodes_at_lineno(raw, lineno)
452+
self.assertEqual(nodes, [2, 3])
317453

318454
if __name__ == '__main__':
319455
unittest.main()

core/lib/metrics.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ def compute_localization_accuracy(
201201
return None
202202

203203
def is_correct(targets, num_targets, prediction):
204+
# targets.shape: max_num_targets
205+
# num_targets.shape: scalar.
204206
is_example = num_targets > 0
205207
mask = jnp.arange(targets.shape[0]) < num_targets
206208
# mask.shape: max_num_nodes

core/lib/test_metrics.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,26 @@ def test_compute_weighted_f1_score_error_only_omits_correct_examples(self):
151151
# weighted average is 1/3.
152152
self.assertEqual(f1_score, 1/3)
153153

154+
def test_compute_localization_accuracy(self):
155+
localization_targets = jnp.array([
156+
[0, 1, 2, 0, 0, 0, 0], # correct
157+
[1, 2, 0, 0, 0, 0, 0], # correct
158+
[0, 0, 0, 0, 0, 0, 0], # is_example == False
159+
[0, 0, 0, 0, 0, 0, 0], # correct
160+
[0, 1, 2, 0, 0, 0, 0], # incorrect
161+
[1, 2, 0, 0, 0, 0, 0], # incorrect
162+
[0, 0, 0, 0, 0, 0, 0], # is_example == False
163+
[0, 0, 0, 0, 0, 0, 0], # incorrect
164+
[4, 5, 6, 0, 0, 0, 0], # correct
165+
])
166+
localization_num_targets = jnp.array([3, 2, 0, 1, 3, 2, 0, 1, 3])
167+
localization_predictions = jnp.array([0, 2, 0, 0, 3, 0, 1, 1, 4])
168+
acc = metrics.compute_localization_accuracy(
169+
localization_targets,
170+
localization_num_targets,
171+
localization_predictions)
172+
self.assertEqual(acc, 4/7)
173+
174+
154175
if __name__ == '__main__':
155176
unittest.main()

core/models/ipagnn.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,4 @@ def __call__(self, x):
122122
)(exit_node_embeddings)
123123
# logits.shape: batch_size, num_classes
124124

125-
if config.raise_in_ipagnn:
126-
per_node_raise_contributions = raise_contributions_lib.get_raise_contribution_from_batch_and_aux(
127-
x, ipagnn_output)
128-
localization_logits = per_node_raise_contributions
129-
ipagnn_output['localization_logits'] = localization_logits
130-
131125
return logits, ipagnn_output

core/modules/ipagnn/ipagnn.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from core.lib.metrics import EvaluationMetric
88
from core.modules.ipagnn import rnn
9+
from core.modules.ipagnn import raise_contributions as raise_contributions_lib
910

1011

1112
def _rnn_state_to_embedding(hidden_state):
@@ -61,7 +62,7 @@ def __call__(
6162
config = self.config
6263

6364
# State. Varies from step to step.
64-
hidden_states, instruction_pointer, current_step = carry
65+
hidden_states, instruction_pointer, attribution, current_step = carry
6566

6667
# Inputs.
6768
vocab_size = info.vocab_size
@@ -229,6 +230,8 @@ def set_values(a, value, index):
229230
# raise_decision.shape: batch_size, num_nodes, 2
230231
# Make sure you cannot raise from the exit node.
231232
raise_decisions = batch_set(raise_decisions, jnp.array([0, 1]), exit_node_indexes)
233+
# Make sure you cannot raise from the raise node.
234+
raise_decisions = batch_set(raise_decisions, jnp.array([0, 1]), raise_node_indexes)
232235
# raise_decision.shape: batch_size, num_nodes, 2
233236
else:
234237
raise_decisions = jnp.concatenate([
@@ -257,6 +260,18 @@ def set_values(a, value, index):
257260
raise_node_indexes, true_indexes, false_indexes, raise_indexes)
258261
# leaves(hidden_states_new).shape: batch_size, num_nodes, hidden_size
259262

263+
attribution = raise_contributions_lib.get_raise_contribution_step_batch(
264+
attribution,
265+
instruction_pointer,
266+
branch_decisions,
267+
raise_decisions,
268+
true_indexes,
269+
false_indexes,
270+
raise_indexes,
271+
num_nodes,
272+
)
273+
# attribution.shape: batch_size, num_nodes, num_nodes
274+
260275
# current_step.shape: batch_size
261276
# step_limits.shape: batch_size
262277
instruction_pointer_orig = instruction_pointer
@@ -281,7 +296,7 @@ def set_values(a, value, index):
281296
'hidden_state_contributions': hidden_state_contributions,
282297
}
283298
aux.update(aux_ip)
284-
return (hidden_states, instruction_pointer, current_step), aux
299+
return (hidden_states, instruction_pointer, attribution, current_step), aux
285300

286301

287302
class IPAGNNModule(nn.Module):
@@ -400,10 +415,12 @@ def make_instruction_pointer(start_node_index):
400415
instruction_pointer = jax.vmap(make_instruction_pointer)(start_node_indexes)
401416
# instruction_pointer.shape: batch_size, num_nodes
402417

418+
attribution = jnp.zeros((batch_size, num_nodes, num_nodes))
419+
403420
# Run self.max_steps steps of IPAGNNLayer.
404-
(hidden_states, instruction_pointer, current_step), aux = self.ipagnn_layer_scan(
421+
(hidden_states, instruction_pointer, attribution, current_step), aux = self.ipagnn_layer_scan(
405422
# State:
406-
(hidden_states, instruction_pointer, current_step),
423+
(hidden_states, instruction_pointer, attribution, current_step),
407424
# Inputs:
408425
node_embeddings,
409426
edge_sources,
@@ -438,6 +455,10 @@ def get_hidden_state_single_example(hidden_states, node_index):
438455
raise_node_instruction_pointer = get_instruction_pointer_value(instruction_pointer, raise_node_indexes)
439456
# raise_node_instruction_pointer.shape: batch_size
440457

458+
if config.raise_in_ipagnn and config.unsupervised_localization:
459+
localization_logits = attribution[jnp.arange(batch_size), raise_node_indexes]
460+
aux['localization_logits'] = localization_logits
461+
441462
aux.update({
442463
'exit_node_instruction_pointer': exit_node_instruction_pointer,
443464
'exit_node_embeddings': exit_node_embeddings,

0 commit comments

Comments
 (0)