Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
546 changes: 546 additions & 0 deletions candidate_formats/archiveformat.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions candidate_formats/benchmark_results.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
N_res,npz_write,npz_size,hdf5_write,hdf5_size,zarr_write,zarr_size,npz_full_read,hdf5_full_read,zarr_full_read,npz_layer_read,hdf5_layer_read,zarr_layer_read,hdf5_head_read,zarr_head_read,zarr_random,hdf5_random,viz_hdf5,viz_zarr,zarr_parallel,hdf5_parallel,npz_write_MBps,hdf5_write_MBps,zarr_write_MBps,useful_bytes,hdf5_efficiency,zarr_efficiency
256,20.959038972854614,400619395,0.4188868999481201,432569736,3.2118980884552,387861993,0.0011882781982421875,0.0011639595031738281,0.0006070137023925781,0.008630990982055664,0.0006988048553466797,0.04167675971984863,0.00026297569274902344,0.004118919372558594,0.14974498748779297,0.029961109161376953,0.0058100223541259766,0.029222965240478516,0.09261107444763184,0.037297964096069336,19.114397159090533,1032.6647504459427,120.7578765945674,4096,15575584.029011786,994435.5860152813
512,83.98201513290405,1593689281,0.6405029296875,1720749448,10.872176885604858,1542865571,0.0007750988006591797,0.0013039112091064453,0.0003631114959716797,0.03376269340515137,0.004105806350708008,0.16460800170898438,0.0004391670227050781,0.012581825256347656,0.24428105354309082,0.0844881534576416,0.015540122985839844,0.05322766304016113,0.08516788482666016,0.0678720474243164,18.976554426301142,2686.559839530398,141.90953543469365,4096,9326747.65689468,325548.949897673
1024,348.0411250591278,6357243010,4.268572807312012,6864022920,44.16462802886963,6154440006,0.004289865493774414,0.0027647018432617188,0.00041222572326660156,0.14620518684387207,0.025741100311279297,1.0825958251953125,0.002031087875366211,0.0656440258026123,0.38530993461608887,0.13773298263549805,0.031294822692871094,0.09022998809814453,0.11237192153930664,0.08473801612854004,18.26578111687228,1608.036978599033,139.35224365474,4096,2016653.2672848925,62397.14810173936
2048,1411.5884323120117,25393961662,281.561222076416,27418226056,351.5809578895569,24585016703,0.011762857437133789,0.016555070877075195,0.0022859573364257812,0.556689977645874,0.09723091125488281,3.3311328887939453,0.0060617923736572266,0.21351218223571777,0.42122387886047363,0.2177579402923584,0.05145597457885742,0.10213494300842285,0.0932619571685791,0.1259291172027588,17.98963570451463,97.37926925377054,69.92704283695296,4096,675707.7358505408,19183.917081967764
13 changes: 13 additions & 0 deletions candidate_formats/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
numpy>=1.24

pandas>=2.0

matplotlib>=3.7

h5py>=3.10

zarr<3

numcodecs>=0.11

tqdm>=4.65
17 changes: 17 additions & 0 deletions standardizedarchive/MNIST/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# datasets
data/

# generated archives
*.zarr

# virtual environments
zarr_env/
venv/
.env/

# python cache
__pycache__/
*.pyc

# OS files
.DS_Store
45 changes: 45 additions & 0 deletions standardizedarchive/MNIST/mnist_trace_project/archive/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import zarr
from numcodecs import Blosc

def create_archive(path, dataset_size):
    """Create an empty Zarr archive for MNIST inference traces.

    The archive is opened with ``mode="w"``. Every dataset has
    ``dataset_size`` rows along its first axis, is chunked 64 rows at a
    time, and shares one Blosc/zstd compressor.

    Parameters:
        path: Location passed to ``zarr.open``.
        dataset_size: Number of samples each dataset is sized for.

    Returns:
        The root object returned by ``zarr.open``, with all datasets
        created and the ``dataset`` / ``archive_version`` attributes set.
    """
    rows_per_chunk = 64
    codec = Blosc(cname='zstd', clevel=5, shuffle=Blosc.SHUFFLE)

    store = zarr.open(path, mode="w")

    # (dataset path, per-sample shape, dtype), listed in creation order.
    layout = (
        ("inputs/images", (1, 28, 28), "float32"),
        ("outputs/logits", (10,), "float32"),
        ("outputs/predictions", (), "int64"),
        # Activations (example shapes — update after first forward pass if needed)
        ("activations/conv1", (16, 26, 26), "float32"),
    )

    for dataset_path, sample_shape, dtype in layout:
        store.create_dataset(
            dataset_path,
            shape=(dataset_size, *sample_shape),
            chunks=(rows_per_chunk, *sample_shape),
            dtype=dtype,
            compressor=codec,
        )

    store.attrs["dataset"] = "MNIST"
    store.attrs["archive_version"] = "v1"

    return store
10 changes: 10 additions & 0 deletions standardizedarchive/MNIST/mnist_trace_project/archive/writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def write_batch(root, start_idx, images, logits, preds, activations):
    """Copy one batch of results into the archive rows starting at ``start_idx``.

    Parameters:
        root: Mapping-like archive indexed by dataset path; each entry must
            support slice assignment.
        start_idx: First row to write.
        images: Batch of inputs; ``images.shape[0]`` sets the number of rows.
        logits: Values written to ``outputs/logits``.
        preds: Values written to ``outputs/predictions``.
        activations: Mapping that must contain a ``"conv1"`` entry.

    Returns:
        The index one past the last row written (``start_idx`` + batch size),
        ready to be used as the next ``start_idx``.
    """
    rows = slice(start_idx, start_idx + images.shape[0])

    direct_writes = (
        ("inputs/images", images),
        ("outputs/logits", logits),
        ("outputs/predictions", preds),
    )
    for dataset_path, values in direct_writes:
        root[dataset_path][rows] = values

    # Only the conv1 activation is archived; it is looked up last, after the
    # other three datasets have been written.
    root["activations/conv1"][rows] = activations["conv1"]

    return rows.stop
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import time
import zarr

def benchmark(path, index=500):
    """Time a single-sample read of ``activations/conv1`` from a Zarr archive.

    The archive is opened read-only. The timed region covers both the
    dataset lookup and the read of one row, as in the original measurement.

    Parameters:
        path: Location of the archive, passed to ``zarr.open``.
        index: Row of ``activations/conv1`` to read. Defaults to 500, the
            row the benchmark has always used.

    Returns:
        Elapsed time in seconds (also printed).
    """
    root = zarr.open(path, mode="r")

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and is too coarse for sub-millisecond
    # reads.
    start = time.perf_counter()
    _ = root["activations/conv1"][index]
    elapsed = time.perf_counter() - start

    print("Random read time:", elapsed)
    return elapsed
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from models.cnn import SimpleCNN
from utils.hooks import register_hooks
from archive.schema import create_archive
from archive.writer import write_batch

def run(output_path):
    """Run SimpleCNN over the MNIST test split and archive the results.

    Downloads MNIST into ``./data`` if needed, creates a new archive at
    ``output_path`` sized to the dataset, and writes images, logits,
    predictions and hooked activations batch by batch. The model is
    constructed fresh here; no checkpoint is loaded in this function.

    Parameters:
        output_path: Destination passed to ``create_archive``.
    """
    to_tensor = transforms.ToTensor()
    test_set = datasets.MNIST(root="./data", train=False, download=True, transform=to_tensor)
    loader = DataLoader(test_set, batch_size=64)

    model = SimpleCNN()
    model.eval()

    # The hooks refill this dict on every forward pass, so it always holds
    # the activations of the most recent batch.
    activations = register_hooks(model)

    archive = create_archive(output_path, len(test_set))

    next_row = 0
    with torch.no_grad():
        # Labels are not archived; only the model's own outputs are stored.
        for images, _labels in loader:
            logits = model(images)
            preds = logits.argmax(dim=1)

            next_row = write_batch(
                archive,
                next_row,
                images.numpy(),
                logits.numpy(),
                preds.numpy(),
                activations,
            )
4 changes: 4 additions & 0 deletions standardizedarchive/MNIST/mnist_trace_project/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from inference.run_inference import run

if __name__ == "__main__":
    # Script entry point: run inference and write the trace archive.
    run(output_path="mnist_trace.zarr")
19 changes: 19 additions & 0 deletions standardizedarchive/MNIST/mnist_trace_project/models/cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Two-convolution, two-linear classifier producing 10 logits.

    ``fc1`` expects 32 * 24 * 24 flattened features, which corresponds to a
    24x24 map after the two 3x3 convolutions (e.g. 28x28 input: 28 -> 26 -> 24
    with PyTorch's default stride and padding).
    """

    def __init__(self):
        super().__init__()
        flat_features = 32 * 24 * 24
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw (un-normalised) class scores for a batch ``x``."""
        feature_maps = F.relu(self.conv2(F.relu(self.conv1(x))))
        # Keep the batch axis, collapse everything else for the linear layers.
        hidden = F.relu(self.fc1(feature_maps.flatten(start_dim=1)))
        return self.fc2(hidden)
14 changes: 14 additions & 0 deletions standardizedarchive/MNIST/mnist_trace_project/utils/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def register_hooks(model):
    """Attach forward hooks that record the outputs of four named layers.

    Hooks are registered on ``model.conv1``, ``model.conv2``, ``model.fc1``
    and ``model.fc2``. Each time one of those layers runs, its output is
    detached, moved to CPU, converted to a NumPy array and stored under the
    layer's name, replacing the previous entry.

    Parameters:
        model: Module exposing the four layer attributes listed above.

    Returns:
        The dict the hooks write into. It is empty until a hooked layer runs.
    """
    captured = {}

    def make_recorder(layer_name):
        # Hook arguments arrive positionally: module, inputs, output.
        def record(_module, _inputs, output):
            captured[layer_name] = output.detach().cpu().numpy()
        return record

    for layer_name in ("conv1", "conv2", "fc1", "fc2"):
        getattr(model, layer_name).register_forward_hook(make_recorder(layer_name))

    return captured
105 changes: 105 additions & 0 deletions standardizedarchive/MNIST/tests/test1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import zarr
import matplotlib.pyplot as plt

# root = zarr.open("mnist_trace.zarr", mode="r")

# image = root["inputs/images"][0][0] # shape: (28, 28)

# plt.imshow(image, cmap="gray")
# plt.title("MNIST Image")
# plt.axis("off")
# plt.show()
# first shows the black and white MNIST image

###########################################

# activation = root["activations/conv1"][0] # shape (16, 26, 26)

# Pick one filter (e.g., filter 0)
# feature_map = activation[0]

# plt.imshow(feature_map, cmap="viridis")
# plt.title("Conv1 Filter 0 Activation")
# plt.colorbar()
# plt.axis("off")
# plt.show()
#second shows the conv1 activation heatmap from filter 0

###########################################

import numpy as np

# activation = root["activations/conv1"][0]
# num_filters = activation.shape[0]

# fig, axes = plt.subplots(4, 4, figsize=(8, 8))

# for i, ax in enumerate(axes.flat):
# ax.imshow(activation[i], cmap="viridis")
# ax.axis("off")
# ax.set_title(f"F{i}")

# plt.tight_layout()
# plt.show()
#third shows the 16 heatmaps, one for each filter

###########################################

# plt.imshow(image, cmap="gray")
# plt.imshow(feature_map, cmap="jet", alpha=0.5)
# plt.axis("off")
# plt.show()
# fourth shows the original digit and where the filter activates on top of it

###########################################

# Normalize activation for better visualization
# feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min())

# fig, ax = plt.subplots(figsize=(5,5))

# # Show original image
# ax.imshow(image, cmap="gray")

# # Overlay activation
# overlay = ax.imshow(feature_map, cmap="jet", alpha=0.5)

# ax.set_title("Original Digit + Conv1 Filter 0 Activation")
# ax.axis("off")

# # Add colorbar as key
# cbar = plt.colorbar(overlay, ax=ax)
# cbar.set_label("Activation Intensity")

# plt.show()

###########################################

# Overlay every conv1 feature map of sample 0 on the input digit, using one
# shared colour scale so the filters can be compared with each other.
root = zarr.open("mnist_trace.zarr", mode="r")

image = root["inputs/images"][0][0]

activation = root["activations/conv1"][0]

global_min = activation.min()
global_max = activation.max()

# The feature maps are smaller than the image (the convolution trims the
# border), so drawing both with default extents would pin the map to the
# top-left corner. Centre the map on the image instead.
img_h, img_w = image.shape
map_h, map_w = activation.shape[-2:]
pad_y = (img_h - map_h) / 2
pad_x = (img_w - map_w) / 2
# (left, right, bottom, top) in image pixel coordinates, origin at the top.
overlay_extent = (
    pad_x - 0.5,
    pad_x + map_w - 0.5,
    pad_y + map_h - 0.5,
    pad_y - 0.5,
)

fig, axes = plt.subplots(4, 4, figsize=(8, 8))

for i, ax in enumerate(axes.flat):
    fmap = activation[i]
    ax.imshow(image, cmap="gray")

    im = ax.imshow(
        fmap,
        cmap="jet",
        alpha=0.5,
        vmin=global_min,
        vmax=global_max,
        extent=overlay_extent,
    )
    # The second imshow can reset the axes limits to the overlay's extent;
    # restore them so the full digit stays visible.
    ax.set_xlim(-0.5, img_w - 0.5)
    ax.set_ylim(img_h - 0.5, -0.5)
    ax.set_title(f"Filter {i}")
    ax.axis("off")

fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, label="Activation Intensity")

plt.show()
16 changes: 16 additions & 0 deletions standardizedarchive/vit/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Python
__pycache__/
*.pyc

# Virtual environments
venv/
.env/

# Zarr archives
*.zarr/

# OS files
.DS_Store

# Jupyter
.ipynb_checkpoints/
58 changes: 58 additions & 0 deletions standardizedarchive/vit/archive_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import zarr
import numpy as np

def create_vit_archive(archive_path, model_config, num_layers=12, num_heads=12, hidden_dim=768, num_tokens=197):
    """
    Create the skeleton of a ViT trace archive: metadata plus empty groups.

    No datasets are created here; only the ``metadata`` tree is populated.

    Parameters:
        archive_path (str): Path to the Zarr archive (opened with mode "w").
        model_config: Model config object. Only its ``model_type`` attribute
            is read; "vit-base" is recorded if it is missing.
        num_layers (int): Number of transformer layers.
        num_heads (int): Number of attention heads.
        hidden_dim (int): Hidden dimension size.
        num_tokens (int): Number of tokens (patches + CLS).

    Returns:
        The root object returned by ``zarr.open``.
    """
    root = zarr.open(archive_path, mode="w")

    # Model-level metadata.
    metadata = root.create_group("metadata")
    model_level = {
        "model_name": getattr(model_config, "model_type", "vit-base"),
        "num_layers": num_layers,
        "num_heads": num_heads,
        "hidden_dim": hidden_dim,
        "num_tokens": num_tokens,
        "input_shape": (3, 224, 224),
    }
    for key, value in model_level.items():
        metadata.attrs[key] = value

    # Per-layer metadata: every layer records the same shapes.
    per_layer = {
        "hidden_shape": (num_tokens, hidden_dim),
        "num_heads": num_heads,
        "attention_shape": (num_heads, num_tokens, num_tokens),
    }
    layer_index = metadata.require_group("layers")
    for i in range(num_layers):
        layer_attrs = layer_index.require_group(f"layer_{i}").attrs
        for key, value in per_layer.items():
            layer_attrs[key] = value

    # Empty containers, to be filled with datasets later:
    #   inputs      - processed images, e.g. (batch, 3, 224, 224)
    #   activations - per-layer "hidden_states", e.g. (batch, num_tokens, hidden_dim)
    #   attention   - per-layer maps, (batch, num_heads, tokens, tokens)
    #   outputs     - logits and predicted class
    for section in ("inputs", "activations", "attention", "outputs"):
        root.create_group(section)

    return root
21 changes: 21 additions & 0 deletions standardizedarchive/vit/run_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from utils import load_image
from trace_vit import trace_vit


IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"


def main():
    """Load the image at IMAGE_URL, pass it to trace_vit, and report the output path."""
    # Single source of truth for the output path, so the printed message
    # cannot drift from the path actually handed to trace_vit.
    archive_path = "vit_trace.zarr"

    image = load_image(IMAGE_URL)

    # The returned archive handle is not needed here.
    trace_vit(
        image=image,
        archive_path=archive_path,
    )

    print(f"Trace archive created at {archive_path}")


if __name__ == "__main__":
    main()
Loading