diff --git a/analysis/box_indexer/benchmark.py b/analysis/box_indexer/benchmark.py new file mode 100644 index 00000000..aa547b18 --- /dev/null +++ b/analysis/box_indexer/benchmark.py @@ -0,0 +1,147 @@ +""" + +Benchmarking different ways of indexing and searching over Boxes + +RTree + INFO:root:RTreeIndexer took 0.23895001411437988 seconds to index 10000 boxes) + INFO:root:RTreeIndexer took 188.41757702827454 seconds for 1000 queries + + INFO:root:RTreeIndexer took 0.24275708198547363 seconds to index 10000 boxes) + INFO:root:RTreeIndexer took 1.1544108390808105 seconds for 1000 queries + +KDTree + Stopped Implementation was kind of tricky. + +Numpy array + INFO:root:NumpyIndexer took 0.0027608871459960938 seconds to index 10000 boxes) + INFO:root:NumpyIndexer took 187.8270788192749 seconds for 1000 queries + + INFO:root:NumpyIndexer took 0.0034339427947998047 seconds to index 10000 boxes) + INFO:root:NumpyIndexer took 0.07767200469970703 seconds for 1000 queries + +""" + +import logging +import time +from typing import List, Set, Tuple + +import numpy as np +from rtree import index +from sklearn.neighbors import KDTree + +from mmda.types.box import Box + +logging.basicConfig(level=logging.INFO) + +# create random boxes +boxes = [ + Box( + l=np.random.rand(), + t=np.random.rand(), + w=np.random.rand(), + h=np.random.rand(), + page=0, + ) + for _ in range(10000) +] + +# create random boxes for querying +queries = [ + Box( + l=np.random.rand(), + t=np.random.rand(), + w=np.random.rand(), + h=np.random.rand(), + page=0, + ) + for _ in range(1000) +] + + +class RTreeIndexer: + def __init__(self, boxes: List[Box]): + self.boxes = boxes + self.rtree = index.Index(interleaved=True) + for i, b in enumerate(boxes): + x1, y1, x2, y2 = b.coordinates + self.rtree.insert(i, (x1, y1, x2, y2)) + + def find(self, query: Box) -> List[int]: + x1, y1, x2, y2 = query.coordinates + box_ids = self.rtree.intersection((x1, y1, x2, y2)) + return list(box_ids) + + +class NumpyIndexer: + def __init__(self, boxes: List[Box]): + self.boxes = boxes + self.np_boxes_x1 = np.array([b.l for b in boxes]) + self.np_boxes_y1 = np.array([b.t for b in boxes]) + self.np_boxes_x2 = np.array([b.l + b.w for b in boxes]) + self.np_boxes_y2 = np.array([b.t + b.h for b in boxes]) + + def find(self, query: Box) -> List[int]: + x1, y1, x2, y2 = query.coordinates + mask = ( + (self.np_boxes_x1 <= x2) + & (self.np_boxes_x2 >= x1) + & (self.np_boxes_y1 <= y2) + & (self.np_boxes_y2 >= y1) + ) + return np.where(mask)[0].tolist() + + +def bulk_query(indexer, boxes, queries, is_validate: bool = True): + for q in queries: + found = indexer.find(q) + if is_validate: + for i in range(len(boxes)): + if i in found: + assert boxes[i].is_overlap(q) + else: + assert not boxes[i].is_overlap(q) + + +def benchmark_rtree(boxes, queries, is_validate: bool = True): + # indexing time + start = time.time() + logging.info("Starting benchmarking") + rtree_indexer = RTreeIndexer(boxes) + end = time.time() + logging.info( + f"RTreeIndexer took {end - start} seconds to index {len(boxes)} boxes)" + ) + + # searching time + start = time.time() + logging.info("Starting benchmarking") + bulk_query(rtree_indexer, boxes, queries, is_validate=is_validate) + + end = time.time() + logging.info(f"RTreeIndexer took {end - start} seconds for {len(queries)} queries") + + +def benchmark_numpy(boxes, queries, is_validate: bool = True): + # indexing time + start = time.time() + logging.info("Starting benchmarking") + numpy_indexer = NumpyIndexer(boxes) + end = time.time() + logging.info( + f"NumpyIndexer took {end - start} seconds to index {len(boxes)} boxes)" + ) + + # searching time + start = time.time() + logging.info("Starting benchmarking") + bulk_query(numpy_indexer, boxes, queries, is_validate=is_validate) + + end = time.time() + logging.info(f"NumpyIndexer took {end - start} seconds for {len(queries)} queries") + + +benchmark_rtree(boxes=boxes, queries=queries) +benchmark_numpy(boxes=boxes, queries=queries) + +benchmark_rtree(boxes=boxes, queries=queries, is_validate=False) +benchmark_numpy(boxes=boxes, queries=queries, is_validate=False) diff --git a/analysis/box_indexer/requirements.txt b/analysis/box_indexer/requirements.txt new file mode 100644 index 00000000..e7ced466 --- /dev/null +++ b/analysis/box_indexer/requirements.txt @@ -0,0 +1,2 @@ +scikit-learn +rtree \ No newline at end of file diff --git a/src/mmda/eval/metrics.py b/src/mmda/eval/metrics.py index abc431f9..dfdd36ce 100644 --- a/src/mmda/eval/metrics.py +++ b/src/mmda/eval/metrics.py @@ -75,11 +75,11 @@ def levenshtein( def box_overlap(box: Box, container: Box) -> float: """Returns the percentage of area of a box inside of a container.""" - bl, bt, bw, bh = box.xywh + bl, bt, bw, bh = box.l, box.t, box.w, box.h br = bl + bw bb = bt + bh - cl, ct, cw, ch = container.xywh + cl, ct, cw, ch = container.l, container.t, container.w, container.h cr = cl + cw cb = ct + ch diff --git a/src/mmda/parsers/NOTES.md b/src/mmda/parsers/NOTES.md new file mode 100644 index 00000000..d4c07c58 --- /dev/null +++ b/src/mmda/parsers/NOTES.md @@ -0,0 +1,122 @@ +# notes + + +## bounding boxes + + +**1. Problem** + +Token boxes from PDFPlumber aren't necessarily disjoint, which is annoying. + +Here's an example (converted to xy coordinates for ease): + +``` +Token=0 +(0.8155551504075421, + 0.9343934350623241, + 0.8317358800186805, + 0.9445089287110585, + 'Vol') + +Token=1 +(0.8315880494325322, + 0.9343934350623241, + 0.8343161957041776, + 0.9445089287110585, + '.') +``` + +We see that Token0 and Token1 are on the same line given their ymin and ymax hovering at 0.93-0.94. They should be disjoint but Token0's xmax (0.8317) extends into Token1's xmin (0.8316). This should never be the case. + +The weird thing is though, while boxes are kind of messed up like this, the text isn't. The two tokens together spell out `Vol.` correctly. + +Why is this bad? Because it messes up our ability to lookup overlapping entities based on bounding box overlap. This is *especially* bad when we're trying to go from a bounding box (e.g. from LayoutParser) to token spans. + +**2. Solutions** + +There are a few solutions: + +1. Fix this in PDFPlumberParser with posthoc operator over bounding boxes to enforce disjointness. + + ➔ Kind of problematic because requires so much reprocessing. + + ➔ May also make it harder to keep in-sync w/ PDFPlumber over time. Posthoc corrections maybe should still be decoupled from the parser itself? + +2. Apply fixes at `Document.from_json()` time to adjust all token bounding boxes "inward". + + ➔ Not as risky. Somewhat hacky but fairly straightforward. Problem is it doesn't quite fix the data, which is still being generated with weird boxes. + +3. Overlapping bounding boxes primarily affects vision-based lookup methods (e.g. give me all entities within a visual region). A different way to fix this is to base all bounding box-based lookup methods not on the bounding boxes, but instead on the box centroid or something else. + + ➔ Not sure whether it would work. Seems dependent on box quality. For example, if the box is extremely off-center, then it's not like this would solve anything. And it's actually quite a bit of code to refactor. Let's rule it out. + + +Thinking about it, #2 is pretty reasonable for now given that we're aiming for overall stability. We can worry about #1 later. + +**2. Implementation details of Approach 2** + +There are a few ways to do #2 actually. + +If the Boxes are actually pretty good, but only overlap with each other slightly (e.g. boundaries kinda fuzzy), then the easiest way to fix things is to shrink all boxes to avoid all overlapping regions. + +But if the Boxes are pretty poor quality (imagine something that's really big/off-center), then there's no real way to resolve that box without actually pulling up the original image and trying to do some pixel-based localization. + +We *really* don't want to do the latter, so let's investigate whether the former is the more common case. + +See this example where we cluster all the token boxes on a page and only shade the ones that have some overlap with another box: ![image](fixtures/token-box-overlap.png). + +The overlapping regions are really small. Let's go with the easier implementation. + + +**3. Global epsilon-based adjustment to all boxes** + +Before we do that, there are actually boxes out of PDFPlumber that are perfectly shared borders: + +``` +{'x1': 0.38126804444444445, 'y1': 0.12202539797979782, 'x2': 0.3861468209150327, 'y2': 0.13334661010100995} +{'x1': 0.3861468209150327, 'y1': 0.12202539797979782, 'x2': 0.42277427189542477, 'y2': 0.13334661010100995} +``` + +As you can see, `x2` of the first box is exactly `x1` of the second box. + +One way to make this way less annoying is to apply a very small shrinkage. How big is this? Well, looking at the size of typical images rendered at `dpi=72`, we're probably looking at pages that have dimension `(800, 620)` at most. So conservatively, we can set an `epsilon=1e-4` without worrying about it being perceptible. + +For example, just doing something like this: +``` +BUFFER = 1e-4 +for token in doc.tokens: + token.box_group.boxes[0].l += BUFFER + token.box_group.boxes[0].t += BUFFER + token.box_group.boxes[0].w -= 2 * BUFFER + token.box_group.boxes[0].h -= 2 * BUFFER +``` + +will fix a lot of the overlapping boxes from the first figure: + +![image](fixtures/token-box-overlap-fixed-with-epsilon.png) + + +**5. Clustering-based solution*** + +Now how do we fix the remaining boxes? + +One thought is -- Can we do a fast thing just writing rules to compare `xmin, xmax, ymin, ymax` and directly make adjustments to those boundaries? Probably, but it gets confusing really quickly. Consider this case: + +![image](fixtures/overlap-tokens-edge-case.png) + +The problem is that we don't really want to work off clusters of boxes, but actually *pairs* of boxes that overlap. But the O(N^2) over all boxes on a page is kind of costly to check. So let's pre-cluster and then perform the O(N^2) within each cluster. + + +The technique looks something like: + +![image](fixtures/overlap-token-strategy.png) + +The figure above is what would happen if you just swapped the correct coordinates of the overlapping boxes. It does create more whitespace between, which may not be desirable, but it's also easier to follow. You can instead replace the relevant box coordinates with something newly calculated (e.g. split the difference). + +**6. Boxes that aren't even on the page** + +If we look at PDF in test fixtures `4be952924cd565488b4a239dc6549095029ee578.pdf`, we'll actually find weird boxes that come out of PDFPlumber that are off the page: + +``` + +``` \ No newline at end of file diff --git a/src/mmda/parsers/fixtures/overlap-token-strategy.png b/src/mmda/parsers/fixtures/overlap-token-strategy.png new file mode 100644 index 00000000..a50c75c5 Binary files /dev/null and b/src/mmda/parsers/fixtures/overlap-token-strategy.png differ diff --git a/src/mmda/parsers/fixtures/overlap-tokens-edge-case.png b/src/mmda/parsers/fixtures/overlap-tokens-edge-case.png new file mode 100644 index 00000000..e5d82a3b Binary files /dev/null and b/src/mmda/parsers/fixtures/overlap-tokens-edge-case.png differ diff --git a/src/mmda/parsers/fixtures/token-box-overlap-fixed-with-epsilon.png b/src/mmda/parsers/fixtures/token-box-overlap-fixed-with-epsilon.png new file mode 100644 index 00000000..d1b4f341 Binary files /dev/null and b/src/mmda/parsers/fixtures/token-box-overlap-fixed-with-epsilon.png differ diff --git a/src/mmda/parsers/fixtures/token-box-overlap.png b/src/mmda/parsers/fixtures/token-box-overlap.png new file mode 100644 index 00000000..ae9ade8d Binary files /dev/null and b/src/mmda/parsers/fixtures/token-box-overlap.png differ diff --git a/src/mmda/parsers/pdfplumber_parser.py b/src/mmda/parsers/pdfplumber_parser.py index b5472417..35fccebd 100644 --- a/src/mmda/parsers/pdfplumber_parser.py +++ b/src/mmda/parsers/pdfplumber_parser.py @@ -1,6 +1,6 @@ import itertools import string -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import pdfplumber @@ -201,24 +201,28 @@ def parse(self, input_pdf_path: str) -> Document: ) assert len(word_ids_of_fine_tokens) == len(fine_tokens) # 4) normalize / clean tokens & boxes + boxes = [ + Box.from_coordinates( + x1=float(token["x0"]), + y1=float(token["top"]), + x2=float(token["x1"]), + y2=float(token["bottom"]), + page=int(page_id), + ).get_relative( + page_width=float(page.width), page_height=float(page.height) + ) + for token in fine_tokens + ] + # 4.5) reformat + # TODO: remove tokens that are 'off page' fine_tokens = [ { "text": token["text"], "fontname": token["fontname"], "size": token["size"], - "bbox": Box.from_pdf_coordinates( - x1=float(token["x0"]), - y1=float(token["top"]), - x2=float(token["x1"]), - y2=float(token["bottom"]), - page_width=float(page.width), - page_height=float(page.height), - page=int(page_id), - ).get_relative( - page_width=float(page.width), page_height=float(page.height) - ), + "bbox": box, } - for token in fine_tokens + for box, token in zip(boxes, fine_tokens) ] # 5) group tokens into lines # TODO - doesnt belong in parser; should be own predictor diff --git a/src/mmda/types/annotation.py b/src/mmda/types/annotation.py index 4857df5c..4cc28ac6 100644 --- a/src/mmda/types/annotation.py +++ b/src/mmda/types/annotation.py @@ -9,6 +9,7 @@ import warnings from abc import abstractmethod from copy import deepcopy +from itertools import combinations from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union from mmda.types.box import Box @@ -22,7 +23,6 @@ __all__ = ["Annotation", "BoxGroup", "SpanGroup", "Relation"] - def warn_deepcopy_of_annotation(obj: "Annotation") -> None: """Warns when a deepcopy is performed on an Annotation.""" @@ -34,15 +34,14 @@ def warn_deepcopy_of_annotation(obj: "Annotation") -> None: warnings.warn(msg, UserWarning, stacklevel=2) - class Annotation: """Annotation is intended for storing model predictions for a document.""" def __init__( - self, - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None + self, + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, ): self.id = id self.doc = doc @@ -77,23 +76,30 @@ def __getattr__(self, field: str) -> List["Annotation"]: return self.__getattribute__(field) - class BoxGroup(Annotation): def __init__( - self, - boxes: List[Box], - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None, + self, + boxes: List[Box], + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, + allow_overlap: Optional[bool] = False, ): self.boxes = boxes + if not allow_overlap: + clusters = Box.cluster_boxes(boxes=boxes) + if any([len(cluster) > 1 for cluster in clusters]): + raise ValueError( + "BoxGroup does not allow overlapping boxes. " + "Consider setting allow_overlap=True." + ) super().__init__(id=id, doc=doc, metadata=metadata) def to_json(self) -> Dict: box_group_dict = dict( boxes=[box.to_json() for box in self.boxes], id=self.id, - metadata=self.metadata.to_json() + metadata=self.metadata.to_json(), ) return { key: value for key, value in box_group_dict.items() if value @@ -101,16 +107,13 @@ def to_json(self) -> Dict: @classmethod def from_json(cls, box_group_dict: Dict) -> "BoxGroup": - if "metadata" in box_group_dict: metadata_dict = box_group_dict["metadata"] else: # this fallback is necessary to ensure compatibility with box # groups that were create before the metadata migration and # therefore have "type" in the root of the json dict instead. - metadata_dict = { - "type": box_group_dict.get("type", None) - } + metadata_dict = {"type": box_group_dict.get("type", None)} return cls( boxes=[ @@ -132,7 +135,7 @@ def __deepcopy__(self, memo): box_group = BoxGroup( boxes=deepcopy(self.boxes, memo), id=self.id, - metadata=deepcopy(self.metadata, memo) + metadata=deepcopy(self.metadata, memo), ) # Don't copy an attached document @@ -150,25 +153,30 @@ def type(self, type: Union[str, None]) -> None: class SpanGroup(Annotation): - def __init__( - self, - spans: List[Span], - box_group: Optional[BoxGroup] = None, - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None, + self, + spans: List[Span], + box_group: Optional[BoxGroup] = None, + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, + allow_overlap: Optional[bool] = False, ): self.spans = spans + if not allow_overlap: + clusters = Span.cluster_spans(spans=spans) + if any([len(cluster) > 1 for cluster in clusters]): + raise ValueError( + "SpanGroup does not allow overlapping spans. " + "Consider setting allow_overlap=True." + ) self.box_group = box_group super().__init__(id=id, doc=doc, metadata=metadata) @property def symbols(self) -> List[str]: if self.doc is not None: - return [ - self.doc.symbols[span.start: span.end] for span in self.spans - ] + return [self.doc.symbols[span.start : span.end] for span in self.spans] else: return [] @@ -187,12 +195,10 @@ def to_json(self) -> Dict: spans=[span.to_json() for span in self.spans], id=self.id, metadata=self.metadata.to_json(), - box_group=self.box_group.to_json() if self.box_group else None + box_group=self.box_group.to_json() if self.box_group else None, ) return { - key: value - for key, value in span_group_dict.items() - if value is not None + key: value for key, value in span_group_dict.items() if value is not None } # only serialize non-null values @classmethod @@ -211,7 +217,7 @@ def from_json(cls, span_group_dict: Dict) -> "SpanGroup": # therefore have "id", "type" in the root of the json dict instead. metadata_dict = { "type": span_group_dict.get("type", None), - "text": span_group_dict.get("text", None) + "text": span_group_dict.get("text", None), } return cls( @@ -256,7 +262,7 @@ def __deepcopy__(self, memo): spans=deepcopy(self.spans, memo), id=self.id, metadata=deepcopy(self.metadata, memo), - box_group=deepcopy(self.box_group, memo) + box_group=deepcopy(self.box_group, memo), ) # Don't copy an attached document @@ -284,6 +290,5 @@ def text(self, text: Union[str, None]) -> None: self.metadata.text = text - class Relation(Annotation): - pass \ No newline at end of file + pass diff --git a/src/mmda/types/box.py b/src/mmda/types/box.py index c1bfd4a9..76c1d9f2 100644 --- a/src/mmda/types/box.py +++ b/src/mmda/types/box.py @@ -5,79 +5,66 @@ """ -from typing import List, Dict, Tuple, Union -from dataclasses import dataclass +import logging import warnings +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + import numpy as np -def is_overlap_1d(start1: float, end1: float, start2: float, end2: float, x: float = 0) -> bool: +def _is_overlap_1d( + start1: float, end1: float, start2: float, end2: float, x: float = 0 +) -> bool: """Return whether two 1D intervals overlaps given x""" assert start1 <= end1 assert start2 <= end2 - return not (start1 - x > end2 or start1 > end2 + x or end1 + x < start2 or end1 < start2 - x) # ll # rr + return not ( + start1 - x > end2 or start1 > end2 + x or end1 + x < start2 or end1 < start2 - x + ) # ll # rr -@dataclass class Box: - l: float - t: float - w: float - h: float - page: int + def __init__(self, l: float, t: float, w: float, h: float, page: int) -> None: + if w < 0 or h < 0: + raise ValueError(f"Width and height must be non-negative, got {w} and {h}") + if page < 0: + raise ValueError(f"Page must be non-negative, got {page}") + if l < 0 or t < 0: + raise ValueError(f"Left and top must be non-negative, got {l} and {t}") + self.l = l + self.t = t + self.w = w + self.h = h + self.page = page def to_json(self) -> Dict[str, float]: - return {'left': self.l, 'top': self.t, 'width': self.w, 'height': self.h, 'page': self.page} + return { + "left": self.l, + "top": self.t, + "width": self.w, + "height": self.h, + "page": self.page, + } @classmethod def from_json(cls, box_dict: Dict[str, Union[float, int]]) -> "Box": - return Box(l=box_dict['left'], t=box_dict['top'], w=box_dict['width'], h=box_dict['height'], - page=box_dict['page']) + return Box( + l=box_dict["left"], + t=box_dict["top"], + w=box_dict["width"], + h=box_dict["height"], + page=box_dict["page"], + ) @classmethod def from_coordinates(cls, x1: float, y1: float, x2: float, y2: float, page: int): - return cls(x1, y1, x2 - x1, y2 - y1, page) - - @classmethod - def from_pdf_coordinates( - cls, - x1: float, - y1: float, - x2: float, - y2: float, - page_width: float, - page_height: float, - page: int, - ): - """ - Convert PDF coordinates to absolute coordinates. - The difference between from_pdf_coordinates and from_coordinates is that this function - will perform extra checks to ensure the coordinates are valid, i.e., - 0<= x1 <= x2 <= page_width and 0<= y1 <= y2 <= page_height. - """ - - _x1, _x2 = np.clip([x1, x2], 0, page_width) - _y1, _y2 = np.clip([y1, y2], 0, page_height) - - if _x2 < _x1: - _x2 = _x1 - if _y2 < _y1: - _y2 = _y1 - if (_x1, _y1, _x2, _y2) != (x1, y1, x2, y2): - warnings.warn( - f"The coordinates ({x1}, {y1}, {x2}, {y2}) are not valid and converted to ({_x1}, {_y1}, {_x2}, {_y2})." - ) - - return cls(_x1, _y1, _x2 - _x1, _y2 - _y1, page) + return Box(l=x1, t=y1, w=x2 - x1, h=y2 - y1, page=page) @classmethod def small_boxes_to_big_box(cls, boxes: List["Box"]) -> "Box": """Computes one big box that tightly encapsulates all smaller input boxes""" - boxes = [box for box in boxes if box is not None] - if not boxes: - return None - if len({box.page for box in boxes}) != 1: raise ValueError(f"Bboxes not all on same page: {boxes}") x1 = min([bbox.l for bbox in boxes]) @@ -95,11 +82,6 @@ def coordinates(self) -> Tuple[float, float, float, float]: def center(self) -> Tuple[float, float]: return self.l + self.w / 2, self.t + self.h / 2 - @property - def xywh(self) -> Tuple[float, float, float, float]: - """Return a tuple of the (left, top, width, height) format.""" - return self.l, self.t, self.w, self.h - def get_relative(self, page_width: float, page_height: float) -> "Box": """Get the relative coordinates of self based on page_width, page_height.""" return self.__class__( @@ -120,18 +102,108 @@ def get_absolute(self, page_width: int, page_height: int) -> "Box": page=self.page, ) - def is_overlap(self, other: "Box", x: float = 0.0, y: float = 0, center: bool = False) -> bool: + def is_overlap( + self, other: "Box", x: float = 0.0, y: float = 0, center: bool = False + ) -> bool: """ Whether self overlaps with the other Box object. x, y distances for padding center (bool) if True, only consider overlapping if this box's center is contained by other """ + if self.page != other.page: + return False + x11, y11, x12, y12 = self.coordinates x21, y21, x22, y22 = other.coordinates if center: center_x, center_y = self.center - res = is_overlap_1d(center_x, center_x, x21, x22, x) and is_overlap_1d(center_y, center_y, y21, y22, y) + res = is_overlap_1d(center_x, center_x, x21, x22, x) and is_overlap_1d( + center_y, center_y, y21, y22, y + ) else: - res = is_overlap_1d(x11, x12, x21, x22, x) and is_overlap_1d(y11, y12, y21, y22, y) - + res = is_overlap_1d(x11, x12, x21, x22, x) and is_overlap_1d( + y11, y12, y21, y22, y + ) return res + + @classmethod + def cluster_boxes(cls, boxes: List["Box"]) -> List[List[int]]: + """ + Cluster boxes into groups based on any overlap. + """ + if not boxes: + return [] + + clusters: List[List[int]] = [[0]] + cluster_id_to_big_box: Dict[int, Box] = {0: boxes[0]} + for box_id in range(1, len(boxes)): + box = boxes[box_id] + + # check all the clusters to see if the box overlaps with any of them + is_overlap = False + for cluster_id, big_box in cluster_id_to_big_box.items(): + if box.is_overlap(big_box, x=0, y=0): + is_overlap = True + break + + # resolve + if is_overlap: + clusters[cluster_id].append(box_id) + cluster_id_to_big_box[cluster_id] = cls.small_boxes_to_big_box( + [box, big_box] + ) + else: + clusters.append([box_id]) + cluster_id_to_big_box[len(clusters) - 1] = box + + # sort clusters + for cluster in clusters: + cluster.sort() + clusters.sort(key=lambda x: x[0]) + + return clusters + + def shrink(self, delta: float, ignore: bool = True, clip: bool = True): + x1, y1, x2, y2 = self.coordinates + if x2 - x1 <= 2 * delta: + if ignore: + logging.warning(f"box's x-coords {self} shrink too much. Ignoring.") + else: + raise ValueError( + f"box's x-coords {self} shrink too much with delta={delta}." + ) + else: + if clip: + logging.warning( + f"box's x-coords {self} go beyond page boundary. Clipping..." + ) + x1 = min(x1 + delta, 1.0) + x2 = max(x2 - delta, 0.0) + else: + raise ValueError( + f"box's x-coordinates {self} go beyond page boundary. need clip." + ) + + if y2 - y1 <= 2 * delta: + if ignore: + logging.warning(f"box's y-coords {self} shrink too much. Ignoring.") + else: + raise ValueError( + f"box's y-coords {self} shrink too much with delta={delta}." + ) + else: + if clip: + logging.warning( + f"box's y-coords {self} go beyond page boundary. Clipping..." + ) + y1 = min(y1 + delta, 1.0) + y2 = max(y2 - delta, 0.0) + else: + raise ValueError( + f"box's y-coordinates {self} go beyond page boundary. need clip." + ) + + self.l = x1 + self.t = y1 + self.w = x2 - x1 + self.h = y2 - y1 diff --git a/src/mmda/types/document.py b/src/mmda/types/document.py index 4f2457d3..da4c9f01 100644 --- a/src/mmda/types/document.py +++ b/src/mmda/types/document.py @@ -11,10 +11,14 @@ from mmda.types.annotation import Annotation, BoxGroup, SpanGroup from mmda.types.image import PILImage -from mmda.types.indexers import Indexer, SpanGroupIndexer +from mmda.types.indexers import BoxGroupIndexer, SpanGroupIndexer from mmda.types.metadata import Metadata from mmda.types.names import ImagesField, MetadataField, SymbolsField -from mmda.utils.tools import MergeSpans, allocate_overlapping_tokens_for_box, box_groups_to_span_groups +from mmda.utils.tools import ( + MergeSpans, + allocate_overlapping_tokens_for_box, + box_groups_to_span_groups, +) class Document: @@ -25,20 +29,21 @@ def __init__(self, symbols: str, metadata: Optional[Metadata] = None): self.symbols = symbols self.images = [] self.__fields = [] - self.__indexers: Dict[str, Indexer] = {} + self.__sg_indexers: Dict[str, SpanGroupIndexer] = {} + self.__bg_indexers: Dict[str, BoxGroupIndexer] = {} self.metadata = metadata if metadata else Metadata() @property def fields(self) -> List[str]: return self.__fields - # TODO: extend implementation to support DocBoxGroup def find_overlapping(self, query: Annotation, field_name: str) -> List[Annotation]: - if not isinstance(query, SpanGroup): - raise NotImplementedError( - f"Currently only supports query of type SpanGroup" - ) - return self.__indexers[field_name].find(query=query) + if isinstance(query, SpanGroup): + return self.__sg_indexers[field_name].find(query=query) + elif isinstance(query, BoxGroup): + return self.__bg_indexers[field_name].find(query=query) + else: + raise NotImplementedError(f"Only supports query SpanGroup or BoxGroup") def add_metadata(self, **kwargs): """Copy kwargs into the document metadata""" @@ -46,7 +51,7 @@ def add_metadata(self, **kwargs): self.metadata.set(k, value) def annotate( - self, is_overwrite: bool = False, **kwargs: Iterable[Annotation] + self, is_overwrite: bool = False, **kwargs: Iterable[Annotation] ) -> None: """Annotate the fields for document symbols (correlating the annotations with the symbols) and store them into the papers. @@ -54,7 +59,7 @@ def annotate( # 1) check validity of field names for field_name in kwargs.keys(): assert ( - field_name not in self.SPECIAL_FIELDS + field_name not in self.SPECIAL_FIELDS ), f"The field_name {field_name} should not be in {self.SPECIAL_FIELDS}." if field_name in self.fields: @@ -83,7 +88,7 @@ def annotate( annotation_types = {type(a) for a in annotations} assert ( - len(annotation_types) == 1 + len(annotation_types) == 1 ), f"Annotations in field_name {field_name} more than 1 type: {annotation_types}" annotation_type = annotation_types.pop() @@ -94,7 +99,8 @@ def annotate( elif annotation_type == BoxGroup: # TODO: not good. BoxGroups should be stored on their own, not auto-generating SpanGroups. span_groups = self._annotate_span_group( - span_groups=box_groups_to_span_groups(annotations, self), field_name=field_name + span_groups=box_groups_to_span_groups(annotations, self), + field_name=field_name, ) else: raise NotImplementedError( @@ -108,10 +114,10 @@ def annotate( def remove(self, field_name: str): delattr(self, field_name) self.__fields = [f for f in self.__fields if f != field_name] - del self.__indexers[field_name] + del self.__sg_indexers[field_name] def annotate_images( - self, images: Iterable[PILImage], is_overwrite: bool = False + self, images: Iterable[PILImage], is_overwrite: bool = False ) -> None: if not is_overwrite and len(self.images) > 0: raise AssertionError( @@ -133,22 +139,134 @@ def annotate_images( self.images = images def _annotate_span_group( - self, span_groups: List[SpanGroup], field_name: str + self, span_groups: List[SpanGroup], field_name: str ) -> List[SpanGroup]: """Annotate the Document using a bunch of span groups. It will associate the annotations with the document symbols. """ assert all([isinstance(group, SpanGroup) for group in span_groups]) - # 1) add Document to each SpanGroup + # 1) Build fast overlap lookup index + self.__sg_indexers[field_name] = SpanGroupIndexer(span_groups) + + # 2) add Document to each SpanGroup for span_group in span_groups: span_group.attach_doc(doc=self) - # 2) Build fast overlap lookup index - self.__indexers[field_name] = SpanGroupIndexer(span_groups) - return span_groups + def _annotate_box_group( + self, box_groups: List[BoxGroup], field_name: str + ) -> List[BoxGroup]: + """Annotate the Document using a bunch of box groups. + It will associate the annotations with the document's pixel coords. + """ + assert all([isinstance(group, BoxGroup) for group in box_groups]) + + # 1) Build fast overlap lookup index + self.__bg_indexers[field_name] = BoxGroupIndexer(box_groups) + + # 2) add Document to each BoxGroup + for box_group in box_groups: + box_group.attach_doc(doc=self) + + return box_groups + + def _convert_box_group_to_span_group( + self, box_groups: List[BoxGroup], field_name: str + ) -> List[SpanGroup]: + """Convert a BoxGroup to a SpanGroup. + It will associate the annotations with the document symbols. + """ + assert all([isinstance(group, BoxGroup) for group in box_groups]) + + all_page_tokens = dict() + derived_span_groups = [] + token_box_in_box_group = None + + for box_id, box_group in enumerate(box_groups): + all_tokens_overlapping_box_group = [] + + for box in box_group.boxes: + # Caching the page tokens to avoid duplicated search + if box.page not in all_page_tokens: + cur_page_tokens = all_page_tokens[box.page] = self.pages[ + box.page + ].tokens + if token_box_in_box_group is None: + # Determine whether box is stored on token SpanGroup span.box or in the box_group + token_box_in_box_group = all( + [ + ( + ( + hasattr(token.box_group, "boxes") + and len(token.box_group.boxes) == 1 + ) + and token.spans[0].box is None + ) + for token in cur_page_tokens + ] + ) + else: + cur_page_tokens = all_page_tokens[box.page] + + # Find all the tokens within the box + tokens_in_box, remaining_tokens = allocate_overlapping_tokens_for_box( + tokens=cur_page_tokens, + box=box, + token_box_in_box_group=token_box_in_box_group, + ) + all_page_tokens[box.page] = remaining_tokens + + all_tokens_overlapping_box_group.extend(tokens_in_box) + + merge_spans = ( + MergeSpans.from_span_groups_with_box_groups( + span_groups=all_tokens_overlapping_box_group, index_distance=1 + ) + if token_box_in_box_group + else MergeSpans( + list_of_spans=list( + itertools.chain.from_iterable( + span_group.spans + for span_group in all_tokens_overlapping_box_group + ) + ), + index_distance=1, + ) + ) + + derived_span_groups.append( + SpanGroup( + spans=merge_spans.merge_neighbor_spans_by_symbol_distance(), + box_group=box_group, + # id = box_id, + ) + # TODO Right now we cannot assign the box id, or otherwise running doc.blocks will + # generate blocks out-of-the-specified order. + ) + + if not token_box_in_box_group: + logging.warning( + "tokens with box stored in SpanGroup span.box will be deprecated (that is, " + "future Spans wont contain box). Ensure Document is annotated with tokens " + "having box stored in SpanGroup box_group.boxes" + ) + + del all_page_tokens + + derived_span_groups = sorted( + derived_span_groups, key=lambda span_group: span_group.start + ) + # ensure they are ordered based on span indices + + for box_id, span_group in enumerate(derived_span_groups): + span_group.id = box_id + + return self._annotate_span_group( + span_groups=derived_span_groups, field_name=field_name + ) + # # to & from JSON # diff --git a/src/mmda/types/indexers.py b/src/mmda/types/indexers.py index beb12b1f..828d7094 100644 --- a/src/mmda/types/indexers.py +++ b/src/mmda/types/indexers.py @@ -4,15 +4,16 @@ """ -from typing import List - from abc import abstractmethod +from collections import defaultdict from dataclasses import dataclass, field +from typing import List -from mmda.types.annotation import SpanGroup, Annotation -from ncls import NCLS import numpy as np import pandas as pd +from ncls import NCLS + +from mmda.types.annotation import Annotation, Box, BoxGroup, SpanGroup @dataclass @@ -56,7 +57,7 @@ def __init__(self, span_groups: List[SpanGroup]) -> None: self._index = NCLS( pd.Series(starts, dtype=np.int64), pd.Series(ends, dtype=np.int64), - pd.Series(ids, dtype=np.int64) + pd.Series(ids, dtype=np.int64), ) self._ensure_disjoint() @@ -68,15 +69,26 @@ def _ensure_disjoint(self) -> None: """ for span_group in self._sgs: for span in span_group.spans: - matches = [match for match in self._index.find_overlap(span.start, span.end)] - if len(matches) > 1: + match_ids = [ + matched_id + for _start, _end, matched_id in self._index.find_overlap( + span.start, span.end + ) + ] + if len(match_ids) > 1: + matches = [self._sgs[match_id].to_json() for match_id in match_ids] raise ValueError( - f"Detected overlap with existing SpanGroup(s) {matches} for {span_group}" + f"Detected overlap! While processing the Span {span} as part of query SpanGroup {span_group.to_json()}, we found that it overlaps with existing SpanGroup(s):\n" + + "\n".join( + [f"\t{i}\t{m} " for i, m in zip(match_ids, matches)] + ) ) def find(self, query: SpanGroup) -> List[SpanGroup]: if not isinstance(query, SpanGroup): - raise ValueError(f'SpanGroupIndexer only works with `query` that is SpanGroup type') + raise ValueError( + f"SpanGroupIndexer only works with `query` that is SpanGroup type" + ) if not query.spans: return [] @@ -84,7 +96,9 @@ def find(self, query: SpanGroup) -> List[SpanGroup]: matched_ids = set() for span in query.spans: - for _start, _end, matched_id in self._index.find_overlap(span.start, span.end): + for _start, _end, matched_id in self._index.find_overlap( + span.start, span.end + ): matched_ids.add(matched_id) matched_span_groups = [self._sgs[matched_id] for matched_id in matched_ids] @@ -95,3 +109,78 @@ def find(self, query: SpanGroup) -> List[SpanGroup]: return sorted(list(matched_span_groups)) +class BoxGroupIndexer(Indexer): + """ + Manages a data structure for locating overlapping BoxGroups. + Builds a static nested containment list from BoxGroups + and accepts other BoxGroups as search probes. + """ + + def __init__(self, box_groups: List[BoxGroup]) -> None: + self._bgs = box_groups + + self._box_id_to_box_group_id = {} + self._boxes = [] + box_id = 0 + for bg_id, bg in enumerate(box_groups): + for box in bg.boxes: + self._boxes.append(box) + self._box_id_to_box_group_id[box_id] = bg_id + box_id += 1 + + self._np_boxes_x1 = np.array([b.l for b in self._boxes]) + self._np_boxes_y1 = np.array([b.t for b in self._boxes]) + self._np_boxes_x2 = np.array([b.l + b.w for b in self._boxes]) + self._np_boxes_y2 = np.array([b.t + b.h for b in self._boxes]) + self._np_boxes_page = np.array([b.page for b in self._boxes]) + + self._ensure_disjoint() + + def _find_overlap_boxes(self, query: Box) -> List[int]: + x1, y1, x2, y2 = query.coordinates + mask = ( + (self._np_boxes_x1 <= x2) + & (self._np_boxes_x2 >= x1) + & (self._np_boxes_y1 <= y2) + & (self._np_boxes_y2 >= y1) + & (self._np_boxes_page == query.page) + ) + return np.where(mask)[0].tolist() + + def _find_overlap_box_groups(self, query: Box) -> List[int]: + return [ + self._box_id_to_box_group_id[box_id] + for box_id in self._find_overlap_boxes(query) + ] + + def _ensure_disjoint(self) -> None: + """ + Constituent box groups must be fully disjoint. + Ensure the integrity of the built index. + """ + for box_group in self._bgs: + for box in box_group.boxes: + match_ids = self._find_overlap_box_groups(query=box) + if len(match_ids) > 1: + matches = [self._bgs[match_id].to_json() for match_id in match_ids] + raise ValueError( + f"Detected overlap! While processing the Box {box} as part of query BoxGroup {box_group.to_json()}, we found that it overlaps with existing BoxGroup(s):\n" + + "\n".join( + [f"\t{i}\t{m} " for i, m in zip(match_ids, matches)] + ) + ) + + def find(self, query: BoxGroup) -> List[BoxGroup]: + if not isinstance(query, BoxGroup): + raise ValueError( + f"BoxGroupIndexer only works with `query` that is BoxGroup type" + ) + + if not query.boxes: + return [] + + match_ids = [] + for box in query.boxes: + match_ids.extend(self._find_overlap_box_groups(query=box)) + + return [self._bgs[match_id] for match_id in sorted(set(match_ids))] diff --git a/src/mmda/utils/outline_metadata.py b/src/mmda/utils/outline_metadata.py index 06d59022..c1d8211e 100644 --- a/src/mmda/utils/outline_metadata.py +++ b/src/mmda/utils/outline_metadata.py @@ -6,6 +6,7 @@ @rauthur """ +import logging from dataclasses import asdict, dataclass from io import BytesIO from typing import Any, Dict, List, Union @@ -16,6 +17,10 @@ import pdfminer.pdftypes as pt import pdfminer.psparser as pr +# set logging level to ERROR +logging.getLogger("pdfminer").propagate = False +logging.getLogger("pdfminer").setLevel(logging.ERROR) + from mmda.types.document import Document from mmda.types.metadata import Metadata diff --git a/tests/fixtures/56c0a25e7bd3f220df8f9939f23c1982c2cb5fc4.pdf b/tests/fixtures/56c0a25e7bd3f220df8f9939f23c1982c2cb5fc4.pdf new file mode 100644 index 00000000..8135a1f3 Binary files /dev/null and b/tests/fixtures/56c0a25e7bd3f220df8f9939f23c1982c2cb5fc4.pdf differ diff --git a/tests/fixtures/72b37044a17c9210ed56c2cc7b9a737b1385311b.pdf b/tests/fixtures/72b37044a17c9210ed56c2cc7b9a737b1385311b.pdf new file mode 100644 index 00000000..c7ca65fd Binary files /dev/null and b/tests/fixtures/72b37044a17c9210ed56c2cc7b9a737b1385311b.pdf differ diff --git a/tests/fixtures/faa06090392e9633e608516b8c35f163f4a8f38a.pdf b/tests/fixtures/faa06090392e9633e608516b8c35f163f4a8f38a.pdf new file mode 100644 index 00000000..f8e82b1d Binary files /dev/null and b/tests/fixtures/faa06090392e9633e608516b8c35f163f4a8f38a.pdf differ diff --git a/tests/test_recipes/test_core_recipe.py b/tests/test_recipes/test_core_recipe.py index 6adb4d30..3256888d 100644 --- a/tests/test_recipes/test_core_recipe.py +++ b/tests/test_recipes/test_core_recipe.py @@ -142,3 +142,28 @@ def test_manual_create_using_annotate(self): == doc_json["blocks"] == [b.to_json() for b in self.doc.blocks] ) + + def test_no_fail_pdfs_in_fixtures(self): + pdfpath = os.path.join( + os.path.dirname(__file__), + "../fixtures/56c0a25e7bd3f220df8f9939f23c1982c2cb5fc4.pdf", + ) + doc = self.recipe.from_path(pdfpath=pdfpath) + + pdfpath = os.path.join( + os.path.dirname(__file__), + "../fixtures/72b37044a17c9210ed56c2cc7b9a737b1385311b.pdf", + ) + doc = self.recipe.from_path(pdfpath=pdfpath) + + pdfpath = os.path.join( + os.path.dirname(__file__), + "../fixtures/faa06090392e9633e608516b8c35f163f4a8f38a.pdf", + ) + doc = self.recipe.from_path(pdfpath=pdfpath) + + assert ( + [b.to_json() for b in doc2.blocks] + == doc_json["blocks"] + == [b.to_json() for b in self.doc.blocks] + ) diff --git a/tests/test_types/test_box.py b/tests/test_types/test_box.py index 8528ecc6..9d24d960 100644 --- a/tests/test_types/test_box.py +++ b/tests/test_types/test_box.py @@ -1,18 +1,100 @@ import unittest + from mmda.types import box as mmda_box class TestBox(unittest.TestCase): def setUp(cls) -> None: - cls.box_dict = {'left': 0.2, - 'top': 0.09, - 'width': 0.095, - 'height': 0.017, - 'page': 0} + cls.box_dict = { + "left": 0.2, + "top": 0.09, + "width": 0.095, + "height": 0.017, + "page": 0, + } cls.box = mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0) - def test_from_json(self): - self.assertEqual(self.box.from_json(self.box_dict), self.box) + def test_to_from_json(self): + box = self.box.from_json(self.box_dict) + self.assertDictEqual(box.to_json(), self.box_dict) + + def test_cluster_boxes(self): + # overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0, 1, 2]]) + + # on-overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.3, t=0.20, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.4, t=0.30, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0], [1], [2]]) + + # partially overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.3, t=0.20, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.301, t=0.201, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0], [1, 2]]) + + def test_create_invalid_box(self): + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7 + 0.0000001, t=0.2, w=0.3, h=0.4, page=0) + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6 + 0.000001, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6, page=-1) + + def test_shrink(self): + # usual + box = mmda_box.Box(l=0.1, t=0.2, w=0.3, h=0.4, page=0) + box.shrink(delta=0.1) + self.assertAlmostEqual(box.l, 0.2) # 0.1 + 0.1 + self.assertAlmostEqual(box.t, 0.3) # 0.2 + 0.1 + self.assertAlmostEqual(box.w, 0.1) # 0.3 - 0.1 * 2 + self.assertAlmostEqual(box.h, 0.2) # 0.4 - 0.1 * 2 + + # shrinking until inverts box. would ignore shrinking along appropriate axis. + box = mmda_box.Box(l=0.9, t=0.5, w=0.1, h=0.3, page=0) + box.shrink(delta=0.1, ignore=True) + self.assertAlmostEqual(box.l, 0.9) # ignored + self.assertAlmostEqual(box.t, 0.6) # adjusted; 0.5 + 0.1 + self.assertAlmostEqual(box.w, 0.1) # ignored + self.assertAlmostEqual(box.h, 0.1) # adjusted; 0.3 - 2 * 0.1 + + # shrinking until out of bounds. would clip along appropriate axis. + # actually... does this ever happen unless Box is already out of bounds? - def test_to_json(self): - self.assertEqual(self.box.to_json(), self.box_dict) + def test_cluster_boxes_hard(self): + # from 4be952924cd565488b4a239dc6549095029ee578.pdf, page 2, tokens 650:655 + boxes = [ + mmda_box.Box( + l=0.7761069934640523, + t=0.14276190217171716, + w=0.005533858823529373, + h=0.008037272727272593, + page=2, + ), + mmda_box.Box( + l=0.7836408522875816, + t=0.14691867138383832, + w=0.005239432156862763, + h=0.005360666666666692, + page=2, + ), + mmda_box.Box( + l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 + ), + mmda_box.Box( + l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 + ), + mmda_box.Box( + l=1.0, t=0.32670311318181816, w=0.0, h=0.010037272727272737, page=2 + ), + ] diff --git a/tests/test_types/test_document.py b/tests/test_types/test_document.py index 355e4993..791ca03f 100644 --- a/tests/test_types/test_document.py +++ b/tests/test_types/test_document.py @@ -1,13 +1,12 @@ import json -import unittest import os +import unittest +from ai2_internal import api from mmda.types.annotation import SpanGroup from mmda.types.document import Document from mmda.types.names import MetadataField, SymbolsField -from ai2_internal import api - def resolve(file: str) -> str: return os.path.join(os.path.dirname(__file__), "../fixtures/types", file) @@ -58,10 +57,15 @@ def test_annotate_box_groups_gets_text(self): spp_doc = Document.from_json(json.load(f)) with open(resolve("test_document_box_groups.json")) as f: - box_groups = [api.BoxGroup(**bg).to_mmda() for bg in json.load(f)["grobid_bibs_box_groups"]] + box_groups = [ + api.BoxGroup(**bg).to_mmda() + for bg in json.load(f)["grobid_bibs_box_groups"] + ] spp_doc.annotate(new_span_groups=box_groups) - assert spp_doc.new_span_groups[0].text.startswith("Gutman G, Rosenzweig D, Golan J") + assert spp_doc.new_span_groups[0].text.startswith( + "Gutman G, Rosenzweig D, Golan J" + ) # when token boxes are on spans plumber_doc = "c8b53e2d9cd247e2d42719e337bfb13784d22bd2.json" @@ -70,7 +74,10 @@ def test_annotate_box_groups_gets_text(self): doc = Document.from_json(json.load(f)) with open(resolve("test_document_box_groups.json")) as f: - box_groups = [api.BoxGroup(**bg).to_mmda() for bg in json.load(f)["grobid_bibs_box_groups"]] + box_groups = [ + api.BoxGroup(**bg).to_mmda() + for bg in json.load(f)["grobid_bibs_box_groups"] + ] doc.annotate(new_span_groups=box_groups) assert doc.new_span_groups[0].text.startswith("Gutman G, Rosenzweig D, Golan J") @@ -80,7 +87,10 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): a last-known-good fixture """ # basic doc annotated with pages and tokens, from pdfplumber parser split at punctuation - with open(resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json"), "r") as f: + with open( + resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json"), + "r", + ) as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) doc = Document.from_json(fixture_doc_json) @@ -88,9 +98,16 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): # spangroups derived from boxgroups of boxes drawn neatly around bib entries by calling `.annotate` on # list of BoxGroups fixture_span_groups = [] - with open(resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entry_span_groups_from_box_groups.json"), "r") as f: + with open( + resolve( + "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entry_span_groups_from_box_groups.json" + ), + "r", + ) as f: raw_json = f.read() - fixture_bib_entries_json = json.loads(raw_json)["bib_entry_span_groups_from_box_groups"] + fixture_bib_entries_json = json.loads(raw_json)[ + "bib_entry_span_groups_from_box_groups" + ] # make box_groups to annotate from test fixture bib entry span groups, and save the for bib_entry in fixture_bib_entries_json: @@ -108,3 +125,30 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): assert sg1.spans == sg2.spans assert sg1.text == sg2.text + def test_indexing_and_find_overlapping(self): + # this doc only has pages, tokens, rows + spp_plumber_doc = "spp-dag-0-0-4-doc.json" + doc_file = resolve(spp_plumber_doc) + with open(doc_file) as f: + spp_doc = Document.from_json(json.load(f)) + + # tokens & rows have box groups that can be indexed + for token in spp_doc.tokens: + assert token.box_group + for row in spp_doc.rows: + assert row.box_group + + # index + spp_doc._annotate_box_group( + box_groups=[token.box_group for token in spp_doc.tokens], + field_name="tokens", + ) + spp_doc._annotate_box_group( + box_groups=[row.box_group for row in spp_doc.rows], field_name="rows" + ) + + # for every row, find the tokens that have visual overlap + for row in spp_doc.rows: + overlap_tokens = spp_doc.find_overlapping( + query=row.box_group, field_name="tokens" + ) diff --git a/tests/test_types/test_indexers.py b/tests/test_types/test_indexers.py index 05ab6885..2cb7b33c 100644 --- a/tests/test_types/test_indexers.py +++ b/tests/test_types/test_indexers.py @@ -1,19 +1,13 @@ import unittest -from mmda.types import SpanGroup, Span -from mmda.types.indexers import SpanGroupIndexer +from mmda.types import Box, BoxGroup, Span, SpanGroup +from mmda.types.indexers import BoxGroupIndexer, SpanGroupIndexer class TestSpanGroupIndexer(unittest.TestCase): def test_overlap_within_single_spangroup_fails_checks(self): span_groups = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(4, 7) - ] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(4, 7)], allow_overlap=True) ] with self.assertRaises(ValueError): @@ -21,17 +15,8 @@ def test_overlap_within_single_spangroup_fails_checks(self): def test_overlap_between_spangroups_fails_checks(self): span_groups = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(5, 8) - ] - ), - SpanGroup( - id=2, - spans=[Span(6, 10)] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(5, 8)]), + SpanGroup(id=2, spans=[Span(6, 10)]), ] with self.assertRaises(ValueError): @@ -39,21 +24,9 @@ def test_overlap_between_spangroups_fails_checks(self): def test_finds_matching_groups_in_doc_order(self): span_groups_to_index = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(5, 8) - ] - ), - SpanGroup( - id=2, - spans=[Span(9, 10)] - ), - SpanGroup( - id=3, - spans=[Span(100, 105)] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(5, 8)]), + SpanGroup(id=2, spans=[Span(9, 10)]), + SpanGroup(id=3, spans=[Span(100, 105)]), ] index = SpanGroupIndexer(span_groups_to_index) @@ -66,4 +39,67 @@ def test_finds_matching_groups_in_doc_order(self): self.assertEqual(matches, [span_groups_to_index[0], span_groups_to_index[1]]) +class TestBoxGroupIndexer(unittest.TestCase): + def test_overlap_within_single_boxgroup_fails_checks(self): + box_groups = [ + BoxGroup( + id=1, + boxes=[Box(0, 0, 5, 5, page=0), Box(4, 4, 7, 7, page=0)], + allow_overlap=True, + ) + ] + + with self.assertRaises(ValueError): + BoxGroupIndexer(box_groups) + + def test_overlap_between_boxgroups_fails_checks(self): + box_groups = [ + BoxGroup( + id=1, boxes=[Box(0, 0, 5, 5, page=0), Box(5.01, 5.01, 8, 8, page=0)] + ), + BoxGroup(id=2, boxes=[Box(6, 6, 10, 10, page=0)]), + ] + + with self.assertRaises(ValueError): + BoxGroupIndexer(box_groups) + def test_finds_matching_groups_in_doc_order(self): + box_groups_to_index = [ + BoxGroup(id=1, boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=0)]), + BoxGroup(id=2, boxes=[Box(4, 4, 1, 1, page=0)]), + BoxGroup(id=3, boxes=[Box(100, 100, 1, 1, page=0)]), + ] + + index = BoxGroupIndexer(box_groups_to_index) + + # should intersect 1 and 2 but not 3 + probe = BoxGroup(id=4, boxes=[Box(1, 1, 5, 5, page=0), Box(9, 9, 5, 5, page=0)]) + matches = index.find(probe) + + self.assertEqual(len(matches), 2) + self.assertEqual(matches, [box_groups_to_index[0], box_groups_to_index[1]]) + + def test_finds_matching_groups_accounts_for_pages(self): + box_groups_to_index = [ + BoxGroup(id=1, boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=1)]), + BoxGroup(id=2, boxes=[Box(4, 4, 1, 1, page=1)]), + BoxGroup(id=3, boxes=[Box(100, 100, 1, 1, page=0)]), + ] + + index = BoxGroupIndexer(box_groups_to_index) + + # shouldnt intersect any given page 0 + probe = BoxGroup(id=4, boxes=[Box(1, 1, 5, 5, page=0), Box(9, 9, 5, 5, page=0)]) + matches = index.find(probe) + + self.assertEqual(len(matches), 1) + self.assertEqual(matches, [box_groups_to_index[0]]) + + # shoudl intersect after switching to page 1 (and the page 2 box doesnt intersect) + probe = BoxGroup( + id=4, boxes=[Box(1, 1, 5, 5, page=1), Box(100, 100, 1, 1, page=2)] + ) + matches = index.find(probe) + + self.assertEqual(len(matches), 2) + self.assertEqual(matches, [box_groups_to_index[0], box_groups_to_index[1]]) diff --git a/tests/test_types/test_span.py b/tests/test_types/test_span.py index 53466097..89627ac6 100644 --- a/tests/test_types/test_span.py +++ b/tests/test_types/test_span.py @@ -33,6 +33,54 @@ def test_to_json(self): self.assertEqual(self.span.from_json(self.span_dict).to_json(), self.span_dict) def test_is_overlap(self): + span1 = mmda_span.Span(start=0, end=8) + span2 = mmda_span.Span(start=0, end=8) + self.assertTrue(span1.is_overlap(span2)) + + span3 = mmda_span.Span(start=2, end=5) + self.assertTrue(span1.is_overlap(span3)) + + span4 = mmda_span.Span(start=8, end=10) + self.assertFalse(span1.is_overlap(span4)) + + span5 = mmda_span.Span(start=10, end=12) + self.assertFalse(span1.is_overlap(span5)) + + def small_spans_to_big_span(self): + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=16, end=24), + ] + self.assertEqual( + self.span.small_spans_to_big_span(spans), + mmda_span.Span(start=0, end=24), + ) + + def test_cluster_spans(self): + # overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=0, end=8), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0, 1, 2]]) + + # non-overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=16, end=24), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0], [1], [2]]) + + # partially overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=9, end=16), + mmda_span.Span(start=10, end=15), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0], [1, 2]]) span = mmda_span.Span(start=0, end=2) self.assertTrue(span.is_overlap(mmda_span.Span(start=0, end=1))) self.assertTrue(span.is_overlap(mmda_span.Span(start=1, end=2)))