diff --git a/mmda/predictors/heuristic_predictors/dictionary_word_predictor.py b/mmda/predictors/heuristic_predictors/dictionary_word_predictor.py index 6d8bd4f6..6daec513 100644 --- a/mmda/predictors/heuristic_predictors/dictionary_word_predictor.py +++ b/mmda/predictors/heuristic_predictors/dictionary_word_predictor.py @@ -5,13 +5,15 @@ """ import string -from typing import Optional, Set, List +from typing import Any, Dict, Optional, Set, List from mmda.predictors.base_predictors.base_predictor import BasePredictor from mmda.types.annotation import Annotation, Span, SpanGroup from mmda.types.document import Document from mmda.types.names import Rows, Tokens +from ftfy import fix_text, TextFixerConfig + class DictionaryWordPredictor(BasePredictor): @@ -20,7 +22,11 @@ class DictionaryWordPredictor(BasePredictor): _dictionary: Optional[Set[str]] = None - def __init__(self, dictionary_file_path: str) -> None: + def __init__( + self, + dictionary_file_path: str, + ftfy_config: Optional[Dict[str, Any]] = None + ) -> None: """Build a predictor that indexes the given dictionary file. A dictionary is simply a case-sensitive list of words as a text file. Words should be lower-case in the dictionary unless they are invalid @@ -41,6 +47,9 @@ def __init__(self, dictionary_file_path: str) -> None: """ self.dictionary_file_path = dictionary_file_path + ftfy_config = ftfy_config or {"explain": False} + self.ftfy_config = TextFixerConfig(**ftfy_config) + @property def dictionary(self) -> Set[str]: """Global dictionary and not document specific. This dictionary is the basis for @@ -171,7 +180,9 @@ def predict(self, document: Document) -> List[SpanGroup]: return words def _token_text(self, token: SpanGroup) -> str: - return "".join(token.symbols) + text = "".join(token.symbols) + text = fix_text(text, config=self.ftfy_config) + return text def _copy_token_with_text(self, token: SpanGroup) -> SpanGroup: return SpanGroup(spans=token.spans, text=self._token_text(token)) diff --git a/setup.py b/setup.py index a8747657..d799767b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="mmda", description="mmda", - version="0.0.42", + version="0.0.43", url="https://www.github.com/allenai/mmda", python_requires=">= 3.7", packages=find_namespace_packages(include=["mmda*", "ai2_internal*"]), @@ -15,7 +15,8 @@ "pandas", "pydantic", "ncls", - "necessary" + "necessary", + "ftfy>=6.1.0,<7.0.0", ], extras_require={ "dev": ["pytest"], diff --git a/tests/test_predictors/test_dictionary_word_predictor.py b/tests/test_predictors/test_dictionary_word_predictor.py index d8085362..46959c51 100644 --- a/tests/test_predictors/test_dictionary_word_predictor.py +++ b/tests/test_predictors/test_dictionary_word_predictor.py @@ -186,7 +186,7 @@ def test_optional_plurarl_words_combined(self): def test_next_row_single_token(self): # fmt:off - #0 10 + #0 10 #012345678901 text = "Many lin-es" # fmt:on @@ -214,3 +214,52 @@ def test_next_row_single_token(self): "Many lin-es", " ".join([w.text for w in words]), ) + + +class TestFtfyFixes(unittest.TestCase): + def test_ftfy(self): + text = ( + "Fact Verification, which reasons whether " + "a claim is supported / refuted by mult-" + "iple evidences" + ) + spans = [ + Span(start=0, end=4), # Fact + Span(start=5, end=17), # Verification, + Span(start=18, end=23), # which + Span(start=24, end=31), # reasons + Span(start=32, end=39), # whether + Span(start=40, end=41), # a + Span(start=42, end=47), # claim + Span(start=48, end=50), # is + Span(start=51, end=60), # supported + Span(start=61, end=62), # / + Span(start=63, end=70), # refuted + Span(start=71, end=73), # by + Span(start=74, end=79), # mult- + Span(start=79, end=83), # ple + Span(start=84, end=93), # evidences + ] + rows = [ + SpanGroup(id=1, spans=spans[:5]), + SpanGroup(id=2, spans=spans[5:13]), + SpanGroup(id=3, spans=spans[13:]), + ] + + doc = mock_document(symbols=text, spans=spans, rows=rows) + + with tempfile.NamedTemporaryFile() as f: + f.write("multiple\n".encode("utf-8")) + f.flush() + + predictor = DictionaryWordPredictor( + dictionary_file_path=f.name + ) + words: List[SpanGroup] = predictor.predict(doc) + + parsed = " ".join([str(w.text) for w in words]) + reference = ( + "Fact Verification, which reasons whether a claim is " + "supported / refuted by multiple evidences" + ) + self.assertEqual(reference, parsed) diff --git a/tests/test_predictors/test_vila_predictors.py b/tests/test_predictors/test_vila_predictors.py index 2698ac2d..b555397c 100644 --- a/tests/test_predictors/test_vila_predictors.py +++ b/tests/test_predictors/test_vila_predictors.py @@ -1,4 +1,5 @@ -import json +import json +from pathlib import Path from PIL import Image @@ -52,6 +53,8 @@ S2VL_LABEL_MAP = {int(key): val for key, val in S2VL_LABEL_MAP.items()} +FIXTURES_PATH = Path(__file__).parent.parent / "fixtures" + def test_vila_predictors(): layout_predictor = LayoutParserPredictor.from_pretrained( @@ -61,8 +64,13 @@ def test_vila_predictors(): pdfplumber_parser = PDFPlumberParser() rasterizer = PDF2ImageRasterizer() - doc = pdfplumber_parser.parse(input_pdf_path="tests/fixtures/1903.10676.pdf") - images = rasterizer.rasterize(input_pdf_path="tests/fixtures/1903.10676.pdf", dpi=72) + doc = pdfplumber_parser.parse( + input_pdf_path=str(FIXTURES_PATH /"1903.10676.pdf") + ) + images = rasterizer.rasterize( + input_pdf_path=str(FIXTURES_PATH /"1903.10676.pdf"), + dpi=72 + ) doc.annotate_images(images) layout_regions = layout_predictor.predict(doc) @@ -123,9 +131,9 @@ def test_vila_predictors(): assert [ele.type for ele in resA] == [S2VL_LABEL_MAP[ele.type] for ele in resB] def test_vila_predictors_with_special_unicode_inputs(): - - test_doc_path = "tests/fixtures/unicode-test.json" - + + test_doc_path = FIXTURES_PATH / "unicode-test.json" + with open(test_doc_path, 'r') as fp: res = json.load(fp) @@ -136,4 +144,4 @@ def test_vila_predictors_with_special_unicode_inputs(): "allenai/ivila-row-layoutlm-finetuned-s2vl-v2" ) - ivilaA.predict(doc, subpage_per_run=2) \ No newline at end of file + ivilaA.predict(doc, subpage_per_run=2)