1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file buffer.cc
22 */
23#include <tvm/arith/analyzer.h>
24#include <tvm/runtime/device_api.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/buffer.h>
28#include <tvm/tir/builtin.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/op.h>
31
32#include <iterator>
33#include <stack>
34
35#include "../../arith/pattern_match.h"
36
37namespace tvm {
38namespace tir {
39
40using IndexMod = tir::FloorModNode;
41using IndexDiv = tir::FloorDivNode;
42
43Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
44 for (size_t i = 0; i < array.size(); ++i) {
45 array.Set(i, ana->Simplify(array[i]));
46 }
47 return array;
48}
49
50Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String storage_scope,
51 Array<IntImm> axis_separators, Span span) {
52 DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
53 return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape,
54 Array<PrimExpr>(), PrimExpr(), name, 0, 0, kDefault, axis_separators, span);
55}
56
57// Split the given expression w.r.t the add operator
58inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) {
59 using namespace tir;
60 std::vector<const PrimExpr*> ret;
61 std::stack<const PrimExpr*> split_buffer;
62 split_buffer.push(&expr);
63 while (!split_buffer.empty()) {
64 const PrimExpr* top_ele = split_buffer.top();
65 split_buffer.pop();
66 auto expr_add_match = top_ele->as<AddNode>();
67 if (expr_add_match) {
68 split_buffer.push(&expr_add_match->b);
69 split_buffer.push(&expr_add_match->a);
70 } else {
71 ret.emplace_back(top_ele);
72 }
73 }
74 return ret;
75}
76
// Searches for the following types of expr:
//   mult_expr = (a1 + a2 + ... + aj + c1 / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   mod_l_expr = c2
//   mod_r_expr = k1 * k2 * ... * ki
// where c1 ~= c2 mod k1 * k2 * ... * ki
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1);
// otherwise returns (false, PrimExpr()).
// Currently we will not search the add/mult combinations exhaustively
// as it will take too much computation.
inline std::pair<bool, PrimExpr> MergeMulModInner(arith::Analyzer* analyzer,
                                                  const PrimExpr& mult_expr,
                                                  const PrimExpr& mod_l_expr,
                                                  const PrimExpr& mod_r_expr) {
  using namespace tir;
  const MulNode* mult_ptr = mult_expr.as<MulNode>();
  if (!mult_ptr) return std::make_pair(false, PrimExpr());
  PrimExpr mult_outer = mult_ptr->b;
  const PrimExpr* inner = &(mult_ptr->a);
  // 1. Calculate the outer multiplier
  // Walk down the left spine of nested Muls, folding every right operand
  // into mult_outer; this accumulates kt * ... * ki.
  while (true) {
    mult_ptr = inner->as<MulNode>();
    if (mult_ptr) {
      inner = &(mult_ptr->a);
      mult_outer = mult_ptr->b * mult_outer;
    } else {
      break;
    }
  }
  // 2. Search for the pattern c / (...) * (...) + c % (...)
  // We match the search element with Add, Mul and Div.
  // If Add is found, we need to continue our search for the rhs
  // If Mult is found, we will expand the inner multiplication factor
  // If Div is found, we will go on testing whether lhs matches the lhs of mod expr
  // and returns the optimization result.
  const PrimExpr* search_ptr = inner;
  PrimExpr mult_inner;  // The inner multiplication factor
  PrimExpr no_opt_sum;  // Sum of the exprs that cannot be optimized
  tir::ExprDeepEqual expr_equal;

  while (true) {
    auto inner_div_ptr = search_ptr->as<IndexDiv>();
    auto inner_mult_ptr = search_ptr->as<MulNode>();
    auto inner_add_ptr = search_ptr->as<AddNode>();
    if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
      // Reached a leaf without finding the division: no match.
      return std::make_pair(false, PrimExpr());
    } else if (inner_div_ptr) {
      // Candidate division found.  The merge is valid only when the divisor
      // equals both the accumulated multiplier and the modulus, and the
      // dividend is congruent to the mod's lhs modulo that value.
      PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
      if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) &&
          analyzer->CanProveEqual(floormod(inner_div_ptr->a - mod_l_expr, mod_r_expr), 0)) {
        // Found!
        PrimExpr ret =
            no_opt_sum.get() ? no_opt_sum * mult_outer + inner_div_ptr->a : inner_div_ptr->a;
        return std::make_pair(true, ret);
      } else {
        return std::make_pair(false, PrimExpr());
      }
    } else if (inner_mult_ptr) {
      // Fold the right operand into the inner multiplier and keep searching
      // the left operand.
      mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
      search_ptr = &(inner_mult_ptr->a);
    } else if (inner_add_ptr) {
      // An addition encountered after a multiplication cannot be merged.
      if (mult_inner.get()) {
        return std::make_pair(false, PrimExpr());
      }
      // Accumulate the non-optimizable lhs and continue searching the rhs.
      no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
      search_ptr = &(inner_add_ptr->b);
    } else {
      // Unreachable: the three cases above are exhaustive.
      LOG(FATAL) << "Unexpected search result!";
      break;
    }
  }
  return std::make_pair(false, PrimExpr());
}
148
149// Insert the elements into the corresponding mult_exprs and mod_exprs.
150// If the element is found to match Mul, it will be pushed to the mult_exprs.
151// If the element it found to match Mod, it will be pused to the mod_exprs.
152// Otherwise, the elements will be added to the no_opt_sum variable
153inline void MergeMulModInsertElements(const std::vector<const PrimExpr*>& eles,
154 std::list<PrimExpr>* mult_exprs,
155 std::list<std::pair<PrimExpr, PrimExpr>>* mod_exprs,
156 PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) {
157 using namespace tir;
158 *has_mult = false;
159 *has_mod = false;
160 for (const PrimExpr* ele : eles) {
161 auto mod_ptr = ele->as<IndexMod>();
162 auto mult_ptr = ele->as<MulNode>();
163 if (mod_ptr) {
164 *has_mod = true;
165 mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b)));
166 } else if (mult_ptr) {
167 *has_mult = true;
168 mult_exprs->emplace_back(*ele);
169 } else {
170 *no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele;
171 }
172 }
173}
174
// Searches for these types of expr:
//   (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   + c % (k1 * k2 * ... * ki)
// and simplifies to (a1 + a2 + ... + aj) * kt * ... * ki + c
// The search will be performed repeatedly until no pattern is found.
// Return: the (possibly simplified) base expression if nothing could be
//         merged, otherwise the merged expression.
inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
  using namespace tir;
  // 1. Prepare the lists.
  // We store two lists, a list that contains all the elements that match Mul and
  // a list that contains all the elements that match Mod.
  // The elements in the Mod will be used to match against the elements in Mul.
  // The result will then be split and pushed back to these two lists.
  PrimExpr simplified_base = base;
  // Strip the identity (x floordiv y) * y + (x floormod y) == x up front.
  arith::PVar<PrimExpr> x, y;
  if ((floordiv(x, y) * y + floormod(x, y)).Match(simplified_base)) {
    simplified_base = x.Eval();
  }
  simplified_base = analyzer->Simplify(simplified_base);
  std::vector<const PrimExpr*> eles = ExprSplitAddition(simplified_base);
  std::list<PrimExpr> mult_exprs;
  std::list<std::pair<PrimExpr, PrimExpr>> mod_exprs;
  PrimExpr no_opt_sum;
  bool has_mult;
  bool has_mod;
  MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod);
  bool find_opt = false;
  std::list<std::pair<PrimExpr, PrimExpr>>::iterator search_mod_it = mod_exprs.begin();
  // 2. Exhaustive Search
  while (search_mod_it != mod_exprs.end()) {
    std::list<PrimExpr>::iterator mult_it = mult_exprs.begin();
    bool inner_find_opt = false;
    while (mult_it != mult_exprs.end()) {
      std::pair<bool, PrimExpr> ret =
          MergeMulModInner(analyzer, *mult_it, search_mod_it->first, search_mod_it->second);
      if (ret.first) {
        inner_find_opt = true;
        // Remove the merged Mod/Mul pair, split the merged result back into
        // summands, and reinsert them for further matching.
        auto temp_mod_it = search_mod_it;
        ++search_mod_it;
        mod_exprs.erase(temp_mod_it);
        mult_exprs.erase(mult_it);
        std::vector<const PrimExpr*> ret_eles = ExprSplitAddition(ret.second);
        MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult,
                                  &has_mod);
        if (has_mult) {
          // New Mul elements appeared: restart the scan from the first Mod.
          search_mod_it = mod_exprs.begin();
        } else if (has_mod && search_mod_it == mod_exprs.end()) {
          // A new Mod was appended at the end: step back so it is visited.
          search_mod_it--;
        }
        break;
      } else {
        ++mult_it;
      }
    }
    find_opt = find_opt || inner_find_opt;
    if (!inner_find_opt) {
      ++search_mod_it;
    }
  }
  if (!find_opt) {
    return simplified_base;
  }
  // 3. Reassemble: add the remaining Mul elements and the unmatched Mod
  // elements back into a single sum.
  for (std::list<PrimExpr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
  }
  for (std::list<std::pair<PrimExpr, PrimExpr>>::iterator it = mod_exprs.begin();
       it != mod_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + indexmod(it->first, it->second)
                                  : indexmod(it->first, it->second);
  }
  return no_opt_sum;
}
248
249Array<PrimExpr> Buffer::OffsetOf(Array<PrimExpr> input_indices) const {
250 return (*this)->ElemOffset(std::move(input_indices));
251}
252
// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
// We also perform optimization to simplify the indexing expression.
// Returns one flattened offset per output axis group (groups are delimited
// by axis_separators; k separators => k+1 output offsets).
Array<PrimExpr> BufferNode::ElemOffset(Array<PrimExpr> input_indices) const {
  ICHECK_EQ(shape.size(), input_indices.size())
      << "Buffer " << this->name << " is " << shape.size()
      << "-dimensional, cannot be indexed with the " << input_indices.size()
      << "-dimensional indices provided.";

  if (strides.size()) {
    ICHECK_EQ(this->strides.size(), input_indices.size())
        << "If strides are defined, "
        << "the index's dimensionality must match the dimensionality of the index given.";
  }

  // TODO(Lunderberg): Better handling for cases where there is more
  // than one output index.  Currently, this only allows elem_offset
  // to be non-zero for flat memory allocations.
  Array<PrimExpr> elem_offsets = {};
  if (elem_offset.defined() && !is_zero(elem_offset)) {
    elem_offsets = {elem_offset};
  }

  if (elem_offsets.size()) {
    ICHECK_EQ(elem_offsets.size(), axis_separators.size() + 1)
        << "If element offsets are defined, "
        << "there must be one element offset for each output index.";
  }

  Array<PrimExpr> output_indices(axis_separators.size() + 1, 0);

  size_t current_output_axis = 0;

  arith::Analyzer ana;

  for (size_t i = 0; i < input_indices.size(); i++) {
    // Advance to the next output axis when crossing an axis separator.
    if ((current_output_axis < axis_separators.size()) &&
        (i == size_t(axis_separators[current_output_axis]->value))) {
      current_output_axis++;
    }

    // Accumulate this input index into its output axis: either via explicit
    // strides, or row-major by scaling with the current axis extent.
    PrimExpr output_index = output_indices[current_output_axis];
    if (strides.size()) {
      output_index = output_index + input_indices[i] * strides[i];
    } else {
      output_index = output_index * this->shape[i] + input_indices[i];
    }

    // Merge any c / k * k + c % k patterns produced by the accumulation.
    if (i > 0) {
      output_index = MergeMulMod(&ana, output_index);
    }

    output_indices.Set(current_output_axis, output_index);
  }

  // Apply the per-output-axis element offsets, if any.
  if (elem_offsets.size()) {
    for (size_t i = 0; i < output_indices.size(); i++) {
      output_indices.Set(i, output_indices[i] + elem_offsets[i]);
    }
  }

  return SimplifyArray(&ana, output_indices);
}
316
317inline Array<PrimExpr> BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) {
318 Array<PrimExpr> offsets = n->ElemOffset(index);
319 // If the Buffer has element type with more than one lane, scale to
320 // get the offset in number of scalars.
321 if (n->dtype.lanes() != 1) {
322 PrimExpr last_offset = offsets[offsets.size() - 1];
323 offsets.Set(offsets.size() - 1, last_offset * make_const(last_offset.dtype(), dtype.lanes()));
324 }
325
326 // If the requested type has more than one lane, make a RampNode at
327 // that offset.
328 if (dtype.lanes() != 1) {
329 PrimExpr last_offset = offsets[offsets.size() - 1];
330 PrimExpr stride = make_const(last_offset.dtype(), 1);
331 offsets.Set(offsets.size() - 1, tir::Ramp(last_offset, stride, dtype.lanes()));
332 }
333
334 return offsets;
335}
336
// Produce the flattened view of this buffer: one output axis per group of
// input axes (groups delimited by axis_separators), strides dropped.
Buffer Buffer::GetFlattenedBuffer() const {
  auto self = operator->();

  // These checks ensure that all output axes contain at least one
  // input axis.
  for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) {
    auto sep = self->axis_separators[i]->value;
    auto next_sep = self->axis_separators[i + 1]->value;
    ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order.";
  }
  if (self->axis_separators.size()) {
    auto first_sep = self->axis_separators[0]->value;
    ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, "
                            << "so that first output axis contains at least one input axis";
    auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value;
    ICHECK_LT(last_sep, self->shape.size())
        << "Last output axis must contain at least one input axis.";
  }

  Array<PrimExpr> output_shape;
  if (self->strides.size()) {
    // If strides are defined, then the extent of each flattened
    // buffer is the stride*size for the first input axis used for
    // each output axis.
    ICHECK_EQ(self->shape.size(), self->strides.size());
    output_shape.push_back(self->strides[0] * self->shape[0]);
    for (const auto& sep : self->axis_separators) {
      output_shape.push_back(self->strides[sep->value] * self->shape[sep->value]);
    }

  } else {
    // Otherwise, the extent of each flattened buffer is the product
    // of the extents of each input axis used to generate that output
    // axis.  This also "flattens" rank-0 tensors to a rank-1 buffer
    // of shape [1].
    output_shape = Array<PrimExpr>(self->axis_separators.size() + 1, 1);
    size_t current_output_index = 0;
    for (size_t i = 0; i < self->shape.size(); i++) {
      // Move to the next output axis when an axis separator is crossed.
      if ((current_output_index < self->axis_separators.size()) &&
          (i == size_t(self->axis_separators[current_output_index]->value))) {
        current_output_index += 1;
      }
      output_shape.Set(current_output_index, output_shape[current_output_index] * self->shape[i]);
    }
  }

  // The axis_separators for the output buffer.
  // After flattening, output axis i is a single dimension, so the
  // separators become [1, 2, ..., n].
  Array<IntImm> output_axis_separators;
  for (size_t i = 0; i < self->axis_separators.size(); i++) {
    auto dtype = self->axis_separators[i]->dtype;
    output_axis_separators.push_back(IntImm(dtype, i + 1));
  }

  // Copy-on-write: swap in the flattened shape/separators and drop strides;
  // all other fields (data var, dtype, elem_offset, ...) stay unchanged.
  Buffer output = *this;
  auto writer = output.CopyOnWrite();
  writer->shape = output_shape;
  writer->axis_separators = output_axis_separators;
  writer->strides = {};

  return output;
}
398
399PrimExpr Buffer::vload(Array<PrimExpr> begin, DataType value_dtype) const {
400 // specially handle bool, stored as DataType::Int(8)
401 const BufferNode* n = operator->();
402 ICHECK(n != nullptr);
403 ICHECK(value_dtype.element_of() == n->dtype.element_of() &&
404 value_dtype.lanes() % n->dtype.lanes() == 0)
405 << "Cannot load " << value_dtype << " from buffer of " << n->dtype;
406
407 Array<PrimExpr> indices = begin;
408 int factor = value_dtype.lanes() / n->dtype.lanes();
409 if (factor > 1) {
410 indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor));
411 }
412 return BufferLoad(*this, indices);
413}
414
415Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const {
416 // specially handle bool, stored as DataType::Int(8)
417 const BufferNode* n = operator->();
418 ICHECK(n != nullptr);
419 DataType value_dtype = value.dtype();
420 ICHECK(value_dtype.element_of() == n->dtype.element_of() &&
421 value_dtype.lanes() % n->dtype.lanes() == 0)
422 << "Cannot store " << value_dtype << " to buffer of " << n->dtype;
423
424 Array<PrimExpr> indices = begin;
425 int factor = value_dtype.lanes() / n->dtype.lanes();
426 if (factor > 1) {
427 indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor));
428 }
429 return BufferStore(*this, value, indices);
430}
431
432String Buffer::scope() const {
433 const auto* ptr_type = (*this)->data->type_annotation.as<PointerTypeNode>();
434 ICHECK(ptr_type) << "Buffer variable is not of pointer type";
435 if (ptr_type->storage_scope.empty()) {
436 return "global";
437 }
438 return ptr_type->storage_scope;
439}
440
441Buffer Buffer::MakeStrideView() const {
442 if ((*this)->strides.size() != 0) return *this;
443 if ((*this)->shape.size() == 0) return *this;
444 std::vector<PrimExpr> temp;
445 const BufferNode* self = operator->();
446 ICHECK(self != nullptr);
447 auto n = make_object<BufferNode>(*self);
448 PrimExpr acc = make_const(n->DefaultIndexType(), 1);
449 for (size_t i = n->shape.size(); i != 0; --i) {
450 temp.push_back(acc);
451 acc = acc * n->shape[i - 1];
452 }
453 for (size_t i = temp.size(); i != 0; --i) {
454 n->strides.push_back(temp[i - 1]);
455 }
456 return Buffer(n);
457}
458
// Create a sub-buffer view starting at `begins` with the given `extents`.
Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const {
  const BufferNode* n = operator->();
  ICHECK(n != nullptr);
  arith::Analyzer ana;
  begins = SimplifyArray(&ana, begins);
  // Element offset(s) of the slice origin within the parent buffer.
  Array<PrimExpr> elem_offset =
      n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); });

  Array<PrimExpr> strides = n->strides;
  if (strides.size() == 0) {
    bool can_relax = true;
    bool need_stride = false;
    // check if stride is needed.
    // Scanning from the outermost axis: once a non-unit extent is seen
    // (can_relax becomes false), any later axis that is not taken in full
    // (non-zero begin or extent != shape) makes the slice non-contiguous,
    // so explicit strides are required.
    for (size_t i = 0; i < extents.size(); ++i) {
      if (!can_relax) {
        if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) {
          need_stride = true;
        }
      }
      if (!is_one(extents[i])) can_relax = false;
    }
    // make stride.
    if (need_stride) {
      // Re-slice through a strided view so the slice can carry strides.
      return MakeStrideView().MakeSlice(begins, extents);
    }
  }
  Buffer slice(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
               n->data_alignment, 0, n->buffer_type);

  // Buffer must be constructed with a singular element offset which means there is no
  // support for n-dimensional buffers where n > 1.  Insert sentinel value for
  // ArgBinder::BindBuffer to state that any usage of element offset is invalid
  // in this case.  This allows for construction of a Buffer with multiple element offsets
  // but disallows any usage of those element offsets.  See PR #10816 for discussion on
  // supporting multiple element offsets in TIR Buffer.
  // TODO(Lunderberg): Remove if/when TIR supports multiple element offsets in TIR Buffer
  if (elem_offset.size() != 1) {
    slice.CopyOnWrite()->elem_offset = PrimExpr();
  }
  return slice;
}
500
// Build a tvm_access_ptr() call describing an access window into this buffer:
// (type annotation, data var, elem_offset, extent, access_mask).
PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset,
                            Optional<PrimExpr> input_extent) const {
  const BufferNode* self = operator->();
  ICHECK(self != nullptr);
  PrimExpr e_dtype;
  PrimExpr extent;
  if (self->shape.size() == 0) {
    // Rank-0 buffer: exactly one accessible element.
    extent = make_const(self->DefaultIndexType(), 1);
  } else if (self->strides.size() == self->shape.size()) {
    // Fully strided buffer: span is stride*extent of the outermost axis.
    int highest_dim = 0;
    extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
  } else {
    // Compact buffer: span is the product of all extents.
    extent = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                   make_const(DataType::Int(32), 1), self->shape) -
             offset;
  }
  PrimExpr elem_offset = self->elem_offset + offset;
  if (content_lanes > 1) {
    // Re-express the extent and offset in units of the vectorized type.
    e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
    extent = extent / make_const(self->elem_offset.dtype(), content_lanes);
    // NOTE(review): this assignment discards the `offset` argument that was
    // added to elem_offset just above -- only the buffer's own elem_offset
    // is scaled.  Confirm whether `offset` should also be divided and
    // included here.
    elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes);
  } else {
    e_dtype = tir::TypeAnnotation(self->dtype);
  }

  // An explicitly provided extent overrides the computed one.
  if (input_extent.defined()) {
    extent = input_extent.value();
  }
  Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
                           make_const(DataType::Int(32), access_mask)};
  return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);
}
533
// Construct a Buffer, validating the data variable's pointer type and
// filling in defaults for elem_offset, alignment and offset factor.
Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
               PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
               BufferType buffer_type, Array<IntImm> axis_separators, Span span) {
  DataType storage_dtype = dtype;
  // specially handle bool: it is physically stored as int8
  if (storage_dtype == DataType::Bool()) {
    storage_dtype = DataType::Int(8);
  }
  // The buffer dtype may differ from the dtype of the underlying
  // allocation, such as a single allocation that backs multiple
  // tensors without a common datatype.  Therefore, we check that the
  // data pointer is a pointer, but not the exact type of the
  // pointed-to values.

  // TODO(Lunderberg): Use an explicit pointer cast for the data
  // pointer.  Should be done alongside extensions to StmtExprMutator
  // to more easily handle buffer/buffer_var updates.
  ICHECK(data->type_annotation.defined())
      << "Variable " << data->name_hint << " is missing a type annotation.";
  ICHECK(data->type_annotation.as<PointerTypeNode>())
      << "Variable " << data->name_hint << " is not a pointer.";
  ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
      << "Variable " << data->name_hint << " does not point to a primitive.";

  auto n = make_object<BufferNode>();
  n->data = std::move(data);
  n->dtype = dtype;

  n->shape = std::move(shape);
  n->strides = std::move(strides);
  n->axis_separators = std::move(axis_separators);
  n->name = std::move(name);
  // Undefined elem_offset defaults to 0.
  if (!elem_offset.defined()) {
    elem_offset = make_const(n->DefaultIndexType(), 0);
  }
  // Non-positive alignment defaults to the runtime allocation alignment.
  if (data_alignment <= 0) {
    data_alignment = runtime::kAllocAlignment;
  }
  // Zero offset factor defaults to 1.
  if (offset_factor == 0) {
    offset_factor = 1;
  }
  n->elem_offset = std::move(elem_offset);
  n->data_alignment = data_alignment;
  n->offset_factor = offset_factor;
  n->buffer_type = buffer_type;
  // Auto-broadcast buffers without explicit strides get one symbolic stride
  // variable per dimension, to be bound at call time.
  if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
    for (size_t i = 0; i < n->shape.size(); ++i) {
      n->strides.push_back(Var("stride", n->shape[i].dtype()));
    }
  }
  n->span = std::move(span);
  data_ = std::move(n);
}
587
588tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
589 int data_alignment, int offset_factor, bool compact,
590 std::string memory_scope) {
591 DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
592 auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope));
593 bool has_any = false;
594 if (!compact) {
595 for (const auto& it : shape) {
596 if (it.as<tir::VarNode>()) {
597 has_any = true;
598 break;
599 }
600 }
601 }
602 tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
603
604 PrimExpr elem_offset;
605 if (offset_factor != 0) {
606 elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
607 } else {
608 elem_offset = PrimExpr();
609 }
610
611 return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment,
612 offset_factor, buffer_type);
613}
614
TVM_REGISTER_NODE_TYPE(BufferNode);

// FFI constructor: takes the 11 positional Buffer arguments; the
// buffer_type string at position 8 is mapped onto the BufferType enum.
TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
  ICHECK_EQ(args.size(), 11);
  auto buffer_type = args[8].operator String();
  BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
  *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type,
                args[9], args[10]);
});

// FFI bindings exposing Buffer methods to the frontend.
TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr);

TVM_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer").set_body_method(&Buffer::GetFlattenedBuffer);

TVM_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf);

TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload);

TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore);

TVM_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope);
636
637} // namespace tir
638} // namespace tvm
639