Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
/coti_sdk.egg-info/

/build/
.venv/
__pycache__/
*.pyc
Comparisson.md
170 changes: 168 additions & 2 deletions coti/crypto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from eth_keys import keys
from .types import CtString, CtUint, ItString, ItUint
from .types import CtUint256, ItUint256, CtString, CtUint, ItString, ItUint

block_size = AES.block_size
address_size = 20
Expand Down Expand Up @@ -94,6 +94,46 @@ def sign_input_text(sender_address: str, contract_address: str, function_selecto
return sign(message, key)


def sign_input_text_256(sender_address: str, contract_address: str, function_selector: str, ct, key):
"""
Sign input text for 256-bit encrypted values.

Similar to sign_input_text but accepts 64-byte ciphertext (CtUint256 blob).

Args:
sender_address: Address of the sender (20 bytes, without 0x prefix)
contract_address: Address of the contract (20 bytes, without 0x prefix)
function_selector: Function selector (hex string with 0x prefix, e.g., '0x12345678')
ct: Ciphertext bytes (must be 64 bytes for uint256)
key: Signing key (32 bytes)

Returns:
bytes: The signature

Raises:
ValueError: If any input has invalid length
"""
function_selector_bytes = bytes.fromhex(function_selector[2:])

if len(sender_address) != address_size:
raise ValueError(f"Invalid sender address length: {len(sender_address)} bytes, must be {address_size} bytes")
if len(contract_address) != address_size:
raise ValueError(f"Invalid contract address length: {len(contract_address)} bytes, must be {address_size} bytes")
if len(function_selector_bytes) != function_selector_size:
raise ValueError(f"Invalid signature size: {len(function_selector_bytes)} bytes, must be {function_selector_size} bytes")

# 256-bit IT has 64 bytes CT
if len(ct) != 64:
raise ValueError(f"Invalid ct length: {len(ct)} bytes, must be 64 bytes for uint256")

if len(key) != key_size:
raise ValueError(f"Invalid key length: {len(key)} bytes, must be {key_size} bytes")

message = sender_address + contract_address + function_selector_bytes + ct

return sign(message, key)


def sign(message, key):
# Sign the message
pk = keys.PrivateKey(key)
Expand Down Expand Up @@ -122,7 +162,7 @@ def build_input_text(plaintext: int, user_aes_key: str, sender_address: str, con
}


def build_string_input_text(plaintext: int, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItString:
def build_string_input_text(plaintext: str, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItString:
input_text = {
'ciphertext': {
'value': []
Expand Down Expand Up @@ -169,6 +209,130 @@ def decrypt_uint(ciphertext: CtUint, user_aes_key: str) -> int:
return decrypted_uint


def create_ciphertext_256(plaintext_int: int, user_aes_key_bytes: bytes) -> bytes:
"""
Create a 256-bit ciphertext by encrypting high and low 128-bit parts separately.

Args:
plaintext_int: Integer value to encrypt (must fit in 256 bits)
user_aes_key_bytes: AES encryption key (16 bytes)

Returns:
bytes: 64-byte ciphertext blob formatted as:
[high_ciphertext(16) | high_r(16) | low_ciphertext(16) | low_r(16)]

Raises:
ValueError: If plaintext exceeds 256 bits
"""
# Convert 256-bit int to 32 bytes (Big Endian)
try:
plaintext_bytes = plaintext_int.to_bytes(32, 'big')
except OverflowError:
raise ValueError("Plaintext size must be 256 bits or smaller.")

# Split into High and Low 128-bit parts
high_bytes = plaintext_bytes[:16]
low_bytes = plaintext_bytes[16:]

# Encrypt High
high_ct, high_r = encrypt(user_aes_key_bytes, high_bytes)

# Encrypt Low
low_ct, low_r = encrypt(user_aes_key_bytes, low_bytes)

# Construct format: high.ciphertext + high.r + low.ciphertext + low.r
return high_ct + high_r + low_ct + low_r


def prepare_it_256(plaintext: int, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItUint256:
"""
Prepare Input Text for a 256-bit encrypted integer.

Encrypts a 256-bit integer and creates a signed Input Text structure
suitable for smart contract interaction.

Args:
plaintext: Integer value to encrypt (must fit in 256 bits)
user_aes_key: AES encryption key (hex string without 0x prefix, 32 hex chars)
sender_address: Address of the sender (hex string with 0x prefix, 40 hex chars)
contract_address: Address of the contract (hex string with 0x prefix, 40 hex chars)
function_selector: Function selector (hex string with 0x prefix, e.g., '0x12345678')
signing_key: Private key for signing (32 bytes)

Returns:
ItUint256: Dictionary containing ciphertext (ciphertextHigh, ciphertextLow) and signature

Raises:
ValueError: If plaintext exceeds 256 bits
"""
if plaintext.bit_length() > 256:
raise ValueError("Plaintext size must be 256 bits or smaller.")

user_aes_key_bytes = bytes.fromhex(user_aes_key)

ct_blob = create_ciphertext_256(plaintext, user_aes_key_bytes)

# Split for types
# ct_blob is 64 bytes: [high_ct(16) | high_r(16) | low_ct(16) | low_r(16)]
high_blob = ct_blob[:32]
low_blob = ct_blob[32:]

# Sign the full 64-byte blob
signature = sign_input_text_256(
bytes.fromhex(sender_address[2:]),
bytes.fromhex(contract_address[2:]),
function_selector,
ct_blob,
signing_key
)

return {
'ciphertext': {
'ciphertextHigh': int.from_bytes(high_blob, 'big'),
'ciphertextLow': int.from_bytes(low_blob, 'big')
},
'signature': signature
}


def decrypt_uint256(ciphertext: CtUint256, user_aes_key: str) -> int:
"""
Decrypt a 256-bit encrypted integer.

Decrypts both high and low 128-bit parts and combines them back into
a single 256-bit integer.

Args:
ciphertext: CtUint256 dictionary containing ciphertextHigh and ciphertextLow
user_aes_key: AES decryption key (hex string without 0x prefix, 32 hex chars)

Returns:
int: The decrypted 256-bit integer value
"""
user_aes_key_bytes = bytes.fromhex(user_aes_key)

# Process High
ct_high_int = ciphertext['ciphertextHigh']
ct_high_bytes = ct_high_int.to_bytes(32, 'big')
cipher_high = ct_high_bytes[:block_size]
r_high = ct_high_bytes[block_size:]

plaintext_high = decrypt(user_aes_key_bytes, r_high, cipher_high)

# Process Low
ct_low_int = ciphertext['ciphertextLow']
ct_low_bytes = ct_low_int.to_bytes(32, 'big')
cipher_low = ct_low_bytes[:block_size]
r_low = ct_low_bytes[block_size:]

plaintext_low = decrypt(user_aes_key_bytes, r_low, cipher_low)

# Combine back to 256-bit int
# High part is MSB
full_bytes = plaintext_high + plaintext_low
return int.from_bytes(full_bytes, 'big')


def decrypt_string(ciphertext: CtString, user_aes_key: str) -> str:
if 'value' in ciphertext or hasattr(ciphertext, 'value'): # format when reading ciphertext from an event
__ciphertext = ciphertext['value']
Expand Down Expand Up @@ -231,6 +395,7 @@ def decrypt_rsa(private_key_bytes: bytes, ciphertext: bytes):
)
return plaintext


#This function recovers a user's key by decrypting two encrypted key shares with the given private key,
#and then XORing the two key shares together.
def recover_user_key(private_key_bytes: bytes, encrypted_key_share0: bytes, encrypted_key_share1: bytes):
Expand All @@ -239,3 +404,4 @@ def recover_user_key(private_key_bytes: bytes, encrypted_key_share0: bytes, encr

# XOR both key shares to get the user key
return bytes([a ^ b for a, b in zip(key_share0, key_share1)])

17 changes: 16 additions & 1 deletion coti/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,19 @@ class ItStringCiphertext(TypedDict):

class ItString(TypedDict):
ciphertext: ItStringCiphertext
signature: List[bytes]
signature: List[bytes]


class CtUint256(TypedDict):
ciphertextHigh: int
ciphertextLow: int


class ItUint256Ciphertext(TypedDict):
ciphertextHigh: int
ciphertextLow: int


class ItUint256(TypedDict):
ciphertext: ItUint256Ciphertext
signature: bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
import os
from eth_keys import keys
from coti.crypto_utils import generate_aes_key

@pytest.fixture
def user_key():
# Return hex string of the key
return os.environ.get("TEST_USER_KEY") or generate_aes_key().hex()

@pytest.fixture
def private_key_bytes():
pk_hex = os.environ.get("TEST_PRIVATE_KEY")
if pk_hex:
# Handle 0x prefix if present
clean_hex = pk_hex[2:] if pk_hex.startswith("0x") else pk_hex
return bytes.fromhex(clean_hex)
return os.urandom(32)

@pytest.fixture
def sender_address(private_key_bytes):
# Derive address from private key to ensure consistency in signing checks if needed
pk = keys.PrivateKey(private_key_bytes)
return pk.public_key.to_checksum_address()

@pytest.fixture
def contract_address():
# Dummy contract address
return "0x1000000000000000000000000000000000000001"

@pytest.fixture
def function_selector():
# Dummy function selector
return "0x11223344"
Loading