Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions QKV_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""QKV Accelerator ISA Definition for Tutorial 2"""

from taidl import Accelerator

qkv_2 = Accelerator("QKV_2")

# Data models, all bf16.
# NOTE(review): the two list arguments look like [row count] and [row width]
# (d1/d3: 128 rows, d2: 64 rows, every row 64 elements wide) -- confirm
# against taidl's add_data_model signature.
for _model_name, _model_rows in (("d1", 128), ("d2", 64), ("d3", 128)):
    qkv_2.add_data_model(_model_name, [_model_rows], [64], "bf16")

# Load instructions: load_rm_d1 / load_rm_d2 / load_rm_d3.
# Each one reads `n * 128` bytes from d0 at addr_in, regroups them as
# n rows x 64 elements x 2 bytes, and bit-casts every byte pair to bf16,
# producing n rows of the target data model at addr_out.
# The three definitions differ only in the target data model, so they are
# registered from one template (in the order d1, d2, d3).
for _target in ("d1", "d2", "d3"):
    instr = qkv_2.add_instruction(
        f"load_rm_{_target}", ["n"], ["addr_in", "addr_out"]
    )
    instr.set_inputs([["d0", ["@a.addr_in"], ["@c.n * 128"]]])
    instr.set_outputs([[_target, ["@a.addr_out"], ["@c.n"]]])
    # Doubled braces are literal braces in the generated semantics text.
    instr.add_semantics(f"""
ENTRY load_rm_{_target} {{
%In1 = u8[`@c.n * 128`] parameter(0);
%a = u8[`@c.n`,64,2] reshape(%In1);
ROOT %Out0 = bf16[`@c.n`,64] bitcast_convert(%a);
}}
""")


# Store instruction: the inverse of the load_rm_* instructions.
# Takes n rows of d2 starting at addr_in (bf16[n,64]), bit-casts every
# element into its 2 bytes and flattens the result to n * 128 bytes,
# which are written to d0 at addr_out.
instr = qkv_2.add_instruction("store_rm_d2", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d2", ["@a.addr_in"], ["@c.n"]]])
instr.set_outputs([["d0", ["@a.addr_out"], ["@c.n * 128"]]])
instr.add_semantics("""
ENTRY store_rm_d2 {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = u8[`@c.n`,64,2] bitcast_convert(%In1);
ROOT %Out0 = u8[`@c.n*128`] reshape(%a);
}
""")


# Compute instructions
# gemm_d1_d3 / gemm_d3_d3: multiply two 64x64 bf16 tiles (contracting the
# left operand's dimension 1 with the right operand's dimension 0) and write
# the 64x64 product to d2. The right operand always comes from d3; only the
# data model of the left operand differs between the two instructions.
for _lhs in ("d1", "d3"):
    instr = qkv_2.add_instruction(
        f"gemm_{_lhs}_d3", [], ["addr_1", "addr_2", "addr_out"]
    )
    instr.set_inputs([
        [_lhs, ["@a.addr_1"], ["64"]],
        ["d3", ["@a.addr_2"], ["64"]],
    ])
    instr.set_outputs([["d2", ["@a.addr_out"], ["64"]]])
    # Doubled braces are literal braces in the generated semantics text.
    instr.add_semantics(f"""
ENTRY gemm_{_lhs}_d3 {{
%In1 = bf16[64,64] parameter(0);
%In2 = bf16[64,64] parameter(1);
ROOT %Out0 = bf16[64,64] dot(%In1, %In2), lhs_contracting_dims={{1}}, rhs_contracting_dims={{0}};
}}
""")


# softmax: row-wise softmax over n rows of d2, performed in place
# (input and output are the same d2 region starting at addr).
# Semantics: exponentiate every element, sum each row (reduction over
# dimension 1), broadcast the row sums back to [n,64] and divide.
# NOTE(review): there is no max-subtraction before the exponential, so
# large inputs may overflow in bf16 -- confirm this matches the intended
# hardware behaviour.
instr = qkv_2.add_instruction("softmax", ["n"], ["addr"])
instr.set_inputs([["d2", ["@a.addr"], ["@c.n"]]])
instr.set_outputs([["d2", ["@a.addr"], ["@c.n"]]])
instr.add_semantics("""
ENTRY softmax {
%In1 = bf16[`@c.n`,64] parameter(0);
%a = bf16[`@c.n`,64] exponential(%In1);
%reduced = bf16[`@c.n`] reduce_add(%a), dimensions={1};
%b = bf16[`@c.n`,64] broadcast(%reduced), dimensions={0};
ROOT %Out0 = bf16[`@c.n`,64] divide(%a, %b);
}
""")

# copy_d2_d1 / copy_d2_d3: copy n rows (bf16[n,64]) unchanged from d2 at
# addr_in into the destination data model at addr_out. The two instructions
# differ only in their destination, so both come from one template
# (registered in the order d1, d3).
for _dest in ("d1", "d3"):
    instr = qkv_2.add_instruction(
        f"copy_d2_{_dest}", ["n"], ["addr_in", "addr_out"]
    )
    instr.set_inputs([["d2", ["@a.addr_in"], ["@c.n"]]])
    instr.set_outputs([[_dest, ["@a.addr_out"], ["@c.n"]]])
    # Doubled braces are literal braces in the generated semantics text.
    instr.add_semantics(f"""
ENTRY copy_d2_{_dest} {{
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[`@c.n`,64] copy(%In1);
}}
""")

# transpose_d1_d3: read n rows of d1 at addr_in (bf16[n,64]), swap the two
# dimensions and write the bf16[64,n] result as 64 rows of d3 at addr_out.
# NOTE(review): the output is declared as a fixed 64 rows while its row
# length is n; with the other instructions treating d3 rows as 64 elements
# wide this presumably only lines up for n = 64 -- confirm.
instr = qkv_2.add_instruction("transpose_d1_d3", ["n"], ["addr_in", "addr_out"])
instr.set_inputs([["d1", ["@a.addr_in"], ["@c.n"]]])
instr.set_outputs([["d3", ["@a.addr_out"], ["64"]]])
instr.add_semantics("""
ENTRY transpose_d1_d3 {
%In1 = bf16[`@c.n`,64] parameter(0);
ROOT %Out0 = bf16[64,`@c.n`] transpose(%In1), dimensions={1,0};
}
""")

# Generate programming APIs and test oracle (functional simulator).
# Called after every data model and instruction above has been registered.
qkv_2.generate_oracle()

# Generate compiler backend for the same accelerator definition.
qkv_2.generate_backend()
47 changes: 47 additions & 0 deletions asm/attention_qkv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import jax.numpy as jnp

def qkv(kernel, api):
    """Build the baseline single-tile attention kernel: softmax(Q x K^T) x V.

    Args:
        kernel: decorator factory; it is called with the HBM size and the
            input / constant / output buffer descriptions and then applied
            to the kernel body.
        api: object exposing the accelerator instructions issued below.

    Returns:
        Whatever the ``kernel`` decorator returns for the kernel body.
    """

    def _tile(addr):
        # One 64x64 bf16 matrix (64 * 64 * 2 = 8192 bytes) located at `addr`.
        return {'addr': addr, 'shape': (64, 64), 'dtype': jnp.bfloat16}

    # HBM layout: Q @ 0, K @ 8192, V @ 16384, result @ 24576.
    # Four 8192-byte tiles fill the 32 KB (32768 bytes) exactly.
    operands = [_tile(0), _tile(8192), _tile(16384)]
    result = [_tile(24576)]

    @kernel(hbm=32768, input=operands, constant=[], output=result)
    def qkv_():
        # K: HBM -> d1
        api.load_rm_d1(n=64, addr_in=8192, addr_out=0)

        # K^T: transpose d1 -> d3
        api.transpose_d1_d3(n=64, addr_in=0, addr_out=0)

        # Q: HBM -> d1 (K is no longer needed there)
        api.load_rm_d1(n=64, addr_in=0, addr_out=0)

        # S = Q x K^T, written to d2
        api.gemm_d1_d3(addr_1=0, addr_2=0, addr_out=0)

        # P = softmax(S), in place in d2
        api.softmax(n=64, addr=0)

        # P: d2 -> d1, to serve as the left operand of the second GEMM
        api.copy_d2_d1(n=64, addr_in=0, addr_out=0)

        # V: HBM -> d3 (replaces K^T)
        api.load_rm_d3(n=64, addr_in=16384, addr_out=0)

        # O = P x V, written to d2
        api.gemm_d1_d3(addr_1=0, addr_2=0, addr_out=0)

        # O: d2 -> HBM output buffer
        api.store_rm_d2(n=64, addr_in=0, addr_out=24576)

    return qkv_
46 changes: 46 additions & 0 deletions asm/attention_qkv2_variant_B.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import jax.numpy as jnp

def qkv(kernel, api):
    """Build attention variant B: softmax(Q x K^T) x V with P and V both in d3.

    Unlike the d1-based flow, the softmax result P is copied into d3 and V
    is loaded next to it (d3 address 64), so the final product uses
    ``gemm_d3_d3`` instead of ``gemm_d1_d3``.

    Args:
        kernel: decorator factory; it is called with the HBM size and the
            input / constant / output buffer descriptions and then applied
            to the kernel body.
        api: object exposing the accelerator instructions issued below.

    Returns:
        Whatever the ``kernel`` decorator returns for the kernel body.
    """

    def _tile(addr):
        # One 64x64 bf16 matrix (64 * 64 * 2 = 8192 bytes) located at `addr`.
        return {'addr': addr, 'shape': (64, 64), 'dtype': jnp.bfloat16}

    # HBM layout: Q @ 0, K @ 8192, V @ 16384, result @ 24576.
    # Four 8192-byte tiles fill the 32 KB (32768 bytes) exactly.
    operands = [_tile(0), _tile(8192), _tile(16384)]
    result = [_tile(24576)]

    @kernel(hbm=32768, input=operands, constant=[], output=result)
    def qkv_():
        # K: HBM -> d1
        api.load_rm_d1(n=64, addr_in=8192, addr_out=0)

        # K^T: transpose d1 -> d3 (address 0)
        api.transpose_d1_d3(n=64, addr_in=0, addr_out=0)

        # Q: HBM -> d1
        api.load_rm_d1(n=64, addr_in=0, addr_out=0)

        # S = Q x K^T, written to d2
        api.gemm_d1_d3(addr_1=0, addr_2=0, addr_out=0)

        # P = softmax(S), in place in d2
        api.softmax(n=64, addr=0)

        # P: d2 -> d3 at address 0 (replaces K^T)
        api.copy_d2_d3(n=64, addr_in=0, addr_out=0)

        # V: HBM -> d3 at address 64, directly after P
        api.load_rm_d3(n=64, addr_in=16384, addr_out=64)

        # O = P x V with both operands read from d3, written to d2
        api.gemm_d3_d3(addr_1=0, addr_2=64, addr_out=0)

        # O: d2 -> HBM output buffer
        api.store_rm_d2(n=64, addr_in=0, addr_out=24576)

    return qkv_
51 changes: 51 additions & 0 deletions asm/attention_qkv2_variant_C.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import jax.numpy as jnp

def qkv(kernel, api):
    """Build attention variant C: softmax(Q x K^T) x V with an extra early V load.

    The instruction stream starts by loading V into d3, then proceeds with
    the d1/d3 flow (K^T in d3, Q and later P in d1) and loads V a second
    time right before the final GEMM.

    Args:
        kernel: decorator factory; it is called with the HBM size and the
            input / constant / output buffer descriptions and then applied
            to the kernel body.
        api: object exposing the accelerator instructions issued below.

    Returns:
        Whatever the ``kernel`` decorator returns for the kernel body.
    """

    def _tile(addr):
        # One 64x64 bf16 matrix (64 * 64 * 2 = 8192 bytes) located at `addr`.
        return {'addr': addr, 'shape': (64, 64), 'dtype': jnp.bfloat16}

    # HBM layout: Q @ 0, K @ 8192, V @ 16384, result @ 24576.
    # Four 8192-byte tiles fill the 32 KB (32768 bytes) exactly.
    operands = [_tile(0), _tile(8192), _tile(16384)]
    result = [_tile(24576)]

    @kernel(hbm=32768, input=operands, constant=[], output=result)
    def qkv_():
        # V: HBM -> d3 at address 0.
        # NOTE(review): the transpose below also writes d3 at address 0
        # before anything reads this tile, and V is loaded again later --
        # presumably a deliberately redundant instruction for this variant;
        # confirm before removing.
        api.load_rm_d3(n=64, addr_in=16384, addr_out=0)

        # K: HBM -> d1
        api.load_rm_d1(n=64, addr_in=8192, addr_out=0)

        # K^T: transpose d1 -> d3 (address 0)
        api.transpose_d1_d3(n=64, addr_in=0, addr_out=0)

        # Q: HBM -> d1
        api.load_rm_d1(n=64, addr_in=0, addr_out=0)

        # S = Q x K^T, written to d2
        api.gemm_d1_d3(addr_1=0, addr_2=0, addr_out=0)

        # P = softmax(S), in place in d2
        api.softmax(n=64, addr=0)

        # P: d2 -> d1, to serve as the left operand of the second GEMM
        api.copy_d2_d1(n=64, addr_in=0, addr_out=0)

        # V: HBM -> d3 again (replaces K^T)
        api.load_rm_d3(n=64, addr_in=16384, addr_out=0)

        # O = P x V, written to d2
        api.gemm_d1_d3(addr_1=0, addr_2=0, addr_out=0)

        # O: d2 -> HBM output buffer
        api.store_rm_d2(n=64, addr_in=0, addr_out=24576)

    return qkv_
47 changes: 47 additions & 0 deletions asm/attention_qkv2_variant_D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import jax.numpy as jnp

def qkv(kernel, api):
    """Build attention variant D: softmax(Q x K^T) x V with Q and K^T both in d3.

    K^T is written to d3 at address 64 and Q is loaded into d3 at address 0,
    so the first product uses ``gemm_d3_d3``; the second product (P x V)
    uses ``gemm_d1_d3`` with P in d1 and V in d3.

    Args:
        kernel: decorator factory; it is called with the HBM size and the
            input / constant / output buffer descriptions and then applied
            to the kernel body.
        api: object exposing the accelerator instructions issued below.

    Returns:
        Whatever the ``kernel`` decorator returns for the kernel body.
    """

    def _tile(addr):
        # One 64x64 bf16 matrix (64 * 64 * 2 = 8192 bytes) located at `addr`.
        return {'addr': addr, 'shape': (64, 64), 'dtype': jnp.bfloat16}

    # HBM layout: Q @ 0, K @ 8192, V @ 16384, result @ 24576.
    # Four 8192-byte tiles fill the 32 KB (32768 bytes) exactly.
    operands = [_tile(0), _tile(8192), _tile(16384)]
    result = [_tile(24576)]

    @kernel(hbm=32768, input=operands, constant=[], output=result)
    def qkv_():
        # K: HBM -> d1
        api.load_rm_d1(n=64, addr_in=8192, addr_out=0)

        # K^T: transpose d1 -> d3 at address 64, leaving address 0 free for Q
        api.transpose_d1_d3(n=64, addr_in=0, addr_out=64)

        # Q: HBM -> d3 at address 0
        api.load_rm_d3(n=64, addr_in=0, addr_out=0)

        # S = Q x K^T with both operands read from d3, written to d2
        api.gemm_d3_d3(addr_1=0, addr_2=64, addr_out=0)

        # P = softmax(S), in place in d2
        api.softmax(n=64, addr=0)

        # P: d2 -> d1, to serve as the left operand of the second GEMM
        api.copy_d2_d1(n=64, addr_in=0, addr_out=0)

        # V: HBM -> d3 at address 0 (replaces Q)
        api.load_rm_d3(n=64, addr_in=16384, addr_out=0)

        # O = P x V, written to d2
        api.gemm_d1_d3(addr_1=0, addr_2=0, addr_out=0)

        # O: d2 -> HBM output buffer
        api.store_rm_d2(n=64, addr_in=0, addr_out=24576)

    return qkv_
Loading