1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
 * "License"); you may not use this file except in compliance
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Compute Op. |
22 | * \file compute_op.cc |
23 | */ |
24 | #include "compute_op.h" |
25 | |
26 | #include <tvm/arith/analyzer.h> |
27 | #include <tvm/runtime/registry.h> |
28 | #include <tvm/te/operation.h> |
29 | #include <tvm/tir/analysis.h> |
30 | #include <tvm/tir/builtin.h> |
31 | #include <tvm/tir/expr.h> |
32 | #include <tvm/tir/stmt_functor.h> |
33 | |
34 | #include <string> |
35 | #include <unordered_set> |
36 | #include <utility> |
37 | |
38 | #include "../../arith/interval_set.h" |
39 | #include "../schedule/message_passing.h" |
40 | #include "op_utils.h" |
41 | |
42 | namespace tvm { |
43 | namespace te { |
44 | using namespace tir; |
45 | |
46 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
47 | .set_dispatch<ComputeOpNode>([](const ObjectRef& node, ReprPrinter* p) { |
48 | auto* op = static_cast<const ComputeOpNode*>(node.get()); |
49 | p->stream << "compute(" << op->name << ", body=" << op->body << ", axis=" << op->axis |
50 | << ", reduce_axis=" << op->reduce_axis << ", tag=" << op->tag |
51 | << ", attrs=" << op->attrs << ")" ; |
52 | }); |
53 | |
54 | TVM_REGISTER_NODE_TYPE(ComputeOpNode); |
55 | |
56 | /// Verify if ComputeOp is valid with respect to Reduce operations. |
57 | static void VerifyComputeOp(const ComputeOpNode* op); |
58 | |
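// Whether two Reduce nodes are interchangeable for verification purposes:
// same combiner, source, and axis, a structurally equal condition, and the
// same (possibly empty) init. value_index is deliberately not compared, so
// the bodies of a multi-output reduction compare equal.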
59 | inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { |
60 | return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && |
61 | (a->axis.same_as(b->axis)) && StructuralEqual()(a->condition, b->condition) && |
62 | ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); |
63 | } |
64 | |
65 | int ComputeOpNode::num_outputs() const { return body.size(); } |
66 | |
67 | Array<IterVar> BaseComputeOpNode::root_iter_vars() const { |
68 | if (reduce_axis.size() == 0) return axis; |
69 | Array<IterVar> ret = axis; |
70 | for (IterVar iv : reduce_axis) { |
71 | ret.push_back(iv); |
72 | } |
73 | return ret; |
74 | } |
75 | |
76 | DataType ComputeOpNode::output_dtype(size_t idx) const { |
77 | ICHECK_LT(idx, num_outputs()); |
78 | return body[idx].dtype(); |
79 | } |
80 | |
81 | Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const { |
82 | ICHECK_LT(idx, num_outputs()); |
83 | // for now, all outputs of a BaseComputeOp have the same shape |
84 | Array<PrimExpr> shape; |
85 | for (const auto& ivar : this->axis) { |
86 | const Range& r = ivar->dom; |
87 | shape.push_back(r->extent); |
88 | } |
89 | return shape; |
90 | } |
91 | |
92 | Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::string tag, |
93 | Map<String, ObjectRef> attrs) { |
  // Construct one data-parallel IterVar per output dimension.
95 | size_t ndim = shape.size(); |
96 | std::vector<IterVar> axis; |
97 | std::vector<Var> args; |
98 | for (size_t i = 0; i < ndim; ++i) { |
99 | std::ostringstream os; |
100 | os << "ax" << i; |
101 | axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); |
102 | args.push_back(axis.back()->var); |
103 | } |
104 | |
105 | return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0); |
106 | } |
107 | |
108 | Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string name, |
109 | std::string tag, Map<String, ObjectRef> attrs) { |
  // Construct one data-parallel IterVar per output dimension.
111 | size_t ndim = shape.size(); |
112 | std::vector<IterVar> axis; |
113 | std::vector<Var> args; |
114 | for (size_t i = 0; i < ndim; ++i) { |
115 | std::ostringstream os; |
116 | os << "ax" << i; |
117 | axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); |
118 | args.push_back(axis.back()->var); |
119 | } |
120 | |
121 | Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args)); |
122 | Array<Tensor> outputs; |
123 | for (int idx = 0; idx < op->num_outputs(); ++idx) { |
124 | outputs.push_back(op.output(idx)); |
125 | } |
126 | return outputs; |
127 | } |
128 | |
129 | ComputeOp::ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, |
130 | Array<IterVar> axis, Array<PrimExpr> body) { |
131 | if (!attrs.defined()) { |
132 | attrs = Map<String, ObjectRef>(); |
133 | } |
134 | auto n = make_object<ComputeOpNode>(); |
135 | n->name = std::move(name); |
136 | n->tag = std::move(tag); |
137 | n->attrs = std::move(attrs); |
138 | n->axis = std::move(axis); |
139 | n->body = std::move(body); |
140 | if (n->body[0]->IsInstance<tir::ReduceNode>()) { |
141 | const tir::ReduceNode* reduce = n->body[0].as<tir::ReduceNode>(); |
142 | n->reduce_axis = reduce->axis; |
143 | } |
144 | VerifyComputeOp(n.get()); |
145 | data_ = std::move(n); |
146 | } |
147 | |
TVM_REGISTER_GLOBAL("te.ComputeOp")
149 | .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs, |
150 | Array<IterVar> axis, |
151 | Array<PrimExpr> body) { return ComputeOp(name, tag, attrs, axis, body); }); |
152 | |
// Schedule-related logic.
154 | Array<Tensor> ComputeOpNode::InputTensors() const { |
155 | Array<Tensor> ret; |
156 | std::unordered_set<Tensor> visited; |
157 | for (auto& e : body) { |
158 | tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { |
159 | if (auto* pload = n.as<tir::ProducerLoadNode>()) { |
160 | Tensor t = Downcast<Tensor>(pload->producer); |
161 | if (!visited.count(t)) { |
162 | ret.push_back(t); |
163 | visited.insert(t); |
164 | } |
165 | } |
166 | }); |
167 | } |
168 | return ret; |
169 | } |
170 | |
171 | Operation ComputeOpNode::ReplaceInputs(const Operation& self, |
172 | const std::unordered_map<Tensor, Tensor>& rmap) const { |
173 | ICHECK_EQ(self.operator->(), this); |
174 | VerifyComputeOp(this); |
175 | Array<PrimExpr> arr; |
176 | if (this->body[0]->IsInstance<tir::ReduceNode>()) { |
    // Specially handle reduce so that the replaced op still shares the
    // underlying ReduceNode components: every body of a multi-output reduce
    // is rebuilt from the same replaced Reduce, differing only in
    // value_index and dtype.
179 | PrimExpr new_reduce = te::ReplaceTensor(this->body[0], rmap); |
180 | if (!new_reduce.same_as(this->body[0])) { |
181 | const tir::ReduceNode* r = new_reduce.as<tir::ReduceNode>(); |
182 | for (size_t k = 0; k < this->body.size(); ++k) { |
183 | auto n = make_object<tir::ReduceNode>(*r); |
184 | n->value_index = static_cast<int>(k); |
185 | n->dtype = r->source[k].dtype(); |
186 | arr.push_back(PrimExpr(n)); |
187 | } |
188 | } else { |
189 | arr = this->body; |
190 | } |
191 | } else { |
192 | arr = |
193 | UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); }); |
194 | } |
195 | if (!arr.same_as(this->body)) { |
196 | return ComputeOp(this->name, this->tag, this->attrs, this->axis, arr); |
197 | } else { |
198 | return self; |
199 | } |
200 | } |
201 | |
202 | void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
203 | const std::unordered_map<const VarNode*, IntSet>& dom_map, |
204 | std::unordered_map<Tensor, TensorDom>* out_dom_map) const { |
205 | ICHECK_EQ(self.operator->(), this); |
206 | auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { |
207 | if (auto* pload = n.as<tir::ProducerLoadNode>()) { |
208 | Tensor t = Downcast<Tensor>(pload->producer); |
209 | if (t->op.defined() && out_dom_map->count(t)) { |
210 | TensorDom& dom = out_dom_map->at(t); |
211 | for (size_t i = 0; i < t.ndim(); ++i) { |
212 | // We assume that the value of the argument cannot be out of bounds (otherwise it is |
213 | // undefined behaviour), so we can intersect the estimated set of the argument with the |
214 | // range expected by the tensor. However, intersection may result in overly complex |
215 | // expressions, so we perform a more relaxed form of intersection. |
216 | IntSet arg_intset = analyzer->int_set(pload->indices[i], ConvertDomMap(dom_map)); |
217 | const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>(); |
218 | if (arg_interval) { |
219 | PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); |
220 | PrimExpr shape_i_max_value = t->shape[i] - 1; |
221 | PrimExpr min_value = arg_interval->min_value; |
222 | PrimExpr max_value = arg_interval->max_value; |
          // Prefer the shape bounds only when we can prove they are tighter.
          // We must update both ends of the bound in pairs. Here is a
          // counterexample: shape_i is [0, 0] and arg_interval is
          // [threadIdx.y, threadIdx.y], where threadIdx.y's range is [0, 7].
          // If we allowed updating only one end, the bound would become
          // [threadIdx.y, 0], which is awkward for further analysis.
228 | if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) || |
229 | (analyzer->CanProve(shape_i_min_value >= min_value) && |
230 | analyzer->CanProve(shape_i_max_value <= max_value))) { |
231 | min_value = shape_i_min_value; |
232 | max_value = shape_i_max_value; |
233 | } |
234 | dom.data[i].push_back(IntSet::Interval(min_value, max_value)); |
235 | } else { |
236 | dom.data[i].push_back(arg_intset); |
237 | } |
238 | } |
239 | } |
240 | } |
241 | }; |
242 | for (auto& e : body) tir::PostOrderVisit(e, fvisit); |
243 | } |
244 | |
245 | void BaseComputeOpNode::GatherBound(const Operation& self, |
246 | const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
247 | std::unordered_map<IterVar, Range>* out_dom_map) const { |
248 | ICHECK_EQ(self.operator->(), this); |
249 | const TensorDom& tdom = tensor_dom.at(self.output(0)); |
250 | for (size_t i = 0; i < this->axis.size(); ++i) { |
251 | Range r = arith::Union(tdom.data.at(i)).CoverRange(this->axis[i]->dom); |
252 | ICHECK(!out_dom_map->count(this->axis[i])); |
253 | (*out_dom_map)[this->axis[i]] = r; |
254 | } |
255 | for (size_t i = 0; i < this->reduce_axis.size(); ++i) { |
256 | ICHECK(!out_dom_map->count(this->reduce_axis[i])); |
257 | (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom; |
258 | } |
259 | } |
260 | |
261 | Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, |
262 | const std::unordered_map<IterVar, Range>& realize_map, |
263 | const Stmt& body, String storage_scope) const { |
264 | ICHECK_EQ(stage->op.get(), this); |
265 | Region bounds; |
266 | for (IterVar iv : this->axis) { |
267 | bounds.push_back(realize_map.at(iv)); |
268 | } |
269 | Stmt realize = body; |
270 | for (int i = this->num_outputs(); i > 0; --i) { |
271 | Tensor t = stage->op.output(i - 1); |
272 | realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope); |
    // Alignment requirement, only meaningful for compute ops: record any
    // dim_align_factor/offset so later lowering can align the buffer strides.
274 | for (size_t i = 0; i < num_schedulable_dims(); ++i) { |
275 | auto it = stage->iter_var_attrs.find(this->axis[i]); |
276 | if (it != stage->iter_var_attrs.end()) { |
277 | IterVarAttr attr = (*it).second; |
278 | if (attr->dim_align_factor != 0) { |
279 | Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor, |
280 | attr->dim_align_offset}; |
281 | realize = |
282 | tir::AttrStmt(t, tir::attr::buffer_dim_align, |
283 | Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), realize); |
284 | } |
285 | } |
286 | } |
287 | } |
288 | return realize; |
289 | } |
290 | |
291 | size_t ComputeOpNode::num_schedulable_dims() const { return axis.size(); } |
292 | |
293 | // Build a reduction body. |
294 | void MakeReduction(const ComputeOpNode* op, const Array<Tensor>& tensors, Stmt* init, |
295 | Stmt* provide) { |
296 | Array<PrimExpr> args; |
297 | for (IterVar iv : op->axis) { |
298 | args.push_back(iv->var); |
299 | } |
300 | std::vector<Stmt> inits, provides; |
301 | |
302 | size_t size = op->body.size(); |
303 | const ReduceNode* reduce = op->body[0].as<ReduceNode>(); |
304 | ICHECK(reduce); |
305 | const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>(); |
306 | ICHECK(combiner); |
307 | Array<PrimExpr> lhs; |
308 | for (size_t i = 0; i < size; ++i) { |
309 | lhs.push_back(tensors[i](args)); |
310 | } |
311 | Array<PrimExpr> init_value = combiner->identity_element; |
312 | Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source); |
313 | |
314 | // If an init was passed to ReduceNode, use that for initialization |
315 | // instead of combiner->identity_element |
316 | Array<PrimExpr> reduce_init = reduce->init; |
317 | if (!reduce_init.empty()) { |
318 | init_value = reduce_init; |
319 | } |
320 | for (size_t i = 0; i < size; ++i) { |
321 | Tensor t = tensors[i]; |
322 | inits.emplace_back(ProducerStore(t, init_value[i], args)); |
323 | provides.emplace_back(ProducerStore(t, update_value[i], args)); |
324 | } |
325 | *init = SeqStmt::Flatten(inits); |
326 | *provide = SeqStmt::Flatten(provides); |
327 | if (!is_one(reduce->condition)) { |
328 | *provide = IfThenElse(reduce->condition, *provide); |
329 | } |
330 | } |
331 | |
332 | // Normal computation. |
333 | Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) { |
334 | Array<PrimExpr> args; |
335 | for (IterVar iv : op->axis) { |
336 | args.push_back(iv->var); |
337 | } |
338 | return ProducerStore(t, op->body[t->value_index], args); |
339 | } |
340 | |
341 | Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, |
342 | const std::unordered_map<IterVar, Range>& dom_map, |
343 | bool debug_keep_trivial_loop) { |
344 | // grab the nest structure |
345 | ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); |
346 | // Normal loop structure |
347 | n.init_nest.emplace_back(MakeIfNest(n.init_predicates)); |
348 | n.main_nest.emplace_back(MakeIfNest(n.main_predicates)); |
349 | if (self->reduce_axis.size() != 0) { |
350 | // make reduction. |
351 | Stmt init, provide; |
352 | Array<Tensor> source; |
353 | for (size_t i = 0; i < self->body.size(); ++i) { |
354 | source.push_back(stage->op.output(i)); |
355 | } |
356 | MakeReduction(self, source, &init, &provide); |
357 | init = MergeNest(n.init_nest, init); |
358 | init = Substitute(init, n.init_vmap); |
359 | // common nest |
360 | std::vector<std::vector<Stmt>> common(n.main_nest.begin(), |
361 | n.main_nest.begin() + n.num_common_loop + 1); |
362 | std::vector<std::vector<Stmt>> reduce(n.main_nest.begin() + n.num_common_loop + 1, |
363 | n.main_nest.end()); |
364 | provide = MergeNest(reduce, provide); |
365 | if (debug_keep_trivial_loop) { |
366 | provide = MergeNest(common, provide); |
367 | } else { |
368 | provide = MergeNest(common, SeqStmt::Flatten(init, provide)); |
369 | } |
    // Run the substitution on the full nest, because the loop condition
    // could depend on outer loops.
372 | return Substitute(provide, n.main_vmap); |
373 | } else { |
374 | std::vector<Stmt> provides; |
375 | for (size_t i = 0; i < self->body.size(); ++i) { |
376 | provides.emplace_back(MakeProvide(self, stage->op.output(i))); |
377 | } |
378 | Stmt provide = SeqStmt::Flatten(provides); |
379 | provide = MergeNest(n.main_nest, provide); |
    // Run the substitution on the full nest, because the loop condition
    // could depend on outer loops.
382 | return Substitute(provide, n.main_vmap); |
383 | } |
384 | } |
385 | |
386 | enum class ComputeType { kNormal, kCrossThreadReduction, kTensorize }; |
387 | |
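// Decide how the stage is lowered: any leaf IterVar marked kTensorized selects
// kTensorize; a kCommReduce leaf bound to a thread IterVar (e.g. via
// Stage::bind to threadIdx.x) selects kCrossThreadReduction; otherwise the
// stage lowers as a normal compute statement.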
388 | ComputeType DetectComputeType(const ComputeOpNode* self, const Stage& stage) { |
389 | // Verify correctness of leaf nest. |
390 | int thread_red = 0, tensorize = 0; |
391 | |
392 | for (IterVar iv : stage->leaf_iter_vars) { |
393 | IterVarAttr attr; |
394 | auto it = stage->iter_var_attrs.find(iv); |
395 | if (it != stage->iter_var_attrs.end()) { |
396 | attr = (*it).second; |
397 | } |
398 | if (attr.defined() && attr->iter_type == kTensorized) { |
399 | ++tensorize; |
400 | } |
401 | if (iv->iter_type == kCommReduce) { |
402 | if (attr.defined() && attr->bind_thread.defined()) { |
403 | ++thread_red; |
404 | } |
405 | } else { |
      ICHECK_EQ(thread_red, 0) << "Cross thread reduce cannot swap with normal data axis";
407 | } |
408 | } |
409 | if (tensorize != 0) { |
    ICHECK(thread_red == 0) << "Cannot mix cross thread reduction with Tensorize";
411 | return ComputeType::kTensorize; |
412 | } |
413 | if (thread_red != 0) { |
414 | return ComputeType::kCrossThreadReduction; |
415 | } else { |
416 | return ComputeType::kNormal; |
417 | } |
418 | } |
419 | |
420 | // implement the provide utility. |
421 | Stmt ComputeOpNode::BuildProvide(const Stage& stage, |
422 | const std::unordered_map<IterVar, Range>& dom_map, |
423 | bool debug_keep_trivial_loop) const { |
424 | ICHECK_EQ(stage->op.operator->(), this); |
425 | ComputeType ctype = DetectComputeType(this, stage); |
426 | if (ctype == ComputeType::kCrossThreadReduction) { |
427 | // specially handle cross thread reduction. |
428 | return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop); |
429 | } else if (ctype == ComputeType::kTensorize) { |
430 | return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop); |
431 | } else { |
432 | return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop); |
433 | } |
434 | } |
435 | |
436 | ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Stage& stage, |
437 | const std::unordered_map<IterVar, Range>& dom_map, |
438 | bool debug_keep_trivial_loop) { |
439 | ICHECK_EQ(stage->op.operator->(), self); |
440 | ComputeLoopNest ret; |
441 | // make main loop nest |
442 | ret.main_nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set<IterVar>(), |
443 | &ret.main_vmap, debug_keep_trivial_loop); |
444 | ret.main_predicates = |
445 | MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set<IterVar>()); |
446 | for (auto& e : ret.main_predicates) { |
447 | e = likely(e); |
448 | } |
449 | if (stage->store_predicate.defined()) { |
450 | ret.main_predicates.push_back(stage->store_predicate); |
451 | } |
452 | if (self->reduce_axis.size() != 0) { |
453 | // try to find the location to insert the initialization. |
454 | // Fuse the initialization and provide loop when possible. |
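    // Encode the relation in a bitmask: bit 1 marks iter vars derived from a
    // data-parallel axis, bit 2 those derived from a reduce axis; the flags
    // are then ORed down the split/fuse relations to the leaf iter vars.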
455 | std::unordered_map<IterVar, int> update_state; |
456 | for (IterVar iv : self->reduce_axis) { |
457 | update_state[iv] = 2; |
458 | } |
459 | for (size_t i = 0; i < self->num_schedulable_dims(); ++i) { |
460 | update_state[self->axis[i]] = 1; |
461 | } |
462 | // find which iter var is related to reduction and which is related to axis. |
463 | te::PassDownBitMaskOr(stage, &update_state); |
464 | auto leaf_iter_vars = stage->leaf_iter_vars; |
    // find the first loop that is related to the reduction.
466 | size_t begin_loop = leaf_iter_vars.size(); |
467 | for (size_t i = 0; i < leaf_iter_vars.size(); ++i) { |
468 | auto iv = leaf_iter_vars[i]; |
469 | int flag = update_state.at(iv); |
470 | if ((flag & 2) != 0) { |
471 | begin_loop = i; |
472 | break; |
473 | } |
474 | ret.init_vmap[iv] = ret.main_vmap.at(iv); |
475 | } |
476 | ret.num_common_loop = begin_loop; |
477 | // skip loops that are related to reduction and are unrelated to axis. |
478 | std::unordered_set<IterVar> skip_iter; |
479 | for (auto kv : update_state) { |
480 | int flag = kv.second; |
481 | if (flag == 2) skip_iter.insert(kv.first); |
482 | } |
483 | ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap), |
484 | debug_keep_trivial_loop); |
485 | ret.init_predicates = |
486 | MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter); |
487 | for (auto& e : ret.init_predicates) { |
488 | e = likely(e); |
489 | } |
490 | } else { |
491 | ICHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1); |
492 | ret.num_common_loop = stage->leaf_iter_vars.size(); |
493 | } |
  // copy elision here.
495 | return ret; |
496 | } |
497 | |
498 | namespace { |
499 | /*! |
500 | * \brief Verify if ComputeOp is valid with respect to Reduce operations. |
501 | * |
502 | * The following two properties are verified: |
503 | * (1) All Reduce operations must exist at top level. |
 *  (2) For a list of bodies, if one is a Reduce, then the others
 *      must be Reduce as well, and they must share the same
 *      attributes (everything except value_index).
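 *
 *  For example (illustrative), with `k` a reduction IterVar:
 *    compute(s, [&](const Array<Var>& i) { return sum(A(i[0], k->var), {k}); });
 *  is valid, while
 *    compute(s, [&](const Array<Var>& i) { return A(i[0]) + sum(B(i[0], k->var), {k}); });
 *  violates (1) because the Reduce is no longer at the top level.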
507 | */ |
508 | class ComputeVerifier final : protected tir::ExprVisitor { |
509 | public: |
510 | /// Special member functions |
511 | //@{ |
512 | explicit ComputeVerifier(const ComputeOpNode* compute) |
513 | : compute_(compute), reduce_(compute->body[0].as<tir::ReduceNode>()) {} |
514 | virtual ~ComputeVerifier() = default; |
515 | ComputeVerifier(const ComputeVerifier&) = delete; |
516 | ComputeVerifier(ComputeVerifier&&) = delete; |
517 | ComputeVerifier& operator=(const ComputeVerifier&) = delete; |
518 | ComputeVerifier& operator=(ComputeVerifier&&) = delete; |
519 | //@} |
520 | |
521 | /// Interface to perform compute verification |
522 | void Run() { |
    for (const PrimExpr& e : compute_->body) {
524 | // Check for consistency of top level reductions |
525 | const tir::ReduceNode* reduce = e.as<tir::ReduceNode>(); |
      ICHECK((reduce && reduce_) || (!reduce && !reduce_))
          << "All bodies of a ComputeOp should consistently be Reduce operations or not.";
528 | |
529 | if (reduce && reduce_) { |
530 | ICHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should " |
531 | << "have the same attribute except value_index" ; |
532 | } |
533 | |
534 | level_ = 0; |
535 | ExprVisitor::VisitExpr(e); |
536 | } |
537 | } |
538 | |
539 | protected: |
540 | /// Visitor implementation |
541 | //@{ |
542 | void VisitExpr(const PrimExpr& n) final { |
543 | ++level_; |
544 | ExprVisitor::VisitExpr(n); |
545 | --level_; |
546 | } |
547 | |
548 | void VisitExpr_(const tir::ReduceNode* op) final { |
549 | // Check for non top level reductions |
550 | ICHECK(0 == level_) << "Reductions are only allowed at the top level of compute. " |
551 | << "Please create another tensor for further composition." ; |
552 | } |
553 | //@} |
554 | |
555 | private: |
556 | const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify |
557 | const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation |
558 | int level_{0}; ///< Level of op being processed |
559 | }; |
560 | } // namespace |
561 | |
562 | /// Verify if ComputeOp is valid with respect to Reduce operations. |
563 | static void VerifyComputeOp(const ComputeOpNode* op) { |
564 | ComputeVerifier v(op); |
565 | v.Run(); |
566 | } |
567 | |
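// Transform the update rule of a tensorized stage: execute `update` once any
// reduction loop is past its first iteration, and `body` (the reset path)
// otherwise, i.e. roughly
//   if (likely(k > k_min) || ...) { update } else { body }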
568 | Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map, |
569 | const ComputeLoopNest& n, Stmt body, Stmt update) { |
570 | Array<PrimExpr> conds; |
571 | std::unordered_set<const VarNode*> banned; |
572 | for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { |
573 | IterVar iv = stage->leaf_iter_vars[i]; |
574 | auto iit = stage->iter_var_attrs.find(iv); |
575 | if (iit != stage->iter_var_attrs.end()) { |
576 | const IterVarAttr& attr = (*iit).second; |
577 | if (attr->iter_type == kTensorized) { |
578 | break; |
579 | } |
580 | } |
581 | if (iv->iter_type == kCommReduce) { |
582 | auto vit = dom_map.find(iv); |
583 | ICHECK(vit != dom_map.end()); |
584 | const Range& vrange = vit->second; |
585 | conds.push_back(likely(iv->var > vrange->min)); |
586 | banned.insert(iv->var.get()); |
587 | } |
588 | } |
589 | |
590 | auto fbanned = [&](const VarNode* node) { return banned.count(node); }; |
591 | |
592 | for (const PrimExpr& pred : n.main_predicates) { |
593 | if (tir::UsesVar(pred, fbanned)) { |
594 | LOG(FATAL) << "Tensorize update transform failed, the condition " << pred |
595 | << " has a conflict with the reset condition" ; |
596 | } |
597 | } |
598 | |
599 | auto cond = foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_or(a, b, span); }, |
600 | const_false(1), conds); |
601 | return IfThenElse(cond, update, body); |
602 | } |
603 | |
604 | } // namespace te |
605 | } // namespace tvm |
606 | |