Skip to content

Commit 4a4b64a

Browse files
authored
Feat/chipper v3 (#308)
This new version of Chipper should substantially improve the output for tables. In the attached file, the output appeared to have many cells spread across multiple columns, largely because of the $ character, which was inconsistently annotated in the Odetta set. In addition, colspan did not work properly for the header. This new version of Chipper does not predict thead and tbody tokens for tables. To test it, run the code below. It prints the predicted elements; the output should contain only one page and one element. The element has a field named text_as_html. The HTML in that field can be pasted into a new file saved with an .html extension and opened in a browser. Example with Chipperv2 <img width="1146" alt="image" src="https://github.com/Unstructured-IO/unstructured-inference/assets/3939469/feffe674-8c9b-4c64-bd6d-08bd602c596a"> Example with Chipperv3 <img width="666" alt="image" src="https://github.com/Unstructured-IO/unstructured-inference/assets/3939469/f06867a9-2636-4055-a158-42badc58dd09"> <img width="677" alt="apple" src="https://github.com/Unstructured-IO/unstructured-inference/assets/3939469/d7ec628e-0dca-409c-894a-612350fce71f">
```
from unstructured_inference.inference.layout import DocumentLayout
from unstructured_inference.models.base import get_model

model = get_model("chipper")
doc = DocumentLayout.from_image_file("[point to the location of the file]/apple.png", detection_model=model)

for i in range(len(doc.pages)):
    print(f"********** Page {i}")
    print(*[element.__dict__ for element in doc.pages[i].elements], sep="\n")
```
--------- Co-authored-by: Antonio Jimeno Yepes <[email protected]>
1 parent 4e5c4e6 commit 4a4b64a

File tree

6 files changed

+105
-18
lines changed

6 files changed

+105
-18
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.7.20
2+
3+
* chipper-v3: improved table prediction
4+
15
## 0.7.19
26

37
* refactor: remove all OCR related code

test_unstructured_inference/models/test_chippermodel.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,8 @@ def test_no_repeat_ngram_logits():
139139

140140
no_repeat_ngram_size = 2
141141

142-
output = chipper._no_repeat_ngram_logits(
143-
input_ids=input_ids,
144-
cur_len=cur_len,
145-
logits=logits,
146-
batch_size=batch_size,
147-
no_repeat_ngram_size=no_repeat_ngram_size,
148-
)
142+
logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2)
143+
output = logitsProcessor(input_ids=input_ids, scores=logits)
149144

150145
assert (
151146
int(
@@ -194,6 +189,25 @@ def test_no_repeat_ngram_logits():
194189
)
195190

196191

192+
def test_ngram_repetiton_stopping_criteria():
193+
input_ids = torch.tensor([[1, 2, 3, 4, 0, 1, 2, 3, 4]])
194+
logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])
195+
196+
stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
197+
repetition_window=2, skip_tokens={0, 1, 2, 3, 4}
198+
)
199+
200+
output = stoppingCriteria(input_ids=input_ids, scores=logits)
201+
202+
assert output is False
203+
204+
stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
205+
repetition_window=2, skip_tokens={1, 2, 3, 4}
206+
)
207+
output = stoppingCriteria(input_ids=input_ids, scores=logits)
208+
assert output is True
209+
210+
197211
@pytest.mark.parametrize(
198212
("decoded_str", "expected_classes"),
199213
[
@@ -241,7 +255,51 @@ def test_postprocess_bbox(decoded_str, expected_classes):
241255
assert out[i].type == expected_classes[i]
242256

243257

244-
def test_run_chipper_v2():
258+
def test_predict_tokens_beam_indices():
259+
model = get_model("chipper")
260+
model.stopping_criteria = [
261+
chipper.NGramRepetitonStoppingCriteria(
262+
repetition_window=1,
263+
skip_tokens={},
264+
),
265+
]
266+
img = Image.open("sample-docs/easy_table.jpg")
267+
output = model.predict_tokens(image=img)
268+
assert len(output) > 0
269+
270+
271+
def test_largest_margin_edge():
272+
model = get_model("chipper")
273+
img = Image.open("sample-docs/easy_table.jpg")
274+
output = model.largest_margin(image=img, input_bbox=[0, 1, 0, 0], transpose=False)
275+
276+
assert output is None
277+
278+
output = model.largest_margin(img, [1, 1, 1, 1], False)
279+
280+
assert output is None
281+
282+
output = model.largest_margin(img, [2, 1, 3, 10], True)
283+
284+
assert output == (0, 0, 0)
285+
286+
287+
def test_deduplicate_detected_elements():
288+
model = get_model("chipper")
289+
img = Image.open("sample-docs/easy_table.jpg")
290+
elements = model(img)
291+
292+
output = model.deduplicate_detected_elements(elements)
293+
294+
assert len(output) == 2
295+
296+
297+
def test_norepeatnGramlogitsprocessor_exception():
298+
with pytest.raises(ValueError):
299+
chipper.NoRepeatNGramLogitsProcessor(ngram_size="")
300+
301+
302+
def test_run_chipper_v3():
245303
model = get_model("chipper")
246304
img = Image.open("sample-docs/easy_table.jpg")
247305
elements = model(img)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.19" # pragma: no cover
1+
__version__ = "0.7.20" # pragma: no cover

unstructured_inference/constants.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,19 @@ class Source(Enum):
1313
CHIPPER = "chipper"
1414
CHIPPERV1 = "chipperv1"
1515
CHIPPERV2 = "chipperv2"
16+
CHIPPERV3 = "chipperv3"
1617
MERGED = "merged"
1718
SUPER_GRADIENTS = "super-gradients"
1819

1920

21+
CHIPPER_VERSIONS = (
22+
Source.CHIPPER,
23+
Source.CHIPPERV1,
24+
Source.CHIPPERV2,
25+
Source.CHIPPERV3,
26+
)
27+
28+
2029
class ElementType:
2130
IMAGE = "Image"
2231
FIGURE = "Figure"
@@ -37,3 +46,6 @@ class ElementType:
3746

3847

3948
FULL_PAGE_REGION_THRESHOLD = 0.99
49+
50+
# this field is defined by pytesseract/unstructured.pytesseract
51+
TESSERACT_TEXT_HEIGHT = "height"

unstructured_inference/inference/layoutelement.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from unstructured_inference.config import inference_config
1212
from unstructured_inference.constants import (
13+
CHIPPER_VERSIONS,
1314
FULL_PAGE_REGION_THRESHOLD,
1415
ElementType,
1516
Source,
@@ -108,7 +109,7 @@ def merge_inferred_layout_with_extracted_layout(
108109
continue
109110
region_matched = False
110111
for inferred_region in inferred_layout:
111-
if inferred_region.source in (Source.CHIPPER, Source.CHIPPERV1):
112+
if inferred_region.source in CHIPPER_VERSIONS:
112113
continue
113114

114115
if inferred_region.bbox.intersects(extracted_region.bbox):

unstructured_inference/models/chipper.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from transformers.generation.logits_process import LogitsProcessor
1616
from transformers.generation.stopping_criteria import StoppingCriteria
1717

18-
from unstructured_inference.constants import Source
18+
from unstructured_inference.constants import CHIPPER_VERSIONS, Source
1919
from unstructured_inference.inference.elements import Rectangle
2020
from unstructured_inference.inference.layoutelement import LayoutElement
2121
from unstructured_inference.logger import logger
@@ -44,11 +44,22 @@
4444
"max_length": 1536,
4545
"heatmap_h": 40,
4646
"heatmap_w": 30,
47+
"source": Source.CHIPPERV2,
48+
},
49+
"chipperv3": {
50+
"pre_trained_model_repo": "unstructuredio/chipper-v3",
51+
"swap_head": True,
52+
"swap_head_hidden_layer_size": 128,
53+
"start_token_prefix": "<s_",
54+
"prompt": "<s><s_hierarchical>",
55+
"max_length": 1536,
56+
"heatmap_h": 40,
57+
"heatmap_w": 30,
4758
"source": Source.CHIPPER,
4859
},
4960
}
5061

51-
MODEL_TYPES["chipper"] = MODEL_TYPES["chipperv2"]
62+
MODEL_TYPES["chipper"] = MODEL_TYPES["chipperv3"]
5263

5364

5465
class UnstructuredChipperModel(UnstructuredElementExtractionModel):
@@ -390,7 +401,7 @@ def deduplicate_detected_elements(
390401
min_text_size: int = 15,
391402
) -> List[LayoutElement]:
392403
"""For chipper, remove elements from other sources."""
393-
return [el for el in elements if el.source in (Source.CHIPPER, Source.CHIPPERV1)]
404+
return [el for el in elements if el.source in CHIPPER_VERSIONS]
394405

395406
def adjust_bbox(self, bbox, x_offset, y_offset, ratio, target_size):
396407
"""Translate bbox by (x_offset, y_offset) and shrink by ratio."""
@@ -516,12 +527,13 @@ def reduce_element_bbox(
516527
Given a LayoutElement element, reduce the size of the bounding box,
517528
depending on existing elements
518529
"""
519-
bbox = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
530+
if element.bbox:
531+
bbox = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
520532

521-
if not self.element_overlap(elements, element):
522-
element.bbox = Rectangle(*self.reduce_bbox_no_overlap(image, bbox))
523-
else:
524-
element.bbox = Rectangle(*self.reduce_bbox_overlap(image, bbox))
533+
if not self.element_overlap(elements, element):
534+
element.bbox = Rectangle(*self.reduce_bbox_no_overlap(image, bbox))
535+
else:
536+
element.bbox = Rectangle(*self.reduce_bbox_overlap(image, bbox))
525537

526538
def bbox_overlap(
527539
self,

0 commit comments

Comments
 (0)