1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/arith/analyzer.h> |
20 | #include <tvm/script/ir_builder/tir/ir.h> |
21 | |
22 | #include "./utils.h" |
23 | |
24 | namespace tvm { |
25 | namespace script { |
26 | namespace ir_builder { |
27 | namespace tir { |
28 | |
29 | using tvm::tir::IterVar; |
30 | |
31 | Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data, |
32 | Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset, |
33 | String storage_scope, int align, int offset_factor, String buffer_type, |
34 | Optional<Array<IntImm>> axis_separators) { |
35 | Var buffer_data; |
36 | if (!data.defined()) { |
37 | DataType storage_dtype = dtype; |
38 | if (storage_dtype == DataType::Bool()) { |
39 | storage_dtype = DataType::Int(8); |
40 | } |
41 | buffer_data = tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope)); |
42 | } else { |
43 | buffer_data = data.value(); |
44 | } |
45 | if (!elem_offset.defined() && offset_factor) { |
46 | DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype; |
47 | elem_offset = tvm::tir::Var("elem_offset" , shape_dtype); |
48 | } |
49 | return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()), |
50 | elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, |
51 | (buffer_type == "auto_broadcast" ) ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault, |
52 | axis_separators.value_or(Array<IntImm>())); |
53 | } |
54 | |
55 | PrimFuncFrame PrimFunc() { |
56 | ObjectPtr<PrimFuncFrameNode> n = make_object<PrimFuncFrameNode>(); |
57 | n->name = NullOpt; |
58 | n->args.clear(); |
59 | n->ret_type = NullOpt; |
60 | n->buffer_map.clear(); |
61 | n->attrs = NullOpt; |
62 | n->env_threads.clear(); |
63 | n->root_alloc_buffers.clear(); |
64 | return PrimFuncFrame(n); |
65 | } |
66 | |
67 | Var Arg(String name, Var var) { |
68 | PrimFuncFrame frame = FindPrimFuncFrame("T.Arg" ); |
69 | details::Namer::Name(var, name); |
70 | frame->args.push_back(var); |
71 | return var; |
72 | } |
73 | |
74 | Buffer Arg(String name, Buffer buffer) { |
75 | PrimFuncFrame frame = FindPrimFuncFrame("T.Arg" ); |
76 | details::Namer::Name(buffer, name); |
77 | Var handle(buffer->name + "_handle" , DataType::Handle()); |
78 | frame->args.push_back(handle); |
79 | frame->buffer_map.Set(handle, buffer); |
80 | return buffer; |
81 | } |
82 | |
83 | void FuncName(String name) { |
84 | PrimFuncFrame frame = FindPrimFuncFrame("T.func_name" ); |
85 | if (frame->name.defined()) { |
86 | LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value(); |
87 | } |
88 | frame->name = name; |
89 | } |
90 | |
91 | void FuncAttrs(Map<String, ObjectRef> attrs) { |
92 | using namespace tvm::tir; |
93 | PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr" ); |
94 | if (frame->attrs.defined()) { |
95 | LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs; |
96 | } |
97 | frame->attrs = attrs; |
98 | } |
99 | |
100 | tvm::Type FuncRet(tvm::Type ret_type) { |
101 | PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type" ); |
102 | if (frame->ret_type.defined()) { |
103 | LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is " |
104 | << frame->ret_type.value(); |
105 | } |
106 | frame->ret_type = ret_type; |
107 | return ret_type; |
108 | } |
109 | |
110 | Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype, Optional<Var> data, |
111 | Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align, |
112 | int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) { |
113 | Buffer buffer = BufferDecl(shape, dtype, "" , data, strides, elem_offset, storage_scope, align, |
114 | offset_factor, buffer_type_str, axis_separators); |
115 | if (const auto* var = param.as<tvm::tir::VarNode>()) { |
116 | PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer" ); |
117 | Var v = GetRef<Var>(var); |
118 | for (auto const& arg : frame->args) { |
119 | if (arg.same_as(v)) { |
120 | frame->buffer_map.Set(v, buffer); |
121 | return buffer; |
122 | } |
123 | } |
124 | LOG(FATAL) << "ValueError: Can not bind non-input param to buffer." ; |
125 | } else if (const auto* buffer_load = param.as<tvm::tir::BufferLoadNode>()) { |
126 | BlockFrame frame = FindBlockFrame("T.match_buffer" ); |
127 | frame->match_buffers.push_back(tvm::tir::MatchBufferRegion( |
128 | buffer, BufferRegionFromLoad(GetRef<tvm::tir::BufferLoad>(buffer_load)))); |
129 | } else if (const auto* buffer_region = param.as<tvm::tir::BufferRegionNode>()) { |
130 | BlockFrame frame = FindBlockFrame("T.match_buffer" ); |
131 | frame->match_buffers.push_back( |
132 | tvm::tir::MatchBufferRegion(buffer, GetRef<tvm::tir::BufferRegion>(buffer_region))); |
133 | } else { |
134 | LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer." ; |
135 | } |
136 | return buffer; |
137 | } |
138 | |
139 | BlockFrame Block(String name, bool no_realize) { |
140 | ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>(); |
141 | n->name = name; |
142 | n->iter_vars.clear(); |
143 | n->reads = NullOpt; |
144 | n->writes = NullOpt; |
145 | n->init = NullOpt; |
146 | n->alloc_buffers.clear(); |
147 | n->match_buffers.clear(); |
148 | n->annotations = NullOpt; |
149 | n->iter_values.clear(); |
150 | n->predicate = NullOpt; |
151 | n->no_realize = no_realize; |
152 | return BlockFrame(n); |
153 | } |
154 | |
155 | BlockInitFrame Init() { return BlockInitFrame(make_object<BlockInitFrameNode>()); } |
156 | |
157 | void Where(PrimExpr predicate) { |
158 | BlockFrame frame = FindBlockFrame("T.where" ); |
159 | if (frame->predicate.defined()) { |
160 | LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is " |
161 | << frame->predicate; |
162 | } |
163 | frame->predicate = predicate; |
164 | } |
165 | |
166 | void Reads(Array<ObjectRef> buffer_slices) { |
167 | using namespace tvm::tir; |
168 | BlockFrame frame = FindBlockFrame("T.reads" ); |
169 | if (frame->reads.defined()) { |
170 | LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads; |
171 | } |
172 | Array<BufferRegion> reads; |
173 | for (const ObjectRef& obj : buffer_slices) { |
174 | if (const auto* buffer_region = obj.as<BufferRegionNode>()) { |
175 | reads.push_back(GetRef<BufferRegion>(buffer_region)); |
176 | } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) { |
177 | reads.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load))); |
178 | } else { |
179 | LOG(FATAL) << "Invalid type for buffer reads." ; |
180 | } |
181 | } |
182 | frame->reads = reads; |
183 | } |
184 | |
185 | void Writes(Array<ObjectRef> buffer_slices) { |
186 | using namespace tvm::tir; |
187 | BlockFrame frame = FindBlockFrame("T.writes" ); |
188 | if (frame->writes.defined()) { |
189 | LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is " |
190 | << frame->writes; |
191 | } |
192 | Array<BufferRegion> writes; |
193 | for (const ObjectRef& obj : buffer_slices) { |
194 | if (const auto* buffer_region = obj.as<BufferRegionNode>()) { |
195 | writes.push_back(GetRef<BufferRegion>(buffer_region)); |
196 | } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) { |
197 | writes.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load))); |
198 | } else { |
199 | LOG(FATAL) << "Invalid type for buffer writes." ; |
200 | } |
201 | } |
202 | frame->writes = writes; |
203 | } |
204 | |
205 | void BlockAttrs(Map<String, ObjectRef> attrs) { |
206 | BlockFrame frame = FindBlockFrame("T.block_attr" ); |
207 | if (frame->annotations.defined()) { |
208 | LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations; |
209 | } |
210 | frame->annotations = attrs; |
211 | } |
212 | |
213 | Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data, |
214 | Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align, |
215 | int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) { |
216 | Buffer buffer = BufferDecl(shape, dtype, "" , data, strides, elem_offset, storage_scope, align, |
217 | offset_factor, buffer_type_str, axis_separators); |
218 | IRBuilder builder = IRBuilder::Current(); |
219 | if (Optional<BlockFrame> frame = builder->GetLastFrame<BlockFrame>()) { |
220 | frame.value()->alloc_buffers.push_back(buffer); |
221 | } else if (Optional<PrimFuncFrame> frame = builder->GetLastFrame<PrimFuncFrame>()) { |
222 | frame.value()->root_alloc_buffers.push_back(buffer); |
223 | } else { |
224 | LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. Please ensure " |
225 | "'T.alloc_buffer' is called under T.block() or T.prim_func()" ; |
226 | } |
227 | return buffer; |
228 | } |
229 | namespace axis { |
230 | |
231 | IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) { |
232 | if (Optional<BlockFrame> opt_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) { |
233 | BlockFrame frame = opt_frame.value(); |
234 | frame->iter_vars.push_back(iter_var); |
235 | frame->iter_values.push_back(binding); |
236 | } else { |
237 | LOG(FATAL) << "TypeError: The last frame is not BlockFrame" ; |
238 | } |
239 | return iter_var; |
240 | } |
241 | |
242 | #define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name) \ |
243 | Var Method(Range dom, PrimExpr binding, DataType dtype) { \ |
244 | ICHECK(dom.defined()) << Name << " axis must have a domain"; \ |
245 | int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \ |
246 | return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), \ |
247 | /*iter_type=*/Kind, /*thread_tag=*/""), \ |
248 | binding) \ |
249 | ->var; \ |
250 | } |
251 | TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial" ); |
252 | TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction" ); |
253 | TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan" ); |
254 | TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque" ); |
255 | #undef TVM_TIR_IR_BUILDER_AXIS |
256 | |
257 | Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) { |
258 | using namespace tvm::tir; |
259 | Array<Var> results; |
260 | ICHECK_EQ(kinds.size(), bindings.size()); |
261 | int n = bindings.size(); |
262 | results.reserve(n); |
263 | for (int i = 0; i < n; ++i) { |
264 | char c = kinds.c_str()[i]; |
265 | PrimExpr e = bindings[i]; |
266 | const VarNode* v = e.as<VarNode>(); |
267 | ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap" ; |
268 | Range dom{nullptr}; |
269 | for (const auto& frame : IRBuilder::Current()->frames) { |
270 | if (const auto* for_frame = frame.as<ForFrameNode>()) { |
271 | ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size()); |
272 | int n = for_frame->doms.size(); |
273 | for (int i = 0; i < n; ++i) { |
274 | if (for_frame->vars[i].get() == v) { |
275 | dom = for_frame->doms[i]; |
276 | break; |
277 | } |
278 | } |
279 | if (dom.defined()) { |
280 | break; |
281 | } |
282 | } |
283 | } |
284 | ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(v); |
285 | DataType dtype = v->dtype; |
286 | if (c == 'S') { |
287 | results.push_back(PushBlockVar(IterVar(/*dom=*/dom, |
288 | /*var=*/Var("" , dtype), |
289 | /*iter_type=*/IterVarType::kDataPar, |
290 | /*thread_tag=*/"" ), |
291 | e) |
292 | ->var); |
293 | } else if (c == 'R') { |
294 | results.push_back(PushBlockVar(IterVar(/*dom=*/dom, |
295 | /*var=*/Var("" , dtype), |
296 | /*iter_type=*/IterVarType::kCommReduce, |
297 | /*thread_tag=*/"" ), |
298 | e) |
299 | ->var); |
300 | } else { |
301 | LOG(FATAL) << "Unknown axis kind: " << c; |
302 | } |
303 | } |
304 | return results; |
305 | } |
306 | |
307 | } // namespace axis |
308 | |
309 | #define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ |
310 | ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) { \ |
311 | PrimExpr min = start; \ |
312 | PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ |
313 | ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); \ |
314 | int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ |
315 | n->vars = {Var("v", DataType::Int(bits))}; \ |
316 | n->doms = {Range::FromMinExtent(min, extent)}; \ |
317 | n->f_make_for_loop = [annotations](Array<Var> vars, Array<Range> doms, tvm::tir::Stmt body) { \ |
318 | ICHECK_EQ(vars.size(), 1); \ |
319 | ICHECK_EQ(doms.size(), 1); \ |
320 | return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, \ |
321 | annotations.value_or(Map<String, ObjectRef>())); \ |
322 | }; \ |
323 | return ForFrame(n); \ |
324 | } |
325 | |
326 | TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial); |
327 | TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel); |
328 | TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized); |
329 | TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled); |
330 | |
331 | #undef TVM_TIR_IR_BUILDER_FOR_FRAME |
332 | |
333 | ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, |
334 | Optional<Map<String, ObjectRef>> annotations) { |
335 | using namespace tvm::tir; |
336 | PrimExpr min = start; |
337 | PrimExpr extent = arith::Analyzer().Simplify(stop - start); |
338 | ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); |
339 | int bits = std::max(min.dtype().bits(), extent.dtype().bits()); |
340 | n->vars = {Var("v" , DataType::Int(bits))}; |
341 | n->doms = {Range::FromMinExtent(min, extent)}; |
342 | n->f_make_for_loop = [annotations, thread](Array<Var> vars, Array<Range> doms, Stmt body) -> For { |
343 | ICHECK_EQ(vars.size(), 1); |
344 | ICHECK_EQ(doms.size(), 1); |
345 | IterVar iter_var(Range(nullptr), Var("iter" , DataType::Int(32)), IterVarType::kThreadIndex, |
346 | thread); |
347 | return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, |
348 | annotations.value_or(Map<String, ObjectRef>())); |
349 | }; |
350 | return ForFrame(n); |
351 | } |
352 | |
353 | ForFrame Grid(Array<PrimExpr> extents) { |
354 | using namespace tvm::tir; |
355 | ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); |
356 | n->vars.reserve(extents.size()); |
357 | n->doms.reserve(extents.size()); |
358 | for (const auto& extent : extents) { |
359 | DataType dtype = extent.dtype(); |
360 | n->vars.push_back(Var("v" , extent.dtype())); |
361 | n->doms.push_back(Range(make_const(dtype, 0), extent)); |
362 | } |
363 | n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, Stmt body) -> Stmt { |
364 | ICHECK_EQ(vars.size(), doms.size()); |
365 | int n = vars.size(); |
366 | for (int i = n - 1; i >= 0; --i) { |
367 | Range dom = doms[i]; |
368 | Var var = vars[i]; |
369 | body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body), |
370 | /*thread_binding=*/NullOpt, /*annotations=*/{}); |
371 | } |
372 | return body; |
373 | }; |
374 | return ForFrame(n); |
375 | } |
376 | |
377 | AssertFrame Assert(PrimExpr condition, String message) { |
378 | ObjectPtr<AssertFrameNode> n = make_object<AssertFrameNode>(); |
379 | n->condition = condition; |
380 | n->message = tvm::tir::StringImm(message); |
381 | return AssertFrame(n); |
382 | } |
383 | |
384 | LetFrame Let(Var var, PrimExpr value) { |
385 | ObjectPtr<LetFrameNode> n = make_object<LetFrameNode>(); |
386 | n->var = var; |
387 | n->value = value; |
388 | return LetFrame(n); |
389 | } |
390 | |
391 | LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { |
392 | IterVar iter_var{nullptr}; |
393 | |
394 | if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) { |
395 | if (Optional<IterVar> opt_iter_var = opt_frame.value()->env_threads.Get(var)) { |
396 | iter_var = opt_iter_var.value(); |
397 | } else { |
398 | LOG(FATAL) << "ValueError: " << var->name_hint |
399 | << " is not an env_thread created using T.env_thread." ; |
400 | } |
401 | } else { |
402 | LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc" ; |
403 | } |
404 | ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>(); |
405 | if (!iter_var->dom.defined()) { |
406 | const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent); |
407 | } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) { |
408 | LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. " |
409 | << iter_var->dom->extent << " vs " << extent; |
410 | } |
411 | n->iter_var = iter_var; |
412 | n->extent = extent; |
413 | n->attr_key = iter_var->thread_tag == "vthread" ? "virtual_thread" : "thread_extent" ; |
414 | return LaunchThreadFrame(n); |
415 | } |
416 | |
417 | RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, |
418 | PrimExpr condition) { |
419 | ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>(); |
420 | n->buffer_slice = buffer_slice; |
421 | n->storage_scope = storage_scope; |
422 | n->condition = condition; |
423 | return RealizeFrame(n); |
424 | } |
425 | |
426 | AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_scope, |
427 | Optional<PrimExpr> condition, Optional<Map<String, ObjectRef>> annotations) { |
428 | ObjectPtr<AllocateFrameNode> n = make_object<AllocateFrameNode>(); |
429 | n->extents = extents; |
430 | n->dtype = dtype; |
431 | n->storage_scope = storage_scope; |
432 | n->condition = condition.value_or(tvm::Bool(true)); |
433 | n->annotations = annotations.value_or(Map<String, ObjectRef>()); |
434 | n->buffer_var = Var("" , tvm::PointerType(tvm::PrimType(dtype), storage_scope)); |
435 | return AllocateFrame(n); |
436 | } |
437 | |
438 | AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, |
439 | Array<PrimExpr> extents, |
440 | Optional<Map<String, ObjectRef>> annotations) { |
441 | ObjectPtr<AllocateConstFrameNode> n = make_object<AllocateConstFrameNode>(); |
442 | n->dtype = dtype; |
443 | n->extents = extents; |
444 | n->data = data; |
445 | n->annotations = annotations.value_or(Map<String, ObjectRef>()); |
446 | n->buffer_var = Var("" , tvm::PointerType(tvm::PrimType(dtype))); |
447 | return AllocateConstFrame(n); |
448 | } |
449 | |
450 | AttrFrame Attr(ObjectRef node, String attr_key, PrimExpr value) { |
451 | ObjectPtr<AttrFrameNode> n = make_object<AttrFrameNode>(); |
452 | n->node = node; |
453 | n->attr_key = attr_key; |
454 | n->value = value; |
455 | return AttrFrame(n); |
456 | } |
457 | |
458 | WhileFrame While(PrimExpr condition) { |
459 | ObjectPtr<WhileFrameNode> n = make_object<WhileFrameNode>(); |
460 | n->condition = condition; |
461 | return WhileFrame(n); |
462 | } |
463 | |
464 | IfFrame If(PrimExpr condition) { |
465 | ObjectPtr<IfFrameNode> n = make_object<IfFrameNode>(); |
466 | n->condition = condition; |
467 | n->then_stmts = NullOpt; |
468 | n->else_stmts = NullOpt; |
469 | return IfFrame(n); |
470 | } |
471 | |
472 | ThenFrame Then() { |
473 | ObjectPtr<ThenFrameNode> n = make_object<ThenFrameNode>(); |
474 | return ThenFrame(n); |
475 | } |
476 | |
477 | ElseFrame Else() { |
478 | ObjectPtr<ElseFrameNode> n = make_object<ElseFrameNode>(); |
479 | return ElseFrame(n); |
480 | } |
481 | |
482 | Var EnvThread(String thread_tag) { |
483 | IterVar iter_var(Range{nullptr}, Var("" , DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex, |
484 | thread_tag); |
485 | Var var = iter_var->var; |
486 | if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) { |
487 | opt_frame.value()->env_threads.Set(var, iter_var); |
488 | } else { |
489 | LOG(FATAL) << "EnvThread can only be used inside a PrimFunc" ; |
490 | } |
491 | return var; |
492 | } |
493 | |
494 | void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) { |
495 | runtime::DataType buffer_dtype = buffer->dtype; |
496 | int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; |
497 | runtime::DataType lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes); |
498 | runtime::DataType rhs_dtype = value->dtype; |
499 | if (lhs_dtype != rhs_dtype) { |
500 | if (lhs_dtype.lanes() != rhs_dtype.lanes()) { |
501 | LOG(FATAL) << "TypeError: Incompatible types in BufferStore" |
502 | << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype |
503 | << "`, indexing lanes: " << index_lanes; |
504 | } |
505 | if (lhs_dtype.code() != rhs_dtype.code()) { |
506 | if ( |
507 | // Case 1. lhs is handle, and rhs needs to be casted to handle. |
508 | (lhs_dtype.code() == runtime::DataType::kHandle) || |
509 | // Case 2. rhs is handle, and it needs to be casted to non-handle. |
510 | (rhs_dtype.code() == runtime::DataType::kHandle) || |
511 | // Case 3. rhs is float or bfloat, and casting to non-float can lose precision. |
512 | ((lhs_dtype.code() == runtime::DataType::kInt || |
513 | lhs_dtype.code() == runtime::DataType::kUInt) && |
514 | (rhs_dtype.code() == runtime::DataType::kFloat || |
515 | rhs_dtype.code() == runtime::DataType::kBFloat))) { |
516 | LOG(WARNING) << "Casting in BufferStore may lose precision" |
517 | << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype |
518 | << "`, indexing lanes: " << index_lanes; |
519 | } |
520 | } |
521 | value = tvm::cast(lhs_dtype, value); |
522 | } |
523 | AddToParent(tvm::tir::BufferStore(buffer, value, indices)); |
524 | } |
525 | |
526 | void Prefetch(Buffer buffer, Array<Range> bounds) { |
527 | AddToParent(tvm::tir::Prefetch(buffer, bounds)); |
528 | } |
529 | |
530 | DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_name, |
531 | Optional<Var> data, Optional<Array<PrimExpr>> strides, |
532 | Optional<PrimExpr> elem_offset, String storage_scope, int align, |
533 | int offset_factor, String buffer_type, |
534 | Optional<Array<IntImm>> axis_separators) { |
535 | ObjectPtr<DeclBufferFrameNode> n = make_object<DeclBufferFrameNode>(); |
536 | n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, |
537 | align, offset_factor, buffer_type, axis_separators); |
538 | n->allocated = data.defined(); |
539 | return DeclBufferFrame(n); |
540 | } |
541 | |
542 | void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } |
543 | |
544 | PrimExpr Ptr(runtime::DataType dtype, String storage_scope) { |
545 | return tvm::tir::Var("" , tvm::PointerType(PrimType(dtype), storage_scope)); |
546 | } |
547 | |
548 | using tvm::script::ir_builder::details::Namer; |
549 | |
550 | TVM_STATIC_IR_FUNCTOR(Namer, vtable) |
551 | .set_dispatch<tvm::tir::BufferNode>([](const ObjectRef& node, String name) -> void { |
552 | tvm::tir::BufferNode* buffer = |
553 | const_cast<tvm::tir::BufferNode*>(node.as<tvm::tir::BufferNode>()); |
554 | buffer->name = name; |
555 | Namer::Name(buffer->data, name); |
556 | int n = buffer->strides.size(); |
557 | for (int i = 0; i < n; ++i) { |
558 | PrimExpr e = buffer->strides[i]; |
559 | if (const tvm::tir::VarNode* v = e.as<tvm::tir::VarNode>()) { |
560 | Namer::Name(GetRef<tvm::tir::Var>(v), name + "_s" + std::to_string(i)); |
561 | } |
562 | } |
563 | }); |
564 | |
565 | TVM_STATIC_IR_FUNCTOR(Namer, vtable) |
566 | .set_dispatch<tvm::tir::SizeVarNode>([](const ObjectRef& node, String name) -> void { |
567 | using namespace tvm::tir; |
568 | SizeVarNode* var = const_cast<SizeVarNode*>(node.as<SizeVarNode>()); |
569 | var->name_hint = name; |
570 | }); |
571 | |
572 | TVM_STATIC_IR_FUNCTOR(Namer, vtable) |
573 | .set_dispatch<tvm::tir::VarNode>([](const ObjectRef& node, String name) -> void { |
574 | using namespace tvm::tir; |
575 | VarNode* var = const_cast<VarNode*>(node.as<VarNode>()); |
576 | var->name_hint = name; |
577 | }); |
578 | |
579 | TVM_STATIC_IR_FUNCTOR(Namer, vtable) |
580 | .set_dispatch<tvm::tir::IterVarNode>([](const ObjectRef& node, String name) -> void { |
581 | using namespace tvm::tir; |
582 | IterVarNode* var = const_cast<IterVarNode*>(node.as<IterVarNode>()); |
583 | Namer::Name(var->var, name); |
584 | }); |
585 | |
586 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferDecl" ).set_body_typed(BufferDecl); |
587 | |
588 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc" ).set_body_typed(PrimFunc); |
589 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg" ) |
590 | .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { |
591 | using namespace tvm::tir; |
592 | if (const auto* var = obj.as<VarNode>()) { |
593 | return Arg(name, GetRef<tvm::tir::Var>(var)); |
594 | } |
595 | if (const auto* buffer = obj.as<BufferNode>()) { |
596 | return Arg(name, GetRef<Buffer>(buffer)); |
597 | } |
598 | LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); |
599 | throw; |
600 | }); |
601 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName" ).set_body_typed(FuncName); |
602 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs" ).set_body_typed(FuncAttrs); |
603 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet" ).set_body_typed(FuncRet); |
604 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer" ).set_body_typed(MatchBuffer); |
605 | |
606 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block" ).set_body_typed(Block); |
607 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init" ).set_body_typed(Init); |
608 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Where" ).set_body_typed(Where); |
609 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Reads" ).set_body_typed(Reads); |
610 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Writes" ).set_body_typed(Writes); |
611 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs" ).set_body_typed(BlockAttrs); |
612 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer" ).set_body_typed(AllocBuffer); |
613 | |
614 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial" ).set_body_typed(axis::Spatial); |
615 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce" ).set_body_typed(axis::Reduce); |
616 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan" ).set_body_typed(axis::Scan); |
617 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque" ).set_body_typed(axis::Opaque); |
618 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap" ).set_body_typed(axis::Remap); |
619 | |
620 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial" ).set_body_typed(Serial); |
621 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel" ).set_body_typed(Parallel); |
622 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized" ).set_body_typed(Vectorized); |
623 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll" ).set_body_typed(Unroll); |
624 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding" ).set_body_typed(ThreadBinding); |
625 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid" ).set_body_typed(Grid); |
626 | |
627 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert" ).set_body_typed(Assert); |
628 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let" ).set_body_typed(Let); |
629 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate" ).set_body_typed(Allocate); |
630 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst" ).set_body_typed(AllocateConst); |
631 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize" ).set_body_typed(Realize); |
632 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Attr" ).set_body_typed(Attr); |
633 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.While" ).set_body_typed(While); |
634 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.If" ).set_body_typed(If); |
635 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then" ).set_body_typed(Then); |
636 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else" ).set_body_typed(Else); |
637 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer" ).set_body_typed(DeclBuffer); |
638 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread" ).set_body_typed(LaunchThread); |
639 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread" ).set_body_typed(EnvThread); |
640 | |
641 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore" ).set_body_typed(BufferStore); |
642 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch" ).set_body_typed(Prefetch); |
643 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate" ).set_body_typed(Evaluate); |
644 | |
645 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr" ).set_body_typed(Ptr); |
646 | |
647 | #define TVM_TMP_STR(x) #x |
648 | |
649 | #define TVM_REGISTER_GLOBAL_SIZE(Prefix, DType) \ |
650 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \ |
651 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \ |
652 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \ |
653 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64); |
654 | |
655 | TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float" , Float); |
656 | TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt" , UInt); |
657 | TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int" , Int); |
658 | |
659 | #define TVM_REGISTER_GLOBAL_LANES(Prefix, Func) \ |
660 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \ |
661 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \ |
662 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \ |
663 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \ |
664 | TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64); |
665 | |
666 | #define TVM_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \ |
667 | TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \ |
668 | TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \ |
669 | TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \ |
670 | TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64); |
671 | |
672 | TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float" , Float); |
673 | TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt" , UInt); |
674 | TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int" , Int); |
675 | |
676 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean" ).set_body_typed(Boolean); |
677 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle" ).set_body_typed(Handle); |
678 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void" ).set_body_typed(Void); |
679 | |
680 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.min" ) |
681 | .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }); |
682 | TVM_REGISTER_GLOBAL("script.ir_builder.tir.max" ) |
683 | .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); |
684 | } // namespace tir |
685 | } // namespace ir_builder |
686 | } // namespace script |
687 | } // namespace tvm |
688 | |