
Commit d699397

[LLVM][METASCHEDULE] Add RISCV V-extension v1.0 kernels to metaschedule
1 parent 585d6d2 commit d699397

File tree

8 files changed: +335, -1 lines


include/tvm/meta_schedule/postproc.h

Lines changed: 2 additions & 0 deletions
@@ -166,6 +166,8 @@ class Postproc : public runtime::ObjectRef {
   TVM_DLL static Array<Postproc, void> DefaultLLVM();
   /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
   TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
+  /*! \brief Create default postprocessors for RISCV */
+  TVM_DLL static Array<Postproc, void> DefaultRISCV();
   /*! \brief Create default postprocessors for CUDA */
   TVM_DLL static Array<Postproc, void> DefaultCUDA();
   /*! \brief Create default postprocessors for CUDA with TensorCore */

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 2 additions & 0 deletions
@@ -301,6 +301,8 @@ class ScheduleRule : public runtime::ObjectRef {
   TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
   /*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
   TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
+  /*! \brief Create default schedule rules for RISCV CPU (RVV) */
+  TVM_DLL static Array<ScheduleRule, void> DefaultRISCV(int vlen);

   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
 };

python/tvm/target/target.py

Lines changed: 8 additions & 0 deletions
@@ -637,6 +637,14 @@ def riscv_cpu(model="sifive-u54", options=None):
             "-mabi=lp64d",
             # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u74
         ],
+        "licheepi3a": [
+            "-num-cores=8",
+            "-mtriple=riscv64-unknown-linux-gnu",
+            "-mcpu=spacemit-x60",
+            "-mfloat-abi=hard",
+            "-mabi=lp64d",
+            # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gcv -mabi=lp64d -mcpu=spacemit-x60
+        ],
     }
     pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
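The new entry is picked up by the existing riscv_cpu helper; a minimal usage sketch (the composed target string depends on the local TVM/LLVM build):

import tvm

# Compose an LLVM target for the Lichee Pi 3A (SpacemiT X60) from the new
# predefined entry; extra flags can still be appended via `options`.
target = tvm.target.riscv_cpu(model="licheepi3a")
print(target)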

python/tvm/tir/tensor_intrin/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -20,4 +20,4 @@
 from . import cuda

 if enabled("llvm"):
-    from . import arm_cpu, x86, rocm, hexagon
+    from . import arm_cpu, x86, rocm, hexagon, riscv_cpu
python/tvm/tir/tensor_intrin/riscv_cpu.py (new file)

Lines changed: 241 additions & 0 deletions

@@ -0,0 +1,241 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,line-too-long
"""Intrinsics for RISCV tensorization"""

import logging
from tvm.ffi import register_func
from tvm.runtime import DataType
from tvm.script import tir as T
from tvm.target.codegen import Target
from tvm.target.codegen import llvm_get_vector_width, target_has_features
from tvm.target.datatype import get_type_name
from .. import TensorIntrin

logger = logging.getLogger(__name__)


def get_max_elems(vlen: int, lmul: int, sew: int) -> int:
    """Returns the number of elements of a given element width (SEW)
    that fit in a multiple (LMUL) of the vector registers (VLEN).

    Args:
        vlen (int): VLEN, vector register length in bits
        lmul (int): LMUL, vector length multiplier
        sew (int): SEW, standard (single) element width in bits

    Returns:
        int: Number of elements
    """
    return (vlen // sew) * lmul
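# Illustrative arithmetic (assumed values, not part of the kernels): with
# VLEN=256, LMUL=8 and SEW=8, a register group holds (256 // 8) * 8 = 256
# int8 elements, i.e. get_max_elems(256, lmul=8, sew=8) == 256.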


def rvv_vec_dot_product_kernels(
    n_elems: int,
    n_lanes: int,
    data_dtype: str,
    weight_dtype: str,
    out_dtype: str,
    lmul: int,
):
    """
    Dot product of a vector with matrix rows using RISC-V vector instructions.

    The kernel takes two arrays A[ELEMS] and B[LANES][ELEMS] and computes the
    dot product of A[ELEMS] with each row of B, accumulating the results
    in C[LANES].

    The pseudo code is as follows:
    .. code-block:: c
        void vec_dot_prod(A[ELEMS], B[LANES][ELEMS], C[LANES]){
            for (j = 0; j < LANES; j++) {
                for (k = 0; k < ELEMS; k++) {
                    C[j] += A[k] * B[j][k]
                }
            }
        }
    """

    @T.prim_func
    def rvv_vec_dot_prod_desc(
        A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
        B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
        C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
            T.writes(C[0:n_lanes])
            for j in T.serial(0, n_lanes):
                for k in T.serial(0, n_elems):
                    with T.block("update"):
                        vj, vk = T.axis.remap("SR", [j, k])
                        C[vj] = C[vj] + T.cast(A[vk], out_dtype) * T.cast(B[vj, vk], out_dtype)

    # LLVM only supports ELEN=32 or ELEN=64
    # https://llvm.org/docs//RISCV/RISCVVectorExtension.html
    d_dtype_lanes = (64 // DataType(data_dtype).bits) * lmul
    w_dtype_lanes = (64 // DataType(weight_dtype).bits) * lmul
    # the reduction narrows the lanes
    o_dtype_lanes = (64 // DataType(out_dtype).bits) * lmul // n_lanes
    # data type widening case
    o_dtype_lanes = max(o_dtype_lanes, 2)

    mask_args = () if data_dtype[0] in ("i", "u") else (T.uint64(7),)

    wide_dtype = out_dtype
    if DataType(out_dtype).bits > DataType(data_dtype).bits:
        wide_dtype = "".join(c for c in data_dtype if not c.isdigit())
        wide_dtype += str(DataType(data_dtype).bits * 2)

    # fmt: off
    @T.prim_func
    def rvv_vec_dot_prod_impl(
        A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
        B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
        C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
            T.writes(C[0:n_lanes])

            vec_A = T.call_llvm_intrin(
                f"{data_dtype}xvscalex{d_dtype_lanes}",
                "llvm.riscv.vle",
                T.broadcast(T.Cast(data_dtype, 0), T.vscale() * d_dtype_lanes),
                T.tvm_access_ptr(T.type_annotation(data_dtype), A.data, 0, n_elems, 1),
                T.int64(n_elems))

            for i in range(n_lanes):
                with T.block("update"):
                    T.reads(B[i, 0:n_elems])
                    T.writes(C[i])

                    vec_B_row = T.call_llvm_intrin(
                        f"{weight_dtype}xvscalex{w_dtype_lanes}",
                        "llvm.riscv.vle",
                        T.broadcast(T.Cast(data_dtype, 0), T.vscale() * w_dtype_lanes),
                        T.tvm_access_ptr(T.type_annotation(weight_dtype), B.data, i * n_elems, n_elems, 1),
                        T.int64(n_elems))

                    product = T.call_llvm_intrin(
                        f"{wide_dtype}xvscalex{w_dtype_lanes}",
                        "llvm.riscv.vfmul" if out_dtype[0] == "f" else \
                        "llvm.riscv.vwmulsu" if (data_dtype[0] != weight_dtype[0]) else \
                        "llvm.riscv.vwmul",
                        T.broadcast(T.Cast(wide_dtype, 0), T.vscale() * w_dtype_lanes),
                        vec_B_row,
                        vec_A,
                        *mask_args,
                        T.uint64(n_elems))

                    ini_acc = T.call_llvm_intrin(
                        f"{out_dtype}xvscalex{o_dtype_lanes}",
                        "llvm.riscv.vle",
                        T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
                        T.tvm_access_ptr(T.type_annotation(out_dtype), C.data, i, 1, 1),
                        T.int64(1))

                    red_sum = T.call_llvm_intrin(
                        f"{out_dtype}xvscalex{o_dtype_lanes}",
                        "llvm.riscv.vfredusum" if out_dtype[0] == "f" else \
                        "llvm.riscv.vwredsum",
                        T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
                        product,
                        ini_acc,
                        *mask_args,
                        T.uint64(n_elems))

                    C[i] = T.call_llvm_intrin(
                        out_dtype,
                        "llvm.riscv.vfmv.f.s" if out_dtype[0] == "f" else \
                        "llvm.riscv.vmv.x.s",
                        red_sum)
    # fmt: on

    return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl


@register_func("tir.tensor_intrin.register_rvv_isa_intrisics")
def register_rvv_isa_intrisics(target: Target, only_inventory=False) -> dict:
    """Register RISCV V (vector) intrinsics.

    The implementation follows the version 1.0 vector specification:
    https://github.com/riscvarchive/riscv-v-spec/releases/tag/v1.0

    Args:
        target (Target): TVM target
        only_inventory (bool): Build the inventory only, without registration

    Returns:
        dict: A catalog with the registered kernel names and properties
    """
    if not target_has_features("v", target):
        raise RuntimeError("Current target does not support `v` extension.")

    vlen = llvm_get_vector_width(target)
    # get the maximum reduction lanes (without grouping)
    n_lanes = get_max_elems(vlen, lmul=1, sew=32)

    data_dtype = ["uint8", "int8", "float16", "float32"]
    weight_dtype = ["int8", "int8", "float16", "float32"]
    output_dtype = ["int32", "int32", "float16", "float32"]

    kernel_inventory = {}

    for d_dtype, w_dtype, o_dtype in zip(data_dtype, weight_dtype, output_dtype):
        # max elements that fit the grouped registers
        max_elems = get_max_elems(vlen, lmul=8, sew=DataType(d_dtype).bits)
        # data widening halves the available vector registers
        if DataType(o_dtype).bits > DataType(d_dtype).bits:
            max_elems //= 2
        # compute the optimal LMUL for a full load
        lmul = max_elems // (vlen // DataType(d_dtype).bits)

        n_elems = max_elems
        while n_elems >= 4:

            dt = DataType(d_dtype)
            wt = DataType(w_dtype)
            ot = DataType(o_dtype)

            kernel_name = "rvv_dot"
            kernel_name += f"_{n_elems}{dt[0]}{dt.bits}"
            kernel_name += f"_{n_lanes}x{n_elems}{wt[0]}{wt.bits}"
            kernel_name += f"_{n_lanes}{ot[0]}{ot.bits}"
            kernel_inventory[kernel_name] = n_elems

            if not only_inventory:
                logger.debug("Registering kernel %s", kernel_name)
                desc, impl = rvv_vec_dot_product_kernels(
                    n_elems, n_lanes, d_dtype, w_dtype, o_dtype, lmul
                )
                TensorIntrin.register(kernel_name, desc, impl, override=True)

            n_elems //= 2

    return kernel_inventory


def register_riscv_intrinsics(target: Target):
    """Register RISCV intrinsics.

    Args:
        target (Target): TVM target
    """

    # RISC-V `v` extension ISA
    _ = register_rvv_isa_intrisics(target)
    logger.debug("Finished registering riscv intrinsics.")

src/meta_schedule/postproc/postproc.cc

Lines changed: 8 additions & 0 deletions
@@ -69,6 +69,14 @@ Array<Postproc> Postproc::DefaultCPUTensorization() {
   };
 }

+Array<Postproc> Postproc::DefaultRISCV() {
+  return Array<Postproc>{
+      Postproc::DisallowDynamicLoop(),   Postproc::RewriteParallelVectorizeUnroll(),
+      Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/false),
+      Postproc::RewriteLayout(),
+  };
+}
+
 Array<Postproc> Postproc::DefaultCUDA() {
   return Array<Postproc>{
       Postproc::DisallowDynamicLoop(),
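For Python users, the same postprocessor set can be written out by hand with tvm.meta_schedule.postproc; this is only a sketch mirroring the C++ defaults above, not an API added by the commit:

from tvm.meta_schedule import postproc as P

# Hand-built mirror of Postproc::DefaultRISCV(): forbid dynamic loops,
# legalize parallel/vectorize/unroll annotations, rewrite reduction blocks,
# lower tensorized blocks to the RVV intrinsics, and apply layout rewrites.
riscv_postprocs = [
    P.DisallowDynamicLoop(),
    P.RewriteParallelVectorizeUnroll(),
    P.RewriteReductionBlock(),
    P.RewriteTensorize(vectorize_init_loop=False),
    P.RewriteLayout(),
]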

src/meta_schedule/schedule_rule/schedule_rule.cc

Lines changed: 57 additions & 0 deletions
@@ -17,6 +17,7 @@
  * under the License.
  */
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/data_type.h>

 #include "../utils.h"

@@ -304,6 +305,62 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
   };
 }

+Array<ScheduleRule> ScheduleRule::DefaultRISCV(const int vlen) {
+  Array<ScheduleRule> rules;
+  rules.push_back(ScheduleRule::ApplyCustomRule());
+  rules.push_back(ScheduleRule::InlineConstantScalars());
+  rules.push_back(ScheduleRule::AutoInline(
+      /*into_producer=*/false,
+      /*into_consumer=*/true,
+      /*inline_const_tensor=*/true,
+      /*disallow_if_then_else=*/true,
+      /*require_injective=*/true,
+      /*require_ordered=*/true,
+      /*disallow_op=*/Array<String>{"tir.exp"}));
+  rules.push_back(ScheduleRule::AddRFactor(
+      /*max_jobs_per_core=*/16,
+      /*max_innermost_factor=*/Integer(64)));
+  auto current_target = tvm::Target::Current();
+  const auto reg_rvv_intrinsics =
+      tvm::ffi::Function::GetGlobalRequired("tir.tensor_intrin.register_rvv_isa_intrisics");
+  const auto rvv_kernels_inventory =
+      reg_rvv_intrinsics(current_target, /*only_inventory=*/true).cast<Map<String, int>>();
+  for (const auto& intrin : rvv_kernels_inventory) {
+    if (!tir::TensorIntrin::Get(intrin.first, /*allow_missing=*/true)) {
+      // register the intrinsics on demand
+      reg_rvv_intrinsics(current_target, /*only_inventory=*/false);
+    }
+    rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin(
+        /*intrin_name=*/intrin.first,
+        /*structure=*/"SSRSRS",
+        /*tile_binds=*/std::nullopt,
+        /*max_innermost_factor=*/Integer(intrin.second),
+        /*vector_load_lens=*/std::nullopt,
+        /*reuse_read=*/std::nullopt,
+        /*reuse_write=*/
+        Map<String, ffi::Any>{{"req", String("may")},
+                              {"levels", Array<Integer>{1, 2}},
+                              {"scope", String("global")}}));
+  }
+  rules.push_back(ScheduleRule::MultiLevelTiling(
+      /*structure=*/"SSRSRS",
+      /*tile_binds=*/std::nullopt,
+      /*max_innermost_factor=*/Integer(64),
+      /*vector_load_lens=*/std::nullopt,
+      /*reuse_read=*/std::nullopt,
+      /*reuse_write=*/
+      Map<String, ffi::Any>{
+          {"req", String("may")}, {"levels", Array<Integer>{1, 2}}, {"scope", String("global")}}));
+  rules.push_back(ScheduleRule::ParallelizeVectorizeUnroll(
+      /*max_jobs_per_core=*/16,
+      /*max_vectorize_extent=*/64,
+      /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
+      /*unroll_explicit=*/true));
+  rules.push_back(ScheduleRule::RandomComputeLocation());
+
+  return rules;
+}
+
 Array<ScheduleRule> GetARMNeonSpecificRules() {
   return {
       ScheduleRule::MultiLevelTilingWithIntrin(
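The per-kernel rule built in the loop above has a Python counterpart in tvm.meta_schedule.schedule_rule.MultiLevelTilingWithIntrin; the sketch below wires one inventory entry into such a rule (the keyword names are assumed from the Python bindings, and the reuse settings of the C++ rule are omitted for brevity):

import tvm
from tvm.meta_schedule import schedule_rule as SR
from tvm.tir.tensor_intrin import riscv_cpu

target = tvm.target.riscv_cpu(model="licheepi3a")
# Register the RVV kernels and take one entry of the returned catalog.
inventory = riscv_cpu.register_rvv_isa_intrisics(target)
kernel_name, n_elems = next(iter(inventory.items()))

rvv_rule = SR.MultiLevelTilingWithIntrin(
    kernel_name,                   # one of the registered rvv_dot_* intrinsics
    structure="SSRSRS",            # same tiling structure as the C++ rule
    max_innermost_factor=n_elems,  # innermost tile bounded by the kernel width
)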
