Skip to content

Commit 87f934d

Browse files
fix: bugs in _prepare_text and IntruderScorer prompt formatting (#153)
* clean up some code and fix bug with preparing text with multiple false positives
* clean up intruderscorer a bit more
* another cleanup in _prepare_and_batch
* cleanup in Pipeline
* fix inconsistency with intruder example formatting
* simplify _generate error reporting
* ignore uv.lock
* undo code readability/style changes
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8a79bb7 commit 87f934d

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,4 @@ results/
183183
statistics/
184184
.embedding_cache/
185185
wandb/
186+
uv.lock

delphi/scorers/classifier/intruder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def _build_prompt(
275275
"""
276276

277277
examples = "\n".join(
278-
f"Example {i}: {example}" for i, example in enumerate(sample.examples)
278+
f"Example {i}:{example}" for i, example in enumerate(sample.examples)
279279
)
280280

281281
return self.prompt(examples=examples)
@@ -319,7 +319,6 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult:
319319
# default result is a error
320320
return IntruderResult()
321321
else:
322-
323322
try:
324323
interpretation, prediction = self._parse(response.text)
325324
except Exception as e:

delphi/scorers/classifier/sample.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def _prepare_text(
9191
str_toks = example.str_tokens
9292
assert str_toks is not None, "str_toks were not set"
9393
clean = "".join(str_toks)
94+
9495
# Just return text if there's no highlighting
9596
if not highlighted:
9697
return clean, str_toks
@@ -125,9 +126,17 @@ def threshold_check(i):
125126
token_pos = len(str_toks) - len(str_toks) // 4
126127
if token_pos in below_threshold:
127128
random_indices = [token_pos]
128-
if n_incorrect > 1:
129+
130+
num_remaining_tokens_to_highlight = n_incorrect - 1
131+
if num_remaining_tokens_to_highlight > 0:
132+
remaining_tokens_below_threshold = below_threshold.tolist()
133+
remaining_tokens_below_threshold.remove(token_pos)
134+
129135
random_indices.extend(
130-
random.sample(below_threshold.tolist(), n_incorrect - 1)
136+
random.sample(
137+
remaining_tokens_below_threshold,
138+
num_remaining_tokens_to_highlight,
139+
)
131140
)
132141
else:
133142
random_indices = random.sample(below_threshold.tolist(), n_incorrect)

0 commit comments

Comments
 (0)