1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file buffer.cc |
22 | */ |
23 | #include <tvm/arith/analyzer.h> |
24 | #include <tvm/runtime/device_api.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/buffer.h> |
28 | #include <tvm/tir/builtin.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/op.h> |
31 | |
32 | #include <iterator> |
33 | #include <stack> |
34 | |
35 | #include "../../arith/pattern_match.h" |
36 | |
37 | namespace tvm { |
38 | namespace tir { |
39 | |
// Aliases for the index mod/div node types that the Mul/Mod pattern matching
// below looks for (FloorMod terms in MergeMulModInsertElements, FloorDiv terms
// in MergeMulModInner).
using IndexMod = tir::FloorModNode;
using IndexDiv = tir::FloorDivNode;
42 | |
43 | Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) { |
44 | for (size_t i = 0; i < array.size(); ++i) { |
45 | array.Set(i, ana->Simplify(array[i])); |
46 | } |
47 | return array; |
48 | } |
49 | |
50 | Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String storage_scope, |
51 | Array<IntImm> axis_separators, Span span) { |
52 | DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); |
53 | return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, |
54 | Array<PrimExpr>(), PrimExpr(), name, 0, 0, kDefault, axis_separators, span); |
55 | } |
56 | |
57 | // Split the given expression w.r.t the add operator |
58 | inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) { |
59 | using namespace tir; |
60 | std::vector<const PrimExpr*> ret; |
61 | std::stack<const PrimExpr*> split_buffer; |
62 | split_buffer.push(&expr); |
63 | while (!split_buffer.empty()) { |
64 | const PrimExpr* top_ele = split_buffer.top(); |
65 | split_buffer.pop(); |
66 | auto expr_add_match = top_ele->as<AddNode>(); |
67 | if (expr_add_match) { |
68 | split_buffer.push(&expr_add_match->b); |
69 | split_buffer.push(&expr_add_match->a); |
70 | } else { |
71 | ret.emplace_back(top_ele); |
72 | } |
73 | } |
74 | return ret; |
75 | } |
76 | |
// Tries to fold one Mul term of a sum against one FloorMod term of the same sum.
//
// Searches for the following types of expr:
//   mult_expr  = (a1 + a2 + ... + aj + c1 / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   mod_l_expr = c2
//   mod_r_expr = k1 * k2 * ... * ki
// where c1 ~= c2 mod k1 * k2 * ... * ki
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1);
// otherwise returns (false, PrimExpr()).
// Currently we do not search the add/mult combinations exhaustively,
// as that would take too much computation.
//
// Matching is structural (ExprDeepEqual): the accumulated inner*outer
// multiplier must be deep-equal to both the FloorDiv divisor and mod_r_expr,
// and the analyzer must prove floormod(c1 - c2, mod_r_expr) == 0.
inline std::pair<bool, PrimExpr> MergeMulModInner(arith::Analyzer* analyzer,
                                                  const PrimExpr& mult_expr,
                                                  const PrimExpr& mod_l_expr,
                                                  const PrimExpr& mod_r_expr) {
  using namespace tir;
  const MulNode* mult_ptr = mult_expr.as<MulNode>();
  if (!mult_ptr) return std::make_pair(false, PrimExpr());
  PrimExpr mult_outer = mult_ptr->b;
  const PrimExpr* inner = &(mult_ptr->a);
  // 1. Calculate the outer multiplier.
  // Walk down the left spine of nested Muls, folding each right-hand factor
  // into mult_outer; `inner` stops at the first non-Mul left operand.
  while (true) {
    mult_ptr = inner->as<MulNode>();
    if (mult_ptr) {
      inner = &(mult_ptr->a);
      mult_outer = mult_ptr->b * mult_outer;
    } else {
      break;
    }
  }
  // 2. Search for the pattern c / (...) * (...) + c % (...)
  // We match the search element with Add, Mul and Div.
  // If Add is found, we need to continue our search for the rhs
  // If Mult is found, we will expand the inner multiplication factor
  // If Div is found, we will go on testing whether lhs matches the lhs of mod expr
  // and returns the optimization result.
  const PrimExpr* search_ptr = inner;
  PrimExpr mult_inner;  // The inner multiplication factor
  PrimExpr no_opt_sum;  // Sum of the exprs that cannot be optimized
  tir::ExprDeepEqual expr_equal;

  while (true) {
    auto inner_div_ptr = search_ptr->as<IndexDiv>();
    auto inner_mult_ptr = search_ptr->as<MulNode>();
    auto inner_add_ptr = search_ptr->as<AddNode>();
    if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
      return std::make_pair(false, PrimExpr());
    } else if (inner_div_ptr) {
      PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
      if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) &&
          analyzer->CanProveEqual(floormod(inner_div_ptr->a - mod_l_expr, mod_r_expr), 0)) {
        // Found!
        PrimExpr ret =
            no_opt_sum.get() ? no_opt_sum * mult_outer + inner_div_ptr->a : inner_div_ptr->a;
        return std::make_pair(true, ret);
      } else {
        return std::make_pair(false, PrimExpr());
      }
    } else if (inner_mult_ptr) {
      mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
      search_ptr = &(inner_mult_ptr->a);
    } else if (inner_add_ptr) {
      // An Add below an inner Mul is not part of the supported pattern.
      if (mult_inner.get()) {
        return std::make_pair(false, PrimExpr());
      }
      // The lhs of the Add is kept aside; the search continues on the rhs.
      no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
      search_ptr = &(inner_add_ptr->b);
    } else {
      LOG(FATAL) << "Unexpected search result!";
      break;
    }
  }
  return std::make_pair(false, PrimExpr());
}
148 | |
149 | // Insert the elements into the corresponding mult_exprs and mod_exprs. |
150 | // If the element is found to match Mul, it will be pushed to the mult_exprs. |
151 | // If the element it found to match Mod, it will be pused to the mod_exprs. |
152 | // Otherwise, the elements will be added to the no_opt_sum variable |
153 | inline void MergeMulModInsertElements(const std::vector<const PrimExpr*>& eles, |
154 | std::list<PrimExpr>* mult_exprs, |
155 | std::list<std::pair<PrimExpr, PrimExpr>>* mod_exprs, |
156 | PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) { |
157 | using namespace tir; |
158 | *has_mult = false; |
159 | *has_mod = false; |
160 | for (const PrimExpr* ele : eles) { |
161 | auto mod_ptr = ele->as<IndexMod>(); |
162 | auto mult_ptr = ele->as<MulNode>(); |
163 | if (mod_ptr) { |
164 | *has_mod = true; |
165 | mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b))); |
166 | } else if (mult_ptr) { |
167 | *has_mult = true; |
168 | mult_exprs->emplace_back(*ele); |
169 | } else { |
170 | *no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele; |
171 | } |
172 | } |
173 | } |
174 | |
// Searches for this type of expr:
//   (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   + c % (k1 * k2 * ... * ki)
// and simplifies it to (a1 + a2 + ... + aj) * kt * ... * ki + c.
// The search is performed repeatedly until no pattern is found.
// Return: the rebuilt sum if at least one Mul/Mod pair was merged; otherwise
//         `simplified_base`.
// Even when nothing is merged, the result is not necessarily `base`: a
// top-level floordiv(x, y) * y + floormod(x, y) is first collapsed to x, and
// the result is passed through analyzer->Simplify.
inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
  using namespace tir;
  // 1. Prepare the lists.
  // We store two lists, a list that contain all the elements that match Mul and
  // a list that contain all the elements that match Mod.
  // The elements in the Mod will be used to match against the elements in Mul.
  // The result will then be split and pushed back to these two lists.
  PrimExpr simplified_base = base;
  arith::PVar<PrimExpr> x, y;
  if ((floordiv(x, y) * y + floormod(x, y)).Match(simplified_base)) {
    simplified_base = x.Eval();
  }
  simplified_base = analyzer->Simplify(simplified_base);
  // `eles` points into simplified_base, which outlives every use below.
  std::vector<const PrimExpr*> eles = ExprSplitAddition(simplified_base);
  std::list<PrimExpr> mult_exprs;
  std::list<std::pair<PrimExpr, PrimExpr>> mod_exprs;
  PrimExpr no_opt_sum;
  bool has_mult;
  bool has_mod;
  MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod);
  bool find_opt = false;
  std::list<std::pair<PrimExpr, PrimExpr>>::iterator search_mod_it = mod_exprs.begin();
  // 2. Exhaustive Search
  // For each Mod term, try every Mul term. On a successful merge, both terms
  // are removed, the merged result is split and re-inserted, and the scan
  // resumes.
  while (search_mod_it != mod_exprs.end()) {
    std::list<PrimExpr>::iterator mult_it = mult_exprs.begin();
    bool inner_find_opt = false;
    while (mult_it != mult_exprs.end()) {
      std::pair<bool, PrimExpr> ret =
          MergeMulModInner(analyzer, *mult_it, search_mod_it->first, search_mod_it->second);
      if (ret.first) {
        inner_find_opt = true;
        // Advance before erasing so that search_mod_it stays valid.
        auto temp_mod_it = search_mod_it;
        ++search_mod_it;
        mod_exprs.erase(temp_mod_it);
        mult_exprs.erase(mult_it);
        std::vector<const PrimExpr*> ret_eles = ExprSplitAddition(ret.second);
        MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult,
                                  &has_mod);
        if (has_mult) {
          // New Mul terms may match Mod terms already passed: restart from the front.
          search_mod_it = mod_exprs.begin();
        } else if (has_mod && search_mod_it == mod_exprs.end()) {
          // New Mod terms were appended past the position we had reached;
          // step back onto the last element so that it is searched.
          // NOTE(review): if several Mod terms were appended here, only the
          // last one is revisited.
          search_mod_it--;
        }
        break;
      } else {
        ++mult_it;
      }
    }
    find_opt = find_opt || inner_find_opt;
    if (!inner_find_opt) {
      ++search_mod_it;
    }
  }
  if (!find_opt) {
    return simplified_base;
  }
  // Rebuild the sum from the untouched terms plus any unmatched Mul/Mod terms.
  for (std::list<PrimExpr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
  }
  for (std::list<std::pair<PrimExpr, PrimExpr>>::iterator it = mod_exprs.begin();
       it != mod_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + indexmod(it->first, it->second)
                                  : indexmod(it->first, it->second);
  }
  return no_opt_sum;
}
248 | |
249 | Array<PrimExpr> Buffer::OffsetOf(Array<PrimExpr> input_indices) const { |
250 | return (*this)->ElemOffset(std::move(input_indices)); |
251 | } |
252 | |
253 | // The buffer offset in convention of number of elements of |
254 | // original data ignoring number of lanes. |
255 | // We also perform optimization to simplify the indexing expression. |
256 | Array<PrimExpr> BufferNode::ElemOffset(Array<PrimExpr> input_indices) const { |
257 | ICHECK_EQ(shape.size(), input_indices.size()) |
258 | << "Buffer " << this->name << " is " << shape.size() |
259 | << "-dimensional, cannot be indexed with the " << input_indices.size() |
260 | << "-dimensional indices provided." ; |
261 | |
262 | if (strides.size()) { |
263 | ICHECK_EQ(this->strides.size(), input_indices.size()) |
264 | << "If strides are defined, " |
265 | << "the index's dimensionality must match the dimensionality of the index given." ; |
266 | } |
267 | |
268 | // TODO(Lunderberg): Better handling for cases where there is more |
269 | // than one output index. Currently, this only allows elem_offset |
270 | // to be non-zero for flat memory allocations. |
271 | Array<PrimExpr> elem_offsets = {}; |
272 | if (elem_offset.defined() && !is_zero(elem_offset)) { |
273 | elem_offsets = {elem_offset}; |
274 | } |
275 | |
276 | if (elem_offsets.size()) { |
277 | ICHECK_EQ(elem_offsets.size(), axis_separators.size() + 1) |
278 | << "If element offsets are defined, " |
279 | << "there must be one element offset for each output index." ; |
280 | } |
281 | |
282 | Array<PrimExpr> output_indices(axis_separators.size() + 1, 0); |
283 | |
284 | size_t current_output_axis = 0; |
285 | |
286 | arith::Analyzer ana; |
287 | |
288 | for (size_t i = 0; i < input_indices.size(); i++) { |
289 | if ((current_output_axis < axis_separators.size()) && |
290 | (i == size_t(axis_separators[current_output_axis]->value))) { |
291 | current_output_axis++; |
292 | } |
293 | |
294 | PrimExpr output_index = output_indices[current_output_axis]; |
295 | if (strides.size()) { |
296 | output_index = output_index + input_indices[i] * strides[i]; |
297 | } else { |
298 | output_index = output_index * this->shape[i] + input_indices[i]; |
299 | } |
300 | |
301 | if (i > 0) { |
302 | output_index = MergeMulMod(&ana, output_index); |
303 | } |
304 | |
305 | output_indices.Set(current_output_axis, output_index); |
306 | } |
307 | |
308 | if (elem_offsets.size()) { |
309 | for (size_t i = 0; i < output_indices.size(); i++) { |
310 | output_indices.Set(i, output_indices[i] + elem_offsets[i]); |
311 | } |
312 | } |
313 | |
314 | return SimplifyArray(&ana, output_indices); |
315 | } |
316 | |
317 | inline Array<PrimExpr> BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) { |
318 | Array<PrimExpr> offsets = n->ElemOffset(index); |
319 | // If the Buffer has element type with more than one lane, scale to |
320 | // get the offset in number of scalars. |
321 | if (n->dtype.lanes() != 1) { |
322 | PrimExpr last_offset = offsets[offsets.size() - 1]; |
323 | offsets.Set(offsets.size() - 1, last_offset * make_const(last_offset.dtype(), dtype.lanes())); |
324 | } |
325 | |
326 | // If the requested type has more than one lane, make a RampNode at |
327 | // that offset. |
328 | if (dtype.lanes() != 1) { |
329 | PrimExpr last_offset = offsets[offsets.size() - 1]; |
330 | PrimExpr stride = make_const(last_offset.dtype(), 1); |
331 | offsets.Set(offsets.size() - 1, tir::Ramp(last_offset, stride, dtype.lanes())); |
332 | } |
333 | |
334 | return offsets; |
335 | } |
336 | |
337 | Buffer Buffer::GetFlattenedBuffer() const { |
338 | auto self = operator->(); |
339 | |
340 | // These checks ensure that all output axes contain at least one |
341 | // input axis. |
342 | for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) { |
343 | auto sep = self->axis_separators[i]->value; |
344 | auto next_sep = self->axis_separators[i + 1]->value; |
345 | ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order." ; |
346 | } |
347 | if (self->axis_separators.size()) { |
348 | auto first_sep = self->axis_separators[0]->value; |
349 | ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, " |
350 | << "so that first output axis contains at least one input axis" ; |
351 | auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value; |
352 | ICHECK_LT(last_sep, self->shape.size()) |
353 | << "Last output axis must contain at least one input axis." ; |
354 | } |
355 | |
356 | Array<PrimExpr> output_shape; |
357 | if (self->strides.size()) { |
358 | // If strides are defined, then the extent of each flattened |
359 | // buffer is the stride*size for the first input axis used for |
360 | // each output axis. |
361 | ICHECK_EQ(self->shape.size(), self->strides.size()); |
362 | output_shape.push_back(self->strides[0] * self->shape[0]); |
363 | for (const auto& sep : self->axis_separators) { |
364 | output_shape.push_back(self->strides[sep->value] * self->shape[sep->value]); |
365 | } |
366 | |
367 | } else { |
368 | // Otherwise, the extent of each flattened buffer is the product |
369 | // of the extents of each input axis used to generate that output |
370 | // axis. This also "flattens" rank-0 tensors to a rank-1 buffer |
371 | // of shape [1]. |
372 | output_shape = Array<PrimExpr>(self->axis_separators.size() + 1, 1); |
373 | size_t current_output_index = 0; |
374 | for (size_t i = 0; i < self->shape.size(); i++) { |
375 | if ((current_output_index < self->axis_separators.size()) && |
376 | (i == size_t(self->axis_separators[current_output_index]->value))) { |
377 | current_output_index += 1; |
378 | } |
379 | output_shape.Set(current_output_index, output_shape[current_output_index] * self->shape[i]); |
380 | } |
381 | } |
382 | |
383 | // The axis_separators for the output buffer. |
384 | Array<IntImm> output_axis_separators; |
385 | for (size_t i = 0; i < self->axis_separators.size(); i++) { |
386 | auto dtype = self->axis_separators[i]->dtype; |
387 | output_axis_separators.push_back(IntImm(dtype, i + 1)); |
388 | } |
389 | |
390 | Buffer output = *this; |
391 | auto writer = output.CopyOnWrite(); |
392 | writer->shape = output_shape; |
393 | writer->axis_separators = output_axis_separators; |
394 | writer->strides = {}; |
395 | |
396 | return output; |
397 | } |
398 | |
399 | PrimExpr Buffer::vload(Array<PrimExpr> begin, DataType value_dtype) const { |
400 | // specially handle bool, stored as DataType::Int(8) |
401 | const BufferNode* n = operator->(); |
402 | ICHECK(n != nullptr); |
403 | ICHECK(value_dtype.element_of() == n->dtype.element_of() && |
404 | value_dtype.lanes() % n->dtype.lanes() == 0) |
405 | << "Cannot load " << value_dtype << " from buffer of " << n->dtype; |
406 | |
407 | Array<PrimExpr> indices = begin; |
408 | int factor = value_dtype.lanes() / n->dtype.lanes(); |
409 | if (factor > 1) { |
410 | indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); |
411 | } |
412 | return BufferLoad(*this, indices); |
413 | } |
414 | |
415 | Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const { |
416 | // specially handle bool, stored as DataType::Int(8) |
417 | const BufferNode* n = operator->(); |
418 | ICHECK(n != nullptr); |
419 | DataType value_dtype = value.dtype(); |
420 | ICHECK(value_dtype.element_of() == n->dtype.element_of() && |
421 | value_dtype.lanes() % n->dtype.lanes() == 0) |
422 | << "Cannot store " << value_dtype << " to buffer of " << n->dtype; |
423 | |
424 | Array<PrimExpr> indices = begin; |
425 | int factor = value_dtype.lanes() / n->dtype.lanes(); |
426 | if (factor > 1) { |
427 | indices.Set(indices.size() - 1, Ramp(indices[indices.size() - 1], 1, factor)); |
428 | } |
429 | return BufferStore(*this, value, indices); |
430 | } |
431 | |
432 | String Buffer::scope() const { |
433 | const auto* ptr_type = (*this)->data->type_annotation.as<PointerTypeNode>(); |
434 | ICHECK(ptr_type) << "Buffer variable is not of pointer type" ; |
435 | if (ptr_type->storage_scope.empty()) { |
436 | return "global" ; |
437 | } |
438 | return ptr_type->storage_scope; |
439 | } |
440 | |
441 | Buffer Buffer::MakeStrideView() const { |
442 | if ((*this)->strides.size() != 0) return *this; |
443 | if ((*this)->shape.size() == 0) return *this; |
444 | std::vector<PrimExpr> temp; |
445 | const BufferNode* self = operator->(); |
446 | ICHECK(self != nullptr); |
447 | auto n = make_object<BufferNode>(*self); |
448 | PrimExpr acc = make_const(n->DefaultIndexType(), 1); |
449 | for (size_t i = n->shape.size(); i != 0; --i) { |
450 | temp.push_back(acc); |
451 | acc = acc * n->shape[i - 1]; |
452 | } |
453 | for (size_t i = temp.size(); i != 0; --i) { |
454 | n->strides.push_back(temp[i - 1]); |
455 | } |
456 | return Buffer(n); |
457 | } |
458 | |
459 | Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const { |
460 | const BufferNode* n = operator->(); |
461 | ICHECK(n != nullptr); |
462 | arith::Analyzer ana; |
463 | begins = SimplifyArray(&ana, begins); |
464 | Array<PrimExpr> elem_offset = |
465 | n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); }); |
466 | |
467 | Array<PrimExpr> strides = n->strides; |
468 | if (strides.size() == 0) { |
469 | bool can_relax = true; |
470 | bool need_stride = false; |
471 | // check if stride is needed. |
472 | for (size_t i = 0; i < extents.size(); ++i) { |
473 | if (!can_relax) { |
474 | if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { |
475 | need_stride = true; |
476 | } |
477 | } |
478 | if (!is_one(extents[i])) can_relax = false; |
479 | } |
480 | // make stride. |
481 | if (need_stride) { |
482 | return MakeStrideView().MakeSlice(begins, extents); |
483 | } |
484 | } |
485 | Buffer slice(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice" , |
486 | n->data_alignment, 0, n->buffer_type); |
487 | |
488 | // Buffer must be constructed with a singular element offset which means there is no |
489 | // support for n-dimensional buffers where n > 1. Insert sentinel value for |
490 | // ArgBinder::BindBuffer to state that any usage of element offset is invalid |
491 | // in this case. This allows for construction of a Buffer with multiple element offsets |
492 | // but disallows any usage of those element offsets. See PR #10816 for discussion on |
493 | // supporting multiple element offsets in TIR Buffer. |
494 | // TODO(Lunderberg): Remove if/when TIR supports multiple element offsets in TIR Buffer |
495 | if (elem_offset.size() != 1) { |
496 | slice.CopyOnWrite()->elem_offset = PrimExpr(); |
497 | } |
498 | return slice; |
499 | } |
500 | |
501 | PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset, |
502 | Optional<PrimExpr> input_extent) const { |
503 | const BufferNode* self = operator->(); |
504 | ICHECK(self != nullptr); |
505 | PrimExpr e_dtype; |
506 | PrimExpr extent; |
507 | if (self->shape.size() == 0) { |
508 | extent = make_const(self->DefaultIndexType(), 1); |
509 | } else if (self->strides.size() == self->shape.size()) { |
510 | int highest_dim = 0; |
511 | extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; |
512 | } else { |
513 | extent = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, |
514 | make_const(DataType::Int(32), 1), self->shape) - |
515 | offset; |
516 | } |
517 | PrimExpr elem_offset = self->elem_offset + offset; |
518 | if (content_lanes > 1) { |
519 | e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); |
520 | extent = extent / make_const(self->elem_offset.dtype(), content_lanes); |
521 | elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes); |
522 | } else { |
523 | e_dtype = tir::TypeAnnotation(self->dtype); |
524 | } |
525 | |
526 | if (input_extent.defined()) { |
527 | extent = input_extent.value(); |
528 | } |
529 | Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent, |
530 | make_const(DataType::Int(32), access_mask)}; |
531 | return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args); |
532 | } |
533 | |
534 | Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides, |
535 | PrimExpr elem_offset, String name, int data_alignment, int offset_factor, |
536 | BufferType buffer_type, Array<IntImm> axis_separators, Span span) { |
537 | DataType storage_dtype = dtype; |
538 | // specially handle bool |
539 | if (storage_dtype == DataType::Bool()) { |
540 | storage_dtype = DataType::Int(8); |
541 | } |
542 | // The buffer dtype may differ from the dtype of the underlying |
543 | // allocation, such as a single allocation that backs multiple |
544 | // tensors without a common datatype. Therefore, we check that the |
545 | // data pointer is a pointer, but not the exact type of the |
546 | // pointed-to values. |
547 | |
548 | // TODO(Lunderberg): Use an explicit pointer cast for the data |
549 | // pointer. Should be done alongside extensions to StmtExprMutator |
550 | // to more easily handle buffer/buffer_var updates. |
551 | ICHECK(data->type_annotation.defined()) |
552 | << "Variable " << data->name_hint << " is missing a type annotation." ; |
553 | ICHECK(data->type_annotation.as<PointerTypeNode>()) |
554 | << "Variable " << data->name_hint << " is not a pointer." ; |
555 | ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()) |
556 | << "Variable " << data->name_hint << " does not point to a primitive." ; |
557 | |
558 | auto n = make_object<BufferNode>(); |
559 | n->data = std::move(data); |
560 | n->dtype = dtype; |
561 | |
562 | n->shape = std::move(shape); |
563 | n->strides = std::move(strides); |
564 | n->axis_separators = std::move(axis_separators); |
565 | n->name = std::move(name); |
566 | if (!elem_offset.defined()) { |
567 | elem_offset = make_const(n->DefaultIndexType(), 0); |
568 | } |
569 | if (data_alignment <= 0) { |
570 | data_alignment = runtime::kAllocAlignment; |
571 | } |
572 | if (offset_factor == 0) { |
573 | offset_factor = 1; |
574 | } |
575 | n->elem_offset = std::move(elem_offset); |
576 | n->data_alignment = data_alignment; |
577 | n->offset_factor = offset_factor; |
578 | n->buffer_type = buffer_type; |
579 | if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) { |
580 | for (size_t i = 0; i < n->shape.size(); ++i) { |
581 | n->strides.push_back(Var("stride" , n->shape[i].dtype())); |
582 | } |
583 | } |
584 | n->span = std::move(span); |
585 | data_ = std::move(n); |
586 | } |
587 | |
588 | tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name, |
589 | int data_alignment, int offset_factor, bool compact, |
590 | std::string memory_scope) { |
591 | DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); |
592 | auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope)); |
593 | bool has_any = false; |
594 | if (!compact) { |
595 | for (const auto& it : shape) { |
596 | if (it.as<tir::VarNode>()) { |
597 | has_any = true; |
598 | break; |
599 | } |
600 | } |
601 | } |
602 | tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault; |
603 | |
604 | PrimExpr elem_offset; |
605 | if (offset_factor != 0) { |
606 | elem_offset = tir::Var(name + "_elem_offset" , shape[0].dtype()); |
607 | } else { |
608 | elem_offset = PrimExpr(); |
609 | } |
610 | |
611 | return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, data_alignment, |
612 | offset_factor, buffer_type); |
613 | } |
614 | |
615 | TVM_REGISTER_NODE_TYPE(BufferNode); |
616 | |
617 | TVM_REGISTER_GLOBAL("tir.Buffer" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
618 | ICHECK_EQ(args.size(), 11); |
619 | auto buffer_type = args[8].operator String(); |
620 | BufferType type = (buffer_type == "auto_broadcast" ) ? kAutoBroadcast : kDefault; |
621 | *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], type, |
622 | args[9], args[10]); |
623 | }); |
624 | |
625 | TVM_REGISTER_GLOBAL("tir.BufferAccessPtr" ).set_body_method(&Buffer::access_ptr); |
626 | |
627 | TVM_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer" ).set_body_method(&Buffer::GetFlattenedBuffer); |
628 | |
629 | TVM_REGISTER_GLOBAL("tir.BufferOffsetOf" ).set_body_method(&Buffer::OffsetOf); |
630 | |
631 | TVM_REGISTER_GLOBAL("tir.BufferVLoad" ).set_body_method(&Buffer::vload); |
632 | |
633 | TVM_REGISTER_GLOBAL("tir.BufferVStore" ).set_body_method(&Buffer::vstore); |
634 | |
635 | TVM_REGISTER_GLOBAL("tir.BufferStorageScope" ).set_body_method(&Buffer::scope); |
636 | |
637 | } // namespace tir |
638 | } // namespace tvm |
639 | |