diff --git a/src/latexify/generate_latex_test.py b/src/latexify/generate_latex_test.py index 6964306..e2104d8 100644 --- a/src/latexify/generate_latex_test.py +++ b/src/latexify/generate_latex_test.py @@ -82,6 +82,24 @@ def f(x): assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag +def test_get_latex_identifiers_with_keyword_value() -> None: + def f(lambda_): + return lambda_ + + latex = generate_latex.get_latex( + f, identifiers={"lambda_": "lambda"}, use_math_symbols=True + ) + assert latex == r"f(\lambda) = \lambda" + + +def test_get_latex_identifiers_with_non_symbol_keyword_value() -> None: + def f(x): + return x + + latex = generate_latex.get_latex(f, identifiers={"x": "return"}) + assert latex == r"f(\mathrm{return}) = \mathrm{return}" + + def test_get_latex_use_math_symbols() -> None: def f(alpha): return alpha diff --git a/src/latexify/transformers/identifier_replacer.py b/src/latexify/transformers/identifier_replacer.py index aa0296e..a92fb75 100644 --- a/src/latexify/transformers/identifier_replacer.py +++ b/src/latexify/transformers/identifier_replacer.py @@ -31,15 +31,18 @@ def __init__(self, mapping: dict[str, str]): Args: mapping: User defined mapping of names. Keys are the original names of the identifiers, and corresponding values are the replacements. - Both keys and values have to represent valid Python identifiers: - ^[A-Za-z_][A-Za-z0-9_]*$ + Keys have to represent valid Python identifiers (not keywords). + Values have to be valid Python identifiers but may be keywords + (e.g., "lambda" for use with math symbols). """ self._mapping = mapping for k, v in self._mapping.items(): + # Keys must be valid non-keyword identifiers (they reference source code). if not str.isidentifier(k) or keyword.iskeyword(k): raise ValueError(f"'{k}' is not an identifier name.") - if not str.isidentifier(v) or keyword.iskeyword(v): + # Values may be keywords (e.g., "lambda" resolves to a math symbol). + if not str.isidentifier(v): raise ValueError(f"'{v}' is not an identifier name.") def _replace_args(self, args: list[ast.arg]) -> list[ast.arg]: diff --git a/src/latexify/transformers/identifier_replacer_test.py b/src/latexify/transformers/identifier_replacer_test.py index 16a5ac4..8efbc39 100644 --- a/src/latexify/transformers/identifier_replacer_test.py +++ b/src/latexify/transformers/identifier_replacer_test.py @@ -16,7 +16,19 @@ def test_invalid_mapping() -> None: with pytest.raises(ValueError, match=r"'456' is not an identifier name."): IdentifierReplacer({"foo": "456"}) with pytest.raises(ValueError, match=r"'def' is not an identifier name."): - IdentifierReplacer({"foo": "def"}) + IdentifierReplacer({"def": "foo"}) + + +def test_keyword_value_accepted() -> None: + assert IdentifierReplacer({"x": "lambda"}) is not None + assert IdentifierReplacer({"x": "return"}) is not None + + +def test_name_replaced_with_keyword_value() -> None: + source = ast.Name(id="x", ctx=ast.Load()) + expected = ast.Name(id="lambda", ctx=ast.Load()) + transformed = IdentifierReplacer({"x": "lambda"}).visit(source) + test_utils.assert_ast_equal(transformed, expected) def test_name_replaced() -> None: