Skip to content

Commit 3e6bbac

Browse files
voodoo11Manul from Pathway
authored andcommitted
PaddleOCR parser (#9338)
GitOrigin-RevId: e63bf7bd8ea1a268397edadb5cbcf1c7f33b6223
1 parent 24c8443 commit 3e6bbac

File tree

4 files changed

+266
-8
lines changed

4 files changed

+266
-8
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
55
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66
## [Unreleased]
77

8+
### Added
9+
10+
- New parser `pathway.xpacks.llm.parsers.PaddleOCRParser` supporting parsing of PDF, PPTX and images.
11+
812
## [0.26.2]
913

1014
### Added

integration_tests/xpack/test_parsers.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55
import pytest
66

77
import pathway as pw
8-
from pathway.tests.utils import assert_table_equality
9-
from pathway.xpacks.llm.parsers import DoclingParser, UnstructuredParser
8+
from pathway.tests.utils import assert_table_equality, run_all
9+
from pathway.xpacks.llm.parsers import (
10+
DoclingParser,
11+
PaddleOCRParser,
12+
PypdfParser,
13+
UnstructuredParser,
14+
)
1015

1116
FOLDER_WITH_ONE_FILE_ID = "1XisWrSjKMCx2jfUW8OSgt6L8veq8c4Mh"
1217

@@ -29,7 +34,6 @@ class schema(pw.Schema):
2934

3035

3136
@pytest.mark.environment_changes
32-
@pytest.mark.asyncio
3337
def test_parse_unstructured_unk_exception(monkeypatch):
3438
parser = UnstructuredParser()
3539

@@ -55,7 +59,14 @@ class schema(pw.Schema):
5559
assert "FileType.UNK" in exception_msg
5660

5761

58-
def test_single_file_read_with_constraints(tmp_path, credentials_dir):
62+
single_page_parsers = {
63+
"docling": lambda: DoclingParser(table_parsing_strategy="docling", chunk=False),
64+
"paddle_ocr_structure": lambda: PaddleOCRParser(concatenate_pages=True),
65+
}
66+
67+
68+
@pytest.mark.parametrize("parser_name", single_page_parsers.keys())
69+
def test_parse_pdf_single_page(parser_name, tmp_path, credentials_dir):
5970
files_table = pw.io.gdrive.read(
6071
FOLDER_WITH_ONE_FILE_ID,
6172
mode="static",
@@ -64,12 +75,13 @@ def test_single_file_read_with_constraints(tmp_path, credentials_dir):
6475
with_metadata=True,
6576
)
6677

67-
parser = DoclingParser(table_parsing_strategy="docling", chunk=False)
78+
parser = single_page_parsers[parser_name]()
79+
6880
parse_table = files_table.select(parsed_text=parser(pw.this.data)[0][0])
6981

7082
pw.io.jsonlines.write(parse_table, tmp_path / "output.jsonl")
7183

72-
pw.run()
84+
run_all()
7385

7486
rows_count = 0
7587
with open(tmp_path / "output.jsonl", "r") as f:
@@ -80,3 +92,44 @@ def test_single_file_read_with_constraints(tmp_path, credentials_dir):
8092

8193
assert rows_count == 1
8294
assert "first decomposed with a parse tree and converted" in text
95+
96+
97+
parsers = {
98+
"pypdf": lambda: PypdfParser(apply_text_cleanup=True),
99+
"paddle_ocr_structure": lambda: PaddleOCRParser(),
100+
}
101+
102+
103+
@pytest.mark.parametrize("parser_name", parsers.keys())
104+
def test_parse_pdf_multi_page_output(parser_name, tmp_path, credentials_dir):
105+
files_table = pw.io.gdrive.read(
106+
FOLDER_WITH_ONE_FILE_ID,
107+
mode="static",
108+
service_user_credentials_file=str(credentials_dir / "credentials.json"),
109+
object_size_limit=None,
110+
with_metadata=True,
111+
)
112+
113+
parser = parsers[parser_name]()
114+
115+
parse_table = (
116+
files_table.select(result=parser(pw.this.data))
117+
.flatten(pw.this.result)
118+
.select(
119+
text=pw.this.result[0],
120+
page_number=pw.this.result[1]["page_number"].as_int(),
121+
)
122+
)
123+
124+
pw.io.jsonlines.write(parse_table, tmp_path / "output.jsonl")
125+
126+
run_all()
127+
128+
pages = set()
129+
with open(tmp_path / "output.jsonl", "r") as f:
130+
for raw_row in f:
131+
row = json.loads(raw_row)
132+
pages.add(row["page_number"])
133+
assert len(row["text"]) > 100
134+
135+
assert pages == set(range(0, 10))

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ xpack-llm-docs = [
8080
"pdf2image",
8181
"pypdf",
8282
# unstructured dependency, but unstructured doesn't force this version even though it's needed
83-
"pdfminer.six == 20250506",
83+
"pdfminer.six == 20250506",
84+
"paddleocr[doc-parser] >= 3.2.0",
85+
"paddlepaddle >=3.1.1"
8486
]
8587
xpack-sharepoint = [
8688
"Office365-REST-Python-Client >= 2.5.3",

python/pathway/xpacks/llm/parsers.py

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,30 @@
1111
import logging
1212
import re
1313
import warnings
14+
from abc import ABC, abstractmethod
1415
from collections import defaultdict
1516
from collections.abc import Callable
1617
from functools import partial
1718
from io import BytesIO
1819
from typing import TYPE_CHECKING, Iterable, Iterator, Literal, TypeAlias, get_args
1920

21+
import numpy as np
22+
from pdf2image import convert_from_bytes
2023
from PIL import Image
2124
from pydantic import BaseModel
25+
from unstructured.file_utils.filetype import FileType, detect_filetype
2226

2327
import pathway as pw
2428
from pathway.internals import udfs
2529
from pathway.internals.config import _check_entitlements
2630
from pathway.optional_import import optional_imports
27-
from pathway.xpacks.llm import llms, prompts
31+
from pathway.xpacks.llm import _parser_utils, llms, prompts
2832
from pathway.xpacks.llm._utils import _prepare_executor
2933
from pathway.xpacks.llm.constants import DEFAULT_VISION_MODEL
3034

3135
if TYPE_CHECKING:
3236
with optional_imports("xpack-llm-docs"):
37+
from paddleocr import PaddleOCR, PPStructureV3
3338
from unstructured.documents.elements import Element
3439

3540
logger = logging.getLogger(__name__)
@@ -1091,3 +1096,197 @@ def replace_newline(match: re.Match):
10911096

10921097
modified_text = re.sub(r"\n(\w)", replace_newline, text)
10931098
return modified_text
1099+
1100+
1101+
class _PaddleParser(ABC):
1102+
"""
1103+
Abstract wrapper for Paddle pipeline, that extracts text from OCR results.
1104+
"""
1105+
1106+
pipeline: PaddleOCR | PPStructureV3
1107+
1108+
def __init__(self, pipeline: PaddleOCR | PPStructureV3):
1109+
self.pipeline = pipeline
1110+
1111+
def parse(self, image: np.ndarray) -> str:
1112+
ocr_result = self.pipeline.predict(image)
1113+
return self.extract_text(ocr_result)
1114+
1115+
@abstractmethod
1116+
def extract_text(self, ocr_result: list) -> str:
1117+
pass
1118+
1119+
@staticmethod
1120+
def create_for(pipeline: PaddleOCR | PPStructureV3) -> _PaddleParser:
1121+
with optional_imports("xpack-llm-docs"):
1122+
from paddleocr import PaddleOCR, PPStructureV3
1123+
1124+
match pipeline:
1125+
case PPStructureV3():
1126+
return _PaddlePPStructureV3Parser(pipeline)
1127+
case PaddleOCR():
1128+
return _PaddleOCRParser(pipeline)
1129+
case _:
1130+
raise NotImplementedError(
1131+
f"Extractor for {type(pipeline)} is not implemented."
1132+
)
1133+
1134+
1135+
class _PaddlePPStructureV3Parser(_PaddleParser):
1136+
def extract_text(self, ocr_result: list) -> str:
1137+
pages = []
1138+
1139+
for res in ocr_result:
1140+
try:
1141+
pages.append(res.markdown)
1142+
except AttributeError:
1143+
logger.error("Failed to extract text from OCR result.")
1144+
continue
1145+
1146+
result = self.pipeline.concatenate_markdown_pages(pages)
1147+
1148+
return result
1149+
1150+
1151+
class _PaddleOCRParser(_PaddleParser):
1152+
def extract_text(self, ocr_result: list) -> str:
1153+
result = ""
1154+
for res in ocr_result:
1155+
try:
1156+
text = res["rec_texts"]
1157+
result += " ".join(text) + "\n\n"
1158+
except KeyError:
1159+
logger.error("Failed to extract text from OCR result.")
1160+
continue
1161+
return result
1162+
1163+
1164+
class PaddleOCRParser(pw.UDF):
1165+
"""
1166+
A class to parse images, PDFs and PPTX slides using PaddleOCR.
1167+
1168+
Args:
1169+
pipeline: A Paddle pipeline object. Currently PaddleOCR and PPStructureV3 are supported.
1170+
If not provided, a default PPStructureV3 pipeline will be used.
1171+
Use PPStructureV3 for better accuracy on documents with complex layouts. PaddleOCR can be used for
1172+
simpler documents, extracting only text but may be faster.
1173+
concatenate_pages: Whether to concatenate multi-paged documents into a single output. Defaults to False.
1174+
intermediate_image_format: Intermediate image format used when converting PDFs to images.
1175+
Defaults to ``"jpg"`` for speed and memory use.
1176+
max_image_size: Maximum allowed size of the images in bytes. Default is 15 MB.
1177+
downsize_horizontal_width: Width to which images are downsized if necessary.
1178+
Default is 1920.
1179+
cache_strategy: Defines the caching mechanism. To enable caching,
1180+
a valid :py:class:``~pathway.udfs.CacheStrategy`` should be provided.
1181+
Defaults to None.
1182+
async_mode: Mode of execution for the UDF, either ``"batch_async"`` or ``"fully_async"``.
1183+
Default is ``"batch_async"``.
1184+
"""
1185+
1186+
parser: _PaddleParser
1187+
intermediate_image_format: str
1188+
max_image_size: int
1189+
downsize_horizontal_width: int
1190+
1191+
def __init__(
1192+
self,
1193+
pipeline: PaddleOCR | PPStructureV3 | None = None,
1194+
*,
1195+
concatenate_pages: bool = False,
1196+
intermediate_image_format: str = "jpg",
1197+
max_image_size: int = 15 * 1024 * 1024,
1198+
downsize_horizontal_width: int = 1920,
1199+
cache_strategy: udfs.CacheStrategy | None = None,
1200+
async_mode: Literal["batch_async", "fully_async"] = "batch_async",
1201+
):
1202+
super().__init__(
1203+
executor=_prepare_executor(async_mode=async_mode),
1204+
cache_strategy=cache_strategy,
1205+
)
1206+
1207+
with optional_imports("xpack-llm-docs"):
1208+
import paddleocr # noqa:F401
1209+
1210+
self.intermediate_image_format = intermediate_image_format
1211+
self.max_image_size = max_image_size
1212+
self.downsize_horizontal_width = downsize_horizontal_width
1213+
self.concatenate_pages = concatenate_pages
1214+
1215+
if pipeline is None:
1216+
pipeline = self._default_pipeline()
1217+
1218+
self.parser = _PaddleParser.create_for(pipeline)
1219+
1220+
def _default_pipeline(self) -> PPStructureV3:
1221+
with optional_imports("xpack-llm-docs"):
1222+
from paddleocr import PPStructureV3
1223+
return PPStructureV3(
1224+
use_table_recognition=False,
1225+
use_doc_orientation_classify=False,
1226+
use_doc_unwarping=False,
1227+
use_textline_orientation=False,
1228+
use_seal_recognition=False,
1229+
use_formula_recognition=False,
1230+
use_chart_recognition=False,
1231+
use_region_detection=False,
1232+
)
1233+
1234+
def _normalize_input(
1235+
self,
1236+
contents: bytes,
1237+
) -> tuple[list[Image.Image], FileType | None]:
1238+
byte_file = io.BytesIO(contents)
1239+
filetype = detect_filetype(file=byte_file)
1240+
1241+
match filetype:
1242+
case FileType.PPT | FileType.PPTX:
1243+
contents = _parser_utils._convert_pptx_to_pdf(contents)
1244+
images = convert_from_bytes(
1245+
contents, fmt=self.intermediate_image_format
1246+
)
1247+
case FileType.PDF:
1248+
images = convert_from_bytes(
1249+
contents, fmt=self.intermediate_image_format
1250+
)
1251+
case _ as filetype:
1252+
try:
1253+
images = [Image.open(io.BytesIO(contents)).convert("RGB")]
1254+
except Exception as e:
1255+
logger.error(f"Failed to parse provided file. Reason: {e}")
1256+
return [], None
1257+
1258+
images = [
1259+
_parser_utils.maybe_downscale(
1260+
img,
1261+
max_image_size=self.max_image_size,
1262+
downsize_horizontal_width=self.downsize_horizontal_width,
1263+
)
1264+
for img in images
1265+
]
1266+
1267+
return images, filetype
1268+
1269+
async def __wrapped__(self, contents: bytes) -> list[tuple[str, dict]]:
1270+
images, original_filetype = self._normalize_input(contents)
1271+
1272+
def metadata(page_number: int) -> dict:
1273+
if original_filetype in [FileType.PPT, FileType.PPTX, FileType.PDF]:
1274+
return {"page_number": page_number}
1275+
return {}
1276+
1277+
docs = []
1278+
1279+
for i, image in enumerate(images):
1280+
try:
1281+
img_np = np.array(image)
1282+
text = self.parser.parse(img_np)
1283+
docs.append((text, metadata(i)))
1284+
except Exception as e:
1285+
logger.error(f"Failed to process an image. Reason: {e}")
1286+
continue
1287+
1288+
if self.concatenate_pages and len(docs) > 1:
1289+
concatenated_text = "\n\n".join([doc[0] for doc in docs])
1290+
docs = [(concatenated_text, {"page_number": 0})]
1291+
1292+
return docs

0 commit comments

Comments
 (0)