Skip to content

Commit ddead88

Browse files
committed
solve by extension class of KeyError
Signed-off-by: dafnapension <[email protected]>
1 parent 68db0ef commit ddead88

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

src/unitxt/error_utils.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _get_existing_context(error: Exception):
109109
existing["context_object"],
110110
existing["context"],
111111
)
112-
return str(error), None, {}
112+
return error.original_error if type(error) == ExtKeyError else str(error), None, {}
113113

114114

115115
def _format_object_context(obj: Any) -> Optional[str]:
@@ -239,13 +239,22 @@ def _store_context_attributes(
239239
"original_message": original_message,
240240
}
241241
try:
242-
error.original_error = type(error)(original_message)
242+
error.original_error = (
243+
original_message
244+
if type(error) == KeyError
245+
else type(error)(original_message)
246+
)
243247
except (TypeError, ValueError):
244248
error.original_error = Exception(original_message)
245249
error.context_object = context_object
246250
error.context = context
247251

248252

253+
class ExtKeyError(KeyError):
254+
def __str__(self):
255+
return "\n" + self.args[0]
256+
257+
249258
def _add_context_to_exception(
250259
original_error: Exception, context_object: Any = None, **context
251260
):
@@ -262,14 +271,17 @@ def _add_context_to_exception(
262271
}
263272
context_parts = _build_context_parts(final_context_object, final_context)
264273
context_message = _create_context_box(context_parts)
265-
_store_context_attributes(
266-
original_error, final_context_object, final_context, original_message
267-
)
274+
if type(original_error) == KeyError:
275+
f = ExtKeyError(original_message)
276+
else:
277+
f = original_error
278+
_store_context_attributes(f, final_context_object, final_context, original_message)
268279
if context_parts:
269280
formatted_message = f"\n{context_message}\n\n{original_message}"
270-
original_error.args = (formatted_message,)
281+
f.args = (formatted_message,)
271282
else:
272-
original_error.args = (original_message,)
283+
f.args = (original_message,)
284+
return f
273285

274286

275287
@contextmanager
@@ -298,7 +310,5 @@ def error_context(context_object: Any = None, **context):
298310
try:
299311
yield
300312
except Exception as e:
301-
if e.__class__.__name__ == "KeyError":
302-
e = RuntimeError(e.__class__.__name__ + ": '" + e.args[0] + "'")
303-
_add_context_to_exception(e, context_object, **context)
304-
raise e
313+
f = _add_context_to_exception(e, context_object, **context)
314+
raise f from None

tests/library/test_error_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class TestProcessor:
4747

4848
processor = TestProcessor()
4949

50-
with self.assertRaises(RuntimeError) as cm:
50+
with self.assertRaises(KeyError) as cm:
5151
with error_context(processor):
5252
raise KeyError("Missing key")
5353

@@ -186,12 +186,12 @@ class TestProcessor:
186186

187187
def test_error_context_without_object(self):
188188
"""Test error_context without a context object."""
189-
with self.assertRaises(RuntimeError) as cm:
189+
with self.assertRaises(KeyError) as cm:
190190
with error_context(input_file="data.json", line_number=156):
191191
raise KeyError("Missing field")
192192

193193
error = cm.exception
194-
self.assertIsInstance(error, RuntimeError)
194+
self.assertIsInstance(error, KeyError)
195195
self.assertIsNone(error.context_object)
196196
# Context now includes version info plus the specified context
197197
self.assertIn("Unitxt", error.context)

0 commit comments

Comments
 (0)