1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/arith/analyzer.h>
20#include <tvm/script/ir_builder/tir/ir.h>
21
22#include "./utils.h"
23
24namespace tvm {
25namespace script {
26namespace ir_builder {
27namespace tir {
28
29using tvm::tir::IterVar;
30
31Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
32 Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
33 String storage_scope, int align, int offset_factor, String buffer_type,
34 Optional<Array<IntImm>> axis_separators) {
35 Var buffer_data;
36 if (!data.defined()) {
37 DataType storage_dtype = dtype;
38 if (storage_dtype == DataType::Bool()) {
39 storage_dtype = DataType::Int(8);
40 }
41 buffer_data = tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope));
42 } else {
43 buffer_data = data.value();
44 }
45 if (!elem_offset.defined() && offset_factor) {
46 DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype;
47 elem_offset = tvm::tir::Var("elem_offset", shape_dtype);
48 }
49 return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
50 elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor,
51 (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault,
52 axis_separators.value_or(Array<IntImm>()));
53}
54
55PrimFuncFrame PrimFunc() {
56 ObjectPtr<PrimFuncFrameNode> n = make_object<PrimFuncFrameNode>();
57 n->name = NullOpt;
58 n->args.clear();
59 n->ret_type = NullOpt;
60 n->buffer_map.clear();
61 n->attrs = NullOpt;
62 n->env_threads.clear();
63 n->root_alloc_buffers.clear();
64 return PrimFuncFrame(n);
65}
66
67Var Arg(String name, Var var) {
68 PrimFuncFrame frame = FindPrimFuncFrame("T.Arg");
69 details::Namer::Name(var, name);
70 frame->args.push_back(var);
71 return var;
72}
73
74Buffer Arg(String name, Buffer buffer) {
75 PrimFuncFrame frame = FindPrimFuncFrame("T.Arg");
76 details::Namer::Name(buffer, name);
77 Var handle(buffer->name + "_handle", DataType::Handle());
78 frame->args.push_back(handle);
79 frame->buffer_map.Set(handle, buffer);
80 return buffer;
81}
82
83void FuncName(String name) {
84 PrimFuncFrame frame = FindPrimFuncFrame("T.func_name");
85 if (frame->name.defined()) {
86 LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value();
87 }
88 frame->name = name;
89}
90
91void FuncAttrs(Map<String, ObjectRef> attrs) {
92 using namespace tvm::tir;
93 PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr");
94 if (frame->attrs.defined()) {
95 LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs;
96 }
97 frame->attrs = attrs;
98}
99
100tvm::Type FuncRet(tvm::Type ret_type) {
101 PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type");
102 if (frame->ret_type.defined()) {
103 LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is "
104 << frame->ret_type.value();
105 }
106 frame->ret_type = ret_type;
107 return ret_type;
108}
109
110Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
111 Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
112 int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
113 Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
114 offset_factor, buffer_type_str, axis_separators);
115 if (const auto* var = param.as<tvm::tir::VarNode>()) {
116 PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer");
117 Var v = GetRef<Var>(var);
118 for (auto const& arg : frame->args) {
119 if (arg.same_as(v)) {
120 frame->buffer_map.Set(v, buffer);
121 return buffer;
122 }
123 }
124 LOG(FATAL) << "ValueError: Can not bind non-input param to buffer.";
125 } else if (const auto* buffer_load = param.as<tvm::tir::BufferLoadNode>()) {
126 BlockFrame frame = FindBlockFrame("T.match_buffer");
127 frame->match_buffers.push_back(tvm::tir::MatchBufferRegion(
128 buffer, BufferRegionFromLoad(GetRef<tvm::tir::BufferLoad>(buffer_load))));
129 } else if (const auto* buffer_region = param.as<tvm::tir::BufferRegionNode>()) {
130 BlockFrame frame = FindBlockFrame("T.match_buffer");
131 frame->match_buffers.push_back(
132 tvm::tir::MatchBufferRegion(buffer, GetRef<tvm::tir::BufferRegion>(buffer_region)));
133 } else {
134 LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer.";
135 }
136 return buffer;
137}
138
139BlockFrame Block(String name, bool no_realize) {
140 ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
141 n->name = name;
142 n->iter_vars.clear();
143 n->reads = NullOpt;
144 n->writes = NullOpt;
145 n->init = NullOpt;
146 n->alloc_buffers.clear();
147 n->match_buffers.clear();
148 n->annotations = NullOpt;
149 n->iter_values.clear();
150 n->predicate = NullOpt;
151 n->no_realize = no_realize;
152 return BlockFrame(n);
153}
154
155BlockInitFrame Init() { return BlockInitFrame(make_object<BlockInitFrameNode>()); }
156
157void Where(PrimExpr predicate) {
158 BlockFrame frame = FindBlockFrame("T.where");
159 if (frame->predicate.defined()) {
160 LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
161 << frame->predicate;
162 }
163 frame->predicate = predicate;
164}
165
166void Reads(Array<ObjectRef> buffer_slices) {
167 using namespace tvm::tir;
168 BlockFrame frame = FindBlockFrame("T.reads");
169 if (frame->reads.defined()) {
170 LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
171 }
172 Array<BufferRegion> reads;
173 for (const ObjectRef& obj : buffer_slices) {
174 if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
175 reads.push_back(GetRef<BufferRegion>(buffer_region));
176 } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
177 reads.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
178 } else {
179 LOG(FATAL) << "Invalid type for buffer reads.";
180 }
181 }
182 frame->reads = reads;
183}
184
185void Writes(Array<ObjectRef> buffer_slices) {
186 using namespace tvm::tir;
187 BlockFrame frame = FindBlockFrame("T.writes");
188 if (frame->writes.defined()) {
189 LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
190 << frame->writes;
191 }
192 Array<BufferRegion> writes;
193 for (const ObjectRef& obj : buffer_slices) {
194 if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
195 writes.push_back(GetRef<BufferRegion>(buffer_region));
196 } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
197 writes.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
198 } else {
199 LOG(FATAL) << "Invalid type for buffer writes.";
200 }
201 }
202 frame->writes = writes;
203}
204
205void BlockAttrs(Map<String, ObjectRef> attrs) {
206 BlockFrame frame = FindBlockFrame("T.block_attr");
207 if (frame->annotations.defined()) {
208 LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
209 }
210 frame->annotations = attrs;
211}
212
213Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
214 Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
215 int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
216 Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
217 offset_factor, buffer_type_str, axis_separators);
218 IRBuilder builder = IRBuilder::Current();
219 if (Optional<BlockFrame> frame = builder->GetLastFrame<BlockFrame>()) {
220 frame.value()->alloc_buffers.push_back(buffer);
221 } else if (Optional<PrimFuncFrame> frame = builder->GetLastFrame<PrimFuncFrame>()) {
222 frame.value()->root_alloc_buffers.push_back(buffer);
223 } else {
224 LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. Please ensure "
225 "'T.alloc_buffer' is called under T.block() or T.prim_func()";
226 }
227 return buffer;
228}
229namespace axis {
230
231IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
232 if (Optional<BlockFrame> opt_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
233 BlockFrame frame = opt_frame.value();
234 frame->iter_vars.push_back(iter_var);
235 frame->iter_values.push_back(binding);
236 } else {
237 LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
238 }
239 return iter_var;
240}
241
242#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name) \
243 Var Method(Range dom, PrimExpr binding, DataType dtype) { \
244 ICHECK(dom.defined()) << Name << " axis must have a domain"; \
245 int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \
246 return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), \
247 /*iter_type=*/Kind, /*thread_tag=*/""), \
248 binding) \
249 ->var; \
250 }
251TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial");
252TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction");
253TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
254TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
255#undef TVM_TIR_IR_BUILDER_AXIS
256
257Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
258 using namespace tvm::tir;
259 Array<Var> results;
260 ICHECK_EQ(kinds.size(), bindings.size());
261 int n = bindings.size();
262 results.reserve(n);
263 for (int i = 0; i < n; ++i) {
264 char c = kinds.c_str()[i];
265 PrimExpr e = bindings[i];
266 const VarNode* v = e.as<VarNode>();
267 ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap";
268 Range dom{nullptr};
269 for (const auto& frame : IRBuilder::Current()->frames) {
270 if (const auto* for_frame = frame.as<ForFrameNode>()) {
271 ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
272 int n = for_frame->doms.size();
273 for (int i = 0; i < n; ++i) {
274 if (for_frame->vars[i].get() == v) {
275 dom = for_frame->doms[i];
276 break;
277 }
278 }
279 if (dom.defined()) {
280 break;
281 }
282 }
283 }
284 ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(v);
285 DataType dtype = v->dtype;
286 if (c == 'S') {
287 results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
288 /*var=*/Var("", dtype),
289 /*iter_type=*/IterVarType::kDataPar,
290 /*thread_tag=*/""),
291 e)
292 ->var);
293 } else if (c == 'R') {
294 results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
295 /*var=*/Var("", dtype),
296 /*iter_type=*/IterVarType::kCommReduce,
297 /*thread_tag=*/""),
298 e)
299 ->var);
300 } else {
301 LOG(FATAL) << "Unknown axis kind: " << c;
302 }
303 }
304 return results;
305}
306
307} // namespace axis
308
309#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \
310 ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) { \
311 PrimExpr min = start; \
312 PrimExpr extent = arith::Analyzer().Simplify(stop - start); \
313 ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); \
314 int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
315 n->vars = {Var("v", DataType::Int(bits))}; \
316 n->doms = {Range::FromMinExtent(min, extent)}; \
317 n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms, tvm::tir::Stmt body) { \
318 ICHECK_EQ(vars.size(), 1); \
319 ICHECK_EQ(doms.size(), 1); \
320 return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, \
321 annotations.value_or(Map<String, ObjectRef>())); \
322 }; \
323 return ForFrame(n); \
324 }
325
326TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial);
327TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel);
328TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized);
329TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled);
330
331#undef TVM_TIR_IR_BUILDER_FOR_FRAME
332
333ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
334 Optional<Map<String, ObjectRef>> annotations) {
335 using namespace tvm::tir;
336 PrimExpr min = start;
337 PrimExpr extent = arith::Analyzer().Simplify(stop - start);
338 ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
339 int bits = std::max(min.dtype().bits(), extent.dtype().bits());
340 n->vars = {Var("v", DataType::Int(bits))};
341 n->doms = {Range::FromMinExtent(min, extent)};
342 n->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For {
343 ICHECK_EQ(vars.size(), 1);
344 ICHECK_EQ(doms.size(), 1);
345 IterVar iter_var(Range(nullptr), Var("iter", DataType::Int(32)), IterVarType::kThreadIndex,
346 thread);
347 return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
348 annotations.value_or(Map<String, ObjectRef>()));
349 };
350 return ForFrame(n);
351}
352
353ForFrame Grid(Array<PrimExpr> extents) {
354 using namespace tvm::tir;
355 ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
356 n->vars.reserve(extents.size());
357 n->doms.reserve(extents.size());
358 for (const auto& extent : extents) {
359 DataType dtype = extent.dtype();
360 n->vars.push_back(Var("v", extent.dtype()));
361 n->doms.push_back(Range(make_const(dtype, 0), extent));
362 }
363 n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt {
364 ICHECK_EQ(vars.size(), doms.size());
365 int n = vars.size();
366 for (int i = n - 1; i >= 0; --i) {
367 Range dom = doms[i];
368 Var var = vars[i];
369 body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body),
370 /*thread_binding=*/NullOpt, /*annotations=*/{});
371 }
372 return body;
373 };
374 return ForFrame(n);
375}
376
377AssertFrame Assert(PrimExpr condition, String message) {
378 ObjectPtr<AssertFrameNode> n = make_object<AssertFrameNode>();
379 n->condition = condition;
380 n->message = tvm::tir::StringImm(message);
381 return AssertFrame(n);
382}
383
384LetFrame Let(Var var, PrimExpr value) {
385 ObjectPtr<LetFrameNode> n = make_object<LetFrameNode>();
386 n->var = var;
387 n->value = value;
388 return LetFrame(n);
389}
390
391LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
392 IterVar iter_var{nullptr};
393
394 if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
395 if (Optional<IterVar> opt_iter_var = opt_frame.value()->env_threads.Get(var)) {
396 iter_var = opt_iter_var.value();
397 } else {
398 LOG(FATAL) << "ValueError: " << var->name_hint
399 << " is not an env_thread created using T.env_thread.";
400 }
401 } else {
402 LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc";
403 }
404 ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
405 if (!iter_var->dom.defined()) {
406 const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
407 } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
408 LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
409 << iter_var->dom->extent << " vs " << extent;
410 }
411 n->iter_var = iter_var;
412 n->extent = extent;
413 n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent";
414 return LaunchThreadFrame(n);
415}
416
417RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
418 PrimExpr condition) {
419 ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>();
420 n->buffer_slice = buffer_slice;
421 n->storage_scope = storage_scope;
422 n->condition = condition;
423 return RealizeFrame(n);
424}
425
426AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope,
427 Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) {
428 ObjectPtr<AllocateFrameNode> n = make_object<AllocateFrameNode>();
429 n->extents = extents;
430 n->dtype = dtype;
431 n->storage_scope = storage_scope;
432 n->condition = condition.value_or(tvm::Bool(true));
433 n->annotations = annotations.value_or(Map<String, ObjectRef>());
434 n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype), storage_scope));
435 return AllocateFrame(n);
436}
437
438AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
439 Array<PrimExpr> extents,
440 Optional<Map<String, ObjectRef>> annotations) {
441 ObjectPtr<AllocateConstFrameNode> n = make_object<AllocateConstFrameNode>();
442 n->dtype = dtype;
443 n->extents = extents;
444 n->data = data;
445 n->annotations = annotations.value_or(Map<String, ObjectRef>());
446 n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype)));
447 return AllocateConstFrame(n);
448}
449
450AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) {
451 ObjectPtr<AttrFrameNode> n = make_object<AttrFrameNode>();
452 n->node = node;
453 n->attr_key = attr_key;
454 n->value = value;
455 return AttrFrame(n);
456}
457
458WhileFrame While(PrimExpr condition) {
459 ObjectPtr<WhileFrameNode> n = make_object<WhileFrameNode>();
460 n->condition = condition;
461 return WhileFrame(n);
462}
463
464IfFrame If(PrimExpr condition) {
465 ObjectPtr<IfFrameNode> n = make_object<IfFrameNode>();
466 n->condition = condition;
467 n->then_stmts = NullOpt;
468 n->else_stmts = NullOpt;
469 return IfFrame(n);
470}
471
472ThenFrame Then() {
473 ObjectPtr<ThenFrameNode> n = make_object<ThenFrameNode>();
474 return ThenFrame(n);
475}
476
477ElseFrame Else() {
478 ObjectPtr<ElseFrameNode> n = make_object<ElseFrameNode>();
479 return ElseFrame(n);
480}
481
482Var EnvThread(String thread_tag) {
483 IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
484 thread_tag);
485 Var var = iter_var->var;
486 if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
487 opt_frame.value()->env_threads.Set(var, iter_var);
488 } else {
489 LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
490 }
491 return var;
492}
493
494void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
495 runtime::DataType buffer_dtype = buffer->dtype;
496 int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
497 runtime::DataType lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes);
498 runtime::DataType rhs_dtype = value->dtype;
499 if (lhs_dtype != rhs_dtype) {
500 if (lhs_dtype.lanes() != rhs_dtype.lanes()) {
501 LOG(FATAL) << "TypeError: Incompatible types in BufferStore"
502 << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype
503 << "`, indexing lanes: " << index_lanes;
504 }
505 if (lhs_dtype.code() != rhs_dtype.code()) {
506 if (
507 // Case 1. lhs is handle, and rhs needs to be casted to handle.
508 (lhs_dtype.code() == runtime::DataType::kHandle) ||
509 // Case 2. rhs is handle, and it needs to be casted to non-handle.
510 (rhs_dtype.code() == runtime::DataType::kHandle) ||
511 // Case 3. rhs is float or bfloat, and casting to non-float can lose precision.
512 ((lhs_dtype.code() == runtime::DataType::kInt ||
513 lhs_dtype.code() == runtime::DataType::kUInt) &&
514 (rhs_dtype.code() == runtime::DataType::kFloat ||
515 rhs_dtype.code() == runtime::DataType::kBFloat))) {
516 LOG(WARNING) << "Casting in BufferStore may lose precision"
517 << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype
518 << "`, indexing lanes: " << index_lanes;
519 }
520 }
521 value = tvm::cast(lhs_dtype, value);
522 }
523 AddToParent(tvm::tir::BufferStore(buffer, value, indices));
524}
525
526void Prefetch(Buffer buffer, Array<Range> bounds) {
527 AddToParent(tvm::tir::Prefetch(buffer, bounds));
528}
529
530DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name,
531 Optional<Var> data, Optional<Array<PrimExpr>> strides,
532 Optional<PrimExpr> elem_offset, String storage_scope, int align,
533 int offset_factor, String buffer_type,
534 Optional<Array<IntImm>> axis_separators) {
535 ObjectPtr<DeclBufferFrameNode> n = make_object<DeclBufferFrameNode>();
536 n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope,
537 align, offset_factor, buffer_type, axis_separators);
538 n->allocated = data.defined();
539 return DeclBufferFrame(n);
540}
541
542void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
543
544PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
545 return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
546}
547
548using tvm::script::ir_builder::details::Namer;
549
550TVM_STATIC_IR_FUNCTOR(Namer, vtable)
551 .set_dispatch<tvm::tir::BufferNode>([](const ObjectRef& node, String name) -> void {
552 tvm::tir::BufferNode* buffer =
553 const_cast<tvm::tir::BufferNode*>(node.as<tvm::tir::BufferNode>());
554 buffer->name = name;
555 Namer::Name(buffer->data, name);
556 int n = buffer->strides.size();
557 for (int i = 0; i < n; ++i) {
558 PrimExpr e = buffer->strides[i];
559 if (const tvm::tir::VarNode* v = e.as<tvm::tir::VarNode>()) {
560 Namer::Name(GetRef<tvm::tir::Var>(v), name + "_s" + std::to_string(i));
561 }
562 }
563 });
564
565TVM_STATIC_IR_FUNCTOR(Namer, vtable)
566 .set_dispatch<tvm::tir::SizeVarNode>([](const ObjectRef& node, String name) -> void {
567 using namespace tvm::tir;
568 SizeVarNode* var = const_cast<SizeVarNode*>(node.as<SizeVarNode>());
569 var->name_hint = name;
570 });
571
572TVM_STATIC_IR_FUNCTOR(Namer, vtable)
573 .set_dispatch<tvm::tir::VarNode>([](const ObjectRef& node, String name) -> void {
574 using namespace tvm::tir;
575 VarNode* var = const_cast<VarNode*>(node.as<VarNode>());
576 var->name_hint = name;
577 });
578
579TVM_STATIC_IR_FUNCTOR(Namer, vtable)
580 .set_dispatch<tvm::tir::IterVarNode>([](const ObjectRef& node, String name) -> void {
581 using namespace tvm::tir;
582 IterVarNode* var = const_cast<IterVarNode*>(node.as<IterVarNode>());
583 Namer::Name(var->var, name);
584 });
585
586TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl").set_body_typed(BufferDecl);
587
588TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc);
589TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg")
590 .set_body_typed([](String name, ObjectRef obj) -> ObjectRef {
591 using namespace tvm::tir;
592 if (const auto* var = obj.as<VarNode>()) {
593 return Arg(name, GetRef<tvm::tir::Var>(var));
594 }
595 if (const auto* buffer = obj.as<BufferNode>()) {
596 return Arg(name, GetRef<Buffer>(buffer));
597 }
598 LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey();
599 throw;
600 });
601TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName);
602TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
603TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet);
604TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
605
606TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
607TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init);
608TVM_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where);
609TVM_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads);
610TVM_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes);
611TVM_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs);
612TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer);
613
614TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
615TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
616TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
617TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
618TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);
619
620TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial);
621TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel);
622TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized);
623TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll);
624TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding);
625TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid);
626
627TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert);
628TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let);
629TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate);
630TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
631TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize);
632TVM_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr);
633TVM_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While);
634TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If);
635TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
636TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
637TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
638TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
639TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);
640
641TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);
642TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch);
643TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
644
645TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr);
646
647#define TVM_TMP_STR(x) #x
648
649#define TVM_REGISTER_GLOBAL_SIZE(Prefix, DType) \
650 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \
651 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \
652 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \
653 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64);
654
655TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float);
656TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt);
657TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int);
658
659#define TVM_REGISTER_GLOBAL_LANES(Prefix, Func) \
660 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \
661 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \
662 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \
663 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \
664 TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64);
665
666#define TVM_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \
667 TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \
668 TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \
669 TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \
670 TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64);
671
672TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
673TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
674TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
675
676TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
677TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
678TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
679
680TVM_REGISTER_GLOBAL("script.ir_builder.tir.min")
681 .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); });
682TVM_REGISTER_GLOBAL("script.ir_builder.tir.max")
683 .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); });
684} // namespace tir
685} // namespace ir_builder
686} // namespace script
687} // namespace tvm
688