/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Compute Op.
 * \file compute_op.cc
 */
#include "compute_op.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <string>
#include <unordered_set>
#include <utility>

#include "../../arith/interval_set.h"
#include "../schedule/message_passing.h"
#include "op_utils.h"

namespace tvm {
namespace te {
using namespace tir;

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<ComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const ComputeOpNode*>(node.get());
      p->stream << "compute(" << op->name << ", body=" << op->body << ", axis=" << op->axis
                << ", reduce_axis=" << op->reduce_axis << ", tag=" << op->tag
                << ", attrs=" << op->attrs << ")";
    });

TVM_REGISTER_NODE_TYPE(ComputeOpNode);

/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op);

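// Whether two Reduce nodes describe the same reduction (same combiner, source, axis and init,
// structurally equal condition), i.e. they may differ only in value_index.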
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
  return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) && StructuralEqual()(a->condition, b->condition) &&
         ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
}

int ComputeOpNode::num_outputs() const { return body.size(); }

Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
  if (reduce_axis.size() == 0) return axis;
  Array<IterVar> ret = axis;
  for (IterVar iv : reduce_axis) {
    ret.push_back(iv);
  }
  return ret;
}

DataType ComputeOpNode::output_dtype(size_t idx) const {
  ICHECK_LT(idx, num_outputs());
  return body[idx].dtype();
}

Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const {
  ICHECK_LT(idx, num_outputs());
  // for now, all outputs of a BaseComputeOp have the same shape
  Array<PrimExpr> shape;
  for (const auto& ivar : this->axis) {
    const Range& r = ivar->dom;
    shape.push_back(r->extent);
  }
  return shape;
}

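// Create a single-output compute op over `shape`: one data-parallel IterVar ("ax0", "ax1", ...)
// is created per dimension and their vars are passed to fcompute. Rough usage sketch
// (A, m and n assumed to be defined elsewhere):
//   te::compute({m, n}, [&](const Array<Var>& i) { return A(i[0], i[1]) + 1; });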
Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::string tag,
               Map<String, ObjectRef> attrs) {
  // compute dimension.
  size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
    args.push_back(axis.back()->var);
  }

  return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0);
}

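// Batched variant: fcompute returns one expression per output; all outputs share the same axes.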
Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string name,
                      std::string tag, Map<String, ObjectRef> attrs) {
  // compute dimension.
  size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
    args.push_back(axis.back()->var);
  }

  Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args));
  Array<Tensor> outputs;
  for (int idx = 0; idx < op->num_outputs(); ++idx) {
    outputs.push_back(op.output(idx));
  }
  return outputs;
}

ComputeOp::ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                     Array<IterVar> axis, Array<PrimExpr> body) {
  if (!attrs.defined()) {
    attrs = Map<String, ObjectRef>();
  }
  auto n = make_object<ComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->axis = std::move(axis);
  n->body = std::move(body);
  if (n->body[0]->IsInstance<tir::ReduceNode>()) {
    const tir::ReduceNode* reduce = n->body[0].as<tir::ReduceNode>();
    n->reduce_axis = reduce->axis;
  }
  VerifyComputeOp(n.get());
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("te.ComputeOp")
    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
                       Array<IterVar> axis,
                       Array<PrimExpr> body) { return ComputeOp(name, tag, attrs, axis, body); });

// The schedule-related logic.
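// Collect the distinct input tensors referenced by the body expressions, in order of first use.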
Array<Tensor> ComputeOpNode::InputTensors() const {
  Array<Tensor> ret;
  std::unordered_set<Tensor> visited;
  for (auto& e : body) {
    tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
      if (auto* pload = n.as<tir::ProducerLoadNode>()) {
        Tensor t = Downcast<Tensor>(pload->producer);
        if (!visited.count(t)) {
          ret.push_back(t);
          visited.insert(t);
        }
      }
    });
  }
  return ret;
}

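// Rewrite the body so that tensors in rmap are replaced by their mapped counterparts;
// returns self unchanged when no replacement happens.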
Operation ComputeOpNode::ReplaceInputs(const Operation& self,
                                       const std::unordered_map<Tensor, Tensor>& rmap) const {
  ICHECK_EQ(self.operator->(), this);
  VerifyComputeOp(this);
  Array<PrimExpr> arr;
  if (this->body[0]->IsInstance<tir::ReduceNode>()) {
    // Specially handle reduce so that the replaced op
    // still shares all the components.
    PrimExpr new_reduce = te::ReplaceTensor(this->body[0], rmap);
    if (!new_reduce.same_as(this->body[0])) {
      const tir::ReduceNode* r = new_reduce.as<tir::ReduceNode>();
      for (size_t k = 0; k < this->body.size(); ++k) {
        auto n = make_object<tir::ReduceNode>(*r);
        n->value_index = static_cast<int>(k);
        n->dtype = r->source[k].dtype();
        arr.push_back(PrimExpr(n));
      }
    } else {
      arr = this->body;
    }
  } else {
    arr =
        UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); });
  }
  if (!arr.same_as(this->body)) {
    return ComputeOp(this->name, this->tag, this->attrs, this->axis, arr);
  } else {
    return self;
  }
}

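// Propagate the bounds required by this op's body back to the regions of its input tensors.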
void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                                      const std::unordered_map<const VarNode*, IntSet>& dom_map,
                                      std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  ICHECK_EQ(self.operator->(), this);
  auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
    if (auto* pload = n.as<tir::ProducerLoadNode>()) {
      Tensor t = Downcast<Tensor>(pload->producer);
      if (t->op.defined() && out_dom_map->count(t)) {
        TensorDom& dom = out_dom_map->at(t);
        for (size_t i = 0; i < t.ndim(); ++i) {
          // We assume that the value of the argument cannot be out of bounds (otherwise it is
          // undefined behaviour), so we can intersect the estimated set of the argument with the
          // range expected by the tensor. However, intersection may result in overly complex
          // expressions, so we perform a more relaxed form of intersection.
          IntSet arg_intset = analyzer->int_set(pload->indices[i], ConvertDomMap(dom_map));
          const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
          if (arg_interval) {
            PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
            PrimExpr shape_i_max_value = t->shape[i] - 1;
            PrimExpr min_value = arg_interval->min_value;
            PrimExpr max_value = arg_interval->max_value;
            // Prefer the shape bounds only when we can prove they are tighter.
            // We must update the bound's ends in pairs. Here is a counterexample: shape_i is
            // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range
            // is [0, 7]. If we allowed updating only one end, the bound would become
            // [threadIdx.y, 0], which is awkward for further analysis.
            if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) ||
                (analyzer->CanProve(shape_i_min_value >= min_value) &&
                 analyzer->CanProve(shape_i_max_value <= max_value))) {
              min_value = shape_i_min_value;
              max_value = shape_i_max_value;
            }
            dom.data[i].push_back(IntSet::Interval(min_value, max_value));
          } else {
            dom.data[i].push_back(arg_intset);
          }
        }
      }
    }
  };
  for (auto& e : body) tir::PostOrderVisit(e, fvisit);
}

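// Derive the ranges of the root iteration variables from the domains consumers request
// for output(0).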
void BaseComputeOpNode::GatherBound(const Operation& self,
                                    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                                    std::unordered_map<IterVar, Range>* out_dom_map) const {
  ICHECK_EQ(self.operator->(), this);
  const TensorDom& tdom = tensor_dom.at(self.output(0));
  for (size_t i = 0; i < this->axis.size(); ++i) {
    Range r = arith::Union(tdom.data.at(i)).CoverRange(this->axis[i]->dom);
    ICHECK(!out_dom_map->count(this->axis[i]));
    (*out_dom_map)[this->axis[i]] = r;
  }
  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
    ICHECK(!out_dom_map->count(this->reduce_axis[i]));
    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
  }
}

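// Wrap `body` in ProducerRealize nodes for every output, attaching buffer_dim_align
// attributes for axes that carry an alignment request.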
Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
                                     const std::unordered_map<IterVar, Range>& realize_map,
                                     const Stmt& body, String storage_scope) const {
  ICHECK_EQ(stage->op.get(), this);
  Region bounds;
  for (IterVar iv : this->axis) {
    bounds.push_back(realize_map.at(iv));
  }
  Stmt realize = body;
  for (int i = this->num_outputs(); i > 0; --i) {
    Tensor t = stage->op.output(i - 1);
    realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope);
    // alignment requirement, only useful for compute
    for (size_t i = 0; i < num_schedulable_dims(); ++i) {
      auto it = stage->iter_var_attrs.find(this->axis[i]);
      if (it != stage->iter_var_attrs.end()) {
        IterVarAttr attr = (*it).second;
        if (attr->dim_align_factor != 0) {
          Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
                                   attr->dim_align_offset};
          realize =
              tir::AttrStmt(t, tir::attr::buffer_dim_align,
                            Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), realize);
        }
      }
    }
  }
  return realize;
}

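// For a plain ComputeOp every spatial axis is schedulable.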
size_t ComputeOpNode::num_schedulable_dims() const { return axis.size(); }

// Build a reduction body.
void MakeReduction(const ComputeOpNode* op, const Array<Tensor>& tensors, Stmt* init,
                   Stmt* provide) {
  Array<PrimExpr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  std::vector<Stmt> inits, provides;

  size_t size = op->body.size();
  const ReduceNode* reduce = op->body[0].as<ReduceNode>();
  ICHECK(reduce);
  const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
  ICHECK(combiner);
  Array<PrimExpr> lhs;
  for (size_t i = 0; i < size; ++i) {
    lhs.push_back(tensors[i](args));
  }
  Array<PrimExpr> init_value = combiner->identity_element;
  Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);

  // If an init was passed to ReduceNode, use that for initialization
  // instead of combiner->identity_element
  Array<PrimExpr> reduce_init = reduce->init;
  if (!reduce_init.empty()) {
    init_value = reduce_init;
  }
  for (size_t i = 0; i < size; ++i) {
    Tensor t = tensors[i];
    inits.emplace_back(ProducerStore(t, init_value[i], args));
    provides.emplace_back(ProducerStore(t, update_value[i], args));
  }
  *init = SeqStmt::Flatten(inits);
  *provide = SeqStmt::Flatten(provides);
  if (!is_one(reduce->condition)) {
    *provide = IfThenElse(reduce->condition, *provide);
  }
}

// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) {
  Array<PrimExpr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  return ProducerStore(t, op->body[t->value_index], args);
}

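// Build the compute statement: the loop nest plus init/update stores for reductions,
// or plain element-wise stores otherwise.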
Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) {
  // grab the nest structure
  ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop);
  // Normal loop structure
  n.init_nest.emplace_back(MakeIfNest(n.init_predicates));
  n.main_nest.emplace_back(MakeIfNest(n.main_predicates));
  if (self->reduce_axis.size() != 0) {
    // make reduction.
    Stmt init, provide;
    Array<Tensor> source;
    for (size_t i = 0; i < self->body.size(); ++i) {
      source.push_back(stage->op.output(i));
    }
    MakeReduction(self, source, &init, &provide);
    init = MergeNest(n.init_nest, init);
    init = Substitute(init, n.init_vmap);
    // common nest
    std::vector<std::vector<Stmt>> common(n.main_nest.begin(),
                                          n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt>> reduce(n.main_nest.begin() + n.num_common_loop + 1,
                                          n.main_nest.end());
    provide = MergeNest(reduce, provide);
    if (debug_keep_trivial_loop) {
      provide = MergeNest(common, provide);
    } else {
      provide = MergeNest(common, SeqStmt::Flatten(init, provide));
    }
    // Run substitution on the full nest, because the loop condition
    // could depend on outer loops.
    return Substitute(provide, n.main_vmap);
  } else {
    std::vector<Stmt> provides;
    for (size_t i = 0; i < self->body.size(); ++i) {
      provides.emplace_back(MakeProvide(self, stage->op.output(i)));
    }
    Stmt provide = SeqStmt::Flatten(provides);
    provide = MergeNest(n.main_nest, provide);
    // Run substitution on the full nest, because the loop condition
    // could depend on outer loops.
    return Substitute(provide, n.main_vmap);
  }
}

enum class ComputeType { kNormal, kCrossThreadReduction, kTensorize };

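// Classify how the stage should be lowered: tensorized, cross-thread reduction
// (a reduce axis bound to a thread), or a normal loop nest.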
ComputeType DetectComputeType(const ComputeOpNode* self, const Stage& stage) {
  // Verify correctness of leaf nest.
  int thread_red = 0, tensorize = 0;

  for (IterVar iv : stage->leaf_iter_vars) {
    IterVarAttr attr;
    auto it = stage->iter_var_attrs.find(iv);
    if (it != stage->iter_var_attrs.end()) {
      attr = (*it).second;
    }
    if (attr.defined() && attr->iter_type == kTensorized) {
      ++tensorize;
    }
    if (iv->iter_type == kCommReduce) {
      if (attr.defined() && attr->bind_thread.defined()) {
        ++thread_red;
      }
    } else {
      ICHECK_EQ(thread_red, 0) << "Cross thread reduce cannot swap with normal data axis";
    }
  }
  if (tensorize != 0) {
    ICHECK(thread_red == 0) << "Cannot mix cross thread reduction with Tensorize";
    return ComputeType::kTensorize;
  }
  if (thread_red != 0) {
    return ComputeType::kCrossThreadReduction;
  } else {
    return ComputeType::kNormal;
  }
}

// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(const Stage& stage,
                                 const std::unordered_map<IterVar, Range>& dom_map,
                                 bool debug_keep_trivial_loop) const {
  ICHECK_EQ(stage->op.operator->(), this);
  ComputeType ctype = DetectComputeType(this, stage);
  if (ctype == ComputeType::kCrossThreadReduction) {
    // specially handle cross thread reduction.
    return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
  } else if (ctype == ComputeType::kTensorize) {
    return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
  } else {
    return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
  }
}

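// Create the loop nest and bound-check predicates of a compute stage, splitting the main nest
// into a common part and a reduction part when reduce_axis is present.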
ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Stage& stage,
                                        const std::unordered_map<IterVar, Range>& dom_map,
                                        bool debug_keep_trivial_loop) {
  ICHECK_EQ(stage->op.operator->(), self);
  ComputeLoopNest ret;
  // make main loop nest
  ret.main_nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(),
                               &ret.main_vmap, debug_keep_trivial_loop);
  ret.main_predicates =
      MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set<IterVar>());
  for (auto& e : ret.main_predicates) {
    e = likely(e);
  }
  if (stage->store_predicate.defined()) {
    ret.main_predicates.push_back(stage->store_predicate);
  }
  if (self->reduce_axis.size() != 0) {
    // try to find the location to insert the initialization.
    // Fuse the initialization and provide loop when possible.
    std::unordered_map<IterVar, int> update_state;
    for (IterVar iv : self->reduce_axis) {
      update_state[iv] = 2;
    }
    for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
      update_state[self->axis[i]] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
    te::PassDownBitMaskOr(stage, &update_state);
    auto leaf_iter_vars = stage->leaf_iter_vars;
    // Find the first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = update_state.at(iv);
      if ((flag & 2) != 0) {
        begin_loop = i;
        break;
      }
      ret.init_vmap[iv] = ret.main_vmap.at(iv);
    }
    ret.num_common_loop = begin_loop;
    // skip loops that are related to reduction and are unrelated to axis.
    std::unordered_set<IterVar> skip_iter;
    for (auto kv : update_state) {
      int flag = kv.second;
      if (flag == 2) skip_iter.insert(kv.first);
    }
    ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap),
                                 debug_keep_trivial_loop);
    ret.init_predicates =
        MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter);
    for (auto& e : ret.init_predicates) {
      e = likely(e);
    }
  } else {
    ICHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
    ret.num_common_loop = stage->leaf_iter_vars.size();
  }
  // copy elision here.
  return ret;
}

namespace {
/*!
 * \brief Verify if ComputeOp is valid with respect to Reduce operations.
 *
 * The following two properties are verified:
 * (1) All Reduce operations must exist at top level.
 * (2) For a list of operations, if one is Reduce, then the others
 *     must be Reduce as well; and their inputs should have the
 *     same attribute except value_index.
 */
class ComputeVerifier final : protected tir::ExprVisitor {
 public:
  /// Special member functions
  //@{
  explicit ComputeVerifier(const ComputeOpNode* compute)
      : compute_(compute), reduce_(compute->body[0].as<tir::ReduceNode>()) {}
  virtual ~ComputeVerifier() = default;
  ComputeVerifier(const ComputeVerifier&) = delete;
  ComputeVerifier(ComputeVerifier&&) = delete;
  ComputeVerifier& operator=(const ComputeVerifier&) = delete;
  ComputeVerifier& operator=(ComputeVerifier&&) = delete;
  //@}

  /// Interface to perform compute verification
  void Run() {
    for (const PrimExpr e : compute_->body) {
      // Check for consistency of top level reductions
      const tir::ReduceNode* reduce = e.as<tir::ReduceNode>();
      ICHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent "
                                                           << "with being Reduce operation or not.";

      if (reduce && reduce_) {
        ICHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should "
                                             << "have the same attribute except value_index";
      }

      level_ = 0;
      ExprVisitor::VisitExpr(e);
    }
  }

 protected:
  /// Visitor implementation
  //@{
  void VisitExpr(const PrimExpr& n) final {
    ++level_;
    ExprVisitor::VisitExpr(n);
    --level_;
  }

  void VisitExpr_(const tir::ReduceNode* op) final {
    // Check for non top level reductions
    ICHECK(0 == level_) << "Reductions are only allowed at the top level of compute. "
                        << "Please create another tensor for further composition.";
  }
  //@}

 private:
  const ComputeOpNode* compute_{nullptr};   ///< ComputeOpNode to verify
  const tir::ReduceNode* reduce_{nullptr};  ///< Top level Reduce operation
  int level_{0};                            ///< Level of op being processed
};
}  // namespace

/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op) {
  ComputeVerifier v(op);
  v.Run();
}

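// Used by tensorize when no init intrinsic is given: emit `update` once any reduction loop
// outside the tensorized scope has passed its first iteration, and `body` otherwise.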
Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     const ComputeLoopNest& n, Stmt body, Stmt update) {
  Array<PrimExpr> conds;
  std::unordered_set<const VarNode*> banned;
  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
    IterVar iv = stage->leaf_iter_vars[i];
    auto iit = stage->iter_var_attrs.find(iv);
    if (iit != stage->iter_var_attrs.end()) {
      const IterVarAttr& attr = (*iit).second;
      if (attr->iter_type == kTensorized) {
        break;
      }
    }
    if (iv->iter_type == kCommReduce) {
      auto vit = dom_map.find(iv);
      ICHECK(vit != dom_map.end());
      const Range& vrange = vit->second;
      conds.push_back(likely(iv->var > vrange->min));
      banned.insert(iv->var.get());
    }
  }

  auto fbanned = [&](const VarNode* node) { return banned.count(node); };

  for (const PrimExpr& pred : n.main_predicates) {
    if (tir::UsesVar(pred, fbanned)) {
      LOG(FATAL) << "Tensorize update transform failed, the condition " << pred
                 << " has a conflict with the reset condition";
    }
  }

  auto cond = foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_or(a, b, span); },
                    const_false(1), conds);
  return IfThenElse(cond, update, body);
}

}  // namespace te
}  // namespace tvm