Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ jobs:

- name: Smoke test wheel
run: >
uv run --isolated --no-project --with dist/*.whl python -c "from
uv run --no-cache --isolated --no-project --with dist/*.whl python -c "from
hilbertsfc import hilbert_decode_2d, hilbert_encode_2d; i =
hilbert_encode_2d(1, 2, nbits=2); assert hilbert_decode_2d(i, nbits=2)
== (1, 2)"

- name: Smoke test source distribution
run: >
uv run --isolated --no-project --with dist/*.tar.gz python -c "from
uv run --no-cache --isolated --no-project --with dist/*.tar.gz python -c "from
hilbertsfc import hilbert_decode_3d, hilbert_encode_3d; i =
hilbert_encode_3d(1, 2, 3, nbits=2); assert hilbert_decode_3d(i,
nbits=2) == (1, 2, 3)"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ jobs:
run: uv build

- name: Smoke test (wheel)
run: uv run --isolated --no-project --with dist/*.whl --with pytest pytest -q
run: uv run --no-cache --isolated --no-project --with dist/*.whl --with pytest pytest -q

- name: Smoke test (source distribution)
run: uv run --isolated --no-project --with dist/*.tar.gz --with pytest pytest -q
run: uv run --no-cache --isolated --no-project --with dist/*.tar.gz --with pytest pytest -q

- name: Publish
run: |
Expand Down
89 changes: 73 additions & 16 deletions src/hilbertsfc/_kernels/numba/hilbert2d_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,24 @@ def build_hilbert_decode_2d_impl(nbits: int, *, tile_nbits: TileNBits2D | None =

if tile_nbits == 7:
lut = lut_2d7b_q_bs_u64()
kernel = _hilbert_decode_2d_7bit_compacted_bs

@nb.njit(inline="always", cache=True)
def decode_2d_7bit(index: IntScalar) -> tuple[int, int]:
return _hilbert_decode_2d_7bit_compacted_bs(index, nbits, lut)

return decode_2d_7bit

elif tile_nbits == 4:
lut = lut_2d4b_q_bs_u64()
kernel = _hilbert_decode_2d_4bit_compacted_bs
else:
raise ValueError("tile_nbits must be 4 or 7 (or None for auto)")

@nb.njit(inline="always", cache=False)
def decode_2d(index: IntScalar) -> tuple[int, int]:
return kernel(index, nbits, lut)
@nb.njit(inline="always", cache=True)
def decode_2d_4bit(index: IntScalar) -> tuple[int, int]:
return _hilbert_decode_2d_4bit_compacted_bs(index, nbits, lut)

return decode_2d_4bit

return decode_2d
else:
raise ValueError("tile_nbits must be 4 or 7 (or None for auto)")


@kernel_cache
Expand All @@ -106,12 +112,63 @@ def build_hilbert_decode_2d_batch_impl(

validate_nbits_2d(nbits)

decode_scalar = build_hilbert_decode_2d_impl(nbits, tile_nbits=tile_nbits)

@nb.njit(parallel=parallel, cache=False)
def decode_2d_batch(indices: UIntArray, xs: UIntArray, ys: UIntArray) -> None:
n = indices.size
for i in nb.prange(n): # type: ignore[not-iterable]
xs.flat[i], ys.flat[i] = decode_scalar(indices.flat[i])
if tile_nbits is None:
tile_nbits = _auto_tile_nbits_2d(nbits)

return decode_2d_batch
if tile_nbits == 7:
lut = lut_2d7b_q_bs_u64()
if parallel:

@nb.njit(parallel=True, cache=True)
def decode_2d_batch_7bit_parallel(
indices: UIntArray, xs: UIntArray, ys: UIntArray
) -> None:
n = indices.size
for i in nb.prange(n): # type: ignore[not-iterable]
xs.flat[i], ys.flat[i] = _hilbert_decode_2d_7bit_compacted_bs(
indices.flat[i], nbits, lut
)

return decode_2d_batch_7bit_parallel

@nb.njit(parallel=False, cache=True)
def decode_2d_batch_7bit_serial(
indices: UIntArray, xs: UIntArray, ys: UIntArray
) -> None:
n = indices.size
for i in range(n):
xs.flat[i], ys.flat[i] = _hilbert_decode_2d_7bit_compacted_bs(
indices.flat[i], nbits, lut
)

return decode_2d_batch_7bit_serial

if tile_nbits == 4:
lut = lut_2d4b_q_bs_u64()
if parallel:

@nb.njit(parallel=True, cache=True)
def decode_2d_batch_4bit_parallel(
indices: UIntArray, xs: UIntArray, ys: UIntArray
) -> None:
n = indices.size
for i in nb.prange(n): # type: ignore[not-iterable]
xs.flat[i], ys.flat[i] = _hilbert_decode_2d_4bit_compacted_bs(
indices.flat[i], nbits, lut
)

return decode_2d_batch_4bit_parallel

@nb.njit(parallel=False, cache=True)
def decode_2d_batch_4bit_serial(
indices: UIntArray, xs: UIntArray, ys: UIntArray
) -> None:
n = indices.size
for i in range(n):
xs.flat[i], ys.flat[i] = _hilbert_decode_2d_4bit_compacted_bs(
indices.flat[i], nbits, lut
)

return decode_2d_batch_4bit_serial

raise ValueError("tile_nbits must be 4 or 7 (or None for auto)")
93 changes: 77 additions & 16 deletions src/hilbertsfc/_kernels/numba/hilbert2d_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,28 @@ def build_hilbert_encode_2d_impl(nbits: int, *, tile_nbits: TileNBits2D | None =

if tile_nbits == 7:
lut = lut_2d7b_b_qs_u64()
kernel = _hilbert_encode_2d_7bit_compacted_qs

@nb.njit(inline="always", cache=True)
def encode_2d_7bit(x: IntScalar, y: IntScalar) -> int:
return _hilbert_encode_2d_7bit_compacted_qs( # type: ignore[reportReturnType]
x, y, nbits, lut
)

return encode_2d_7bit

elif tile_nbits == 4:
lut = lut_2d4b_b_qs_u64()
kernel = _hilbert_encode_2d_4bit_compacted_qs
else:
raise ValueError("tile_nbits must be 4 or 7 (or None for auto)")

@nb.njit(inline="always", cache=False)
def encode_2d(x: IntScalar, y: IntScalar) -> int:
return kernel(x, y, nbits, lut) # type: ignore[reportReturnType]
@nb.njit(inline="always", cache=True)
def encode_2d_4bit(x: IntScalar, y: IntScalar) -> int:
return _hilbert_encode_2d_4bit_compacted_qs( # type: ignore[reportReturnType]
x, y, nbits, lut
)

return encode_2d_4bit

return encode_2d
else:
raise ValueError("tile_nbits must be 4 or 7 (or None for auto)")


@kernel_cache
Expand All @@ -103,12 +113,63 @@ def build_hilbert_encode_2d_batch_impl(

validate_nbits_2d(nbits)

encode_scalar = build_hilbert_encode_2d_impl(nbits, tile_nbits=tile_nbits)

@nb.njit(parallel=parallel, cache=False)
def encode_2d_batch(xs: UIntArray, ys: UIntArray, out: UIntArray) -> None:
n = xs.size
for i in nb.prange(n): # type: ignore[not-iterable]
out.flat[i] = encode_scalar(xs.flat[i], ys.flat[i])
if tile_nbits is None:
tile_nbits = _auto_tile_nbits_2d(nbits)

return encode_2d_batch
if tile_nbits == 7:
lut = lut_2d7b_b_qs_u64()
if parallel:

@nb.njit(parallel=True, cache=True)
def encode_2d_batch_7bit_parallel(
xs: UIntArray, ys: UIntArray, out: UIntArray
) -> None:
n = xs.size
for i in nb.prange(n): # type: ignore[not-iterable]
out.flat[i] = _hilbert_encode_2d_7bit_compacted_qs(
xs.flat[i], ys.flat[i], nbits, lut
)

return encode_2d_batch_7bit_parallel

@nb.njit(parallel=False, cache=True)
def encode_2d_batch_7bit_serial(
xs: UIntArray, ys: UIntArray, out: UIntArray
) -> None:
n = xs.size
for i in range(n):
out.flat[i] = _hilbert_encode_2d_7bit_compacted_qs(
xs.flat[i], ys.flat[i], nbits, lut
)

return encode_2d_batch_7bit_serial

if tile_nbits == 4:
lut = lut_2d4b_b_qs_u64()
if parallel:

@nb.njit(parallel=True, cache=True)
def encode_2d_batch_4bit_parallel(
xs: UIntArray, ys: UIntArray, out: UIntArray
) -> None:
n = xs.size
for i in nb.prange(n): # type: ignore[not-iterable]
out.flat[i] = _hilbert_encode_2d_4bit_compacted_qs(
xs.flat[i], ys.flat[i], nbits, lut
)

return encode_2d_batch_4bit_parallel

@nb.njit(parallel=False, cache=True)
def encode_2d_batch_4bit_serial(
xs: UIntArray, ys: UIntArray, out: UIntArray
) -> None:
n = xs.size
for i in range(n):
out.flat[i] = _hilbert_encode_2d_4bit_compacted_qs(
xs.flat[i], ys.flat[i], nbits, lut
)

return encode_2d_batch_4bit_serial

raise ValueError("tile_nbits must be 4 or 7 (or None for auto)")
30 changes: 23 additions & 7 deletions src/hilbertsfc/_kernels/numba/hilbert3d_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def build_hilbert_decode_3d_impl(

lut = lut_3d2b_so_sb(lut_dtype)

@nb.njit(inline="always", cache=False)
@nb.njit(inline="always", cache=True)
def decode_3d(index: IntScalar) -> tuple[int, int, int]:
return _hilbert_decode_3d_2bit_sb(index, nbits, lut)

Expand All @@ -61,14 +61,30 @@ def build_hilbert_decode_3d_batch_impl(

validate_nbits_3d(nbits)

decode_scalar = build_hilbert_decode_3d_impl(nbits, lut_dtype=lut_dtype)
lut = lut_3d2b_so_sb(lut_dtype)

if parallel:

@nb.njit(parallel=True, cache=True)
def decode_3d_batch_parallel(
indices: UIntArray, xs: UIntArray, ys: UIntArray, zs: UIntArray
) -> None:
n = indices.size
for i in nb.prange(n): # type: ignore[not-iterable]
xs.flat[i], ys.flat[i], zs.flat[i] = _hilbert_decode_3d_2bit_sb(
indices.flat[i], nbits, lut
)

return decode_3d_batch_parallel

@nb.njit(parallel=parallel, cache=False)
def decode_3d_batch(
@nb.njit(parallel=False, cache=True)
def decode_3d_batch_serial(
indices: UIntArray, xs: UIntArray, ys: UIntArray, zs: UIntArray
) -> None:
n = indices.size
for i in nb.prange(n): # type: ignore[not-iterable]
xs.flat[i], ys.flat[i], zs.flat[i] = decode_scalar(indices.flat[i])
for i in range(n):
xs.flat[i], ys.flat[i], zs.flat[i] = _hilbert_decode_3d_2bit_sb(
indices.flat[i], nbits, lut
)

return decode_3d_batch
return decode_3d_batch_serial
30 changes: 23 additions & 7 deletions src/hilbertsfc/_kernels/numba/hilbert3d_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def build_hilbert_encode_3d_impl(

lut = lut_3d2b_sb_so(lut_dtype)

@nb.njit(inline="always", cache=False)
@nb.njit(inline="always", cache=True)
def encode_3d(x: IntScalar, y: IntScalar, z: IntScalar) -> int:
return _hilbert_encode_3d_2bit_so(x, y, z, nbits, lut) # type: ignore[reportReturnType]

Expand All @@ -63,14 +63,30 @@ def build_hilbert_encode_3d_batch_impl(

validate_nbits_3d(nbits)

encode_scalar = build_hilbert_encode_3d_impl(nbits, lut_dtype=lut_dtype)
lut = lut_3d2b_sb_so(lut_dtype)

if parallel:

@nb.njit(parallel=True, cache=True)
def encode_3d_batch_parallel(
xs: UIntArray, ys: UIntArray, zs: UIntArray, out: UIntArray
) -> None:
n = xs.size
for i in nb.prange(n): # type: ignore[not-iterable]
out.flat[i] = _hilbert_encode_3d_2bit_so(
xs.flat[i], ys.flat[i], zs.flat[i], nbits, lut
)

return encode_3d_batch_parallel

@nb.njit(parallel=parallel, cache=False)
def encode_3d_batch(
@nb.njit(parallel=False, cache=True)
def encode_3d_batch_serial(
xs: UIntArray, ys: UIntArray, zs: UIntArray, out: UIntArray
) -> None:
n = xs.size
for i in nb.prange(n): # type: ignore[not-iterable]
out.flat[i] = encode_scalar(xs.flat[i], ys.flat[i], zs.flat[i])
for i in range(n):
out.flat[i] = _hilbert_encode_3d_2bit_so(
xs.flat[i], ys.flat[i], zs.flat[i], nbits, lut
)

return encode_3d_batch
return encode_3d_batch_serial
26 changes: 19 additions & 7 deletions src/hilbertsfc/_kernels/numba/morton2d_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def build_morton_decode_2d_impl(nbits: int):

validate_nbits_2d(nbits)

@nb.njit(inline="always", cache=False)
@nb.njit(inline="always", cache=True)
def decode_2d(index: IntScalar) -> tuple[int, int]:
return _morton_decode_2d(index, nbits) # type: ignore[reportReturnType]

Expand All @@ -80,12 +80,24 @@ def build_morton_decode_2d_batch_impl(nbits: int, *, parallel: bool = False):

validate_nbits_2d(nbits)

decode_scalar = build_morton_decode_2d_impl(nbits)
if parallel:

@nb.njit(parallel=parallel, cache=False)
def decode_2d_batch(indices: UIntArray, xs: UIntArray, ys: UIntArray) -> None:
@nb.njit(parallel=True, cache=True)
def decode_2d_batch_parallel(
indices: UIntArray, xs: UIntArray, ys: UIntArray
) -> None:
n = indices.size
for i in nb.prange(n): # type: ignore[not-iterable]
xs.flat[i], ys.flat[i] = _morton_decode_2d(indices.flat[i], nbits)

return decode_2d_batch_parallel

@nb.njit(parallel=False, cache=True)
def decode_2d_batch_serial(
indices: UIntArray, xs: UIntArray, ys: UIntArray
) -> None:
n = indices.size
for i in nb.prange(n): # type: ignore[not-iterable]
xs.flat[i], ys.flat[i] = decode_scalar(indices.flat[i])
for i in range(n):
xs.flat[i], ys.flat[i] = _morton_decode_2d(indices.flat[i], nbits)

return decode_2d_batch
return decode_2d_batch_serial
Loading