This repository was archived by the owner on Jan 22, 2024. It is now read-only.

Commit b8d7f31

Larger and more consistent datasets (#79)
1. Fix get_python_major_version (it was missing a comma) so we aren't dropping nearly 1M good examples any more. Oops. (See the sketch below.)
2. Add an in_dataset field and config.use_in_dataset_field to ensure the docstring and non-docstring datasets contain the same examples. When use_in_dataset_field is set, only examples with in_dataset=1 pass through filtering.
3. Add docstring_tokens to support FiLM-style approaches.
4. Regenerate the dataset, adding a new small 1% dataset to the git repo.
5. Add lots of counting and printing to process_codenet to track down the pesky missing comma.
1 parent a0c3384 commit b8d7f31
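On item 1, the sketch promised above: in a list literal, Python implicitly concatenates adjacent string literals, so a missing comma silently merges two entries instead of raising a syntax error. A minimal illustration (not the repo's actual code):

```python
# Adjacent string literals are concatenated at compile time, so the
# missing comma below merges two versions into one bogus entry.
versions = [
    'Python (3.4.3)',
    'Python (3.8.2)'   # <-- missing comma
    'PyPy3 (2.4.0)',
]
print(versions)
# ['Python (3.4.3)', 'Python (3.8.2)PyPy3 (2.4.0)']

# Membership checks against this list then miss both versions, which
# is how the affected submissions fell out of the dataset:
print('Python (3.8.2)' in versions)  # False
print('PyPy3 (2.4.0)' in versions)   # False
```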

20 files changed (+195751, −21 lines)


config/default.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -20,6 +20,7 @@ def default_config():
   config.experiment_id: Optional[Text] = ''  # An experiment is launched by a single command, may have multiple runs.
   config.run_id: Optional[Text] = ''  # A run is a single trainer run with a single set of hparams. run_id should identify hparams.
   config.notes: Optional[Text] = ''  # Any notes to record about the run.
+  config.use_in_dataset_field = True

   # Training configs
   config.optimizer = 'adam'  # sgd, adam
```

core/data/codenet.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -151,7 +151,7 @@ def get_python_major_version(problem_id, submission_id):
       'Python3',
       'Python (3.4.2)',
       'Python (3.4.3)',
-      'Python (3.8.2)'
+      'Python (3.8.2)',
       'PyPy3 (2.4.0)',
       'PyPy3 (7.3.0)',
   ]:
```

core/data/codenet_paths.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -5,16 +5,16 @@
 import time

 DEFAULT_CONFIG_PATH = 'config/default.py'
-DEFAULT_DATASET_PATH = 'datasets/codenet/2021-11-01-f=0.01'
-TEST_DATASET_PATH = 'datasets/codenet/2021-11-01-f=0.01'
+DEFAULT_DATASET_PATH = 'datasets/codenet/2021-12-06-f=0.01'
+TEST_DATASET_PATH = 'datasets/codenet/2021-12-06-f=0.01'
 DEFAULT_TOKENIZER_PATH = 'out/tokenizers/train-1000000.json'
 DOCSTRING_TOKENIZER_PATH = 'out/tokenizers/train-docstrings-1000000.json'
 DEFAULT_SPLITS_PATH = 'out/splits/default.json'
 DEFAULT_EXPERIMENTS_DIR = 'out/experiments'
 EXPERIMENT_ID_PATH = 'out/experiment_id.txt'

-FULL_DATASET_PATH = '/mnt/runtime-error-problems-experiments/datasets/project-codenet/2021-10-07-full'
-FULL_DATASET_PATH_WITH_DOCSTRINGS = '/mnt/runtime-error-problems-experiments/datasets/project-codenet/2021-11-17'
+FULL_DATASET_PATH = '/mnt/runtime-error-problems-experiments/datasets/project-codenet/2021-12-06-nodoc'
+FULL_DATASET_PATH_WITH_DOCSTRINGS = '/mnt/runtime-error-problems-experiments/datasets/project-codenet/2021-12-06'
 # Raw control_flow_programs data pattern:
 DEFAULT_CFP_DATA_PATTERN = '/mnt/runtime-error-problems-experiments/datasets/control_flow_programs/decimal-large-state-L10/0.0.48/control_flow_programs-train.tfrecord-*'
 # Processed control_flow_programs dataset path:
```

core/data/data_io.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -19,6 +19,7 @@ def to_tf_example(problem):
   """Constructs a tf.train.Example for the process.RuntimeErrorProblem."""
   return tf.train.Example(features=tf.train.Features(feature={
       'tokens': _int64_feature(problem.tokens),
+      'docstring_tokens': _int64_feature(problem.docstring_tokens),
       'edge_sources': _int64_feature(problem.edge_sources),
       'edge_dests': _int64_feature(problem.edge_dests),
       'edge_types': _int64_feature(problem.edge_types),
@@ -39,6 +40,7 @@ def to_tf_example(problem):
       'problem_id': _bytes_feature([problem.problem_id]),
       'submission_id': _bytes_feature([problem.submission_id]),

+      'in_dataset': _int64_feature([problem.in_dataset]),
       'num_tokens': _int64_feature([len(problem.tokens)]),
       'num_nodes': _int64_feature([len(problem.true_branch_nodes)]),
       'num_edges': _int64_feature([len(problem.edge_sources)]),
@@ -48,6 +50,7 @@ def to_tf_example(problem):
 def decode_fn(record_bytes, include_strings=False):
   features = {
       'tokens': _int64_sequence_feature(),
+      'docstring_tokens': _int64_sequence_feature(),
       'edge_sources': _int64_sequence_feature(),
       'edge_dests': _int64_sequence_feature(),
       'edge_types': _int64_sequence_feature(),
@@ -65,6 +68,7 @@ def decode_fn(record_bytes, include_strings=False):
       'target_node_indexes': _int64_sequence_feature(),
       'num_target_nodes': _int64_scalar_feature(),

+      'in_dataset': _int64_scalar_feature(),
       'num_tokens': _int64_scalar_feature(),
       'num_nodes': _int64_scalar_feature(),
       'num_edges': _int64_scalar_feature(),
@@ -80,6 +84,7 @@ def decode_fn(record_bytes, include_strings=False):
 def get_fake_input(batch_size, max_tokens, max_num_nodes, max_num_edges):
   return {
       'tokens': jnp.ones((batch_size, max_tokens), dtype=jnp.int32),
+      'docstring_tokens': jnp.ones((batch_size, max_tokens), dtype=jnp.int32),
       'edge_sources': jnp.zeros((batch_size, max_num_edges), dtype=jnp.int32),
       'edge_dests': jnp.ones((batch_size, max_num_edges), dtype=jnp.int32),
       'edge_types': jnp.zeros((batch_size, max_num_edges), dtype=jnp.int32),
@@ -101,6 +106,7 @@ def get_fake_input(batch_size, max_tokens, max_num_nodes, max_num_edges):
       # 'problem_id': jnp.full((batch_size,), 'p12345', dtype=jnp.string),
       # 'submission_id': jnp.full((batch_size,), 's123456789', dtype=jnp.string),

+      'in_dataset': jnp.ones((batch_size, 1), dtype=jnp.int32),
       'num_tokens': jnp.full((batch_size, 1), max_tokens, dtype=jnp.int32),
       'num_nodes': jnp.full((batch_size, 1), max_num_nodes, dtype=jnp.int32),
       'num_edges': jnp.full((batch_size, 1), max_num_edges, dtype=jnp.int32),
@@ -113,6 +119,7 @@ def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges, include_strings=
   max_target_nodes = 20
   shapes = {
       'tokens': [max_tokens],
+      'docstring_tokens': [max_tokens],
       'edge_sources': [max_num_edges],
       'edge_dests': [max_num_edges],
       'edge_types': [max_num_edges],
@@ -130,6 +137,7 @@ def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges, include_strings=
       'target_node_indexes': [max_target_nodes],
       'num_target_nodes': [1],

+      'in_dataset': [1],
       'num_tokens': [1],
       'num_nodes': [1],
       'num_edges': [1],
@@ -146,6 +154,7 @@ def get_padded_shapes(max_tokens, max_num_nodes, max_num_edges, include_strings=
 def make_filter(
     max_tokens, max_num_nodes, max_num_edges, max_steps, allowlist=None,
     class_subsample_values=None,
+    use_in_dataset_field=True,
 ):
   """Makes a tf.Dataset filter function.

@@ -179,6 +188,9 @@ def fn(example):
       class_ok |= (target == index)
     allowed = allowed & class_ok

+    if use_in_dataset_field:
+      allowed &= tf.squeeze(example['in_dataset'] == 1, axis=-1)
+
     # Filter x% of examples with target == 1 (the most common class).
     if class_subsample_values is not None:
       for key, value in class_subsample_values.items():
```

core/data/process.py

Lines changed: 26 additions & 2 deletions

```diff
@@ -41,6 +41,7 @@ class RawRuntimeErrorProblem:
 class RuntimeErrorProblem:
   """RuntimeErrorProblem for use on an accelerator."""
   tokens: List[int]
+  docstring_tokens: List[int]
   problem_id: Text
   submission_id: Text
   edge_sources: List[int]
@@ -58,6 +59,7 @@ class RuntimeErrorProblem:
   target: int
   target_lineno: Optional[int]
   target_node_indexes: List[int]
+  in_dataset: bool


 def get_character_index(source, lineno, col_offset):
@@ -380,17 +382,38 @@ def get_nodes_at_lineno(raw, lineno):
   return overlapping_nodes


-def make_runtimeerrorproblem(source, target, target_lineno=0, tokenizer=None,
-                             problem_id=None, submission_id=None):
+def hardcoded_filter(tokens_extended):
+  return len(tokens_extended) <= 512
+
+
+def make_runtimeerrorproblem(
+    source, target, docstring=None, extended_source=None,
+    target_lineno=0, tokenizer=None,
+    problem_id=None, submission_id=None):
   raw = make_rawruntimeerrorproblem(
       source, target, target_lineno=target_lineno,
       problem_id=problem_id, submission_id=submission_id)
   tokenizer = tokenizer or tokenization.load_tokenizer()
   token_data = tokenize_raw_with_spans(tokenizer, raw)
+
+  if extended_source is not None and extended_source != source:
+    extended_tokenized = tokenizer(extended_source)
+    tokens_extended = extended_tokenized['input_ids']
+  else:
+    tokens_extended = token_data['tokens']
+  if docstring is not None:
+    docstring_tokenized = tokenizer(docstring)
+    docstring_tokens = docstring_tokenized['input_ids']
+  else:
+    docstring_tokens = []
+
+  in_dataset = hardcoded_filter(tokens_extended)
+
   branch_list = np.array(raw.branch_list)
   target_node_indexes = get_nodes_at_lineno(raw, target_lineno)
   return RuntimeErrorProblem(
       tokens=token_data['tokens'],
+      docstring_tokens=docstring_tokens,
       problem_id=raw.problem_id,
       submission_id=raw.submission_id,
       edge_sources=raw.edge_sources,
@@ -408,6 +431,7 @@ def make_runtimeerrorproblem(source, target, target_lineno=0, tokenizer=None,
       target=raw.target,
       target_lineno=raw.target_lineno,
       target_node_indexes=target_node_indexes,
+      in_dataset=in_dataset,
   )
```
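Per item 3 of the commit message, `docstring_tokens` exists "to support FiLM-style approaches". As a hedged illustration only (the repo's model code is not part of this diff, so all names and shapes here are assumptions), FiLM conditioning pools the docstring tokens into an embedding that scales and shifts the model's hidden features:

```python
import jax.numpy as jnp

def film_condition(hidden, docstring_embedding, w_gamma, w_beta):
  """Feature-wise linear modulation sketch: hidden * gamma + beta.

  hidden: (batch, length, features) model activations.
  docstring_embedding: (batch, d) pooled docstring representation.
  w_gamma, w_beta: (d, features) projection matrices.
  """
  gamma = docstring_embedding @ w_gamma  # (batch, features)
  beta = docstring_embedding @ w_beta    # (batch, features)
  # Broadcast the per-example gamma/beta over the length dimension.
  return hidden * gamma[:, None, :] + beta[:, None, :]
```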

core/lib/trainer.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -65,7 +65,8 @@ def load_dataset(
     allowlist = error_kinds.TIER1_ERROR_IDS
   filter_fn = data_io.make_filter(
       config.max_tokens, config.max_num_nodes, config.max_num_edges,
-      config.max_steps, allowlist=allowlist, class_subsample_values={1: 0.0660801055})
+      config.max_steps, allowlist=allowlist, class_subsample_values={1: 0.0660801055},
+      use_in_dataset_field=config.use_in_dataset_field)

   if config.binary_targets:
     map_fn = functools.partial(data_io.binarize_targets, dataset_path=dataset_path)
```
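Since the flag defaults to True both in `make_filter` and in `config/default.py`, the new filtering is on by default. A hedged sketch of turning it off for an ablation run, assuming `config/default.py` is importable under this path:

```python
from config.default import default_config  # assumed import path

# With the flag off, make_filter ignores the in_dataset mark and
# reproduces the pre-#79 filtering behavior.
config = default_config()
config.use_in_dataset_field = False
```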
