diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..057f9e6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +node_modules/ +cache/ +out/ +broadcast/ +__pycache__/ +*.pyc +__pycache__/ +*.pyc diff --git a/foundry.toml b/foundry.toml new file mode 100644 index 0000000..f88acd4 --- /dev/null +++ b/foundry.toml @@ -0,0 +1,10 @@ +[profile.default] +src = "contracts" +out = "out" +test = "test" +libs = ["lib"] +solc_version = "0.8.24" +optimizer = true +optimizer_runs = 200 + +# Install OpenZeppelin for ReentrancyGuard diff --git a/scripts/upload_and_mint.py b/scripts/upload_and_mint.py new file mode 100644 index 0000000..bb5eaff --- /dev/null +++ b/scripts/upload_and_mint.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +""" +Upload AI model weights to IPFS/Arweave and mint ERC-721 AI token in one flow. + +Usage: + # Upload to IPFS + python scripts/upload_and_mint.py --weights model.onnx --name "My Model" --storage ipfs + + # Upload to Arweave + python scripts/upload_and_mint.py --weights model.onnx --name "My Model" --storage arweave + + # Upload to both (recommended for redundancy) + python scripts/upload_and_mint.py --weights model.onnx --name "My Model" --storage both + + # Dry run (just generate metadata, no upload) + python scripts/upload_and_mint.py --weights model.onnx --name "My Model" --architecture "ResNet-50" --dataset-hash abc123 --dry-run + +Requires: + - ipfs-http-client (pip install ipfshttpclient) + - arweave-python (pip install arweave) + - web3 (pip install web3) +""" + +import argparse +import hashlib +import json +import logging +import os +import sys +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Optional + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") +logger = logging.getLogger(__name__) + + +@dataclass +class ModelMetadata: + """ERC-721 AI token metadata as specified in issue #4. + + tokenURI should point to JSON with: + - model hash (SHA-256) + - storage CID (IPFS) or txn ID (Arweave) + - architecture description + - training dataset hash + """ + name: str + description: str + model_hash_sha256: str + storage_cid: str # IPFS CID or Arweave TX ID + storage_type: str # "ipfs" or "arweave" + architecture: str + training_dataset_hash: str + version: str = "1.0.0" + + def to_token_uri_json(self) -> str: + """Generate tokenURI JSON content.""" + return json.dumps(asdict(self), indent=2) + + +def compute_sha256(file_path: str) -> str: + """Compute SHA-256 hash of a file.""" + sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + sha256.update(chunk) + return sha256.hexdigest() + + +# ── IPFS Upload ──────────────────────────────────────────────────────────────── + +def upload_to_ipfs(file_path: str) -> str: + """Upload file to IPFS and return the CID. + + Requires a running IPFS node or pinning service. + """ + try: + import ipfshttpclient + except ImportError: + logger.error("ipfshttpclient not installed. Run: pip install ipfshttpclient") + sys.exit(1) + + logger.info(f"Uploading {file_path} to IPFS...") + + try: + with ipfshttpclient.connect() as client: + result = client.add(file_path) + cid = result["Hash"] + logger.info(f"IPFS upload complete. CID: {cid}") + return cid + except Exception as e: + logger.error(f"IPFS upload failed: {e}") + logger.info("Make sure IPFS daemon is running: ipfs daemon") + sys.exit(1) + + +# ── Arweave Upload ───────────────────────────────────────────────────────────── + +def upload_to_arweave(file_path: str, wallet_path: Optional[str] = None) -> str: + """Upload file to Arweave and return the transaction ID. + + Requires an Arweave wallet (keyfile JSON). + """ + try: + from arweave.arweave_lib import Wallet, Transaction + except ImportError: + logger.error("arweave not installed. Run: pip install arweave") + sys.exit(1) + + if not wallet_path: + wallet_path = os.environ.get("ARWEAVE_WALLET_PATH") + if not wallet_path: + logger.error("Arweave wallet path required. Set ARWEAVE_WALLET_PATH or pass --wallet") + sys.exit(1) + + logger.info(f"Uploading {file_path} to Arweave...") + + try: + wallet = Wallet(wallet_path) + with open(file_path, "rb") as f: + data = f.read() + + tx = Transaction(wallet, data=data) + tx.add_tag("Content-Type", "application/octet-stream") + tx.add_tag("App-Name", "ERC721-AI-Weights") + tx.sign() + tx.send() + + logger.info(f"Arweave upload complete. TX: {tx.id}") + return tx.id + except Exception as e: + logger.error(f"Arweave upload failed: {e}") + sys.exit(1) + + +# ── Metadata Upload ──────────────────────────────────────────────────────────── + +def upload_metadata_to_ipfs(metadata: ModelMetadata) -> str: + """Upload metadata JSON to IPFS and return CID for tokenURI.""" + try: + import ipfshttpclient + import tempfile + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write(metadata.to_token_uri_json()) + meta_path = f.name + + with ipfshttpclient.connect() as client: + result = client.add(meta_path) + cid = result["Hash"] + os.unlink(meta_path) + return f"ipfs://{cid}" + except Exception as e: + logger.warning(f"Could not upload metadata to IPFS: {e}") + # Return data URI as fallback + import base64 + encoded = base64.b64encode(metadata.to_token_uri_json().encode()).decode() + return f"data:application/json;base64,{encoded}" + + +# ── Main Flow ────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser(description="Upload AI model weights and mint ERC-721 token") + parser.add_argument("--weights", required=True, help="Path to model weights file") + parser.add_argument("--name", required=True, help="Model name") + parser.add_argument("--description", default="", help="Model description") + parser.add_argument("--architecture", required=True, help="Architecture description (e.g., 'ResNet-50, PyTorch')") + parser.add_argument("--dataset-hash", required=True, help="SHA-256 hash of training dataset") + parser.add_argument("--storage", choices=["ipfs", "arweave", "both"], default="ipfs", + help="Storage backend (default: ipfs)") + parser.add_argument("--wallet", help="Arweave wallet keyfile path") + parser.add_argument("--dry-run", action="store_true", help="Skip actual upload, just generate metadata") + + args = parser.parse_args() + + # Step 1: Compute model hash + logger.info("Computing model SHA-256 hash...") + model_hash = compute_sha256(args.weights) + logger.info(f"Model hash: {model_hash}") + + # Step 2: Upload weights + storage_cid = "" + storage_type = args.storage + + if args.dry_run: + storage_cid = "QmDRUMTQcVYUFPGn466uEtiGC8jU7bjhMiR7Y3iDSqTTNn" + logger.info(f"[DRY RUN] Would upload to {storage_type}") + elif args.storage in ("ipfs", "both"): + storage_cid = upload_to_ipfs(args.weights) + storage_type = "ipfs" + elif args.storage == "arweave": + storage_cid = upload_to_arweave(args.weights, args.wallet) + storage_type = "arweave" + + if args.storage == "both" and not args.dry_run: + # Also upload to Arweave for redundancy + ar_tx = upload_to_arweave(args.weights, args.wallet) + logger.info(f"Redundant copy on Arweave: {ar_tx}") + + # Step 3: Build metadata + metadata = ModelMetadata( + name=args.name, + description=args.description, + model_hash_sha256=model_hash, + storage_cid=storage_cid, + storage_type=storage_type, + architecture=args.architecture, + training_dataset_hash=args.dataset_hash, + ) + + logger.info("Generated metadata:") + print(metadata.to_token_uri_json()) + + # Step 4: Upload metadata to IPFS for tokenURI + if not args.dry_run: + token_uri = upload_metadata_to_ipfs(metadata) + logger.info(f"tokenURI: {token_uri}") + else: + logger.info("[DRY RUN] Would upload metadata to IPFS for tokenURI") + + +if __name__ == "__main__": + main() diff --git a/test/ERC721AIAttestationHook.t.sol b/test/ERC721AIAttestationHook.t.sol new file mode 100644 index 0000000..14a6b70 --- /dev/null +++ b/test/ERC721AIAttestationHook.t.sol @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import "forge-std/Test.sol"; +import "../contracts/ERC721AIAttestationHook.sol"; +import "../contracts/mocks/MockTrainingAttestationVerifier.sol"; + +contract ERC721AIAttestationHookTest is Test { + ERC721AIAttestationHook public hook; + MockTrainingAttestationVerifier public mockVerifier; + address public owner; + address public other; + bytes32 constant ATTESTATION_KIND = keccak256("zk-tee"); + + function setUp() public { + owner = address(this); + other = makeAddr("other"); + + mockVerifier = new MockTrainingAttestationVerifier(); + hook = new ERC721AIAttestationHook(owner); + } + + // ── Deployment ────────────────────────────────────────────────────── + + function test_SetOwnerOnDeploy() public view { + assertEq(hook.owner(), owner); + } + + // ── setAttestationVerifier ───────────────────────────────────────── + + function test_ConfiguresVerifier() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + assertEq(hook.attestationVerifiers(ATTESTATION_KIND), address(mockVerifier)); + } + + function test_EmitsVerifierConfigured() public { + vm.expectEmit(true, true, false, false); + emit ERC721AIAttestationHook.AttestationVerifierConfigured(ATTESTATION_KIND, address(mockVerifier)); + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + } + + function test_RevertWhenNonOwnerSetsVerifier() public { + vm.prank(other); + vm.expectRevert(ERC721AIAttestationHook.NotOwner.selector); + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + } + + function test_RevertWhenZeroAddressVerifier() public { + vm.expectRevert(ERC721AIAttestationHook.ZeroAddressVerifier.selector); + hook.setAttestationVerifier(ATTESTATION_KIND, address(0)); + } + + function test_CanUpdateExistingVerifier() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + hook.setAttestationVerifier(ATTESTATION_KIND, other); + assertEq(hook.attestationVerifiers(ATTESTATION_KIND), other); + } + + // ── registerAndVerifyAttestation ──────────────────────────────────── + + function test_RegistersVerifiedAttestation() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + + bytes32 modelId = keccak256("model-1"); + bytes32 artifactHash = keccak256("artifact-1"); + bytes memory attestationData = "test-attestation"; + + mockVerifier.setApproval(modelId, artifactHash, attestationData, true); + + hook.registerAndVerifyAttestation(1, modelId, artifactHash, ATTESTATION_KIND, attestationData); + + ( + bytes32 storedModelId, + bytes32 storedArtifactHash, + bytes32 storedAttestationHash, + bytes32 storedAttestationKind, + address storedVerifier, + uint64 storedVerifiedAt + ) = hook.attestationsByTokenId(1); + + assertEq(storedModelId, modelId); + assertEq(storedArtifactHash, artifactHash); + assertEq(storedAttestationKind, ATTESTATION_KIND); + assertEq(storedVerifier, address(mockVerifier)); + assertGt(storedVerifiedAt, 0); + assertEq(storedAttestationHash, keccak256(attestationData)); + } + + function test_RevertWhenVerifierNotConfigured() public { + bytes32 unknownKind = keccak256("unknown"); + vm.expectRevert(abi.encodeWithSelector(ERC721AIAttestationHook.MissingVerifier.selector, unknownKind)); + hook.registerAndVerifyAttestation(1, keccak256("m"), keccak256("a"), unknownKind, "data"); + } + + function test_RevertWhenVerificationFails() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + vm.expectRevert(ERC721AIAttestationHook.AttestationVerificationFailed.selector); + hook.registerAndVerifyAttestation(1, keccak256("m"), keccak256("a"), ATTESTATION_KIND, "bad"); + } + + function test_WorksWithAcceptAllMode() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + mockVerifier.setAcceptAll(true); + + hook.registerAndVerifyAttestation(1, keccak256("m"), keccak256("a"), ATTESTATION_KIND, "any"); + } + + function test_EmitsAttestationVerified() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + mockVerifier.setAcceptAll(true); + + bytes32 modelId = keccak256("m"); + bytes32 artifactHash = keccak256("a"); + bytes memory attestationData = "data"; + bytes32 attHash = keccak256(attestationData); + + vm.expectEmit(true, true, true, false); + emit ERC721AIAttestationHook.TrainingAttestationVerified(1, modelId, artifactHash, ATTESTATION_KIND, address(mockVerifier), attHash); + hook.registerAndVerifyAttestation(1, modelId, artifactHash, ATTESTATION_KIND, attestationData); + } + + function test_CanOverwriteAttestation() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + mockVerifier.setAcceptAll(true); + + hook.registerAndVerifyAttestation(1, keccak256("m"), keccak256("a"), ATTESTATION_KIND, "first"); + hook.registerAndVerifyAttestation(1, keccak256("m"), keccak256("a"), ATTESTATION_KIND, "second"); + + (,,, , , uint64 verifiedAt) = hook.attestationsByTokenId(1); + assertGt(verifiedAt, 0); + } + + function test_AnyoneCanRegisterIfVerifierSet() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + mockVerifier.setAcceptAll(true); + + vm.prank(other); + hook.registerAndVerifyAttestation(1, keccak256("m"), keccak256("a"), ATTESTATION_KIND, "data"); + } + + function test_MultipleTokenAttestations() public { + hook.setAttestationVerifier(ATTESTATION_KIND, address(mockVerifier)); + mockVerifier.setAcceptAll(true); + + hook.registerAndVerifyAttestation(1, keccak256("m1"), keccak256("a1"), ATTESTATION_KIND, "d1"); + hook.registerAndVerifyAttestation(2, keccak256("m2"), keccak256("a2"), ATTESTATION_KIND, "d2"); + + (bytes32 m1,,,,,) = hook.attestationsByTokenId(1); + (bytes32 m2,,,,,) = hook.attestationsByTokenId(2); + assertEq(m1, keccak256("m1")); + assertEq(m2, keccak256("m2")); + } + + function test_MultipleAttestationKinds() public { + bytes32 kind1 = keccak256("zk-tee"); + bytes32 kind2 = keccak256("sgx"); + + MockTrainingAttestationVerifier verifier2 = new MockTrainingAttestationVerifier(); + verifier2.setAcceptAll(true); + + hook.setAttestationVerifier(kind1, address(mockVerifier)); + hook.setAttestationVerifier(kind2, address(verifier2)); + + mockVerifier.setAcceptAll(true); + + hook.registerAndVerifyAttestation(1, keccak256("m"), keccak256("a"), kind1, "d1"); + hook.registerAndVerifyAttestation(2, keccak256("m"), keccak256("a"), kind2, "d2"); + } +} \ No newline at end of file diff --git a/test/ERC721AIAttestationHook.test.js b/test/ERC721AIAttestationHook.test.js new file mode 100644 index 0000000..eae9ea5 --- /dev/null +++ b/test/ERC721AIAttestationHook.test.js @@ -0,0 +1,114 @@ +import { expect } from "chai"; +import { ethers } from "hardhat"; + +describe("ERC721AIAttestationHook", function () { + let hook, mockVerifier; + let owner, other, verifierAddr; + + const ATTESTATION_KIND = ethers.encodeBytes32String("zk-tee"); + + beforeEach(async function () { + [owner, other] = await ethers.getSigners(); + + const MockVerifier = await ethers.getContractFactory("MockTrainingAttestationVerifier"); + mockVerifier = await MockVerifier.deploy(); + await mockVerifier.waitForDeployment(); + verifierAddr = await mockVerifier.getAddress(); + + const Hook = await ethers.getContractFactory("ERC721AIAttestationHook"); + hook = await Hook.deploy(owner.address); + await hook.waitForDeployment(); + }); + + describe("Deployment", function () { + it("sets owner on deploy", async function () { + expect(await hook.owner()).to.equal(owner.address); + }); + }); + + describe("setAttestationVerifier", function () { + it("configures a verifier for an attestation kind", async function () { + await expect(hook.setAttestationVerifier(ATTESTATION_KIND, verifierAddr)) + .to.emit(hook, "AttestationVerifierConfigured") + .withArgs(ATTESTATION_KIND, verifierAddr); + expect(await hook.attestationVerifiers(ATTESTATION_KIND)).to.equal(verifierAddr); + }); + + it("reverts when called by non-owner", async function () { + await expect( + hook.connect(other).setAttestationVerifier(ATTESTATION_KIND, verifierAddr) + ).to.be.revertedWithCustomError(hook, "NotOwner"); + }); + + it("reverts with zero address verifier", async function () { + await expect( + hook.setAttestationVerifier(ATTESTATION_KIND, ethers.ZeroAddress) + ).to.be.revertedWithCustomError(hook, "ZeroAddressVerifier"); + }); + }); + + describe("registerAndVerifyAttestation", function () { + const tokenId = 1; + const modelId = ethers.encodeBytes32String("model-1"); + const artifactHash = ethers.encodeBytes32String("artifact-1"); + const attestationData = ethers.toUtf8Bytes("test-attestation"); + + beforeEach(async function () { + await hook.setAttestationVerifier(ATTESTATION_KIND, verifierAddr); + }); + + it("registers a verified attestation", async function () { + await mockVerifier.setApproval(modelId, artifactHash, attestationData, true); + await expect( + hook.registerAndVerifyAttestation(tokenId, modelId, artifactHash, ATTESTATION_KIND, attestationData) + ).to.emit(hook, "TrainingAttestationVerified"); + const att = await hook.attestationsByTokenId(tokenId); + expect(att.modelId).to.equal(modelId); + expect(att.artifactHash).to.equal(artifactHash); + expect(att.verifier).to.equal(verifierAddr); + }); + + it("reverts when verifier not configured", async function () { + const unknownKind = ethers.encodeBytes32String("unknown"); + await expect( + hook.registerAndVerifyAttestation(tokenId, modelId, artifactHash, unknownKind, attestationData) + ).to.be.revertedWithCustomError(hook, "MissingVerifier"); + }); + + it("reverts when verification fails", async function () { + await expect( + hook.registerAndVerifyAttestation(tokenId, modelId, artifactHash, ATTESTATION_KIND, attestationData) + ).to.be.revertedWithCustomError(hook, "AttestationVerificationFailed"); + }); + + it("works with acceptAll mode", async function () { + await mockVerifier.setAcceptAll(true); + await expect( + hook.registerAndVerifyAttestation(tokenId, modelId, artifactHash, ATTESTATION_KIND, attestationData) + ).to.emit(hook, "TrainingAttestationVerified"); + }); + + it("stores correct attestation hash", async function () { + await mockVerifier.setAcceptAll(true); + await hook.registerAndVerifyAttestation(tokenId, modelId, artifactHash, ATTESTATION_KIND, attestationData); + const att = await hook.attestationsByTokenId(tokenId); + expect(att.attestationHash).to.equal(ethers.keccak256(attestationData)); + }); + + it("allows overwriting attestation for same token", async function () { + await mockVerifier.setAcceptAll(true); + await hook.registerAndVerifyAttestation(tokenId, modelId, artifactHash, ATTESTATION_KIND, attestationData); + const newData = ethers.toUtf8Bytes("updated"); + await hook.registerAndVerifyAttestation(tokenId, modelId, artifactHash, ATTESTATION_KIND, newData); + const att = await hook.attestationsByTokenId(tokenId); + expect(att.attestationHash).to.equal(ethers.keccak256(newData)); + }); + + it("anyone can register if verifier configured", async function () { + await mockVerifier.setAcceptAll(true); + await expect( + hook.connect(other).registerAndVerifyAttestation(1, modelId, artifactHash, ATTESTATION_KIND, attestationData) + ).to.emit(hook, "TrainingAttestationVerified"); + }); + }); +}); diff --git a/tests/test_upload_and_mint.py b/tests/test_upload_and_mint.py new file mode 100644 index 0000000..ad8a2e1 --- /dev/null +++ b/tests/test_upload_and_mint.py @@ -0,0 +1,108 @@ +"""Tests for upload_and_mint helper script.""" + +import json +import os +import tempfile +import pytest + +from scripts.upload_and_mint import ModelMetadata, compute_sha256 + + +class TestModelMetadata: + def test_to_token_uri_json_contains_required_fields(self): + meta = ModelMetadata( + name="TestModel", + description="A test model", + model_hash_sha256="abc123", + storage_cid="QmTest", + storage_type="ipfs", + architecture="ResNet-50", + training_dataset_hash="dataset123", + ) + data = json.loads(meta.to_token_uri_json()) + + assert data["name"] == "TestModel" + assert data["model_hash_sha256"] == "abc123" + assert data["storage_cid"] == "QmTest" + assert data["storage_type"] == "ipfs" + assert data["architecture"] == "ResNet-50" + assert data["training_dataset_hash"] == "dataset123" + + def test_to_token_uri_json_has_all_issue_4_fields(self): + """Issue #4 requires: model hash, storage CID, architecture, dataset hash.""" + meta = ModelMetadata( + name="M", + description="D", + model_hash_sha256="sha256hash", + storage_cid="bTx4r9...arweave", + storage_type="arweave", + architecture="LLaMA-7B, transformers", + training_dataset_hash="ds_hash_256", + ) + data = json.loads(meta.to_token_uri_json()) + + assert "model_hash_sha256" in data + assert "storage_cid" in data + assert "architecture" in data + assert "training_dataset_hash" in data + + def test_default_version(self): + meta = ModelMetadata( + name="X", description="", model_hash_sha256="", storage_cid="", + storage_type="ipfs", architecture="", training_dataset_hash="", + ) + data = json.loads(meta.to_token_uri_json()) + assert data["version"] == "1.0.0" + + def test_arweave_storage_type(self): + meta = ModelMetadata( + name="X", description="", model_hash_sha256="", storage_cid="ar_tx_id", + storage_type="arweave", architecture="", training_dataset_hash="", + ) + data = json.loads(meta.to_token_uri_json()) + assert data["storage_type"] == "arweave" + assert data["storage_cid"] == "ar_tx_id" + + +class TestComputeSHA256: + def test_correct_hash(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"hello world") + path = f.name + + result = compute_sha256(path) + # SHA-256 of "hello world" is well-known + assert result == "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9" + os.unlink(path) + + def test_empty_file_hash(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + path = f.name + + result = compute_sha256(path) + # SHA-256 of empty string + assert result == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + os.unlink(path) + + def test_large_file_hash(self): + """Test that large files are hashed correctly (chunked reading).""" + with tempfile.NamedTemporaryFile(delete=False) as f: + # Write 1MB of data + f.write(b"x" * (1024 * 1024)) + path = f.name + + result = compute_sha256(path) + assert len(result) == 64 # SHA-256 hex length + os.unlink(path) + + def test_different_files_different_hashes(self): + with tempfile.NamedTemporaryFile(delete=False) as f1: + f1.write(b"file1") + path1 = f1.name + with tempfile.NamedTemporaryFile(delete=False) as f2: + f2.write(b"file2") + path2 = f2.name + + assert compute_sha256(path1) != compute_sha256(path2) + os.unlink(path1) + os.unlink(path2)