diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1bdd1c6..b21e41f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,14 +57,14 @@ jobs: - name: Smoke test wheel run: > - uv run --isolated --no-project --with dist/*.whl python -c "from + uv run --no-cache --isolated --no-project --with dist/*.whl python -c "from hilbertsfc import hilbert_decode_2d, hilbert_encode_2d; i = hilbert_encode_2d(1, 2, nbits=2); assert hilbert_decode_2d(i, nbits=2) == (1, 2)" - name: Smoke test source distribution run: > - uv run --isolated --no-project --with dist/*.tar.gz python -c "from + uv run --no-cache --isolated --no-project --with dist/*.tar.gz python -c "from hilbertsfc import hilbert_decode_3d, hilbert_encode_3d; i = hilbert_encode_3d(1, 2, 3, nbits=2); assert hilbert_decode_3d(i, nbits=2) == (1, 2, 3)" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 8173bb8..c2fb5de 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -29,10 +29,10 @@ jobs: run: uv build - name: Smoke test (wheel) - run: uv run --isolated --no-project --with dist/*.whl --with pytest pytest -q + run: uv run --no-cache --isolated --no-project --with dist/*.whl --with pytest pytest -q - name: Smoke test (source distribution) - run: uv run --isolated --no-project --with dist/*.tar.gz --with pytest pytest -q + run: uv run --no-cache --isolated --no-project --with dist/*.tar.gz --with pytest pytest -q - name: Publish run: | diff --git a/src/hilbertsfc/_kernels/numba/hilbert2d_decode.py b/src/hilbertsfc/_kernels/numba/hilbert2d_decode.py index 91e5156..474940c 100644 --- a/src/hilbertsfc/_kernels/numba/hilbert2d_decode.py +++ b/src/hilbertsfc/_kernels/numba/hilbert2d_decode.py @@ -84,18 +84,24 @@ def build_hilbert_decode_2d_impl(nbits: int, *, tile_nbits: TileNBits2D | None = if tile_nbits == 7: lut = lut_2d7b_q_bs_u64() - kernel = _hilbert_decode_2d_7bit_compacted_bs + + @nb.njit(inline="always", cache=True) + def decode_2d_7bit(index: IntScalar) -> tuple[int, int]: + return _hilbert_decode_2d_7bit_compacted_bs(index, nbits, lut) + + return decode_2d_7bit + elif tile_nbits == 4: lut = lut_2d4b_q_bs_u64() - kernel = _hilbert_decode_2d_4bit_compacted_bs - else: - raise ValueError("tile_nbits must be 4 or 7 (or None for auto)") - @nb.njit(inline="always", cache=False) - def decode_2d(index: IntScalar) -> tuple[int, int]: - return kernel(index, nbits, lut) + @nb.njit(inline="always", cache=True) + def decode_2d_4bit(index: IntScalar) -> tuple[int, int]: + return _hilbert_decode_2d_4bit_compacted_bs(index, nbits, lut) + + return decode_2d_4bit - return decode_2d + else: + raise ValueError("tile_nbits must be 4 or 7 (or None for auto)") @kernel_cache @@ -106,12 +112,63 @@ def build_hilbert_decode_2d_batch_impl( validate_nbits_2d(nbits) - decode_scalar = build_hilbert_decode_2d_impl(nbits, tile_nbits=tile_nbits) - - @nb.njit(parallel=parallel, cache=False) - def decode_2d_batch(indices: UIntArray, xs: UIntArray, ys: UIntArray) -> None: - n = indices.size - for i in nb.prange(n): # type: ignore[not-iterable] - xs.flat[i], ys.flat[i] = decode_scalar(indices.flat[i]) + if tile_nbits is None: + tile_nbits = _auto_tile_nbits_2d(nbits) - return decode_2d_batch + if tile_nbits == 7: + lut = lut_2d7b_q_bs_u64() + if parallel: + + @nb.njit(parallel=True, cache=True) + def decode_2d_batch_7bit_parallel( + indices: UIntArray, xs: UIntArray, ys: UIntArray + ) -> None: + n = indices.size + for i in nb.prange(n): # type: ignore[not-iterable] + xs.flat[i], ys.flat[i] = _hilbert_decode_2d_7bit_compacted_bs( + indices.flat[i], nbits, lut + ) + + return decode_2d_batch_7bit_parallel + + @nb.njit(parallel=False, cache=True) + def decode_2d_batch_7bit_serial( + indices: UIntArray, xs: UIntArray, ys: UIntArray + ) -> None: + n = indices.size + for i in range(n): + xs.flat[i], ys.flat[i] = _hilbert_decode_2d_7bit_compacted_bs( + indices.flat[i], nbits, lut + ) + + return decode_2d_batch_7bit_serial + + if tile_nbits == 4: + lut = lut_2d4b_q_bs_u64() + if parallel: + + @nb.njit(parallel=True, cache=True) + def decode_2d_batch_4bit_parallel( + indices: UIntArray, xs: UIntArray, ys: UIntArray + ) -> None: + n = indices.size + for i in nb.prange(n): # type: ignore[not-iterable] + xs.flat[i], ys.flat[i] = _hilbert_decode_2d_4bit_compacted_bs( + indices.flat[i], nbits, lut + ) + + return decode_2d_batch_4bit_parallel + + @nb.njit(parallel=False, cache=True) + def decode_2d_batch_4bit_serial( + indices: UIntArray, xs: UIntArray, ys: UIntArray + ) -> None: + n = indices.size + for i in range(n): + xs.flat[i], ys.flat[i] = _hilbert_decode_2d_4bit_compacted_bs( + indices.flat[i], nbits, lut + ) + + return decode_2d_batch_4bit_serial + + raise ValueError("tile_nbits must be 4 or 7 (or None for auto)") diff --git a/src/hilbertsfc/_kernels/numba/hilbert2d_encode.py b/src/hilbertsfc/_kernels/numba/hilbert2d_encode.py index 4f90629..f767bce 100644 --- a/src/hilbertsfc/_kernels/numba/hilbert2d_encode.py +++ b/src/hilbertsfc/_kernels/numba/hilbert2d_encode.py @@ -81,18 +81,28 @@ def build_hilbert_encode_2d_impl(nbits: int, *, tile_nbits: TileNBits2D | None = if tile_nbits == 7: lut = lut_2d7b_b_qs_u64() - kernel = _hilbert_encode_2d_7bit_compacted_qs + + @nb.njit(inline="always", cache=True) + def encode_2d_7bit(x: IntScalar, y: IntScalar) -> int: + return _hilbert_encode_2d_7bit_compacted_qs( # type: ignore[reportReturnType] + x, y, nbits, lut + ) + + return encode_2d_7bit + elif tile_nbits == 4: lut = lut_2d4b_b_qs_u64() - kernel = _hilbert_encode_2d_4bit_compacted_qs - else: - raise ValueError("tile_nbits must be 4 or 7 (or None for auto)") - @nb.njit(inline="always", cache=False) - def encode_2d(x: IntScalar, y: IntScalar) -> int: - return kernel(x, y, nbits, lut) # type: ignore[reportReturnType] + @nb.njit(inline="always", cache=True) + def encode_2d_4bit(x: IntScalar, y: IntScalar) -> int: + return _hilbert_encode_2d_4bit_compacted_qs( # type: ignore[reportReturnType] + x, y, nbits, lut + ) + + return encode_2d_4bit - return encode_2d + else: + raise ValueError("tile_nbits must be 4 or 7 (or None for auto)") @kernel_cache @@ -103,12 +113,63 @@ def build_hilbert_encode_2d_batch_impl( validate_nbits_2d(nbits) - encode_scalar = build_hilbert_encode_2d_impl(nbits, tile_nbits=tile_nbits) - - @nb.njit(parallel=parallel, cache=False) - def encode_2d_batch(xs: UIntArray, ys: UIntArray, out: UIntArray) -> None: - n = xs.size - for i in nb.prange(n): # type: ignore[not-iterable] - out.flat[i] = encode_scalar(xs.flat[i], ys.flat[i]) + if tile_nbits is None: + tile_nbits = _auto_tile_nbits_2d(nbits) - return encode_2d_batch + if tile_nbits == 7: + lut = lut_2d7b_b_qs_u64() + if parallel: + + @nb.njit(parallel=True, cache=True) + def encode_2d_batch_7bit_parallel( + xs: UIntArray, ys: UIntArray, out: UIntArray + ) -> None: + n = xs.size + for i in nb.prange(n): # type: ignore[not-iterable] + out.flat[i] = _hilbert_encode_2d_7bit_compacted_qs( + xs.flat[i], ys.flat[i], nbits, lut + ) + + return encode_2d_batch_7bit_parallel + + @nb.njit(parallel=False, cache=True) + def encode_2d_batch_7bit_serial( + xs: UIntArray, ys: UIntArray, out: UIntArray + ) -> None: + n = xs.size + for i in range(n): + out.flat[i] = _hilbert_encode_2d_7bit_compacted_qs( + xs.flat[i], ys.flat[i], nbits, lut + ) + + return encode_2d_batch_7bit_serial + + if tile_nbits == 4: + lut = lut_2d4b_b_qs_u64() + if parallel: + + @nb.njit(parallel=True, cache=True) + def encode_2d_batch_4bit_parallel( + xs: UIntArray, ys: UIntArray, out: UIntArray + ) -> None: + n = xs.size + for i in nb.prange(n): # type: ignore[not-iterable] + out.flat[i] = _hilbert_encode_2d_4bit_compacted_qs( + xs.flat[i], ys.flat[i], nbits, lut + ) + + return encode_2d_batch_4bit_parallel + + @nb.njit(parallel=False, cache=True) + def encode_2d_batch_4bit_serial( + xs: UIntArray, ys: UIntArray, out: UIntArray + ) -> None: + n = xs.size + for i in range(n): + out.flat[i] = _hilbert_encode_2d_4bit_compacted_qs( + xs.flat[i], ys.flat[i], nbits, lut + ) + + return encode_2d_batch_4bit_serial + + raise ValueError("tile_nbits must be 4 or 7 (or None for auto)") diff --git a/src/hilbertsfc/_kernels/numba/hilbert3d_decode.py b/src/hilbertsfc/_kernels/numba/hilbert3d_decode.py index 69b54c3..839a253 100644 --- a/src/hilbertsfc/_kernels/numba/hilbert3d_decode.py +++ b/src/hilbertsfc/_kernels/numba/hilbert3d_decode.py @@ -46,7 +46,7 @@ def build_hilbert_decode_3d_impl( lut = lut_3d2b_so_sb(lut_dtype) - @nb.njit(inline="always", cache=False) + @nb.njit(inline="always", cache=True) def decode_3d(index: IntScalar) -> tuple[int, int, int]: return _hilbert_decode_3d_2bit_sb(index, nbits, lut) @@ -61,14 +61,30 @@ def build_hilbert_decode_3d_batch_impl( validate_nbits_3d(nbits) - decode_scalar = build_hilbert_decode_3d_impl(nbits, lut_dtype=lut_dtype) + lut = lut_3d2b_so_sb(lut_dtype) + + if parallel: + + @nb.njit(parallel=True, cache=True) + def decode_3d_batch_parallel( + indices: UIntArray, xs: UIntArray, ys: UIntArray, zs: UIntArray + ) -> None: + n = indices.size + for i in nb.prange(n): # type: ignore[not-iterable] + xs.flat[i], ys.flat[i], zs.flat[i] = _hilbert_decode_3d_2bit_sb( + indices.flat[i], nbits, lut + ) + + return decode_3d_batch_parallel - @nb.njit(parallel=parallel, cache=False) - def decode_3d_batch( + @nb.njit(parallel=False, cache=True) + def decode_3d_batch_serial( indices: UIntArray, xs: UIntArray, ys: UIntArray, zs: UIntArray ) -> None: n = indices.size - for i in nb.prange(n): # type: ignore[not-iterable] - xs.flat[i], ys.flat[i], zs.flat[i] = decode_scalar(indices.flat[i]) + for i in range(n): + xs.flat[i], ys.flat[i], zs.flat[i] = _hilbert_decode_3d_2bit_sb( + indices.flat[i], nbits, lut + ) - return decode_3d_batch + return decode_3d_batch_serial diff --git a/src/hilbertsfc/_kernels/numba/hilbert3d_encode.py b/src/hilbertsfc/_kernels/numba/hilbert3d_encode.py index d47e46d..1b47354 100644 --- a/src/hilbertsfc/_kernels/numba/hilbert3d_encode.py +++ b/src/hilbertsfc/_kernels/numba/hilbert3d_encode.py @@ -48,7 +48,7 @@ def build_hilbert_encode_3d_impl( lut = lut_3d2b_sb_so(lut_dtype) - @nb.njit(inline="always", cache=False) + @nb.njit(inline="always", cache=True) def encode_3d(x: IntScalar, y: IntScalar, z: IntScalar) -> int: return _hilbert_encode_3d_2bit_so(x, y, z, nbits, lut) # type: ignore[reportReturnType] @@ -63,14 +63,30 @@ def build_hilbert_encode_3d_batch_impl( validate_nbits_3d(nbits) - encode_scalar = build_hilbert_encode_3d_impl(nbits, lut_dtype=lut_dtype) + lut = lut_3d2b_sb_so(lut_dtype) + + if parallel: + + @nb.njit(parallel=True, cache=True) + def encode_3d_batch_parallel( + xs: UIntArray, ys: UIntArray, zs: UIntArray, out: UIntArray + ) -> None: + n = xs.size + for i in nb.prange(n): # type: ignore[not-iterable] + out.flat[i] = _hilbert_encode_3d_2bit_so( + xs.flat[i], ys.flat[i], zs.flat[i], nbits, lut + ) + + return encode_3d_batch_parallel - @nb.njit(parallel=parallel, cache=False) - def encode_3d_batch( + @nb.njit(parallel=False, cache=True) + def encode_3d_batch_serial( xs: UIntArray, ys: UIntArray, zs: UIntArray, out: UIntArray ) -> None: n = xs.size - for i in nb.prange(n): # type: ignore[not-iterable] - out.flat[i] = encode_scalar(xs.flat[i], ys.flat[i], zs.flat[i]) + for i in range(n): + out.flat[i] = _hilbert_encode_3d_2bit_so( + xs.flat[i], ys.flat[i], zs.flat[i], nbits, lut + ) - return encode_3d_batch + return encode_3d_batch_serial diff --git a/src/hilbertsfc/_kernels/numba/morton2d_decode.py b/src/hilbertsfc/_kernels/numba/morton2d_decode.py index 1297145..a367acc 100644 --- a/src/hilbertsfc/_kernels/numba/morton2d_decode.py +++ b/src/hilbertsfc/_kernels/numba/morton2d_decode.py @@ -67,7 +67,7 @@ def build_morton_decode_2d_impl(nbits: int): validate_nbits_2d(nbits) - @nb.njit(inline="always", cache=False) + @nb.njit(inline="always", cache=True) def decode_2d(index: IntScalar) -> tuple[int, int]: return _morton_decode_2d(index, nbits) # type: ignore[reportReturnType] @@ -80,12 +80,24 @@ def build_morton_decode_2d_batch_impl(nbits: int, *, parallel: bool = False): validate_nbits_2d(nbits) - decode_scalar = build_morton_decode_2d_impl(nbits) + if parallel: - @nb.njit(parallel=parallel, cache=False) - def decode_2d_batch(indices: UIntArray, xs: UIntArray, ys: UIntArray) -> None: + @nb.njit(parallel=True, cache=True) + def decode_2d_batch_parallel( + indices: UIntArray, xs: UIntArray, ys: UIntArray + ) -> None: + n = indices.size + for i in nb.prange(n): # type: ignore[not-iterable] + xs.flat[i], ys.flat[i] = _morton_decode_2d(indices.flat[i], nbits) + + return decode_2d_batch_parallel + + @nb.njit(parallel=False, cache=True) + def decode_2d_batch_serial( + indices: UIntArray, xs: UIntArray, ys: UIntArray + ) -> None: n = indices.size - for i in nb.prange(n): # type: ignore[not-iterable] - xs.flat[i], ys.flat[i] = decode_scalar(indices.flat[i]) + for i in range(n): + xs.flat[i], ys.flat[i] = _morton_decode_2d(indices.flat[i], nbits) - return decode_2d_batch + return decode_2d_batch_serial diff --git a/src/hilbertsfc/_kernels/numba/morton2d_encode.py b/src/hilbertsfc/_kernels/numba/morton2d_encode.py index 84873b4..3e3dfcf 100644 --- a/src/hilbertsfc/_kernels/numba/morton2d_encode.py +++ b/src/hilbertsfc/_kernels/numba/morton2d_encode.py @@ -62,7 +62,7 @@ def build_morton_encode_2d_impl(nbits: int): validate_nbits_2d(nbits) - @nb.njit(inline="always", cache=False) + @nb.njit(inline="always", cache=True) def encode_2d(x: IntScalar, y: IntScalar) -> int: return _morton_encode_2d(x, y, nbits) # type: ignore[reportReturnType] @@ -75,12 +75,22 @@ def build_morton_encode_2d_batch_impl(nbits: int, *, parallel: bool = False): validate_nbits_2d(nbits) - encode_scalar = build_morton_encode_2d_impl(nbits) + if parallel: - @nb.njit(parallel=parallel, cache=False) - def encode_2d_batch(xs: UIntArray, ys: UIntArray, out: UIntArray) -> None: + @nb.njit(parallel=True, cache=True) + def encode_2d_batch_parallel( + xs: UIntArray, ys: UIntArray, out: UIntArray + ) -> None: + n = xs.size + for i in nb.prange(n): # type: ignore[not-iterable] + out.flat[i] = _morton_encode_2d(xs.flat[i], ys.flat[i], nbits) + + return encode_2d_batch_parallel + + @nb.njit(parallel=False, cache=True) + def encode_2d_batch_serial(xs: UIntArray, ys: UIntArray, out: UIntArray) -> None: n = xs.size - for i in nb.prange(n): # type: ignore[not-iterable] - out.flat[i] = encode_scalar(xs.flat[i], ys.flat[i]) + for i in range(n): + out.flat[i] = _morton_encode_2d(xs.flat[i], ys.flat[i], nbits) - return encode_2d_batch + return encode_2d_batch_serial diff --git a/src/hilbertsfc/_kernels/numba/morton3d_decode.py b/src/hilbertsfc/_kernels/numba/morton3d_decode.py index 93d317e..d72ff97 100644 --- a/src/hilbertsfc/_kernels/numba/morton3d_decode.py +++ b/src/hilbertsfc/_kernels/numba/morton3d_decode.py @@ -69,7 +69,7 @@ def build_morton_decode_3d_impl(nbits: int): validate_nbits_3d(nbits) - @nb.njit(inline="always", cache=False) + @nb.njit(inline="always", cache=True) def decode_3d(index: IntScalar) -> tuple[int, int, int]: return _morton_decode_3d(index, nbits) # type: ignore[reportReturnType] @@ -82,14 +82,28 @@ def build_morton_decode_3d_batch_impl(nbits: int, *, parallel: bool = False): validate_nbits_3d(nbits) - decode_scalar = build_morton_decode_3d_impl(nbits) + if parallel: - @nb.njit(parallel=parallel, cache=False) - def decode_3d_batch( + @nb.njit(parallel=True, cache=True) + def decode_3d_batch_parallel( + indices: UIntArray, xs: UIntArray, ys: UIntArray, zs: UIntArray + ) -> None: + n = indices.size + for i in nb.prange(n): # type: ignore[not-iterable] + xs.flat[i], ys.flat[i], zs.flat[i] = _morton_decode_3d( + indices.flat[i], nbits + ) + + return decode_3d_batch_parallel + + @nb.njit(parallel=False, cache=True) + def decode_3d_batch_serial( indices: UIntArray, xs: UIntArray, ys: UIntArray, zs: UIntArray ) -> None: n = indices.size - for i in nb.prange(n): # type: ignore[not-iterable] - xs.flat[i], ys.flat[i], zs.flat[i] = decode_scalar(indices.flat[i]) + for i in range(n): + xs.flat[i], ys.flat[i], zs.flat[i] = _morton_decode_3d( + indices.flat[i], nbits + ) - return decode_3d_batch + return decode_3d_batch_serial diff --git a/src/hilbertsfc/_kernels/numba/morton3d_encode.py b/src/hilbertsfc/_kernels/numba/morton3d_encode.py index 5b7190c..025e537 100644 --- a/src/hilbertsfc/_kernels/numba/morton3d_encode.py +++ b/src/hilbertsfc/_kernels/numba/morton3d_encode.py @@ -70,7 +70,7 @@ def build_morton_encode_3d_impl(nbits: int): validate_nbits_3d(nbits) - @nb.njit(inline="always", cache=False) + @nb.njit(inline="always", cache=True) def encode_3d(x: IntScalar, y: IntScalar, z: IntScalar) -> int: return _morton_encode_3d(x, y, z, nbits) # type: ignore[reportReturnType] @@ -83,14 +83,26 @@ def build_morton_encode_3d_batch_impl(nbits: int, *, parallel: bool = False): validate_nbits_3d(nbits) - encode_scalar = build_morton_encode_3d_impl(nbits) + if parallel: - @nb.njit(parallel=parallel, cache=False) - def encode_3d_batch( + @nb.njit(parallel=True, cache=True) + def encode_3d_batch_parallel( + xs: UIntArray, ys: UIntArray, zs: UIntArray, out: UIntArray + ) -> None: + n = xs.size + for i in nb.prange(n): # type: ignore[not-iterable] + out.flat[i] = _morton_encode_3d( + xs.flat[i], ys.flat[i], zs.flat[i], nbits + ) + + return encode_3d_batch_parallel + + @nb.njit(parallel=False, cache=True) + def encode_3d_batch_serial( xs: UIntArray, ys: UIntArray, zs: UIntArray, out: UIntArray ) -> None: n = xs.size - for i in nb.prange(n): # type: ignore[not-iterable] - out.flat[i] = encode_scalar(xs.flat[i], ys.flat[i], zs.flat[i]) + for i in range(n): + out.flat[i] = _morton_encode_3d(xs.flat[i], ys.flat[i], zs.flat[i], nbits) - return encode_3d_batch + return encode_3d_batch_serial