Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions measurements/perplexity/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,23 @@ def _compute(
# if there is not an already assigned pad_token, assign an existing
# special token to also be the padding token
if tokenizer.pad_token is None and batch_size > 1:
existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
# check that the model already has at least one special token defined
assert (
len(existing_special_tokens) > 0
), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
# assign one of the special tokens to also be the pad token
tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
if tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
else:
existing_special_tokens = list(tokenizer.special_tokens_map.values())
# check that the model already has at least one special token defined
assert (
len(existing_special_tokens) > 0
), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
# assign one of the special tokens to also be the pad token
tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

if max_length is None:
max_length = getattr(tokenizer, "model_max_length", None)

# Some tokenizers use very large sentinel values when no real max length is defined.
if max_length is None or max_length > 1e20:
max_length = getattr(model.config, "max_position_embeddings", None)

if add_start_token and max_length:
# leave room for <BOS> token to be added:
Expand All @@ -141,7 +151,7 @@ def _compute(
encodings = tokenizer(
data,
add_special_tokens=False,
padding=True,
padding=batch_size > 1,
truncation=True if max_tokenized_len else False,
max_length=max_tokenized_len,
return_tensors="pt",
Expand Down
26 changes: 18 additions & 8 deletions metrics/perplexity/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,23 @@ def _compute(
# if there is not an already assigned pad_token, assign an existing
# special token to also be the padding token
if tokenizer.pad_token is None and batch_size > 1:
existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
# check that the model already has at least one special token defined
assert (
len(existing_special_tokens) > 0
), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
# assign one of the special tokens to also be the pad token
tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
if tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
else:
existing_special_tokens = list(tokenizer.special_tokens_map.values())
# check that the model already has at least one special token defined
assert (
len(existing_special_tokens) > 0
), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
# assign one of the special tokens to also be the pad token
tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

if max_length is None:
max_length = getattr(tokenizer, "model_max_length", None)

# Some tokenizers use very large sentinel values when no real max length is defined.
if max_length is None or max_length > 1e20:
max_length = getattr(model.config, "max_position_embeddings", None)

if add_start_token and max_length:
# leave room for <BOS> token to be added:
Expand All @@ -140,7 +150,7 @@ def _compute(
encodings = tokenizer(
predictions,
add_special_tokens=False,
padding=True,
padding=batch_size > 1,
truncation=True if max_tokenized_len else False,
max_length=max_tokenized_len,
return_tensors="pt",
Expand Down
27 changes: 27 additions & 0 deletions tests/test_perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import evaluate


def test_perplexity_gpt2():
    """Smoke-test the local perplexity metric on one short prediction.

    Loads the metric from ``./metrics/perplexity``, scores a single
    sentence with the ``gpt2`` model id, and checks the shape of the
    returned mapping.
    """
    metric = evaluate.load("./metrics/perplexity", module_type="metric")
    sentences = ["Hello world."]

    scores = metric.compute(predictions=sentences, model_id="gpt2")

    # The result must expose the aggregate entry ...
    assert "mean_perplexity" in scores
    # ... and exactly one entry under "perplexities" for the one input.
    assert len(scores["perplexities"]) == 1


def test_perplexity_long_input():
    """Check that the perplexity metric completes on one very long prediction.

    The prediction is ``"Hello world. "`` repeated 2000 times and the
    metric is invoked with ``add_start_token=False``.
    """
    metric = evaluate.load("./metrics/perplexity", module_type="metric")

    # NOTE(review): presumably this text is longer than the model's context
    # window, so the call exercises the truncation path -- confirm.
    long_text = "Hello world. " * 2000

    scores = metric.compute(
        predictions=[long_text],
        model_id="gpt2",
        add_start_token=False,
    )

    # Only the presence of the aggregate entry is checked here.
    assert "mean_perplexity" in scores