diff --git a/.gitignore b/.gitignore index 5b870e9..3c3fc5a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__ uv.lock *.pdf dist +models/ diff --git a/commonforms/inference.py b/commonforms/inference.py index f429049..1b2ebaa 100644 --- a/commonforms/inference.py +++ b/commonforms/inference.py @@ -1,6 +1,7 @@ from __future__ import annotations from ultralytics import YOLO from pathlib import Path +from huggingface_hub import hf_hub_download from commonforms.utils import BoundingBox, Page, Widget from commonforms.form_creator import PyPdfFormCreator @@ -10,6 +11,15 @@ import pypdfium2 +# our mapping from (model_name, fast) to (repo_id, filename) for the huggingface hub +models = { + ("FFDNET-S", True): ("jbarrow/FFDNet-S-cpu", "FFDNet-S.onnx"), + ("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"), + ("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"), + ("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"), +} + + class FFDNetDetector: def __init__( self, model_or_path: str, device: int | str = "cpu", fast: bool = False @@ -32,11 +42,9 @@ def get_model_path( """ model_upper = model_or_path.upper() if model_upper in ["FFDNET-S", "FFDNET-L"]: - extension = "onnx" if fast else "pt" - # load from the package - normalize to proper case - model_name = "FFDNet-S" if model_upper == "FFDNET-S" else "FFDNet-L" - model_path = Path(__file__).parent / "models" / f"{model_name}.{extension}" - print(f"using model: {model_path}") + # download the model, will just use the cached version if it already exists + repo_id, filename = models[(model_upper, fast)] + model_path = hf_hub_download(repo_id=repo_id, filename=filename) else: model_path = model_or_path diff --git a/commonforms/models/FFDNet-L.onnx b/commonforms/models/FFDNet-L.onnx deleted file mode 100644 index 02b3289..0000000 Binary files a/commonforms/models/FFDNet-L.onnx and /dev/null differ diff --git a/commonforms/models/FFDNet-L.pt b/commonforms/models/FFDNet-L.pt deleted file mode 100644 index eda3cc3..0000000 Binary files a/commonforms/models/FFDNet-L.pt and /dev/null differ diff --git a/commonforms/models/FFDNet-S.onnx b/commonforms/models/FFDNet-S.onnx deleted file mode 100644 index 3df113f..0000000 Binary files a/commonforms/models/FFDNet-S.onnx and /dev/null differ diff --git a/commonforms/models/FFDNet-S.pt b/commonforms/models/FFDNet-S.pt deleted file mode 100644 index 36d206e..0000000 Binary files a/commonforms/models/FFDNet-S.pt and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml index b45954c..a4e46ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ urls = { Homepage = "https://github.com/jbarrow/commonforms" } dependencies = [ "cryptography>=3.1", "formalpdf==0.1.5", + "huggingface-hub>=0.35.3", "onnx>=1.19.1", "onnxruntime>=1.23.1", "onnxslim>=0.1.71", @@ -28,7 +29,7 @@ commonforms = "commonforms:main" packages = ["commonforms"] [tool.setuptools.package-data] -commonforms = ["models/*.pt"] +commonforms = ["models/*.pt", "models/*.onnx"] [dependency-groups] dev = [