1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/stmt.cc
22 */
23#include <tvm/arith/analyzer.h>
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/op.h>
26#include <tvm/tir/op_attr_types.h>
27#include <tvm/tir/stmt.h>
28
29#include "buffer_common.h"
30
31namespace tvm {
32namespace tir {
33
34// LetStmt
35LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) {
36 ICHECK(value.defined());
37 ICHECK(body.defined());
38 auto vdtype = value.dtype();
39 // It is still valid to bind a pointer type
40 // var to a value that is of type handle.
41 if (var->type_annotation.as<PointerTypeNode>()) {
42 ICHECK(vdtype.is_handle());
43 } else {
44 ICHECK_EQ(value.dtype(), var.dtype());
45 }
46
47 ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>();
48 node->var = std::move(var);
49 node->value = std::move(value);
50 node->body = std::move(body);
51 node->span = std::move(span);
52 data_ = std::move(node);
53}
54
55TVM_REGISTER_GLOBAL("tir.LetStmt")
56 .set_body_typed([](Var var, PrimExpr value, Stmt body, Span span) {
57 return LetStmt(var, value, body, span);
58 });
59
60TVM_REGISTER_NODE_TYPE(LetStmtNode);
61
62// AttrStmt
63AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) {
64 auto n = make_object<AttrStmtNode>();
65 n->node = node;
66 n->attr_key = std::move(attr_key);
67 n->value = std::move(value);
68 n->body = std::move(body);
69 n->span = std::move(span);
70 data_ = std::move(n);
71}
72
73TVM_REGISTER_GLOBAL("tir.AttrStmt")
74 .set_body_typed([](ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) {
75 return AttrStmt(node, attr_key, value, body, span);
76 });
77
78TVM_REGISTER_NODE_TYPE(AttrStmtNode);
79
80// AssertStmt
81AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) {
82 ICHECK(condition.defined());
83 ICHECK(message.dtype() == DataType::Int(32) || message.as<StringImmNode>())
84 << "TypeError: AssertStmt message must be an int or string:" << message << "\n";
85
86 ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>();
87 node->condition = std::move(condition);
88 node->message = std::move(message);
89 node->body = std::move(body);
90 node->span = std::move(span);
91 data_ = std::move(node);
92}
93
94TVM_REGISTER_NODE_TYPE(AssertStmtNode);
95
96TVM_REGISTER_GLOBAL("tir.AssertStmt")
97 .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body, Span span) {
98 if (const auto* str = message.as<StringObj>()) {
99 auto msg = StringImm(str->data);
100 return AssertStmt(condition, msg, body, span);
101 } else {
102 return AssertStmt(condition, Downcast<PrimExpr>(message), body, span);
103 }
104 });
105
106// For
107For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
108 Optional<IterVar> thread_binding, Map<String, ObjectRef> annotations, Span span) {
109 ICHECK(min.defined());
110 ICHECK(extent.defined());
111 ICHECK(min.dtype().is_scalar());
112 ICHECK(extent.dtype().is_scalar());
113 ICHECK(loop_var.dtype().is_scalar());
114 ICHECK(body.defined());
115
116 // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them
117 // without raising errors.
118 auto try_promote_imm_dtype = [&](const PrimExpr& e) {
119 ICHECK(e.dtype().bits() <= loop_var.dtype().bits())
120 << " Loop variable's dtype (" << loop_var.dtype()
121 << ") is narrower than that of `min` or `extent` (" << e.dtype() << ")";
122 const IntImmNode* a = e.as<IntImmNode>();
123 if (a && e.dtype().bits() < loop_var.dtype().bits()) {
124 return make_const(loop_var.dtype(), a->value);
125 } else {
126 return e;
127 }
128 };
129
130 min = try_promote_imm_dtype(min);
131 extent = try_promote_imm_dtype(extent);
132
133 ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype();
134 ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype();
135
136 ObjectPtr<ForNode> node = make_object<ForNode>();
137 node->loop_var = std::move(loop_var);
138 node->min = std::move(min);
139 node->extent = std::move(extent);
140 node->kind = kind;
141 node->body = std::move(body);
142 node->thread_binding = std::move(thread_binding);
143 node->annotations = std::move(annotations);
144 node->span = std::move(span);
145 data_ = std::move(node);
146}
147
148TVM_REGISTER_GLOBAL("tir.For").set_body_typed(
149 [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body,
150 Optional<IterVar> thread_binding, Optional<Map<String, ObjectRef>> annotations, Span span) {
151 return For(loop_var, min, extent, static_cast<ForKind>(kind), body, thread_binding,
152 annotations.value_or(Map<String, ObjectRef>()), span);
153 });
154
155TVM_REGISTER_NODE_TYPE(ForNode);
156
157std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*)
158 switch (type) {
159 case ForKind::kSerial:
160 out << "for";
161 break;
162 case ForKind::kParallel:
163 out << "parallel";
164 break;
165 case ForKind::kUnrolled:
166 out << "unrolled";
167 break;
168 case ForKind::kVectorized:
169 out << "vectorized";
170 break;
171 case ForKind::kThreadBinding:
172 out << "launch_thread";
173 break;
174 }
175 return out;
176}
177
178// While
179While::While(PrimExpr condition, Stmt body, Span span) {
180 ICHECK(condition.defined());
181 ICHECK(condition.dtype().is_scalar());
182 ICHECK(condition.as<tir::IntImmNode>() == nullptr) << "The condition should not be trivial.";
183 ICHECK(body.defined());
184
185 ObjectPtr<WhileNode> node = make_object<WhileNode>();
186 node->condition = std::move(condition);
187 node->body = std::move(body);
188 node->span = std::move(span);
189 data_ = std::move(node);
190}
191
192TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) {
193 return While(condition, body, span);
194});
195
196TVM_REGISTER_NODE_TYPE(WhileNode);
197
198// Store
199Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) {
200 LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint
201 << ". Use BufferStore instead.";
202 ICHECK(value.defined());
203 ICHECK(index.defined());
204 ICHECK(predicate.defined());
205
206 // Assume that the array elements have 1 lane, unless a type
207 // annotation tells us otherwise.
208 int element_lanes = 1;
209 auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
210 if (pointer_type.has_value()) {
211 // Currently cannot check element type of array, see Load::Load
212 // for details.
213
214 // TODO(Lunderberg): Uncomment this check once it can be applied.
215 // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
216 // for discussion.
217
218 // ICHECK_EQ(value.dtype().element_of(), pointer_type->element_of())
219 // << "Type mismatch, cannot store type " << value.dtype() << " into buffer "
220 // << buffer_var->name_hint << " of type " << pointer_type.value();
221 element_lanes = pointer_type->lanes();
222 }
223
224 ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) ||
225 (value.dtype().lanes() == index.dtype().lanes()));
226 ICHECK((value.dtype().lanes() == element_lanes * predicate.dtype().lanes()) ||
227 (value.dtype().lanes() == index.dtype().lanes()));
228
229 ObjectPtr<StoreNode> node = make_object<StoreNode>();
230 node->buffer_var = std::move(buffer_var);
231 node->value = std::move(value);
232 node->index = std::move(index);
233 node->predicate = std::move(predicate);
234 node->span = std::move(span);
235 data_ = std::move(node);
236}
237
238TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) {
239 PrimExpr value = args[1];
240 if (args.size() == 3) {
241 *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes()), Span());
242 } else if (args.size() == 4) {
243 *ret = Store(args[0], value, args[2], args[3], Span());
244 } else {
245 *ret = Store(args[0], value, args[2], args[3], args[4]);
246 }
247});
248
249TVM_REGISTER_NODE_TYPE(StoreNode);
250
251// ProducerStore
252ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
253 Span span) {
254 ObjectPtr<ProducerStoreNode> node = make_object<ProducerStoreNode>();
255 node->producer = std::move(producer);
256 node->value = std::move(value);
257 node->indices = std::move(indices);
258 node->span = std::move(span);
259 data_ = std::move(node);
260}
261
262TVM_REGISTER_GLOBAL("tir.ProducerStore")
263 .set_body_typed([](DataProducer producer, PrimExpr value, Array<PrimExpr> indices, Span span) {
264 return ProducerStore(producer, value, indices, span);
265 });
266
267TVM_REGISTER_NODE_TYPE(ProducerStoreNode);
268
269// Allocate
270Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
271 Stmt body, Map<String, ObjectRef> annotations, Span span) {
272 CHECK(IsPointerType(buffer_var->type_annotation, dtype) ||
273 (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8))))
274 << "The allocated data type (" << dtype
275 << ") does not match the type annotation of the buffer " << buffer_var << " ("
276 << buffer_var->type_annotation
277 << "). The data type should be an element of the pointer type.";
278
279 for (size_t i = 0; i < extents.size(); ++i) {
280 ICHECK(extents[i].defined());
281 ICHECK(extents[i].dtype().is_scalar());
282 }
283 ICHECK(body.defined());
284 ICHECK(condition.defined());
285 ICHECK(condition.dtype().is_bool());
286
287 ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
288 node->buffer_var = std::move(buffer_var);
289 node->dtype = dtype;
290 node->extents = std::move(extents);
291 node->condition = std::move(condition);
292 node->body = std::move(body);
293 node->annotations = std::move(annotations);
294 node->span = std::move(span);
295 data_ = std::move(node);
296}
297
298int64_t AllocateNode::ConstantAllocationSize(const Array<PrimExpr>& extents) {
299 int64_t result = 1;
300 for (size_t i = 0; i < extents.size(); ++i) {
301 if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
302 result *= int_size->value;
303 if (result > std::numeric_limits<int64_t>::max()) {
304 return 0;
305 }
306 } else {
307 return 0;
308 }
309 }
310 return static_cast<int64_t>(result);
311}
312
313TVM_REGISTER_GLOBAL("tir.Allocate")
314 .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
315 Stmt body, Map<String, ObjectRef> annotations, Span span) {
316 return Allocate(buffer_var, type, extents, condition, body, annotations, span);
317 });
318
319TVM_REGISTER_NODE_TYPE(AllocateNode);
320
321// Const
322// The constructor to create a IRNode with constant data
323// depending on the type of ObjectRef, it will either
324// create AllocateConstNode with irmod_storage_idx or data
325AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
326 ObjectRef data_or_idx, Stmt body, Map<String, ObjectRef> annotations,
327 Span span) {
328 ICHECK(IsPointerType(buffer_var->type_annotation, dtype))
329 << "The allocated data type (" << dtype
330 << ") does not match the type annotation of the buffer " << buffer_var << " ("
331 << buffer_var->type_annotation
332 << "). The data type should be an element of the pointer type.";
333
334 for (size_t i = 0; i < extents.size(); ++i) {
335 ICHECK(extents[i].defined());
336 ICHECK(extents[i].dtype().is_scalar());
337 }
338 ICHECK(body.defined());
339 ICHECK(data_or_idx.defined());
340
341 ObjectPtr<AllocateConstNode> node = make_object<AllocateConstNode>();
342 node->buffer_var = std::move(buffer_var);
343 node->dtype = dtype;
344 node->extents = std::move(extents);
345 node->body = std::move(body);
346 node->annotations = annotations;
347 node->span = std::move(span);
348 if (data_or_idx->IsInstance<runtime::NDArray::ContainerType>()) {
349 node->data = Optional<tvm::runtime::NDArray>(Downcast<runtime::NDArray>(data_or_idx));
350 node->irmod_storage_idx = Optional<Integer>();
351 } else if (data_or_idx->IsInstance<IntImmNode>()) {
352 node->data = Optional<tvm::runtime::NDArray>();
353 node->irmod_storage_idx = Optional<Integer>(Downcast<Integer>(data_or_idx));
354 } else {
355 LOG(FATAL) << "Data type not supported: " << data_or_idx->GetTypeKey();
356 }
357 data_ = std::move(node);
358}
359
360int64_t AllocateConstNode::ConstantAllocationSize(const Array<PrimExpr>& extents) {
361 int64_t result = 1;
362 for (size_t i = 0; i < extents.size(); ++i) {
363 if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
364 result *= int_size->value;
365 if (result > std::numeric_limits<int64_t>::max()) {
366 return 0;
367 }
368 } else {
369 return 0;
370 }
371 }
372 return static_cast<int64_t>(result);
373}
374TVM_REGISTER_GLOBAL("tir.AllocateConst")
375 .set_body_typed([](Var buffer_var, DataType dtype, Array<PrimExpr> extents,
376 ObjectRef data_or_idx, Stmt body, Map<String, ObjectRef> annotations,
377 Span span) {
378 return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations, span);
379 });
380
381TVM_REGISTER_NODE_TYPE(AllocateConstNode);
382
383// DeclBuffer
384DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) {
385 ObjectPtr<DeclBufferNode> node = make_object<DeclBufferNode>();
386 node->buffer = std::move(buffer);
387 node->body = std::move(body);
388 node->span = std::move(span);
389 data_ = std::move(node);
390}
391
392TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) {
393 return DeclBuffer(buffer, body, span);
394});
395
396TVM_REGISTER_NODE_TYPE(DeclBufferNode);
397
398// ProducerRealize
399ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
400 Stmt body, String storage_scope, Span span) {
401 for (size_t i = 0; i < bounds.size(); ++i) {
402 ICHECK(bounds[i]->min.defined());
403 ICHECK(bounds[i]->extent.defined());
404 ICHECK(bounds[i]->min.dtype().is_scalar());
405 ICHECK(bounds[i]->extent.dtype().is_scalar());
406 }
407 ICHECK(body.defined());
408 ICHECK(condition.defined());
409 ICHECK(condition.dtype().is_bool());
410
411 ObjectPtr<ProducerRealizeNode> node = make_object<ProducerRealizeNode>();
412 node->producer = std::move(producer);
413 node->bounds = std::move(bounds);
414 node->condition = std::move(condition);
415 node->body = std::move(body);
416 node->span = std::move(span);
417 node->storage_scope = std::move(storage_scope);
418 data_ = std::move(node);
419}
420
421TVM_REGISTER_GLOBAL("tir.ProducerRealize")
422 .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
423 String storage_scope, Span span) {
424 return ProducerRealize(producer, bounds, condition, body, storage_scope, span);
425 });
426
427TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
428
429// Prefetch
430Prefetch::Prefetch(Buffer buffer, Array<Range> bounds, Span span) {
431 data_ = make_object<PrefetchNode>(buffer, bounds, span);
432}
433
434TVM_REGISTER_GLOBAL("tir.Prefetch")
435 .set_body_typed([](Buffer buffer, Array<Range> bounds, Span span) {
436 return Prefetch(buffer, bounds, span);
437 });
438
439TVM_REGISTER_NODE_TYPE(PrefetchNode);
440
441// SeqStmt
442SeqStmt::SeqStmt(Array<Stmt> seq, Span span) {
443 auto node = make_object<SeqStmtNode>();
444 node->seq = std::move(seq);
445 node->span = std::move(span);
446 data_ = std::move(node);
447}
448
449TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq, Span span) {
450 return SeqStmt(std::move(seq), span);
451});
452
453TVM_REGISTER_NODE_TYPE(SeqStmtNode);
454
455// IfThenElse
456IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case, Span span) {
457 ICHECK(condition.defined());
458 ICHECK(then_case.defined());
459 // else_case may be null.
460 ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>();
461 node->condition = std::move(condition);
462 node->then_case = std::move(then_case);
463 node->else_case = std::move(else_case);
464 node->span = std::move(span);
465 data_ = std::move(node);
466}
467
468TVM_REGISTER_NODE_TYPE(IfThenElseNode);
469
470TVM_REGISTER_GLOBAL("tir.IfThenElse")
471 .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) {
472 return IfThenElse(condition, then_case, else_case, span);
473 });
474
475// Evaluate
476Evaluate::Evaluate(PrimExpr value, Span span) {
477 ICHECK(value.defined());
478
479 ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
480 node->value = std::move(value);
481 node->span = std::move(span);
482 data_ = std::move(node);
483}
484
485TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) {
486 return Evaluate(value, span);
487});
488
489TVM_REGISTER_NODE_TYPE(EvaluateNode);
490
491// BufferStore
492BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) {
493 ICHECK_EQ(buffer->shape.size(), indices.size())
494 << "Buffer " << buffer->name << " is " << buffer->shape.size()
495 << "-dimensional, cannot be indexed with the " << indices.size()
496 << "-dimensional indices provided.";
497
498 for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
499 ICHECK(indices[i].dtype().is_scalar())
500 << "Only the last index of a buffer access may be a vector type.";
501 }
502
503 int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
504 int buffer_lanes = buffer->dtype.lanes();
505
506 ICHECK_EQ(index_lanes * buffer_lanes, value.dtype().lanes())
507 << "Cannot store value with " << value.dtype().lanes() << ", expected value with "
508 << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes
509 << " buffer element lanes)";
510 if (buffer->dtype.with_lanes(buffer_lanes * index_lanes) != value.dtype()) {
511 LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " //
512 << "buffer's dtype is `" << buffer->dtype //
513 << "`, the lanes of indexing are: `" << index_lanes //
514 << "`, but RHS's dtype is `" << value.dtype() << "`";
515 }
516
517 ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
518 node->buffer = std::move(buffer);
519 node->value = std::move(value);
520 node->indices = std::move(indices);
521 node->span = std::move(span);
522 data_ = std::move(node);
523}
524
525TVM_REGISTER_GLOBAL("tir.BufferStore")
526 .set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) {
527 return BufferStore(buffer, value, indices, span);
528 });
529
530TVM_REGISTER_NODE_TYPE(BufferStoreNode);
531
532// BufferRealize
533BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
534 Span span) {
535 data_ = make_object<BufferRealizeNode>(buffer, bounds, condition, body, span);
536}
537
538TVM_REGISTER_GLOBAL("tir.BufferRealize")
539 .set_body_typed([](Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
540 Span span) { return BufferRealize(buffer, bounds, condition, body, span); });
541
542TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
543
544// BufferRegion
545BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
546 CHECK_EQ(buffer->shape.size(), region.size())
547 << "The dimension between " << buffer << " and region " << region
548 << " mismatched, the buffer is " << buffer;
549 ObjectPtr<BufferRegionNode> node = make_object<BufferRegionNode>();
550 node->buffer = std::move(buffer);
551 node->region = std::move(region);
552 data_ = std::move(node);
553}
554
555BufferRegion BufferRegion::FullRegion(Buffer buffer) {
556 Array<Range> region;
557 for (PrimExpr extent : buffer->shape) {
558 region.push_back(Range::FromMinExtent(0, extent));
559 }
560 return BufferRegion(buffer, region);
561}
562
563BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
564 Array<Range> region;
565 for (const PrimExpr& index : indices) {
566 if (const RampNode* ramp_index = index.as<RampNode>()) {
567 region.push_back(
568 Range::FromMinExtent(ramp_index->base, ramp_index->stride * ramp_index->lanes));
569 } else {
570 region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1)));
571 }
572 }
573 return BufferRegion(buffer, region);
574}
575
576TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array<Range> region) {
577 return BufferRegion(buffer, region);
578});
579
580TVM_REGISTER_NODE_TYPE(BufferRegionNode);
581
582// MatchBufferRegion
583MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
584 const Buffer& source_buffer = source->buffer;
585 arith::Analyzer analyzer;
586 // Check scope and dtype
587 CHECK_EQ(buffer.scope(), source_buffer.scope())
588 << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. "
589 << source_buffer.scope();
590 CHECK_EQ(buffer->dtype, source_buffer->dtype)
591 << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. "
592 << source_buffer->dtype;
593
594 // Check data_alignment
595 CHECK(source_buffer->data_alignment % buffer->data_alignment == 0)
596 << "Trying to match buffer to another one with lower alignment requirement "
597 << " required_alignment=" << buffer->data_alignment
598 << ", provided_alignment=" << source_buffer->data_alignment;
599
600 // Check BufferType. AutoBroadcast is not allowed for now.
601 CHECK(buffer->buffer_type == BufferType::kDefault &&
602 source_buffer->buffer_type == BufferType::kDefault)
603 << "AutoBroadcast is not allowed in MatchBuffer";
604
605 // Validate shape
606 CHECK(source->region.size() >= buffer->shape.size())
607 << "Dimension of source Region expected to be larger or equal than target buffer shape, but "
608 "got "
609 << source->region.size() << " vs. " << buffer->shape.size();
610 size_t offset = source->region.size() - buffer->shape.size();
611 for (size_t i = 0; i < offset; ++i) {
612 CHECK(analyzer.CanProve(source->region[i]->extent == 1))
613 << "The higher dimension should be 1, but got " << source->region[i]->extent << ".";
614 }
615 for (size_t i = 0; i < buffer->shape.size(); ++i) {
616 const Range& source_range = source->region[i + offset];
617 const PrimExpr& buffer_shape = buffer->shape[i];
618 if (!buffer_shape->IsInstance<VarNode>()) {
619 CHECK(analyzer.CanProve(source_range->extent == buffer_shape))
620 << "The dimension mismatched between source region and target buffer shape, got "
621 << source_range->extent << " vs. " << buffer_shape << ".";
622 }
623 }
624 // Note that we do not check elem_offset and strides in this function
625
626 // Construction
627 ObjectPtr<MatchBufferRegionNode> node = make_object<MatchBufferRegionNode>();
628 node->buffer = std::move(buffer);
629 node->source = std::move(source);
630 data_ = std::move(node);
631}
632
633TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) {
634 return MatchBufferRegion(buffer, source);
635});
636
637TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode);
638
639// Block
640Block::Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, Array<BufferRegion> writes,
641 String name_hint, Stmt body, Optional<Stmt> init, Array<Buffer> alloc_buffers,
642 Array<MatchBufferRegion> match_buffers, Map<String, ObjectRef> annotations,
643 Span span) {
644 ObjectPtr<BlockNode> node = make_object<BlockNode>();
645 node->iter_vars = std::move(iter_vars);
646 node->reads = std::move(reads);
647 node->writes = std::move(writes);
648 node->name_hint = std::move(name_hint);
649 node->body = std::move(body);
650 node->init = std::move(init);
651 node->alloc_buffers = std::move(alloc_buffers);
652 node->match_buffers = std::move(match_buffers);
653 node->annotations = std::move(annotations);
654 node->span = std::move(span);
655 data_ = std::move(node);
656}
657
658TVM_REGISTER_GLOBAL("tir.Block")
659 .set_body_typed([](Array<IterVar> iter_vars, Array<BufferRegion> reads,
660 Array<BufferRegion> writes, String name_hint, Stmt body, Optional<Stmt> init,
661 Array<Buffer> alloc_buffers, Array<MatchBufferRegion> match_buffers,
662 Map<String, ObjectRef> annotations, Span span) {
663 return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers,
664 annotations, span);
665 });
666
667TVM_REGISTER_NODE_TYPE(BlockNode);
668
669// BlockRealize
670BlockRealize::BlockRealize(Array<PrimExpr> values, PrimExpr predicate, Block block, Span span) {
671 CHECK_EQ(block->iter_vars.size(), values.size())
672 << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
673 CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
674 ObjectPtr<BlockRealizeNode> node = make_object<BlockRealizeNode>();
675 node->iter_values = std::move(values);
676 node->predicate = std::move(predicate);
677 node->block = std::move(block);
678 node->span = std::move(span);
679 data_ = std::move(node);
680}
681
682TVM_REGISTER_GLOBAL("tir.BlockRealize")
683 .set_body_typed([](Array<PrimExpr> iter_values, PrimExpr predicate, Block block, Span span) {
684 return BlockRealize(iter_values, predicate, block, span);
685 });
686
687TVM_REGISTER_NODE_TYPE(BlockRealizeNode);
688
689PrimExpr TypeAnnotation(DataType dtype, Span span) {
690 static auto op = Op::Get("tir.type_annotation");
691 return tir::Call(dtype, op, {}, span);
692}
693
694TVM_TIR_REGISTER_OP("type_annotation")
695 .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
696
697} // namespace tir
698} // namespace tvm
699