Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions QKV.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""QKV Accelerator ISA Definition"""

from taidl import Accelerator

qkv = Accelerator("QKV")


# Define Data Models
# d1: 128 rows x 64 columns of bf16
qkv.add_data_model("d1", [128], [64], "bf16")

# d3: 128 rows x 64 columns of bf16
qkv.add_data_model("d3", [128], [64], "bf16")

# d2: 64 rows x 64 columns of bf16
qkv.add_data_model("d2", [64], [64], "bf16")


# Define Instruction semantics.
# Semantics bodies are written in an XLA-HLO-style DSL: `@c.<name>` expands to
# a compile-time parameter of the instruction and `@a.<name>` to an address
# operand.

# (1) load_rm_01: Loads n rows from HBM (d0) in row-major format into d1.
# Each row is 128 bytes = 64 bf16 values; the byte stream is reinterpreted
# via reshape + bitcast_convert.
instr = qkv.add_instruction("load_rm_01", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d0", ["@a.addr_in"], ["@c.n * 128"]]]) # u8[@c.n * 128]
instr.set_outputs([["d1", ["@a.addr_out"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.add_semantics("""
ENTRY load_rm_01 {
%In1 = u8[`@c.n * 128`] parameter(0);
%a = u8[`@c.n`,64,2] reshape(%In1);
ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
}
""")


# (2) load_rm_03: Same as load_rm_01 but the destination scratchpad is d3.
instr = qkv.add_instruction("load_rm_03", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d0", ["@a.addr_in"], ["@c.n * 128"]]]) # u8[@c.n * 128]
instr.set_outputs([["d3", ["@a.addr_out"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.add_semantics("""
ENTRY load_rm_03 {
%In1 = u8[`@c.n * 128`] parameter(0);
%a = u8[`@c.n`,64,2] reshape(%In1);
ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
}
""")


# (3) store_rm_10: Stores n rows from d1 to HBM (d0) in row-major format
# (inverse of load_rm_01: bitcast to bytes, then flatten).
instr = qkv.add_instruction("store_rm_10", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d1", ["@a.addr_in"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.set_outputs([["d0", ["@a.addr_out"], ["@c.n * 128"]]]) # u8[@c.n * 128]
instr.add_semantics("""
ENTRY store_rm_10 {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = u8[`@c.n`,64,2] bitcast_convert(%In1);
ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
}
""")

# (4) store_rm_30: Same as store_rm_10 but the source scratchpad is d3.
instr = qkv.add_instruction("store_rm_30", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d3", ["@a.addr_in"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.set_outputs([["d0", ["@a.addr_out"], ["@c.n * 128"]]]) # u8[@c.n * 128]
instr.add_semantics("""
ENTRY store_rm_30 {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = u8[`@c.n`,64,2] bitcast_convert(%In1);
ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
}
""")

# (5) transpose_13: Transposes a fixed 64x64 bf16 tile from d1 into d3.
# (The original comment called this "store_cm", but the semantics below are a
# scratchpad-to-scratchpad transpose, not an HBM store.)
instr = qkv.add_instruction("transpose_13",[], ["addr_in", "addr_out"])
instr.set_inputs([["d1", ["@a.addr_in"], ["64"]]]) # bf16[64, 64]
instr.set_outputs([["d3", ["@a.addr_out"], ["64"]]]) # bf16[64, 64]
instr.add_semantics("""
ENTRY transpose_13 {
%In1 = bf16[64,64] parameter(0);
%a = bf16[64,64] transpose(%In1), dimensions={1,0};
ROOT %Out0 = bf16[64, 64] copy(%a);
}
""")


# (6) mov_21: Copies n rows from d2 to d1 (plain copy, no layout change).
instr = qkv.add_instruction("mov_21", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d2", ["@a.addr_in"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.set_outputs([["d1", ["@a.addr_out"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.add_semantics("""
ENTRY mov_21 {
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
}
""")

# (7) mov_23: Copies n rows from d2 to d3 (plain copy, no layout change).
# NOTE(review): the original comment said "Couldn't figure out how to move 64
# columns into a 128 column buffer", but the data models above declare d3 as
# 128 rows x 64 columns, so a row-limited copy like this one looks
# sufficient — confirm the intended d3 layout.
instr = qkv.add_instruction("mov_23", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d2", ["@a.addr_in"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.set_outputs([["d3", ["@a.addr_out"], ["@c.n"]]]) # bf16[@c.n, 64] (ignore last 64 columns)
instr.add_semantics("""
ENTRY mov_23 {
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
}
""")

# (8) gemm_3: 64x64 matrix multiply of two d3 tiles, result written to d2.
# Contracts lhs dim 1 with rhs dim 0, i.e. Out = In1 @ In2.
instr = qkv.add_instruction("gemm_3", [], ["addr_1", "addr_2", "addr_out"])
instr.set_inputs([["d3", ["@a.addr_1"], ["64"]], ["d3", ["@a.addr_2"], ["64"]]])
instr.set_outputs([["d2", ["@a.addr_out"], ["64"]]]) # bf16[64, 64]
instr.add_semantics("""
ENTRY gemm_3 {
%In1 = bf16[64,64] parameter(0);
%In2 = bf16[64,64] parameter(1);
ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
}
""")

# (9) gemm_13: 64x64 matrix multiply of a d1 tile (lhs) with a d3 tile (rhs),
# result written to d2. Same contraction as gemm_3: Out = In1 @ In2.
instr = qkv.add_instruction("gemm_13", [], ["addr_1", "addr_2", "addr_out"])
instr.set_inputs([["d1", ["@a.addr_1"], ["64"]], ["d3", ["@a.addr_2"], ["64"]]])
instr.set_outputs([["d2", ["@a.addr_out"], ["64"]]]) # bf16[64, 64]
instr.add_semantics("""
ENTRY gemm_13 {
%In1 = bf16[64,64] parameter(0);
%In2 = bf16[64,64] parameter(1);
ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={1}, rhs_contracting_dims={0};
}
""")

# (10) softmax: Row-wise softmax (along dimension 1), applied in place on d2
# (input and output alias the same address).
# NOTE(review): exp is taken without subtracting the per-row max first, so
# large inputs may overflow in bf16 — confirm whether that is acceptable.
instr = qkv.add_instruction("softmax", ["n"], ["addr"])
instr.set_inputs([["d2", ["@a.addr"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.set_outputs([["d2", ["@a.addr"], ["@c.n"]]]) # bf16[@c.n, 64]
instr.add_semantics("""
ENTRY softmax {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = bf16[`@c.n`,64] exponential(%In1);
%reduced = bf16[`@c.n`] reduce_add(%a), dimensions={1};
%b = bf16[`@c.n`,64] broadcast(%reduced), dimensions={0};
ROOT %Out0 = bf16[`@c.n`,64] divide(%a, %b);
}
""")




# Unused instruction definitions kept for reference.
# NOTE(review): both commented-out definitions below register the SAME
# instruction name "load_cm"; rename one of them before re-enabling, or the
# second registration would conflict with the first.
# # (2) load_cm: Loads data from HBM (d0) in column-major format to d3 (with transpose)
# instr = qkv.add_instruction("load_cm", ["n"], ["addr_in", "addr_out"])
# instr.set_inputs([["d0", ["@a.addr_in"], ["@c.n * 128"]]]) # u8[@c.n * 128]
# instr.set_outputs([["d3", ["@a.addr_out"], ["@c.n"]]]) # bf16[@c.n, 64]
# instr.add_semantics("""
# ENTRY load_cm {
# %In1 = u8[`@c.n * 128`] parameter(0);
# %a = u8[`@c.n`,64,2] reshape(%In1);
# %b = bf16[`@c.n`,64] bitcast_convert(%a);
# ROOT %Out0 = bf16[64,`@c.n`] transpose(%b), dimensions={1,0};
# }
# """)

# # (2) load_cm: Loads data from HBM (d0) in column-major format to d1 (with transpose)
# instr = qkv.add_instruction("load_cm", ["n"], ["addr_in", "addr_out"])
# instr.set_inputs([["d0", ["@a.addr_in"], ["@c.n * 128"]]]) # u8[@c.n * 128]
# instr.set_outputs([["d1", ["@a.addr_out"], ["@c.n"]]]) # bf16[@c.n, 64]
# instr.add_semantics("""
# ENTRY load_cm {
# %In1 = u8[`@c.n * 128`] parameter(0);
# %a = u8[`@c.n`,64,2] reshape(%In1);
# %b = bf16[`@c.n`,64] bitcast_convert(%a);
# ROOT %Out0 = bf16[64,`@c.n`] transpose(%b), dimensions={1,0};
# }
# """)



# Generate programming APIs and test oracle (functional simulator)
qkv.generate_oracle()

# Generate compiler backend (not enabled yet)
# qkv.generate_backend()
94 changes: 94 additions & 0 deletions asm/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import jax.numpy as jnp


def qkv(kernel, api):
    """Single-head attention: out = softmax(Q @ K^T) @ V on 64x64 bf16 tiles.

    `kernel` is the kernel-declaration decorator and `api` the instruction
    API, both produced by the taidl-generated oracle for the QKV ISA.
    Returns the decorated kernel program.
    """
    @kernel(
        hbm=32768,  # 32 KB of HBM: three 8 KB inputs followed by one 8 KB output
        input=[
            {'addr': 0, 'shape': (64, 64), 'dtype': jnp.bfloat16},      # Q
            {'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16},   # K
            {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},  # V
        ],
        constant=[],  # no constants needed by this kernel
        output=[
            {'addr': 24576, 'shape': (64, 64), 'dtype': jnp.bfloat16},  # attention output
        ]
    )
    def qkv_():
        # "matmul 3,3" variant: both gemms read their operands from d3.
        # (Variants routing either product through gemm_13 were also tried
        # during development and removed in favour of this one.)

        # --- Stage 1: attention scores S = Q @ K^T ---
        api.load_rm_01(n=64, addr_in=8192, addr_out=0)    # K   -> d1[0]
        api.transpose_13(addr_in=0, addr_out=64)          # K^T -> d3[64]
        api.load_rm_03(n=64, addr_in=0, addr_out=0)       # Q   -> d3[0]
        api.gemm_3(addr_1=0, addr_2=64, addr_out=0)       # S   -> d2[0]

        # --- Stage 2: P = softmax(S), in place on d2 ---
        api.softmax(n=64, addr=0)

        # --- Stage 3: stage operands for the second product ---
        api.mov_23(n=64, addr_in=0, addr_out=0)           # P -> d3[0]
        api.load_rm_03(n=64, addr_in=16384, addr_out=64)  # V -> d3[64]

        # --- Stage 4: O = P @ V ---
        api.gemm_3(addr_1=0, addr_2=64, addr_out=0)       # O -> d2[0]
        api.mov_21(n=64, addr_in=0, addr_out=0)           # O -> d1[0]

        # --- Stage 5: write the result back to HBM ---
        api.store_rm_10(n=64, addr_in=0, addr_out=24576)

    return qkv_
22 changes: 22 additions & 0 deletions asm/identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import jax.numpy as jnp


def identity(kernel, api):
    """Copy one 64x64 bf16 matrix from HBM to HBM through scratchpad d1.

    Fix: the QKV ISA defined in this change only exposes suffixed
    instruction names (load_rm_01, store_rm_10); the unsuffixed
    api.load_rm / api.store_rm used previously are not defined and would
    fail at run time.  NOTE(review): confirm the suffix convention against
    the taidl-generated API.
    """
    @kernel(
        hbm=16384,  # 16 KB: 8 KB input + 8 KB output
        input=[
            {'addr': 0, 'shape': (64, 64), 'dtype': jnp.bfloat16},
        ],
        constant=[],
        output=[
            {'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16},
        ]
    )
    def identity_():
        # HBM[0] -> d1[0]: 64 rows, row-major
        api.load_rm_01(n=64, addr_in=0, addr_out=0)

        # d1[0] -> HBM[8192]
        api.store_rm_10(n=64, addr_in=0, addr_out=8192)

    return identity_
32 changes: 32 additions & 0 deletions asm/matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import jax.numpy as jnp


def matmul(kernel, api):
    """Compute C = A @ B for 64x64 bf16 matrices.

    Fixes vs. the previous version:
    - The QKV ISA only exposes suffixed instruction names (load_rm_01,
      load_rm_03, gemm_13, mov_21, store_rm_10); the unsuffixed
      load_rm/gemm/mov/store_rm used before do not exist.
    - Both operands were loaded into d1, but no gemm instruction reads two
      d1 tiles; gemm_13 contracts a d1 tile (lhs dim 1) with a d3 tile
      (rhs dim 0), so B is loaded into d3 instead.
    NOTE(review): confirm instruction names against the generated API.
    """
    @kernel(
        hbm=24576,  # 24 KB: 3 matrices x 8 KB
        input=[
            {'addr': 0, 'shape': (64, 64), 'dtype': jnp.bfloat16},      # Matrix A
            {'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16},   # Matrix B
        ],
        constant=[],
        output=[
            {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},  # Matrix C = A @ B
        ]
    )
    def matmul_():
        # A -> d1[0] (row-major)
        api.load_rm_01(n=64, addr_in=0, addr_out=0)

        # B -> d3[0] (row-major; gemm_13 contracts rhs dim 0, giving A @ B)
        api.load_rm_03(n=64, addr_in=8192, addr_out=0)

        # C = A @ B -> d2[0]
        api.gemm_13(addr_1=0, addr_2=0, addr_out=0)

        # d2[0] -> d1[0] (only d1 can be stored back to HBM row-major)
        api.mov_21(n=64, addr_in=0, addr_out=0)

        # d1[0] -> HBM[16384]
        api.store_rm_10(n=64, addr_in=0, addr_out=16384)

    return matmul_
38 changes: 38 additions & 0 deletions asm/softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import jax.numpy as jnp


def softmax(kernel, api):
    """Row-wise softmax of a 64x64 bf16 matrix.

    The softmax instruction operates on d2, and only gemm writes d2, so the
    input A is multiplied by the identity constant to route it into d2.

    Fixes vs. the previous version: the unsuffixed api.load_rm / api.load_cm /
    api.gemm / api.mov / api.store_rm are not defined by the QKV ISA
    (load_cm exists only as commented-out code in QKV.py).  The identity
    matrix is symmetric (I^T == I), so a plain row-major load into d3
    (load_rm_03) replaces the column-major load, and gemm_13 computes
    A @ I = A.  NOTE(review): confirm names against the generated API.
    """
    @kernel(
        hbm=24576,
        input=[
            {'addr': 0, 'shape': (64, 64), 'dtype': jnp.bfloat16},
        ],
        constant=[
            # Identity matrix I
            {'addr': 8192, 'shape': (64, 64), 'dtype': jnp.bfloat16,
             'value': jnp.eye(64, dtype=jnp.bfloat16)},
        ],
        output=[
            {'addr': 16384, 'shape': (64, 64), 'dtype': jnp.bfloat16},
        ]
    )
    def softmax_():
        # A -> d1[0]
        api.load_rm_01(n=64, addr_in=0, addr_out=0)

        # I -> d3[0] (row-major load; equivalent to column-major since I^T == I)
        api.load_rm_03(n=64, addr_in=8192, addr_out=0)

        # A @ I = A -> d2[0]
        api.gemm_13(addr_1=0, addr_2=0, addr_out=0)

        # Row-wise softmax, in place on d2[0]
        api.softmax(n=64, addr=0)

        # d2[0] -> d1[0]
        api.mov_21(n=64, addr_in=0, addr_out=0)

        # d1[0] -> HBM[16384]
        api.store_rm_10(n=64, addr_in=0, addr_out=16384)

    return softmax_
Loading