From 42a9f3cd6b1dac8540e100ad4a72539bdb5b6363 Mon Sep 17 00:00:00 2001
From: Dan Zheng
Date: Fri, 24 Sep 2021 16:39:35 -0700
Subject: [PATCH] Add instruction pointer entropy metric.

Work in-progress.
---
 core/lib/metrics.py | 34 ++++++++++++++++++++++++++++++++++
 core/lib/trainer.py |  8 ++++++++
 2 files changed, 42 insertions(+)

diff --git a/core/lib/metrics.py b/core/lib/metrics.py
index bef84a42..6061a154 100644
--- a/core/lib/metrics.py
+++ b/core/lib/metrics.py
@@ -20,6 +20,7 @@ class EvaluationMetric(enum.Enum):
   F1_SCORE = 'f1_score'
   CONFUSION_MATRIX = 'confusion_matrix'
   INSTRUCTION_POINTER = 'instruction_pointer'
+  INSTRUCTION_POINTER_ENTROPY = 'instruction_pointer_entropy'
 
 
 def all_metric_names() -> Tuple[str]:
@@ -161,6 +162,39 @@ def instruction_pointers_to_images(instruction_pointer, multidevice: bool):
   return jnp.array(instruction_pointer_image_list)
 
 
+def instruction_pointers_to_entropy(instruction_pointer, multidevice: bool):
+  """Converts the given batched instruction pointer to an entropy value.
+
+  The entropy value measures the sharpness of the instruction pointer, i.e. how
+  hard vs soft it is: 0 for a one-hot (hard) pointer, log(num_nodes) for a
+  uniform (maximally soft) one.
+
+  Args:
+    instruction_pointer: The batched instruction pointer. Each slice along the
+      last (num_nodes) axis is assumed to be a probability distribution.
+    multidevice: Whether the instruction pointer has a leading device axis. If
+      so, only the first device's shard is used.
+
+  Returns:
+    A scalar: the Shannon entropy (in nats) over the num_nodes axis, averaged
+    over all examples and timesteps.
+  """
+  if multidevice:
+    # instruction_pointer: device, batch_size / device, timesteps, num_nodes
+    instruction_pointer = instruction_pointer[0]
+
+  # instruction_pointer: batch_size / device, timesteps, num_nodes
+  # Define 0 * log(0) as 0 so that hard (one-hot) pointers yield zero entropy
+  # rather than NaN.
+  plogp = jnp.where(
+      instruction_pointer > 0,
+      instruction_pointer * jnp.log(instruction_pointer),
+      0.0)
+  entropy = -jnp.sum(plogp, axis=-1)
+  # entropy: batch_size / device, timesteps
+  return jnp.mean(entropy)
+
+
 def pad(array, leading_dim_size: int):
   """Pad the leading dimension of the given array."""
   leading_dim_difference = max(0, leading_dim_size - array.shape[0])
diff --git a/core/lib/trainer.py b/core/lib/trainer.py
index 401bfe65..2a200a53 100644
--- a/core/lib/trainer.py
+++ b/core/lib/trainer.py
@@ -378,6 +378,14 @@ def run_train(self, dataset_path=DEFAULT_DATASET_PATH, split='train', steps=None
             transform_fn=functools.partial(
                 metrics.instruction_pointers_to_images,
                 multidevice=config.multidevice))
+        metrics.write_metric(
+            EvaluationMetric.INSTRUCTION_POINTER_ENTROPY.value,
+            aux,
+            train_writer.scalar,
+            step,
+            transform_fn=functools.partial(
+                metrics.instruction_pointers_to_entropy,
+                multidevice=config.multidevice))
 
         # Write validation metrics.
         valid_writer.scalar('loss', valid_loss, step)