Commit f110c82

Merge pull request #9 from shaham-lab/feat/bugfixes-tests-cicd
Feat/bugfixes tests cicd
2 parents 62de252 + 05e0b31 commit f110c82

20 files changed: 5357 additions & 154 deletions

.github/workflows/ci.yml

Lines changed: 40 additions & 0 deletions

```yaml
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  test:
    name: Test (Python ${{ matrix.python-version }})
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.11", "3.12"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: pip

      - name: Install package with test dependencies
        run: |
          pip install --upgrade pip
          pip install -e ".[dev]"

      - name: Run tests
        run: pytest src/tests -v --tb=short --no-header

      - name: Upload test results
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: test-results-${{ matrix.python-version }}
          path: .pytest_cache/
```
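The `strategy.matrix` in ci.yml fans the test job out into one run per combination of axis values. The expansion is a simple cross-product, sketched below in Python (the extra `os` axis is hypothetical, added only to show how multiple axes multiply; the actual workflow has a single `python-version` axis and pins `ubuntu-latest`):

```python
from itertools import product

# GitHub Actions expands strategy.matrix into the cross-product of its axes.
matrix = {"python-version": ["3.11", "3.12"], "os": ["ubuntu-latest"]}
jobs = [dict(zip(matrix, combo)) for combo in product(*matrix.values())]
for job in jobs:
    print(job)
# two jobs: one per Python version, both on ubuntu-latest
```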

.github/workflows/release.yml

Lines changed: 57 additions & 0 deletions

```yaml
name: Release

on:
  push:
    branches: [main]

permissions:
  contents: write # push tags / GitHub release
  id-token: write # OIDC token for trusted PyPI publishing

jobs:
  release:
    name: Semantic release & publish
    runs-on: ubuntu-latest
    # Only run when CI passes — avoid publishing broken code
    needs: []

    concurrency:
      group: release
      cancel-in-progress: false

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0 # full history so PSR can read all commits
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: pip

      - name: Install build tools
        run: pip install python-semantic-release build

      # python-semantic-release reads commit history, bumps the version in
      # setup.cfg, creates a tag, a GitHub release, and builds the package.
      - name: Run semantic-release
        id: semrel
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: semantic-release version

      # Build sdist + wheel only when a new version was actually released
      - name: Build distribution
        if: steps.semrel.outputs.released == 'true'
        run: python -m build

      # Publish to PyPI via OIDC (no API token needed — configure trusted
      # publisher on pypi.org: owner=<your-gh-username>,
      # repo=SpectralNet, workflow=release.yml, environment=pypi)
      - name: Publish to PyPI
        if: steps.semrel.outputs.released == 'true'
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          print-hash: true
```
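python-semantic-release decides the version bump from Conventional Commit prefixes in the history it reads. A simplified sketch of that rule (subject-line parsing only; the real tool also honours `BREAKING CHANGE:` footers and is configurable, so this is an illustration, not its actual parser):

```python
import re

def bump_for(message: str) -> str:
    """Return the semver bump implied by a Conventional Commit subject.
    Simplified sketch: real semantic-release also parses footers."""
    subject = message.splitlines()[0]
    m = re.match(r"^(\w+)(\([^)]*\))?(!)?:", subject)
    if not m:
        return "none"      # non-conventional commits trigger no release
    kind, _scope, bang = m.groups()
    if bang:
        return "major"     # "feat!:" / "fix!:" mark breaking changes
    if kind == "feat":
        return "minor"
    if kind == "fix":
        return "patch"
    return "none"          # chore:, docs:, ci:, ... release nothing

print(bump_for("feat: add approximate k-NN option"))  # minor
print(bump_for("fix(core): handle empty batch"))      # patch
print(bump_for("refactor!: drop Python 3.10"))        # major
```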

README.md

Lines changed: 94 additions & 23 deletions

````diff
@@ -8,52 +8,123 @@ This package is based on the following paper - [SpectralNet](https://openreview.
 
 ## Installation
 
-You can install the latest package version via
+### From PyPI
 
 ```bash
 pip install spectralnet
 ```
 
+### From source (with pixi)
+
+[pixi](https://pixi.sh) is the recommended way to set up a fully reproducible
+development environment after cloning the repo.
+
+```bash
+# 1. Install pixi (once, system-wide)
+curl -fsSL https://pixi.sh/install.sh | sh
+
+# 2. Clone and enter the repo
+git clone https://github.com/shaham-lab/SpectralNet.git
+cd SpectralNet
+
+# 3. Install all dependencies (conda + PyPI) into an isolated environment
+pixi install
+
+# 4. Run the test suite to verify everything works
+pixi run test
+```
+
+After `pixi install` you can prefix any command with `pixi run` to execute it
+inside the managed environment, or activate the environment with:
+
+```bash
+pixi shell
+```
+
 ## Usage
 
-### Clustering
+### Clustering — small datasets (in-memory tensor)
 
-The basic functionality is quite intuitive and easy to use, e.g.,
+For datasets that fit in RAM, pass a `torch.Tensor` directly:
 
 ```python
 from spectralnet import SpectralNet
 
 spectralnet = SpectralNet(n_clusters=10)
-spectralnet.fit(X) # X is the dataset and it should be a torch.Tensor
-cluster_assignments = spectralnet.predict(X) # Get the final assignments to clusters
+spectralnet.fit(X)  # X: torch.Tensor of shape (N, ...)
+cluster_assignments = spectralnet.predict(X)
 ```
 
-If you have labels to your dataset and you want to measure ACC and NMI you can do the following:
+To measure ACC and NMI when labels are available:
 
 ```python
-from spectralnet import SpectralNet
-from spectralnet import Metrics
-
+from spectralnet import SpectralNet, Metrics
 
 spectralnet = SpectralNet(n_clusters=2)
-spectralnet.fit(X, y) # X is the dataset and it should be a torch.Tensor
-cluster_assignments = spectralnet.predict(X) # Get the final assignments to clusters
-
-y = y_train.detach().cpu().numpy() # In case your labels are of torch.Tensor type.
-acc_score = Metrics.acc_score(cluster_assignments, y, n_clusters=2)
-nmi_score = Metrics.nmi_score(cluster_assignments, y)
-print(f"ACC: {np.round(acc_score, 3)}")
-print(f"NMI: {np.round(nmi_score, 3)}")
+spectralnet.fit(X, y)  # y: integer label tensor
+cluster_assignments = spectralnet.predict(X)
+
+y_np = y.detach().cpu().numpy()
+acc_score = Metrics.acc_score(cluster_assignments, y_np, n_clusters=2)
+nmi_score = Metrics.nmi_score(cluster_assignments, y_np)
+print(f"ACC: {acc_score:.3f} NMI: {nmi_score:.3f}")
+```
+
+### Clustering — large datasets (streaming from disk)
+
+For datasets too large to hold in RAM (e.g. millions of images on disk),
+define a `torch.utils.data.Dataset` that loads **one sample at a time**
+and pass it to `fit()`. Nothing large ever lives in memory at once — every
+trainer pulls mini-batches through its own `DataLoader` internally.
+
+```python
+from torch.utils.data import Dataset, DataLoader
+from spectralnet import SpectralNet
+from PIL import Image
+import torchvision.transforms as T
+import os
+
+class ImageFolderDataset(Dataset):
+    def __init__(self, root):
+        self.paths = [
+            os.path.join(root, f) for f in os.listdir(root) if f.endswith(".jpg")
+        ]
+        self.transform = T.Compose([T.Resize(64), T.ToTensor(), T.Normalize(0.5, 0.5)])
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, idx):
+        return self.transform(Image.open(self.paths[idx]).convert("RGB"))
+
+dataset = ImageFolderDataset("/path/to/images")
+
+spectralnet = SpectralNet(
+    n_clusters=10,
+    should_use_ae=True,  # compress images before clustering
+    ae_hiddens=[2048, 512, 64, 10],
+    spectral_hiddens=[512, 512, 10],
+)
+spectralnet.fit(dataset)
+
+# predict() also accepts a DataLoader for large test sets
+test_loader = DataLoader(dataset, batch_size=512, shuffle=False)
+cluster_assignments = spectralnet.predict(test_loader)
 ```
 
-You can read the code docs for more information and functionalities<br>
+> **Note on Siamese training with large datasets:** the Siamese network
+> builds exact k-NN pairs, which requires loading all features into memory.
+> For very large datasets either disable it (`should_use_siamese=False`),
+> enable approximate neighbours (`siamese_use_approx=True`), or pass a
+> representative subset as the Dataset.
 
-#### Running examples
+### Running examples
 
-In order to run the model on twomoons or MNIST datasets, you should first cd to the examples folder and then run:<br>
-`python3 cluster_twomoons.py`<br>
-or<br>
-`python3 cluster_mnist.py`
+```bash
+cd examples
+python3 cluster_twomoons.py
+python3 cluster_mnist.py
+```
 
 <!-- ### Data reduction and visualization
````
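The `Metrics.acc_score` call in the README above scores a clustering by matching cluster ids to labels under the best possible mapping. A brute-force sketch of that idea (an illustration only; the package's actual implementation is not shown in this diff and presumably uses an efficient assignment algorithm):

```python
from itertools import permutations

def acc_score_sketch(assignments, labels, n_clusters):
    """Clustering accuracy under the best cluster-to-label permutation.
    Brute force over mappings: fine for small n_clusters only."""
    best = 0
    for perm in permutations(range(n_clusters)):
        hits = sum(1 for a, y in zip(assignments, labels) if perm[a] == y)
        best = max(best, hits)
    return best / len(labels)

# Perfect clustering with swapped ids still scores 1.0:
print(acc_score_sketch([0, 0, 1, 1], [1, 1, 0, 0], n_clusters=2))  # 1.0
```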
