/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Logic related to tensorize, used by ComputeOpNode.
 * \file tensorize.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include "../schedule/message_passing.h"
#include "compute_op.h"
#include "op_utils.h"

namespace tvm {
namespace te {

using namespace tir;

// Detect the region of input and output to be tensorized.
// out_dom: the domain of the root iter vars in the output op.
// in_region: the region of each input tensor.
// Returns the index of the leaf iter var at which the tensorized scope starts.
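//
// As an illustrative sketch (not a case taken from this file): if the stage's
// leaf iter vars are [i.outer, i.inner, j] and i.inner carries the kTensorized
// attribute, the returned index points at i.inner; out_dom then holds the
// ranges of the output's root iter vars restricted to one instance of that
// scope, and in_region holds the rectangular sub-region of each input tensor
// that the intrinsic is expected to read.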
size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage,
                            const std::unordered_map<IterVar, Range>& dom_map,
                            std::unordered_map<IterVar, Range>* out_dom,
                            std::unordered_map<Tensor, Array<Range>>* in_region) {
  // Get the bound of the tensorized scope.
  bool found_point = false;
  size_t loc_scope = 0;
  std::unordered_map<IterVar, IntSet> up_state;
  // Loop over the leaf iter vars, from innermost to outermost.
  for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
    IterVar iv = stage->leaf_iter_vars[i - 1];
    ICHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce);
    auto vit = dom_map.find(iv);
    ICHECK(vit != dom_map.end());
    const Range& vrange = vit->second;
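    // Three cases when building the upward state:
    //  - a unit-extent loop is pinned to the single point vrange->min;
    //  - once the tensorized point has been found, loops outside of it are
    //    kept symbolic as a single point at their own loop var;
    //  - loops at or inside the tensorized scope contribute their full range,
    //    so the scope covers their whole extent.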
    if (is_one(vrange->extent)) {
      up_state[iv] = IntSet::SinglePoint(vrange->min);
    } else if (found_point) {
      ICHECK(is_zero(vrange->min));
      up_state[iv] = IntSet::SinglePoint(iv->var);
    } else {
      up_state[iv] = IntSet::FromRange(vrange);
    }
    auto iit = stage->iter_var_attrs.find(iv);
    if (iit != stage->iter_var_attrs.end()) {
      const IterVarAttr& attr = (*iit).second;
      if (!found_point) {
        ICHECK(!attr->bind_thread.defined())
            << "Thread binding is not allowed inside the tensorize scope";
      }
      if (attr->iter_type == kTensorized) {
        ICHECK(!found_point) << "Only one tensorized point is allowed";
        found_point = true;
        loc_scope = i - 1;
      }
    }
  }
  ICHECK(found_point);
  // Get the domain of the tensorized scope.
  te::PassUpDomain(stage, dom_map, &up_state);
  // Get the domains of the inputs.
  std::unordered_map<Tensor, TensorDom> in_dom;
  std::unordered_map<const VarNode*, IntSet> temp_dmap;
  arith::Analyzer analyzer;
  Array<Tensor> inputs = self->InputTensors();
  for (Tensor t : inputs) {
    in_dom.emplace(t, TensorDom(t.ndim()));
  }
  for (IterVar iv : self->root_iter_vars()) {
    IntSet iset = up_state.at(iv);
    Range iv_range = iset.CoverRange(dom_map.at(iv));
    (*out_dom)[iv] = iv_range;
    analyzer.Bind(iv->var, iv_range);
    temp_dmap[iv->var.get()] = iset;
  }
  // Propagate the output domain to the input tensors.
  self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
  Range none;
  for (const auto& kv : in_dom) {
    Array<Range> vec;
    const Tensor& t = kv.first;
    for (size_t i = 0; i < t.ndim(); ++i) {
      Range r = arith::Union(kv.second.data.at(i)).CoverRange(none);
      ICHECK(r.defined()) << "Cannot deduce the region of the tensorized scope for input " << t;
      vec.push_back(std::move(r));
    }
    (*in_region)[t] = std::move(vec);
  }
  return loc_scope;
}

void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage,
                             const ComputeLoopNest& n, size_t tloc) {
  // Verification step.
  std::unordered_set<const VarNode*> banned;
  ICHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
  ICHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || n.init_nest.size() == 0);
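  // Collect every loop, thread, and let variable that is defined strictly
  // inside the tensorized scope. Any loop predicate that refers to such a
  // variable cannot be hoisted outside the intrinsic call, so it is rejected
  // below.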
  auto f_push_banned = [&banned](const Stmt& s) {
    if (const ForNode* op = s.as<ForNode>()) {
      banned.insert(op->loop_var.get());
    } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
      if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
        banned.insert(iv->var.get());
      }
    } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
      banned.insert(op->var.get());
    }
  };
  for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) {
    for (const Stmt& s : n.main_nest[i + 1]) {
      f_push_banned(s);
    }
    if (n.init_nest.size() != 0) {
      for (const Stmt& s : n.init_nest[i + 1]) {
        f_push_banned(s);
      }
    }
  }

  auto fbanned = [&](const VarNode* node) { return banned.count(node); };

  for (const PrimExpr& pred : n.main_predicates) {
    if (tir::UsesVar(pred, fbanned)) {
      LOG(FATAL) << "Tensorize failed, split condition " << pred
                 << " relies on var defined inside tensorize scope";
    }
  }
  for (const PrimExpr& pred : n.init_predicates) {
    if (tir::UsesVar(pred, fbanned)) {
      LOG(FATAL) << "Tensorize failed, split condition " << pred
                 << " relies on var defined inside tensorize scope";
    }
  }
}
// Remap the tensor placeholders and indices of the stage's body onto the
// tensor intrinsic's declaration, so that the two bodies can be compared.
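//
// For example (an illustrative sketch, not a case taken from this file): with
// C[i, j] = sum(A[i, k] * B[k, j]) tensorized over a 16x16x16 tile, a load
// A[i, k] in the stage body is rewritten into a load of the intrinsic's input
// placeholder at [i - i_min, k - k_min], outer unit-extent axes are pinned to
// their min, and the tiled axes are remapped to the intrinsic's own iter vars.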
class TensorIntrinMatcher final : public StmtExprMutator {
 public:
  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<ProducerLoadNode>();
    auto t = Downcast<Tensor>(op->producer);
    auto it = in_remap_.find(t);
    if (it != in_remap_.end()) {
      const InputEntry& e = it->second;
      ICHECK_EQ(op->indices.size(), e.region.size());
      Array<PrimExpr> indices;
      for (size_t i = e.start; i < e.region.size(); ++i) {
        indices.push_back(op->indices[i] - e.region[i]->min);
      }
      return ProducerLoad(e.tensor, indices);
    }
    return expr;
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

  PrimExpr VisitExpr_(const ReduceNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<ReduceNode>();
    Array<IterVar> axis;
    for (size_t i = 0; i < op->axis.size(); ++i) {
      auto it = axis_remap_.find(op->axis[i]);
      if (it != axis_remap_.end()) {
        axis.push_back(it->second);
      }
    }
    return Reduce(op->combiner, op->source, axis, op->condition, op->value_index, op->init);
  }

  void Init(const ComputeOpNode* self, const Stage& stage,
            const std::unordered_map<IterVar, Range>& dom_map,
            const std::unordered_map<IterVar, Range>& out_dom,
            const std::unordered_map<Tensor, Array<Range>>& in_region, const TensorIntrin& intrin,
            Map<Var, Range>* compute_intrin_iter_space) {
    ICHECK(self == stage->op.get());

    for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
      IterVar iv = stage->leaf_iter_vars[i];
      auto vit = dom_map.find(iv);
      if (vit != dom_map.end()) {
        const Range vrange = vit->second;
        compute_intrin_iter_space->Set(iv->var, vrange);
      }
    }
    analyzer_.Bind(*compute_intrin_iter_space);

    // Input remap.
    Array<Tensor> inputs = self->InputTensors();
    ICHECK_EQ(inputs.size(), intrin->inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      InputEntry e;
      e.tensor = intrin->inputs[i];
      e.region = Array<Range>(in_region.at(inputs[i]));
      ICHECK_GE(e.region.size(), e.tensor.ndim());
      // Enable fuzzy matching, to match [1, n, m] to [n, m].
      e.start = e.region.size() - e.tensor.ndim();
      for (size_t j = 0; j < e.start; ++j) {
        auto canonical_extent = analyzer_.Simplify(e.region[j]->extent);
        ICHECK(is_one(canonical_extent))
            << "Tensorize " << intrin->name
            << ": Input dimension mismatch with tensor intrin,"
            << " expected shape=" << e.tensor->shape << ", given region=" << e.region;
      }
      in_remap_[inputs[i]] = e;
    }
    // Output remap.
    const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
    ICHECK(intrin_compute) << "Only compute intrinsics are supported for now";
    ICHECK_GE(self->axis.size(), intrin_compute->axis.size())
        << "Tensorize: Output mismatch with tensor intrin";
    // Enable fuzzy matching, to match [1, n, m] to [n, m].
    size_t axis_start = self->axis.size() - intrin_compute->axis.size();
    for (size_t i = 0; i < axis_start; ++i) {
      Range r = out_dom.at(self->axis[i]);
      ICHECK(is_one(r->extent)) << "Tensorize: Output mismatch with tensor intrin,"
                                << " intrin-dim=" << intrin_compute->axis.size()
                                << ", tensorize-dim=" << self->axis.size();
      var_remap_[self->axis[i]->var.get()] = r->min;
    }
    // Assume we tensorize at region axis i with range [min, min + extent).
    // The corresponding intrinsic axis is j with range [0, extent).
    // Remap index i to j + min.
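    // For example (illustrative numbers, not taken from this file): if out_dom
    // for axis i is [16, 32) and the intrinsic axis j ranges over [0, 16),
    // every occurrence of i in the body is rewritten to j + 16 below.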
    for (size_t i = axis_start; i < self->axis.size(); ++i) {
      IterVar iv = self->axis[i];
      IterVar target_iv = intrin_compute->axis[i - axis_start];
      Range r = out_dom.at(iv);
      var_remap_[iv->var.get()] = target_iv->var + r->min;
      axis_remap_[iv] = target_iv;
      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
    }
    // Remap reduction axis.
    ICHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
        << "Tensorize: Reduction dimension mismatch with tensor intrin";
    axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
    for (size_t i = 0; i < axis_start; ++i) {
      Range r = out_dom.at(self->reduce_axis[i]);
      ICHECK(is_one(r->extent)) << "Tensorize: Reduction mismatch with tensor intrin "
                                << " intrin-dim=" << intrin_compute->reduce_axis.size()
                                << ", tensorize-dim=" << self->reduce_axis.size();
      var_remap_[self->reduce_axis[i]->var.get()] = r->min;
    }
    for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
      IterVar iv = self->reduce_axis[i];
      IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
      Range r = out_dom.at(iv);
      var_remap_[iv->var.get()] = target_iv->var + r->min;
      axis_remap_[iv] = target_iv;
      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
    }
  }

 private:
  // Input entry
  struct InputEntry {
    Tensor tensor;
    size_t start;
    Array<Range> region;
  };
  // input data remap
  std::unordered_map<Tensor, InputEntry> in_remap_;
  // variable remap.
  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
  // IterVar remap.
  std::unordered_map<IterVar, IterVar> axis_remap_;
  // arith analyzer
  arith::Analyzer analyzer_;
};

// Try to match the tensor dataflow of the stage with the intrinsic.
Array<PrimExpr> MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage,
                                   const std::unordered_map<IterVar, Range>& dom_map,
                                   const std::unordered_map<IterVar, Range>& out_dom,
                                   const std::unordered_map<Tensor, Array<Range>>& in_region,
                                   const TensorIntrin& intrin,
                                   Map<Var, Range>* compute_intrin_iter_space) {
  TensorIntrinMatcher matcher;
  matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
  Array<PrimExpr> ret;
  for (PrimExpr expr : self->body) {
    ret.push_back(matcher(expr));
  }
  return ret;
}

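// Verify that the stage's compute body matches the declaration of the tensor
// intrinsic: both bodies are rewritten into the intrinsic's index space,
// simplified, and then compared expression by expression, first by dtype and
// then structurally.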
void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage,
                         const std::unordered_map<IterVar, PrimExpr>& value_map,
                         const std::unordered_map<IterVar, Range>& dom_map,
                         const std::unordered_map<IterVar, Range>& out_dom,
                         const std::unordered_map<Tensor, Array<Range>>& in_region,
                         const TensorIntrin& intrin) {
  StructuralEqual expr_equal;
  Map<Var, Range> compute_intrin_iter_space;
  Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
                                            &compute_intrin_iter_space);
  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
  ICHECK(intrin_compute) << "Only compute intrinsics are supported for now";
  ICHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch";
  arith::Analyzer ana;
  ana.Bind(compute_intrin_iter_space);

  for (size_t i = 0; i < body.size(); ++i) {
    PrimExpr lhs = ana.Simplify(Substitute(body[i], value_map));
    // Run substitution because the intrin body could depend on outer loop vars.
    PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map));
    if (lhs.dtype() != rhs.dtype()) {
      LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name
                 << "'s declaration:"
                 << " provided=" << lhs.dtype() << ", intrin=" << rhs.dtype();
    }
    ICHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name
                                 << "'s declaration:"
                                 << " provided=" << lhs << ", intrin=" << rhs
                                 << ", running this stage: " << stage;
  }
}

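// Lower a tensorized stage: infer the tensorized region, verify the loop nest
// and the body against the intrinsic declaration, bind the input/output
// buffers to the inferred sub-regions, and emit the intrinsic's body,
// splitting it into reduce_init / reduce_update when the reduction loops
// extend beyond the common outer loops.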
Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
                   const std::unordered_map<IterVar, Range>& dom_map,
                   bool debug_keep_trivial_loop) {
  std::unordered_map<IterVar, Range> out_dom;
  std::unordered_map<Tensor, Array<Range>> in_region;
  size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
  TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin;
  ICHECK(intrin.defined());
  ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop);
  VerifyTensorizeLoopNest(self, stage, n, tloc);
  VerifyTensorizeBody(self, stage, n.main_vmap, dom_map, out_dom, in_region, intrin);
  // Start binding data.
  Stmt nop = Evaluate(0);
  std::vector<Stmt> input_bind_nest, output_bind_nest;
  Array<Tensor> inputs = self->InputTensors();
  ICHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch";
  // Input binding.
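  // Each intrinsic input buffer is attached to the corresponding stage input
  // via a buffer_bind_scope AttrStmt whose value is tvm_tuple(min_0, extent_0,
  // min_1, extent_1, ...), one (min, extent) pair per dimension of the
  // inferred input region.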
  for (size_t i = 0; i < intrin->inputs.size(); ++i) {
    Tensor tensor = inputs[i];
    Buffer buffer = intrin->buffers[i];
    Array<ObjectRef> bind_spec{buffer, tensor};
    auto it = in_region.find(tensor);
    ICHECK(it != in_region.end());
    const Array<Range>& region = it->second;
    Array<PrimExpr> tuple;
    for (const Range& r : region) {
      tuple.push_back(r->min);
      tuple.push_back(r->extent);
    }
    input_bind_nest.emplace_back(
        AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
                 Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
  }
  // Output binding.
  const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
  ICHECK(intrin_compute) << "Only compute intrinsics are supported for now";
  ICHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
  ICHECK_EQ(intrin_compute->body.size(), self->body.size());
  Array<PrimExpr> tuple;
  for (IterVar iv : self->axis) {
    auto it = out_dom.find(iv);
    ICHECK(it != out_dom.end());
    tuple.push_back(it->second->min);
    tuple.push_back(it->second->extent);
  }
  for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) {
    Tensor tensor = stage->op.output(i - intrin->inputs.size());
    Buffer buffer = intrin->buffers[i];
    Array<ObjectRef> bind_spec{buffer, tensor};
    output_bind_nest.emplace_back(
        AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
                 Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop));
  }
  // Check the variable remap.
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  tir::ArgBinder binder(&vmap);
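  // The binder maps the intrinsic's symbolic reduction bounds (min/extent)
  // onto the concrete values derived from out_dom; bindings that cannot be
  // recorded as a simple variable substitution become assert statements
  // (binder.asserts()), which are merged into the generated body below.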
  ICHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
      << "Tensorization failed: reduction axis sizes do not match";
  size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
  for (size_t i = 0; i < start; ++i) {
    IterVar iv = self->reduce_axis[i];
    auto it = out_dom.find(iv);
    ICHECK(it != out_dom.end());
    ICHECK(is_one(it->second->extent))
        << "Tensorization failed: reduction axis sizes do not match";
  }
  for (size_t i = start; i < self->reduce_axis.size(); ++i) {
    IterVar iv = self->reduce_axis[i];
    IterVar target = intrin_compute->reduce_axis[i - start];
    auto it = out_dom.find(iv);
    ICHECK(it != out_dom.end());
    binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0),
                "tensor_intrin.reduction.min");
    binder.Bind(target->dom->extent, it->second->extent, "tensor_intrin.reduction.extent");
  }
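  // Two emission strategies follow. If the tensorized scope starts at or
  // before the common outer loops (tloc <= num_common_loop), the intrinsic's
  // normal body covers the whole computation and no reduction splitting is
  // needed. Otherwise the reduction has to be split: reduce_init (or, if it
  // is absent, the normal body guarded to run on the first iteration)
  // initializes the output, and reduce_update is emitted inside the remaining
  // loops.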
  if (tloc <= n.num_common_loop) {
    // No need to split the reduction.
    std::vector<std::vector<Stmt>> nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
    nest.emplace_back(MakeIfNest(n.main_predicates));
    ICHECK_EQ(n.init_predicates.size(), 0U);
    ICHECK(intrin->body.defined()) << "Normal store op for intrin " << intrin << " is not defined";
    Stmt body = MergeNest(output_bind_nest, intrin->body);
    body = MergeNest(input_bind_nest, body);
    body = tir::Substitute(body, vmap);
    body = MergeNest(binder.asserts(), body);
    body = te::Substitute(body, n.main_vmap);
    return MergeNest(nest, body);
  } else {
    // Need to split the reduction.
    ICHECK(intrin->reduce_update.defined())
        << "Reduction update op for intrin " << intrin << " is not defined";
    // Need init and update steps.
    ICHECK_NE(self->reduce_axis.size(), 0U);
    std::vector<std::vector<Stmt>> common(n.main_nest.begin(),
                                          n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt>> update_nest(n.main_nest.begin() + n.num_common_loop + 1,
                                               n.main_nest.begin() + tloc + 1);
    update_nest.emplace_back(MakeIfNest(n.main_predicates));

    if (intrin->reduce_init.defined()) {
      // The init nest.
      std::vector<std::vector<Stmt>> init_nest(n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
      init_nest.emplace_back(MakeIfNest(n.init_predicates));
      Stmt init = MergeNest(output_bind_nest, intrin->reduce_init);
      init = te::Substitute(init, n.init_vmap);
      init = MergeNest(init_nest, init);
      // The update.
      Stmt update = MergeNest(output_bind_nest, intrin->reduce_update);
      update = MergeNest(input_bind_nest, update);
      update = tir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = te::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, SeqStmt::Flatten(init, update));
    } else {
      // When the init op is not available, use the body op to reset in the first iteration.
      ICHECK(intrin->body.defined())
          << "Normal body op for intrin " << intrin << " is not defined";
      Stmt update = TransformUpdate(stage, dom_map, n, intrin->body, intrin->reduce_update);
      update = MergeNest(output_bind_nest, update);
      update = MergeNest(input_bind_nest, update);
      update = tir::Substitute(update, vmap);
      update = MergeNest(binder.asserts(), update);
      update = te::Substitute(update, n.main_vmap);
      update = MergeNest(update_nest, update);
      return MergeNest(common, update);
    }
  }
}

// Register functions for unit tests.
TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion").set_body([](TVMArgs args, TVMRetValue* ret) {
  Stage stage = args[0];
  Map<IterVar, Range> dmap = args[1];
  std::unordered_map<IterVar, Range> out_dom;
  std::unordered_map<Tensor, Array<Range>> in_region;
  ICHECK(stage->op.as<ComputeOpNode>());
  InferTensorizeRegion(stage->op.as<ComputeOpNode>(), stage, as_unordered_map(dmap), &out_dom,
                       &in_region);
  *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom), Map<Tensor, Array<Range>>(in_region)};
});

TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody").set_body([](TVMArgs args, TVMRetValue* ret) {
  Stage stage = args[0];
  Map<IterVar, Range> out_dom = args[1];
  Map<Tensor, Array<Range>> in_region = args[2];
  TensorIntrin intrin = args[3];
  Map<Var, Range> vrange;
  ICHECK(stage->op.as<ComputeOpNode>());
  *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(), stage, {{}}, as_unordered_map(out_dom),
                            as_unordered_map(in_region), intrin, &vrange);
});

}  // namespace te
}  // namespace tvm