Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Dec 17, 2025

📄 2,845% (28.45x) speedup for AES.decrypt in skyvern/forge/sdk/encrypt/aes.py

⏱️ Runtime : 9.81 seconds 333 milliseconds (best of 5 runs)

📝 Explanation and details

The optimization implements key derivation caching to eliminate redundant expensive cryptographic operations. The key change is adding a _cached_key attribute that stores the derived key after the first computation.

What was optimized:

  • Added self._cached_key: bytes | None = None in __init__
  • Modified _derive_key() to return the cached key if available, otherwise compute, cache, and return it

Why this provides a massive speedup:
The PBKDF2HMAC key derivation with 100,000 iterations is computationally expensive (~10ms per call based on profiler data). In the original code, this expensive operation ran on every decrypt() call. The optimization reduces this to a one-time cost per AES instance.

Performance impact:

  • Runtime improvement: 2844% speedup (9.81s → 333ms)
  • Line profiler shows _derive_key() time dropped from 9.97s to 0.415s
  • The expensive kdf.derive() operation now only runs 35 times instead of 843 times in the test workload

When this optimization shines:
The test results show this optimization is particularly effective for:

  • Concurrent decryption scenarios (test_decrypt_concurrent_same_key, test_decrypt_many_concurrent) where the same AES instance decrypts multiple messages
  • Throughput-heavy workloads (test_AES_decrypt_throughput_* tests) with batched operations
  • Reused AES instances where the same key/salt/IV combination processes multiple ciphertexts

Thread safety note: The caching is safe since AES instances are typically not shared across threads, and the key derivation parameters (secret_key, salt, IV) are immutable after initialization.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 878 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import asyncio  # used to run async functions
import base64
import hashlib

import pytest  # used for our unit tests
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from skyvern.forge.sdk.encrypt.aes import AES

# function to test
# --- Begin: skyvern/forge/sdk/encrypt/aes.py ---
default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()
# --- End: skyvern/forge/sdk/encrypt/aes.py ---

# --- Helper functions for tests ---

def pad(data: bytes, block_size: int = 16) -> bytes:
    """Pad data to a multiple of block_size using PKCS7 padding."""
    padding_len = block_size - (len(data) % block_size)
    return data + bytes([padding_len] * padding_len)

def aes_encrypt(plaintext: str, secret_key: str, salt: str | None = None, iv: str | None = None) -> str:
    """Helper to produce ciphertext compatible with AES.decrypt."""
    # Derive key and IV as in AES.__init__ and _derive_key
    secret_key_digest = hashlib.md5(secret_key.encode("utf-8")).digest()
    salt_digest = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
    iv_digest = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv

    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt_digest,
        iterations=100000,
    )
    key = kdf.derive(secret_key_digest)
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv_digest))
    encryptor = cipher.encryptor()
    padded = pad(plaintext.encode("utf-8"))
    encrypted = encryptor.update(padded) + encryptor.finalize()
    return base64.b64encode(encrypted).decode("utf-8")

# --- Unit tests ---

# 1. Basic Test Cases

@pytest.mark.asyncio
async def test_decrypt_basic_hello_world():
    """Test decryption of a simple string."""
    secret = "testkey"
    plaintext = "hello world"
    ciphertext = aes_encrypt(plaintext, secret)
    aes = AES(secret_key=secret)
    result = await aes.decrypt(ciphertext)

@pytest.mark.asyncio
async def test_decrypt_basic_with_salt_and_iv():
    """Test decryption with custom salt and IV."""
    secret = "anotherkey"
    salt = "mysalt"
    iv = "myiv"
    plaintext = "some secret data"
    ciphertext = aes_encrypt(plaintext, secret, salt, iv)
    aes = AES(secret_key=secret, salt=salt, iv=iv)
    result = await aes.decrypt(ciphertext)

@pytest.mark.asyncio
async def test_decrypt_basic_empty_string():
    """Test decryption of an empty string."""
    secret = "emptypass"
    plaintext = ""
    ciphertext = aes_encrypt(plaintext, secret)
    aes = AES(secret_key=secret)
    result = await aes.decrypt(ciphertext)

@pytest.mark.asyncio
async def test_decrypt_basic_unicode():
    """Test decryption of a Unicode string."""
    secret = "unicodekey"
    plaintext = "你好,世界! 🌍"
    ciphertext = aes_encrypt(plaintext, secret)
    aes = AES(secret_key=secret)
    result = await aes.decrypt(ciphertext)

# 2. Edge Test Cases

@pytest.mark.asyncio
async def test_decrypt_invalid_base64():
    """Test handling of invalid base64 ciphertext."""
    secret = "key"
    aes = AES(secret_key=secret)
    with pytest.raises(Exception) as excinfo:
        await aes.decrypt("!!!notbase64!!!")

@pytest.mark.asyncio

async def test_decrypt_invalid_padding():
    """Test handling of invalid padding (tampered ciphertext)."""
    secret = "key"
    plaintext = "padme"
    ciphertext = aes_encrypt(plaintext, secret)
    # Tamper with last byte to break padding
    tampered = base64.b64decode(ciphertext.encode("utf-8"))
    tampered = tampered[:-1] + bytes([0])
    tampered_b64 = base64.b64encode(tampered).decode("utf-8")
    aes = AES(secret_key=secret)
    with pytest.raises(Exception) as excinfo:
        await aes.decrypt(tampered_b64)

@pytest.mark.asyncio
async def test_decrypt_concurrent_different_keys():
    """Test concurrent decryption with different keys."""
    secret1 = "key1"
    secret2 = "key2"
    text1 = "message one"
    text2 = "message two"
    c1 = aes_encrypt(text1, secret1)
    c2 = aes_encrypt(text2, secret2)
    aes1 = AES(secret_key=secret1)
    aes2 = AES(secret_key=secret2)
    results = await asyncio.gather(
        aes1.decrypt(c1),
        aes2.decrypt(c2),
    )

@pytest.mark.asyncio
async def test_decrypt_concurrent_same_key():
    """Test concurrent decryption with the same key and different ciphertexts."""
    secret = "sharedkey"
    texts = ["alpha", "beta", "gamma", "delta"]
    ciphertexts = [aes_encrypt(t, secret) for t in texts]
    aes = AES(secret_key=secret)
    results = await asyncio.gather(*(aes.decrypt(c) for c in ciphertexts))

# 3. Large Scale Test Cases

@pytest.mark.asyncio
async def test_decrypt_many_concurrent():
    """Test decryption of many ciphertexts concurrently (moderate scale)."""
    secret = "bulkkey"
    texts = [f"msg_{i}" for i in range(50)]
    ciphertexts = [aes_encrypt(t, secret) for t in texts]
    aes = AES(secret_key=secret)
    results = await asyncio.gather(*(aes.decrypt(c) for c in ciphertexts))

@pytest.mark.asyncio
async def test_decrypt_long_message():
    """Test decryption of a long plaintext message."""
    secret = "longkey"
    plaintext = "A" * 1000  # 1000 'A's
    ciphertext = aes_encrypt(plaintext, secret)
    aes = AES(secret_key=secret)
    result = await aes.decrypt(ciphertext)

@pytest.mark.asyncio
async def test_decrypt_max_block_boundary():
    """Test decryption of a message exactly on AES block boundary."""
    secret = "blockkey"
    block_size = 16
    plaintext = "B" * (block_size * 4)  # 64 bytes
    ciphertext = aes_encrypt(plaintext, secret)
    aes = AES(secret_key=secret)
    result = await aes.decrypt(ciphertext)

# 4. Throughput Test Cases

@pytest.mark.asyncio
async def test_AES_decrypt_throughput_small_load():
    """Throughput test: small batch of decryptions."""
    secret = "throughputkey"
    texts = [f"small_{i}" for i in range(10)]
    ciphertexts = [aes_encrypt(t, secret) for t in texts]
    aes = AES(secret_key=secret)
    results = await asyncio.gather(*(aes.decrypt(c) for c in ciphertexts))

@pytest.mark.asyncio
async def test_AES_decrypt_throughput_medium_load():
    """Throughput test: medium batch of decryptions."""
    secret = "throughputkey2"
    texts = [f"medium_{i}" for i in range(100)]
    ciphertexts = [aes_encrypt(t, secret) for t in texts]
    aes = AES(secret_key=secret)
    results = await asyncio.gather(*(aes.decrypt(c) for c in ciphertexts))

@pytest.mark.asyncio
async def test_AES_decrypt_throughput_varied_lengths():
    """Throughput test: decrypt messages of varied lengths."""
    secret = "variedkey"
    texts = [
        "short",
        "medium" * 10,
        "long" * 100,
        "unicode 🌟" * 20,
        "",
        "edge" * 16,
    ]
    ciphertexts = [aes_encrypt(t, secret) for t in texts]
    aes = AES(secret_key=secret)
    results = await asyncio.gather(*(aes.decrypt(c) for c in ciphertexts))

@pytest.mark.asyncio
async def test_AES_decrypt_throughput_high_volume():
    """Throughput test: high volume concurrent decryptions (stress, but bounded)."""
    secret = "highvolumekey"
    texts = [f"hv_{i}" for i in range(200)]
    ciphertexts = [aes_encrypt(t, secret) for t in texts]
    aes = AES(secret_key=secret)
    results = await asyncio.gather(*(aes.decrypt(c) for c in ciphertexts))
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import asyncio  # used to run async functions
# function to test
import base64
import hashlib

import pytest
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from skyvern.forge.sdk.encrypt.aes import AES

default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()

# --- Helper functions for test setup ---

def pad_pkcs7(data: bytes, block_size: int = 16) -> bytes:
    """Pads data according to PKCS7."""
    pad_len = block_size - (len(data) % block_size)
    return data + bytes([pad_len] * pad_len)

def aes_encrypt(plaintext: str, secret_key: str, salt: str | None = None, iv: str | None = None) -> str:
    """Encrypt plaintext using the same logic as AES.decrypt expects."""
    # Derive key
    key_material = hashlib.md5(secret_key.encode("utf-8")).digest()
    salt_bytes = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
    iv_bytes = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt_bytes,
        iterations=100000,
    )
    key = kdf.derive(key_material)
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv_bytes))
    encryptor = cipher.encryptor()
    padded = pad_pkcs7(plaintext.encode("utf-8"))
    ciphertext = encryptor.update(padded) + encryptor.finalize()
    return base64.b64encode(ciphertext).decode("utf-8")

# --- Basic Test Cases ---

@pytest.mark.asyncio
async def test_decrypt_basic_string():
    """Test decrypting a simple string."""
    secret = "supersecret"
    plaintext = "hello world"
    aes = AES(secret_key=secret)
    encrypted = aes_encrypt(plaintext, secret)
    # Await the async decrypt function
    result = await aes.decrypt(encrypted)

@pytest.mark.asyncio
async def test_decrypt_with_custom_salt_and_iv():
    """Test decrypting with custom salt and IV."""
    secret = "anothersecret"
    salt = "mysalt"
    iv = "myiv"
    plaintext = "custom salt and iv test"
    aes = AES(secret_key=secret, salt=salt, iv=iv)
    encrypted = aes_encrypt(plaintext, secret, salt, iv)
    result = await aes.decrypt(encrypted)

@pytest.mark.asyncio
async def test_decrypt_empty_string():
    """Test decrypting an empty string."""
    secret = "emptysecret"
    plaintext = ""
    aes = AES(secret_key=secret)
    encrypted = aes_encrypt(plaintext, secret)
    result = await aes.decrypt(encrypted)

@pytest.mark.asyncio
async def test_decrypt_unicode_characters():
    """Test decrypting a string with unicode characters."""
    secret = "unicodesecret"
    plaintext = "こんにちは世界🌏"
    aes = AES(secret_key=secret)
    encrypted = aes_encrypt(plaintext, secret)
    result = await aes.decrypt(encrypted)

@pytest.mark.asyncio
async def test_decrypt_basic_async_behavior():
    """Test that decrypt returns a coroutine and can be awaited."""
    secret = "asyncsecret"
    plaintext = "async test"
    aes = AES(secret_key=secret)
    encrypted = aes_encrypt(plaintext, secret)
    # Ensure that decrypt returns a coroutine
    codeflash_output = aes.decrypt(encrypted); coro = codeflash_output
    result = await coro

# --- Edge Test Cases ---

@pytest.mark.asyncio
async def test_decrypt_invalid_base64():
    """Test decrypting invalid base64 raises an exception."""
    secret = "invalidbase64"
    aes = AES(secret_key=secret)
    with pytest.raises(Exception) as excinfo:
        await aes.decrypt("not-base64!")

@pytest.mark.asyncio
async def test_decrypt_garbage_ciphertext():
    """Test decrypting a valid base64 but random bytes raises an exception."""
    secret = "garbage"
    aes = AES(secret_key=secret)
    # base64-encode random bytes, but not a valid AES ciphertext
    garbage = base64.b64encode(b"garbagegarbagegarbagegarb").decode("utf-8")
    with pytest.raises(Exception) as excinfo:
        await aes.decrypt(garbage)

@pytest.mark.asyncio

async def test_decrypt_concurrent_execution():
    """Test concurrent decryption of multiple ciphertexts."""
    secret = "concurrent"
    aes = AES(secret_key=secret)
    plaintexts = [f"message {i}" for i in range(10)]
    ciphertexts = [aes_encrypt(pt, secret) for pt in plaintexts]

    # Await all decryptions concurrently
    results = await asyncio.gather(*(aes.decrypt(ct) for ct in ciphertexts))

@pytest.mark.asyncio

async def test_decrypt_large_plaintext():
    """Test decrypting a large plaintext (~10KB)."""
    secret = "largeplaintext"
    plaintext = "A" * 10_000
    aes = AES(secret_key=secret)
    encrypted = aes_encrypt(plaintext, secret)
    result = await aes.decrypt(encrypted)

@pytest.mark.asyncio
async def test_decrypt_many_concurrent():
    """Test decrypting many ciphertexts concurrently (100)."""
    secret = "manyconcurrent"
    aes = AES(secret_key=secret)
    plaintexts = [f"data {i}" for i in range(100)]
    ciphertexts = [aes_encrypt(pt, secret) for pt in plaintexts]
    results = await asyncio.gather(*(aes.decrypt(ct) for ct in ciphertexts))

# --- Throughput Test Cases ---

@pytest.mark.asyncio
async def test_AES_decrypt_throughput_small_load():
    """Throughput test: decrypt 10 small messages concurrently."""
    secret = "throughputsmall"
    aes = AES(secret_key=secret)
    plaintexts = [f"msg{i}" for i in range(10)]
    ciphertexts = [aes_encrypt(pt, secret) for pt in plaintexts]
    results = await asyncio.gather(*(aes.decrypt(ct) for ct in ciphertexts))

@pytest.mark.asyncio
async def test_AES_decrypt_throughput_medium_load():
    """Throughput test: decrypt 100 medium messages concurrently."""
    secret = "throughputmedium"
    aes = AES(secret_key=secret)
    plaintexts = [f"medium message {i}" * 5 for i in range(100)]
    ciphertexts = [aes_encrypt(pt, secret) for pt in plaintexts]
    results = await asyncio.gather(*(aes.decrypt(ct) for ct in ciphertexts))

@pytest.mark.asyncio
async def test_AES_decrypt_throughput_large_load():
    """Throughput test: decrypt 200 large messages concurrently."""
    secret = "throughputlarge"
    aes = AES(secret_key=secret)
    plaintexts = [("L" * 1000) + str(i) for i in range(200)]
    ciphertexts = [aes_encrypt(pt, secret) for pt in plaintexts]
    results = await asyncio.gather(*(aes.decrypt(ct) for ct in ciphertexts))

@pytest.mark.asyncio
async def test_AES_decrypt_throughput_mixed_sizes():
    """Throughput test: decrypt a mix of small, medium, and large messages concurrently."""
    secret = "throughputmixed"
    aes = AES(secret_key=secret)
    plaintexts = (
        [f"short{i}" for i in range(10)] +
        [f"medium message {i}" * 10 for i in range(10)] +
        [("L" * 1000) + str(i) for i in range(10)]
    )
    ciphertexts = [aes_encrypt(pt, secret) for pt in plaintexts]
    results = await asyncio.gather(*(aes.decrypt(ct) for ct in ciphertexts))
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-AES.decrypt-mjaaxyed and push.

Codeflash Static Badge

The optimization implements **key derivation caching** to eliminate redundant expensive cryptographic operations. The key change is adding a `_cached_key` attribute that stores the derived key after the first computation.

**What was optimized:**
- Added `self._cached_key: bytes | None = None` in `__init__`
- Modified `_derive_key()` to return the cached key if available, otherwise compute, cache, and return it

**Why this provides a massive speedup:**
The PBKDF2HMAC key derivation with 100,000 iterations is computationally expensive (~10ms per call based on profiler data). In the original code, this expensive operation ran on every `decrypt()` call. The optimization reduces this to a one-time cost per AES instance.

**Performance impact:**
- **Runtime improvement**: 2844% speedup (9.81s → 333ms)
- Line profiler shows `_derive_key()` time dropped from 9.97s to 0.415s
- The expensive `kdf.derive()` operation now only runs 35 times instead of 843 times in the test workload

**When this optimization shines:**
The test results show this optimization is particularly effective for:
- **Concurrent decryption scenarios** (`test_decrypt_concurrent_same_key`, `test_decrypt_many_concurrent`) where the same AES instance decrypts multiple messages
- **Throughput-heavy workloads** (`test_AES_decrypt_throughput_*` tests) with batched operations
- **Reused AES instances** where the same key/salt/IV combination processes multiple ciphertexts

**Thread safety note:** The caching is safe since AES instances are typically not shared across threads, and the key derivation parameters (secret_key, salt, IV) are immutable after initialization.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 December 17, 2025 17:45
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Dec 17, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant