1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/stmt.cc |
22 | */ |
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <limits>

#include "buffer_common.h"
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
34 | // LetStmt |
35 | LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { |
36 | ICHECK(value.defined()); |
37 | ICHECK(body.defined()); |
38 | auto vdtype = value.dtype(); |
39 | // It is still valid to bind a pointer type |
40 | // var to a value that is of type handle. |
41 | if (var->type_annotation.as<PointerTypeNode>()) { |
42 | ICHECK(vdtype.is_handle()); |
43 | } else { |
44 | ICHECK_EQ(value.dtype(), var.dtype()); |
45 | } |
46 | |
47 | ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>(); |
48 | node->var = std::move(var); |
49 | node->value = std::move(value); |
50 | node->body = std::move(body); |
51 | node->span = std::move(span); |
52 | data_ = std::move(node); |
53 | } |
54 | |
55 | TVM_REGISTER_GLOBAL("tir.LetStmt" ) |
56 | .set_body_typed([](Var var, PrimExpr value, Stmt body, Span span) { |
57 | return LetStmt(var, value, body, span); |
58 | }); |
59 | |
60 | TVM_REGISTER_NODE_TYPE(LetStmtNode); |
61 | |
62 | // AttrStmt |
63 | AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { |
64 | auto n = make_object<AttrStmtNode>(); |
65 | n->node = node; |
66 | n->attr_key = std::move(attr_key); |
67 | n->value = std::move(value); |
68 | n->body = std::move(body); |
69 | n->span = std::move(span); |
70 | data_ = std::move(n); |
71 | } |
72 | |
73 | TVM_REGISTER_GLOBAL("tir.AttrStmt" ) |
74 | .set_body_typed([](ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { |
75 | return AttrStmt(node, attr_key, value, body, span); |
76 | }); |
77 | |
78 | TVM_REGISTER_NODE_TYPE(AttrStmtNode); |
79 | |
80 | // AssertStmt |
81 | AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) { |
82 | ICHECK(condition.defined()); |
83 | ICHECK(message.dtype() == DataType::Int(32) || message.as<StringImmNode>()) |
84 | << "TypeError: AssertStmt message must be an int or string:" << message << "\n" ; |
85 | |
86 | ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>(); |
87 | node->condition = std::move(condition); |
88 | node->message = std::move(message); |
89 | node->body = std::move(body); |
90 | node->span = std::move(span); |
91 | data_ = std::move(node); |
92 | } |
93 | |
94 | TVM_REGISTER_NODE_TYPE(AssertStmtNode); |
95 | |
96 | TVM_REGISTER_GLOBAL("tir.AssertStmt" ) |
97 | .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body, Span span) { |
98 | if (const auto* str = message.as<StringObj>()) { |
99 | auto msg = StringImm(str->data); |
100 | return AssertStmt(condition, msg, body, span); |
101 | } else { |
102 | return AssertStmt(condition, Downcast<PrimExpr>(message), body, span); |
103 | } |
104 | }); |
105 | |
106 | // For |
107 | For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, |
108 | Optional<IterVar> thread_binding, Map<String, ObjectRef> annotations, Span span) { |
109 | ICHECK(min.defined()); |
110 | ICHECK(extent.defined()); |
111 | ICHECK(min.dtype().is_scalar()); |
112 | ICHECK(extent.dtype().is_scalar()); |
113 | ICHECK(loop_var.dtype().is_scalar()); |
114 | ICHECK(body.defined()); |
115 | |
116 | // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them |
117 | // without raising errors. |
118 | auto try_promote_imm_dtype = [&](const PrimExpr& e) { |
119 | ICHECK(e.dtype().bits() <= loop_var.dtype().bits()) |
120 | << " Loop variable's dtype (" << loop_var.dtype() |
121 | << ") is narrower than that of `min` or `extent` (" << e.dtype() << ")" ; |
122 | const IntImmNode* a = e.as<IntImmNode>(); |
123 | if (a && e.dtype().bits() < loop_var.dtype().bits()) { |
124 | return make_const(loop_var.dtype(), a->value); |
125 | } else { |
126 | return e; |
127 | } |
128 | }; |
129 | |
130 | min = try_promote_imm_dtype(min); |
131 | extent = try_promote_imm_dtype(extent); |
132 | |
133 | ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); |
134 | ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); |
135 | |
136 | ObjectPtr<ForNode> node = make_object<ForNode>(); |
137 | node->loop_var = std::move(loop_var); |
138 | node->min = std::move(min); |
139 | node->extent = std::move(extent); |
140 | node->kind = kind; |
141 | node->body = std::move(body); |
142 | node->thread_binding = std::move(thread_binding); |
143 | node->annotations = std::move(annotations); |
144 | node->span = std::move(span); |
145 | data_ = std::move(node); |
146 | } |
147 | |
148 | TVM_REGISTER_GLOBAL("tir.For" ).set_body_typed( |
149 | [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, |
150 | Optional<IterVar> thread_binding, Optional<Map<String, ObjectRef>> annotations, Span span) { |
151 | return For(loop_var, min, extent, static_cast<ForKind>(kind), body, thread_binding, |
152 | annotations.value_or(Map<String, ObjectRef>()), span); |
153 | }); |
154 | |
155 | TVM_REGISTER_NODE_TYPE(ForNode); |
156 | |
157 | std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) |
158 | switch (type) { |
159 | case ForKind::kSerial: |
160 | out << "for" ; |
161 | break; |
162 | case ForKind::kParallel: |
163 | out << "parallel" ; |
164 | break; |
165 | case ForKind::kUnrolled: |
166 | out << "unrolled" ; |
167 | break; |
168 | case ForKind::kVectorized: |
169 | out << "vectorized" ; |
170 | break; |
171 | case ForKind::kThreadBinding: |
172 | out << "launch_thread" ; |
173 | break; |
174 | } |
175 | return out; |
176 | } |
177 | |
178 | // While |
179 | While::While(PrimExpr condition, Stmt body, Span span) { |
180 | ICHECK(condition.defined()); |
181 | ICHECK(condition.dtype().is_scalar()); |
182 | ICHECK(condition.as<tir::IntImmNode>() == nullptr) << "The condition should not be trivial." ; |
183 | ICHECK(body.defined()); |
184 | |
185 | ObjectPtr<WhileNode> node = make_object<WhileNode>(); |
186 | node->condition = std::move(condition); |
187 | node->body = std::move(body); |
188 | node->span = std::move(span); |
189 | data_ = std::move(node); |
190 | } |
191 | |
192 | TVM_REGISTER_GLOBAL("tir.While" ).set_body_typed([](PrimExpr condition, Stmt body, Span span) { |
193 | return While(condition, body, span); |
194 | }); |
195 | |
196 | TVM_REGISTER_NODE_TYPE(WhileNode); |
197 | |
198 | // Store |
199 | Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) { |
200 | LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint |
201 | << ". Use BufferStore instead." ; |
202 | ICHECK(value.defined()); |
203 | ICHECK(index.defined()); |
204 | ICHECK(predicate.defined()); |
205 | |
206 | // Assume that the array elements have 1 lane, unless a type |
207 | // annotation tells us otherwise. |
208 | int element_lanes = 1; |
209 | auto pointer_type = tir::GetPointerType(buffer_var->type_annotation); |
210 | if (pointer_type.has_value()) { |
211 | // Currently cannot check element type of array, see Load::Load |
212 | // for details. |
213 | |
214 | // TODO(Lunderberg): Uncomment this check once it can be applied. |
215 | // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 |
216 | // for discussion. |
217 | |
218 | // ICHECK_EQ(value.dtype().element_of(), pointer_type->element_of()) |
219 | // << "Type mismatch, cannot store type " << value.dtype() << " into buffer " |
220 | // << buffer_var->name_hint << " of type " << pointer_type.value(); |
221 | element_lanes = pointer_type->lanes(); |
222 | } |
223 | |
224 | ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) || |
225 | (value.dtype().lanes() == index.dtype().lanes())); |
226 | ICHECK((value.dtype().lanes() == element_lanes * predicate.dtype().lanes()) || |
227 | (value.dtype().lanes() == index.dtype().lanes())); |
228 | |
229 | ObjectPtr<StoreNode> node = make_object<StoreNode>(); |
230 | node->buffer_var = std::move(buffer_var); |
231 | node->value = std::move(value); |
232 | node->index = std::move(index); |
233 | node->predicate = std::move(predicate); |
234 | node->span = std::move(span); |
235 | data_ = std::move(node); |
236 | } |
237 | |
238 | TVM_REGISTER_GLOBAL("tir.Store" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
239 | PrimExpr value = args[1]; |
240 | if (args.size() == 3) { |
241 | *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes()), Span()); |
242 | } else if (args.size() == 4) { |
243 | *ret = Store(args[0], value, args[2], args[3], Span()); |
244 | } else { |
245 | *ret = Store(args[0], value, args[2], args[3], args[4]); |
246 | } |
247 | }); |
248 | |
249 | TVM_REGISTER_NODE_TYPE(StoreNode); |
250 | |
251 | // ProducerStore |
252 | ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices, |
253 | Span span) { |
254 | ObjectPtr<ProducerStoreNode> node = make_object<ProducerStoreNode>(); |
255 | node->producer = std::move(producer); |
256 | node->value = std::move(value); |
257 | node->indices = std::move(indices); |
258 | node->span = std::move(span); |
259 | data_ = std::move(node); |
260 | } |
261 | |
262 | TVM_REGISTER_GLOBAL("tir.ProducerStore" ) |
263 | .set_body_typed([](DataProducer producer, PrimExpr value, Array<PrimExpr> indices, Span span) { |
264 | return ProducerStore(producer, value, indices, span); |
265 | }); |
266 | |
267 | TVM_REGISTER_NODE_TYPE(ProducerStoreNode); |
268 | |
269 | // Allocate |
270 | Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition, |
271 | Stmt body, Map<String, ObjectRef> annotations, Span span) { |
272 | CHECK(IsPointerType(buffer_var->type_annotation, dtype) || |
273 | (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8)))) |
274 | << "The allocated data type (" << dtype |
275 | << ") does not match the type annotation of the buffer " << buffer_var << " (" |
276 | << buffer_var->type_annotation |
277 | << "). The data type should be an element of the pointer type." ; |
278 | |
279 | for (size_t i = 0; i < extents.size(); ++i) { |
280 | ICHECK(extents[i].defined()); |
281 | ICHECK(extents[i].dtype().is_scalar()); |
282 | } |
283 | ICHECK(body.defined()); |
284 | ICHECK(condition.defined()); |
285 | ICHECK(condition.dtype().is_bool()); |
286 | |
287 | ObjectPtr<AllocateNode> node = make_object<AllocateNode>(); |
288 | node->buffer_var = std::move(buffer_var); |
289 | node->dtype = dtype; |
290 | node->extents = std::move(extents); |
291 | node->condition = std::move(condition); |
292 | node->body = std::move(body); |
293 | node->annotations = std::move(annotations); |
294 | node->span = std::move(span); |
295 | data_ = std::move(node); |
296 | } |
297 | |
298 | int64_t AllocateNode::ConstantAllocationSize(const Array<PrimExpr>& extents) { |
299 | int64_t result = 1; |
300 | for (size_t i = 0; i < extents.size(); ++i) { |
301 | if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) { |
302 | result *= int_size->value; |
303 | if (result > std::numeric_limits<int64_t>::max()) { |
304 | return 0; |
305 | } |
306 | } else { |
307 | return 0; |
308 | } |
309 | } |
310 | return static_cast<int64_t>(result); |
311 | } |
312 | |
313 | TVM_REGISTER_GLOBAL("tir.Allocate" ) |
314 | .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, |
315 | Stmt body, Map<String, ObjectRef> annotations, Span span) { |
316 | return Allocate(buffer_var, type, extents, condition, body, annotations, span); |
317 | }); |
318 | |
319 | TVM_REGISTER_NODE_TYPE(AllocateNode); |
320 | |
// AllocateConst
// Constructs an IRNode holding constant data. Depending on the runtime type of
// `data_or_idx`, the resulting AllocateConstNode stores either the NDArray itself
// (`data`) or an IRModule storage index (`irmod_storage_idx`).
325 | AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents, |
326 | ObjectRef data_or_idx, Stmt body, Map<String, ObjectRef> annotations, |
327 | Span span) { |
328 | ICHECK(IsPointerType(buffer_var->type_annotation, dtype)) |
329 | << "The allocated data type (" << dtype |
330 | << ") does not match the type annotation of the buffer " << buffer_var << " (" |
331 | << buffer_var->type_annotation |
332 | << "). The data type should be an element of the pointer type." ; |
333 | |
334 | for (size_t i = 0; i < extents.size(); ++i) { |
335 | ICHECK(extents[i].defined()); |
336 | ICHECK(extents[i].dtype().is_scalar()); |
337 | } |
338 | ICHECK(body.defined()); |
339 | ICHECK(data_or_idx.defined()); |
340 | |
341 | ObjectPtr<AllocateConstNode> node = make_object<AllocateConstNode>(); |
342 | node->buffer_var = std::move(buffer_var); |
343 | node->dtype = dtype; |
344 | node->extents = std::move(extents); |
345 | node->body = std::move(body); |
346 | node->annotations = annotations; |
347 | node->span = std::move(span); |
348 | if (data_or_idx->IsInstance<runtime::NDArray::ContainerType>()) { |
349 | node->data = Optional<tvm::runtime::NDArray>(Downcast<runtime::NDArray>(data_or_idx)); |
350 | node->irmod_storage_idx = Optional<Integer>(); |
351 | } else if (data_or_idx->IsInstance<IntImmNode>()) { |
352 | node->data = Optional<tvm::runtime::NDArray>(); |
353 | node->irmod_storage_idx = Optional<Integer>(Downcast<Integer>(data_or_idx)); |
354 | } else { |
355 | LOG(FATAL) << "Data type not supported: " << data_or_idx->GetTypeKey(); |
356 | } |
357 | data_ = std::move(node); |
358 | } |
359 | |
360 | int64_t AllocateConstNode::ConstantAllocationSize(const Array<PrimExpr>& extents) { |
361 | int64_t result = 1; |
362 | for (size_t i = 0; i < extents.size(); ++i) { |
363 | if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) { |
364 | result *= int_size->value; |
365 | if (result > std::numeric_limits<int64_t>::max()) { |
366 | return 0; |
367 | } |
368 | } else { |
369 | return 0; |
370 | } |
371 | } |
372 | return static_cast<int64_t>(result); |
373 | } |
374 | TVM_REGISTER_GLOBAL("tir.AllocateConst" ) |
375 | .set_body_typed([](Var buffer_var, DataType dtype, Array<PrimExpr> extents, |
376 | ObjectRef data_or_idx, Stmt body, Map<String, ObjectRef> annotations, |
377 | Span span) { |
378 | return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations, span); |
379 | }); |
380 | |
381 | TVM_REGISTER_NODE_TYPE(AllocateConstNode); |
382 | |
383 | // DeclBuffer |
384 | DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { |
385 | ObjectPtr<DeclBufferNode> node = make_object<DeclBufferNode>(); |
386 | node->buffer = std::move(buffer); |
387 | node->body = std::move(body); |
388 | node->span = std::move(span); |
389 | data_ = std::move(node); |
390 | } |
391 | |
392 | TVM_REGISTER_GLOBAL("tir.DeclBuffer" ).set_body_typed([](Buffer buffer, Stmt body, Span span) { |
393 | return DeclBuffer(buffer, body, span); |
394 | }); |
395 | |
396 | TVM_REGISTER_NODE_TYPE(DeclBufferNode); |
397 | |
398 | // ProducerRealize |
399 | ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, |
400 | Stmt body, String storage_scope, Span span) { |
401 | for (size_t i = 0; i < bounds.size(); ++i) { |
402 | ICHECK(bounds[i]->min.defined()); |
403 | ICHECK(bounds[i]->extent.defined()); |
404 | ICHECK(bounds[i]->min.dtype().is_scalar()); |
405 | ICHECK(bounds[i]->extent.dtype().is_scalar()); |
406 | } |
407 | ICHECK(body.defined()); |
408 | ICHECK(condition.defined()); |
409 | ICHECK(condition.dtype().is_bool()); |
410 | |
411 | ObjectPtr<ProducerRealizeNode> node = make_object<ProducerRealizeNode>(); |
412 | node->producer = std::move(producer); |
413 | node->bounds = std::move(bounds); |
414 | node->condition = std::move(condition); |
415 | node->body = std::move(body); |
416 | node->span = std::move(span); |
417 | node->storage_scope = std::move(storage_scope); |
418 | data_ = std::move(node); |
419 | } |
420 | |
421 | TVM_REGISTER_GLOBAL("tir.ProducerRealize" ) |
422 | .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, |
423 | String storage_scope, Span span) { |
424 | return ProducerRealize(producer, bounds, condition, body, storage_scope, span); |
425 | }); |
426 | |
427 | TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); |
428 | |
429 | // Prefetch |
430 | Prefetch::Prefetch(Buffer buffer, Array<Range> bounds, Span span) { |
431 | data_ = make_object<PrefetchNode>(buffer, bounds, span); |
432 | } |
433 | |
434 | TVM_REGISTER_GLOBAL("tir.Prefetch" ) |
435 | .set_body_typed([](Buffer buffer, Array<Range> bounds, Span span) { |
436 | return Prefetch(buffer, bounds, span); |
437 | }); |
438 | |
439 | TVM_REGISTER_NODE_TYPE(PrefetchNode); |
440 | |
441 | // SeqStmt |
442 | SeqStmt::SeqStmt(Array<Stmt> seq, Span span) { |
443 | auto node = make_object<SeqStmtNode>(); |
444 | node->seq = std::move(seq); |
445 | node->span = std::move(span); |
446 | data_ = std::move(node); |
447 | } |
448 | |
449 | TVM_REGISTER_GLOBAL("tir.SeqStmt" ).set_body_typed([](Array<Stmt> seq, Span span) { |
450 | return SeqStmt(std::move(seq), span); |
451 | }); |
452 | |
453 | TVM_REGISTER_NODE_TYPE(SeqStmtNode); |
454 | |
455 | // IfThenElse |
456 | IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case, Span span) { |
457 | ICHECK(condition.defined()); |
458 | ICHECK(then_case.defined()); |
459 | // else_case may be null. |
460 | ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>(); |
461 | node->condition = std::move(condition); |
462 | node->then_case = std::move(then_case); |
463 | node->else_case = std::move(else_case); |
464 | node->span = std::move(span); |
465 | data_ = std::move(node); |
466 | } |
467 | |
468 | TVM_REGISTER_NODE_TYPE(IfThenElseNode); |
469 | |
470 | TVM_REGISTER_GLOBAL("tir.IfThenElse" ) |
471 | .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { |
472 | return IfThenElse(condition, then_case, else_case, span); |
473 | }); |
474 | |
475 | // Evaluate |
476 | Evaluate::Evaluate(PrimExpr value, Span span) { |
477 | ICHECK(value.defined()); |
478 | |
479 | ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>(); |
480 | node->value = std::move(value); |
481 | node->span = std::move(span); |
482 | data_ = std::move(node); |
483 | } |
484 | |
485 | TVM_REGISTER_GLOBAL("tir.Evaluate" ).set_body_typed([](PrimExpr value, Span span) { |
486 | return Evaluate(value, span); |
487 | }); |
488 | |
489 | TVM_REGISTER_NODE_TYPE(EvaluateNode); |
490 | |
491 | // BufferStore |
492 | BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) { |
493 | ICHECK_EQ(buffer->shape.size(), indices.size()) |
494 | << "Buffer " << buffer->name << " is " << buffer->shape.size() |
495 | << "-dimensional, cannot be indexed with the " << indices.size() |
496 | << "-dimensional indices provided." ; |
497 | |
498 | for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) { |
499 | ICHECK(indices[i].dtype().is_scalar()) |
500 | << "Only the last index of a buffer access may be a vector type." ; |
501 | } |
502 | |
503 | int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; |
504 | int buffer_lanes = buffer->dtype.lanes(); |
505 | |
506 | ICHECK_EQ(index_lanes * buffer_lanes, value.dtype().lanes()) |
507 | << "Cannot store value with " << value.dtype().lanes() << ", expected value with " |
508 | << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes |
509 | << " buffer element lanes)" ; |
510 | if (buffer->dtype.with_lanes(buffer_lanes * index_lanes) != value.dtype()) { |
511 | LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " // |
512 | << "buffer's dtype is `" << buffer->dtype // |
513 | << "`, the lanes of indexing are: `" << index_lanes // |
514 | << "`, but RHS's dtype is `" << value.dtype() << "`" ; |
515 | } |
516 | |
517 | ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>(); |
518 | node->buffer = std::move(buffer); |
519 | node->value = std::move(value); |
520 | node->indices = std::move(indices); |
521 | node->span = std::move(span); |
522 | data_ = std::move(node); |
523 | } |
524 | |
525 | TVM_REGISTER_GLOBAL("tir.BufferStore" ) |
526 | .set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) { |
527 | return BufferStore(buffer, value, indices, span); |
528 | }); |
529 | |
530 | TVM_REGISTER_NODE_TYPE(BufferStoreNode); |
531 | |
532 | // BufferRealize |
533 | BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body, |
534 | Span span) { |
535 | data_ = make_object<BufferRealizeNode>(buffer, bounds, condition, body, span); |
536 | } |
537 | |
538 | TVM_REGISTER_GLOBAL("tir.BufferRealize" ) |
539 | .set_body_typed([](Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body, |
540 | Span span) { return BufferRealize(buffer, bounds, condition, body, span); }); |
541 | |
542 | TVM_REGISTER_NODE_TYPE(BufferRealizeNode); |
543 | |
544 | // BufferRegion |
545 | BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) { |
546 | CHECK_EQ(buffer->shape.size(), region.size()) |
547 | << "The dimension between " << buffer << " and region " << region |
548 | << " mismatched, the buffer is " << buffer; |
549 | ObjectPtr<BufferRegionNode> node = make_object<BufferRegionNode>(); |
550 | node->buffer = std::move(buffer); |
551 | node->region = std::move(region); |
552 | data_ = std::move(node); |
553 | } |
554 | |
555 | BufferRegion BufferRegion::FullRegion(Buffer buffer) { |
556 | Array<Range> region; |
557 | for (PrimExpr extent : buffer->shape) { |
558 | region.push_back(Range::FromMinExtent(0, extent)); |
559 | } |
560 | return BufferRegion(buffer, region); |
561 | } |
562 | |
563 | BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) { |
564 | Array<Range> region; |
565 | for (const PrimExpr& index : indices) { |
566 | if (const RampNode* ramp_index = index.as<RampNode>()) { |
567 | region.push_back( |
568 | Range::FromMinExtent(ramp_index->base, ramp_index->stride * ramp_index->lanes)); |
569 | } else { |
570 | region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1))); |
571 | } |
572 | } |
573 | return BufferRegion(buffer, region); |
574 | } |
575 | |
576 | TVM_REGISTER_GLOBAL("tir.BufferRegion" ).set_body_typed([](Buffer buffer, Array<Range> region) { |
577 | return BufferRegion(buffer, region); |
578 | }); |
579 | |
580 | TVM_REGISTER_NODE_TYPE(BufferRegionNode); |
581 | |
582 | // MatchBufferRegion |
583 | MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { |
584 | const Buffer& source_buffer = source->buffer; |
585 | arith::Analyzer analyzer; |
586 | // Check scope and dtype |
587 | CHECK_EQ(buffer.scope(), source_buffer.scope()) |
588 | << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. " |
589 | << source_buffer.scope(); |
590 | CHECK_EQ(buffer->dtype, source_buffer->dtype) |
591 | << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. " |
592 | << source_buffer->dtype; |
593 | |
594 | // Check data_alignment |
595 | CHECK(source_buffer->data_alignment % buffer->data_alignment == 0) |
596 | << "Trying to match buffer to another one with lower alignment requirement " |
597 | << " required_alignment=" << buffer->data_alignment |
598 | << ", provided_alignment=" << source_buffer->data_alignment; |
599 | |
600 | // Check BufferType. AutoBroadcast is not allowed for now. |
601 | CHECK(buffer->buffer_type == BufferType::kDefault && |
602 | source_buffer->buffer_type == BufferType::kDefault) |
603 | << "AutoBroadcast is not allowed in MatchBuffer" ; |
604 | |
605 | // Validate shape |
606 | CHECK(source->region.size() >= buffer->shape.size()) |
607 | << "Dimension of source Region expected to be larger or equal than target buffer shape, but " |
608 | "got " |
609 | << source->region.size() << " vs. " << buffer->shape.size(); |
610 | size_t offset = source->region.size() - buffer->shape.size(); |
611 | for (size_t i = 0; i < offset; ++i) { |
612 | CHECK(analyzer.CanProve(source->region[i]->extent == 1)) |
613 | << "The higher dimension should be 1, but got " << source->region[i]->extent << "." ; |
614 | } |
615 | for (size_t i = 0; i < buffer->shape.size(); ++i) { |
616 | const Range& source_range = source->region[i + offset]; |
617 | const PrimExpr& buffer_shape = buffer->shape[i]; |
618 | if (!buffer_shape->IsInstance<VarNode>()) { |
619 | CHECK(analyzer.CanProve(source_range->extent == buffer_shape)) |
620 | << "The dimension mismatched between source region and target buffer shape, got " |
621 | << source_range->extent << " vs. " << buffer_shape << "." ; |
622 | } |
623 | } |
624 | // Note that we do not check elem_offset and strides in this function |
625 | |
626 | // Construction |
627 | ObjectPtr<MatchBufferRegionNode> node = make_object<MatchBufferRegionNode>(); |
628 | node->buffer = std::move(buffer); |
629 | node->source = std::move(source); |
630 | data_ = std::move(node); |
631 | } |
632 | |
633 | TVM_REGISTER_GLOBAL("tir.MatchBufferRegion" ).set_body_typed([](Buffer buffer, BufferRegion source) { |
634 | return MatchBufferRegion(buffer, source); |
635 | }); |
636 | |
637 | TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode); |
638 | |
639 | // Block |
640 | Block::Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, Array<BufferRegion> writes, |
641 | String name_hint, Stmt body, Optional<Stmt> init, Array<Buffer> alloc_buffers, |
642 | Array<MatchBufferRegion> match_buffers, Map<String, ObjectRef> annotations, |
643 | Span span) { |
644 | ObjectPtr<BlockNode> node = make_object<BlockNode>(); |
645 | node->iter_vars = std::move(iter_vars); |
646 | node->reads = std::move(reads); |
647 | node->writes = std::move(writes); |
648 | node->name_hint = std::move(name_hint); |
649 | node->body = std::move(body); |
650 | node->init = std::move(init); |
651 | node->alloc_buffers = std::move(alloc_buffers); |
652 | node->match_buffers = std::move(match_buffers); |
653 | node->annotations = std::move(annotations); |
654 | node->span = std::move(span); |
655 | data_ = std::move(node); |
656 | } |
657 | |
658 | TVM_REGISTER_GLOBAL("tir.Block" ) |
659 | .set_body_typed([](Array<IterVar> iter_vars, Array<BufferRegion> reads, |
660 | Array<BufferRegion> writes, String name_hint, Stmt body, Optional<Stmt> init, |
661 | Array<Buffer> alloc_buffers, Array<MatchBufferRegion> match_buffers, |
662 | Map<String, ObjectRef> annotations, Span span) { |
663 | return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, |
664 | annotations, span); |
665 | }); |
666 | |
667 | TVM_REGISTER_NODE_TYPE(BlockNode); |
668 | |
669 | // BlockRealize |
670 | BlockRealize::BlockRealize(Array<PrimExpr> values, PrimExpr predicate, Block block, Span span) { |
671 | CHECK_EQ(block->iter_vars.size(), values.size()) |
672 | << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values" ; |
673 | CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression" ; |
674 | ObjectPtr<BlockRealizeNode> node = make_object<BlockRealizeNode>(); |
675 | node->iter_values = std::move(values); |
676 | node->predicate = std::move(predicate); |
677 | node->block = std::move(block); |
678 | node->span = std::move(span); |
679 | data_ = std::move(node); |
680 | } |
681 | |
682 | TVM_REGISTER_GLOBAL("tir.BlockRealize" ) |
683 | .set_body_typed([](Array<PrimExpr> iter_values, PrimExpr predicate, Block block, Span span) { |
684 | return BlockRealize(iter_values, predicate, block, span); |
685 | }); |
686 | |
687 | TVM_REGISTER_NODE_TYPE(BlockRealizeNode); |
688 | |
689 | PrimExpr TypeAnnotation(DataType dtype, Span span) { |
690 | static auto op = Op::Get("tir.type_annotation" ); |
691 | return tir::Call(dtype, op, {}, span); |
692 | } |
693 | |
694 | TVM_TIR_REGISTER_OP("type_annotation" ) |
695 | .set_attr<TCallEffectKind>("TCallEffectKind" , Integer(CallEffectKind::kPure)); |
696 | |
697 | } // namespace tir |
698 | } // namespace tvm |
699 | |