Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: test

on:
push:
branches: [main]
pull_request:
workflow_dispatch:

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
enable-cache: true
- name: Run tests
Comment thread
rbavery marked this conversation as resolved.
run: uv run --frozen pytest -q
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ authors = [
]
classifiers = [ "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.12" ]
dependencies = [
"torch>=2.8",
"torch>=2.8,<2.13",
]

[project.urls]
Expand All @@ -27,6 +27,15 @@ package-dir = {"" = "src"}
where = ["src"]
include = ["wherobots_export*"]

[dependency-groups]
dev = [
"numpy",
"pytest>=8",
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"

[tool.uv.sources]
torch = { index = "pytorch-cpu" }
64 changes: 64 additions & 0 deletions tests/test_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pathlib import Path

import pytest
import torch
from torch.export.pt2_archive._package import load_pt2

Comment thread
rbavery marked this conversation as resolved.
from wherobots_export.torch.export import create_example_input_from_shape, save


class TinyModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)

def forward(self, pixels: torch.Tensor) -> torch.Tensor:
return self.conv(pixels).mean(dim=(2, 3))


class Normalize(torch.nn.Module):
def forward(self, pixels: torch.Tensor) -> torch.Tensor:
return (pixels - 0.5) / 0.5


def test_save_creates_nonempty_pt2(tmp_path: Path) -> None:
output_file = tmp_path / "model.pt2"
save(model=TinyModel(), output_file=output_file, input_shape=[1, 3, 32, 32], device="cpu")
assert output_file.exists()
assert output_file.stat().st_size > 0


def test_save_dynamic_batch_roundtrip(tmp_path: Path) -> None:
model = TinyModel().eval()
output_file = tmp_path / "model.pt2"
save(model=model, output_file=output_file, input_shape=[-1, 3, 32, 32], device="cpu")

contents = load_pt2(str(output_file))
assert set(contents.exported_programs) == {"model"}

exported = contents.exported_programs["model"].module()
pixels = torch.randn(4, 3, 32, 32) # batch differs from the export-time example
torch.testing.assert_close(exported(pixels), model(pixels))


def test_save_with_transforms_packages_both(tmp_path: Path) -> None:
output_file = tmp_path / "model.pt2"
save(
model=TinyModel(),
output_file=output_file,
input_shape=[-1, 3, 32, 32],
device="cpu",
transforms=Normalize(),
)
contents = load_pt2(str(output_file))
assert set(contents.exported_programs) == {"model", "transforms"}


def test_create_example_input_rejects_all_dynamic() -> None:
with pytest.raises(ValueError):
create_example_input_from_shape([-1, -1, -1, -1])


def test_create_example_input_fills_dynamic_dims() -> None:
example = create_example_input_from_shape([-1, 3, -1, -1], shape_default_value=64)
assert example.shape == (2, 3, 64, 64)
238 changes: 238 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading