1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Logics related to tensorize, used by ComputeOpNode. |
22 | * \file tensorize.cc |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/analysis.h> |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | |
29 | #include "../schedule/message_passing.h" |
30 | #include "compute_op.h" |
31 | #include "op_utils.h" |
32 | |
33 | namespace tvm { |
34 | namespace te { |
35 | |
36 | using namespace tir; |
37 | |
// Detect the region of input and output to be tensorized.
// out_dom: the domain of root iter vars in the output op.
// in_region: the region of each input tensor.
// Returns the location (leaf index) where the tensorized scope starts.
size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage,
                            const std::unordered_map<IterVar, Range>& dom_map,
                            std::unordered_map<IterVar, Range>* out_dom,
                            std::unordered_map<Tensor, Array<Range>>* in_region) {
  // Get the bound of the tensorized scope.
  bool found_point = false;
  size_t loc_scope = 0;
  std::unordered_map<IterVar, IntSet> up_state;
  // Walk the leaf iter vars from innermost to outermost, building the
  // "up state": the set of values each leaf var covers within the
  // tensorized scope.
  for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) {
    IterVar iv = stage->leaf_iter_vars[i - 1];
    ICHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce);
    auto vit = dom_map.find(iv);
    ICHECK(vit != dom_map.end());
    const Range& vrange = vit->second;
    if (is_one(vrange->extent)) {
      // Trivial loop: the var always takes its minimum value.
      up_state[iv] = IntSet::SinglePoint(vrange->min);
    } else if (found_point) {
      // Outside the tensorized scope (above the tensorize point): the var
      // stays symbolic — a single point parameterized by the loop var itself.
      ICHECK(is_zero(vrange->min));
      up_state[iv] = IntSet::SinglePoint(iv->var);
    } else {
      // Inside the tensorized scope: the whole range is covered.
      up_state[iv] = IntSet::FromRange(vrange);
    }
    auto iit = stage->iter_var_attrs.find(iv);
    if (iit != stage->iter_var_attrs.end()) {
      const IterVarAttr& attr = (*iit).second;
      if (!found_point) {
        // Thread bindings inside the tensorized scope are rejected.
        ICHECK(!attr->bind_thread.defined()) << "Do not allow thread in tensorize scope";
      }
      if (attr->iter_type == kTensorized) {
        ICHECK(!found_point) << "Do not allow two tensorized point";
        found_point = true;
        loc_scope = i - 1;
      }
    }
  }
  ICHECK(found_point);
  // Get domain of the tensorized scope.
  te::PassUpDomain(stage, dom_map, &up_state);
  // Get domains of inputs.
  std::unordered_map<Tensor, TensorDom> in_dom;
  std::unordered_map<const VarNode*, IntSet> temp_dmap;
  arith::Analyzer analyzer;
  Array<Tensor> inputs = self->InputTensors();
  for (Tensor t : inputs) {
    in_dom.emplace(t, TensorDom(t.ndim()));
  }
  // Convert the up state of each root iter var into a concrete range,
  // recording it in out_dom and binding it for simplification.
  for (IterVar iv : self->root_iter_vars()) {
    IntSet iset = up_state.at(iv);
    Range iv_range = iset.CoverRange(dom_map.at(iv));
    (*out_dom)[iv] = iv_range;
    analyzer.Bind(iv->var, iv_range);
    temp_dmap[iv->var.get()] = iset;
  }
  // Propagate the output domain back to each input tensor's region.
  self->PropBoundToInputs(stage->op, &analyzer, temp_dmap, &in_dom);
  Range none;
  for (const auto& kv : in_dom) {
    Array<Range> vec;
    const Tensor& t = kv.first;
    for (size_t i = 0; i < t.ndim(); ++i) {
      // Union the per-dimension int sets into a single covering range.
      Range r = arith::Union(kv.second.data.at(i)).CoverRange(none);
      ICHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
      vec.push_back(std::move(r));
    }
    (*in_region)[t] = std::move(vec);
  }
  return loc_scope;
}
111 | |
void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage,
                             const ComputeLoopNest& n, size_t tloc) {
  // Verification step: collect every variable defined inside the tensorized
  // scope and make sure no split predicate refers to one of them.
  std::unordered_set<const VarNode*> banned;
  ICHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
  ICHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || n.init_nest.size() == 0);
  // Record variables introduced by loops, iter-var attributes, and lets.
  auto f_push_banned = [&banned](const Stmt& s) {
    if (const ForNode* op = s.as<ForNode>()) {
      banned.insert(op->loop_var.get());
    } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
      if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
        banned.insert(iv->var.get());
      }
    } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
      banned.insert(op->var.get());
    }
  };
  // Scan all statements at or below the tensorize point (main and init nests).
  for (size_t i = tloc; i < stage->leaf_iter_vars.size(); ++i) {
    for (const Stmt& s : n.main_nest[i + 1]) {
      f_push_banned(s);
    }
    if (n.init_nest.size() != 0) {
      for (const Stmt& s : n.init_nest[i + 1]) {
        f_push_banned(s);
      }
    }
  }

  auto fbanned = [&](const VarNode* node) { return banned.count(node); };

  // Split predicates must be decidable outside the tensorized scope.
  for (const PrimExpr& pred : n.main_predicates) {
    if (tir::UsesVar(pred, fbanned)) {
      LOG(FATAL) << "Tensorize failed, split condition " << pred
                 << " relies on var defined inside tensorize scope";
    }
  }
  for (const PrimExpr& pred : n.init_predicates) {
    if (tir::UsesVar(pred, fbanned)) {
      LOG(FATAL) << "Tensorize failed, split condition " << pred
                 << " relies on var defined inside tensorize scope";
    }
  }
}
155 | |
156 | // Remap the tensor placeholder, index and inline things. |
class TensorIntrinMatcher final : public StmtExprMutator {
 public:
  // Rewrite loads from the compute op's input tensors into loads from the
  // intrinsic's placeholder tensors, rebasing indices at the tensorized
  // region's minimum and dropping leading unit dimensions (fuzzy match).
  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<ProducerLoadNode>();
    auto t = Downcast<Tensor>(op->producer);
    auto it = in_remap_.find(t);
    if (it != in_remap_.end()) {
      const InputEntry& e = it->second;
      ICHECK_EQ(op->indices.size(), e.region.size());
      Array<PrimExpr> indices;
      // Skip the first e.start unit dimensions; shift the rest to be
      // relative to the region minimum.
      for (size_t i = e.start; i < e.region.size(); ++i) {
        indices.push_back(op->indices[i] - e.region[i]->min);
      }
      return ProducerLoad(e.tensor, indices);
    }
    return expr;
  }

  // Substitute variables that Init() remapped (axis vars and unit-extent
  // vars fixed to their minimum); other vars pass through unchanged.
  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

  // Rewrite reduction axes to the intrinsic's reduction axes. Axes without
  // an entry in axis_remap_ (those fixed to a single point) are dropped
  // from the rebuilt Reduce node.
  PrimExpr VisitExpr_(const ReduceNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<ReduceNode>();
    Array<IterVar> axis;
    for (size_t i = 0; i < op->axis.size(); ++i) {
      auto it = axis_remap_.find(op->axis[i]);
      if (it != axis_remap_.end()) {
        axis.push_back(it->second);
      }
    }
    return Reduce(op->combiner, op->source, axis, op->condition, op->value_index, op->init);
  }

  // Build the remapping tables from the stage/intrinsic pair:
  // - in_remap_: input tensor -> intrinsic placeholder + region info
  // - var_remap_: compute-op axis vars -> intrinsic axis vars (+ offset)
  // - axis_remap_: compute-op IterVars -> intrinsic IterVars
  // Also populates compute_intrin_iter_space with the iteration ranges.
  void Init(const ComputeOpNode* self, const Stage& stage,
            const std::unordered_map<IterVar, Range>& dom_map,
            const std::unordered_map<IterVar, Range>& out_dom,
            const std::unordered_map<Tensor, Array<Range>>& in_region, const TensorIntrin& intrin,
            Map<Var, Range>* compute_intrin_iter_space) {
    ICHECK(self == stage->op.get());

    // Record the iteration space of the leaf loops for simplification.
    for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
      IterVar iv = stage->leaf_iter_vars[i];
      auto vit = dom_map.find(iv);
      if (vit != dom_map.end()) {
        const Range vrange = vit->second;
        compute_intrin_iter_space->Set(iv->var, vrange);
      }
    }
    analyzer_.Bind(*compute_intrin_iter_space);

    // input remap.
    Array<Tensor> inputs = self->InputTensors();
    ICHECK_EQ(inputs.size(), intrin->inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      InputEntry e;
      e.tensor = intrin->inputs[i];
      e.region = Array<Range>(in_region.at(inputs[i]));
      ICHECK_GE(e.region.size(), e.tensor.ndim());
      // Enable fuzzy matching, to match [1, n, m] to [n, m]
      e.start = e.region.size() - e.tensor.ndim();
      for (size_t j = 0; j < e.start; ++j) {
        // Every skipped leading dimension must have extent 1.
        auto canonical_extent = analyzer_.Simplify(e.region[j]->extent);
        ICHECK(is_one(canonical_extent))
            << "Tensorize " << intrin->name << ":"
            << " Input dimension mismatch with tensor intrin "
            << " expected shape=" << e.tensor->shape << ", given region=" << e.region;
      }
      in_remap_[inputs[i]] = e;
    }
    // output remap
    const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
    ICHECK(intrin_compute) << "Only support compute intrinsic for now";
    ICHECK_GE(self->axis.size(), intrin_compute->axis.size())
        << "Tensorize: Output mismatch with tensor intrin ";
    // Enable fuzzy matching, to match [1, n, m] to [n, m]
    size_t axis_start = self->axis.size() - intrin_compute->axis.size();
    for (size_t i = 0; i < axis_start; ++i) {
      // Leading unit axes are fixed at their minimum value.
      Range r = out_dom.at(self->axis[i]);
      ICHECK(is_one(r->extent)) << "Tensorize: Output mismatch with tensor intrin "
                                << " intrin-dim=" << intrin_compute->axis.size()
                                << ", tensorize-dim=" << self->axis.size();
      var_remap_[self->axis[i]->var.get()] = r->min;
    }
    // Assume we tensorize at regin axis i [min, min + extent)
    // The corresponding intrinsic axis is j [0, extent)
    // Remap index i to j + min
    for (size_t i = axis_start; i < self->axis.size(); ++i) {
      IterVar iv = self->axis[i];
      IterVar target_iv = intrin_compute->axis[i - axis_start];
      Range r = out_dom.at(iv);
      var_remap_[iv->var.get()] = target_iv->var + r->min;
      axis_remap_[iv] = target_iv;
      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
    }
    // Remap reduction axis (same fuzzy-matching scheme as the output axes).
    ICHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
        << "Tensorize: Reduction dimension mismatch with tensor intrin";
    axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size();
    for (size_t i = 0; i < axis_start; ++i) {
      Range r = out_dom.at(self->reduce_axis[i]);
      ICHECK(is_one(r->extent)) << "Tensorize: Reduction mismatch with tensor intrin "
                                << " intrin-dim=" << intrin_compute->reduce_axis.size()
                                << ", tensorize-dim=" << self->reduce_axis.size();
      var_remap_[self->reduce_axis[i]->var.get()] = r->min;
    }
    for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) {
      IterVar iv = self->reduce_axis[i];
      IterVar target_iv = intrin_compute->reduce_axis[i - axis_start];
      Range r = out_dom.at(iv);
      var_remap_[iv->var.get()] = target_iv->var + r->min;
      axis_remap_[iv] = target_iv;
      compute_intrin_iter_space->Set(target_iv->var, target_iv->dom);
    }
  }

 private:
  // Input entry: intrinsic placeholder, number of skipped leading unit
  // dimensions, and the tensorized region of the original input.
  struct InputEntry {
    Tensor tensor;
    size_t start;
    Array<Range> region;
  };
  // input data remap
  std::unordered_map<Tensor, InputEntry> in_remap_;
  // variable remap.
  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
  // IterVar remap.
  std::unordered_map<IterVar, IterVar> axis_remap_;
  // arith analyzer
  arith::Analyzer analyzer_;
};
296 | |
297 | // Try to match tensor dataflow of the stage with the intrinsic |
298 | Array<PrimExpr> MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage, |
299 | const std::unordered_map<IterVar, Range>& dom_map, |
300 | const std::unordered_map<IterVar, Range>& out_dom, |
301 | const std::unordered_map<Tensor, Array<Range>>& in_region, |
302 | const TensorIntrin& intrin, |
303 | Map<Var, Range>* compute_intrin_iter_space) { |
304 | TensorIntrinMatcher matcher; |
305 | matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space); |
306 | Array<PrimExpr> ret; |
307 | for (PrimExpr expr : self->body) { |
308 | ret.push_back(matcher(expr)); |
309 | } |
310 | return ret; |
311 | } |
312 | |
313 | void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, |
314 | const std::unordered_map<IterVar, PrimExpr>& value_map, |
315 | const std::unordered_map<IterVar, Range>& dom_map, |
316 | const std::unordered_map<IterVar, Range>& out_dom, |
317 | const std::unordered_map<Tensor, Array<Range>>& in_region, |
318 | const TensorIntrin& intrin) { |
319 | StructuralEqual expr_equal; |
320 | Map<Var, Range> compute_intrin_iter_space; |
321 | Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin, |
322 | &compute_intrin_iter_space); |
323 | const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>(); |
324 | ICHECK(intrin_compute) << "Only support compute intrinsic for now" ; |
325 | ICHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch" ; |
326 | arith::Analyzer ana; |
327 | ana.Bind(compute_intrin_iter_space); |
328 | |
329 | for (size_t i = 0; i < body.size(); ++i) { |
330 | PrimExpr lhs = ana.Simplify(Substitute(body[i], value_map)); |
331 | // run substitution because the intrin body could depend on outer loop vars. |
332 | PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map)); |
333 | if (lhs.dtype() != rhs.dtype()) { |
334 | LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name |
335 | << "'s declaration " |
336 | << " provided=" << lhs.dtype() << ", intrin=" << rhs.dtype(); |
337 | } |
338 | ICHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name |
339 | << "'s declaration " |
340 | << " provided= " << lhs << ", intrin= " << rhs |
341 | << ", running this stage: " << stage; |
342 | } |
343 | } |
344 | |
345 | Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, |
346 | const std::unordered_map<IterVar, Range>& dom_map, |
347 | bool debug_keep_trivial_loop) { |
348 | std::unordered_map<IterVar, Range> out_dom; |
349 | std::unordered_map<Tensor, Array<Range>> in_region; |
350 | size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); |
351 | TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin; |
352 | ICHECK(intrin.defined()); |
353 | ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); |
354 | VerifyTensorizeLoopNest(self, stage, n, tloc); |
355 | VerifyTensorizeBody(self, stage, n.main_vmap, dom_map, out_dom, in_region, intrin); |
356 | // Start bind data. |
357 | Stmt nop = Evaluate(0); |
358 | std::vector<Stmt> input_bind_nest, output_bind_nest; |
359 | Array<Tensor> inputs = self->InputTensors(); |
360 | ICHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch " ; |
361 | // input binding |
362 | for (size_t i = 0; i < intrin->inputs.size(); ++i) { |
363 | Tensor tensor = inputs[i]; |
364 | Buffer buffer = intrin->buffers[i]; |
365 | Array<ObjectRef> bind_spec{buffer, tensor}; |
366 | auto it = in_region.find(tensor); |
367 | ICHECK(it != in_region.end()); |
368 | const Array<Range>& region = it->second; |
369 | Array<PrimExpr> tuple; |
370 | for (const Range r : region) { |
371 | tuple.push_back(r->min); |
372 | tuple.push_back(r->extent); |
373 | } |
374 | input_bind_nest.emplace_back( |
375 | AttrStmt(bind_spec, tir::attr::buffer_bind_scope, |
376 | Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); |
377 | } |
378 | // output binding |
379 | const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>(); |
380 | ICHECK(intrin_compute) << "Only support compute intrinsic for now" ; |
381 | ICHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size()); |
382 | ICHECK_EQ(intrin_compute->body.size(), self->body.size()); |
383 | Array<PrimExpr> tuple; |
384 | for (IterVar iv : self->axis) { |
385 | auto it = out_dom.find(iv); |
386 | ICHECK(it != out_dom.end()); |
387 | tuple.push_back(it->second->min); |
388 | tuple.push_back(it->second->extent); |
389 | } |
390 | for (size_t i = intrin->inputs.size(); i < intrin->buffers.size(); ++i) { |
391 | Tensor tensor = stage->op.output(i - intrin->inputs.size()); |
392 | Buffer buffer = intrin->buffers[i]; |
393 | Array<ObjectRef> bind_spec{buffer, tensor}; |
394 | output_bind_nest.emplace_back( |
395 | AttrStmt(bind_spec, tir::attr::buffer_bind_scope, |
396 | Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple), nop)); |
397 | } |
398 | // Check variable remap |
399 | std::unordered_map<const VarNode*, PrimExpr> vmap; |
400 | tir::ArgBinder binder(&vmap); |
401 | ICHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size()) |
402 | << "Tensorization fail: reduction axis size do not match" ; |
403 | size_t start = self->reduce_axis.size() - intrin_compute->reduce_axis.size(); |
404 | for (size_t i = 0; i < start; ++i) { |
405 | IterVar iv = self->reduce_axis[i]; |
406 | auto it = out_dom.find(iv); |
407 | ICHECK(it != out_dom.end()); |
408 | ICHECK(is_one(it->second->extent)) << "Tensorization fail: reduction axis size do not match" ; |
409 | } |
410 | for (size_t i = start; i < self->reduce_axis.size(); ++i) { |
411 | IterVar iv = self->reduce_axis[i]; |
412 | IterVar target = intrin_compute->reduce_axis[i - start]; |
413 | auto it = out_dom.find(iv); |
414 | ICHECK(it != out_dom.end()); |
415 | binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0), |
416 | "tensir_intrin.reduction.min" ); |
417 | binder.Bind(target->dom->extent, it->second->extent, "tensir_intrin.reduction.extent" ); |
418 | } |
419 | if (tloc <= n.num_common_loop) { |
420 | // Do no need to split reduction |
421 | std::vector<std::vector<Stmt>> nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); |
422 | nest.emplace_back(MakeIfNest(n.main_predicates)); |
423 | ICHECK_EQ(n.init_predicates.size(), 0U); |
424 | ICHECK(intrin->body.defined()) << "Normal store op for intrin " << intrin << " is not defined" ; |
425 | Stmt body = MergeNest(output_bind_nest, intrin->body); |
426 | body = MergeNest(input_bind_nest, body); |
427 | body = tir::Substitute(body, vmap); |
428 | body = MergeNest(binder.asserts(), body); |
429 | body = te::Substitute(body, n.main_vmap); |
430 | return MergeNest(nest, body); |
431 | } else { |
432 | // Need to split reduction |
433 | ICHECK(intrin->reduce_update.defined()) |
434 | << "Reduction update op for intrin " << intrin << " is not defined" ; |
435 | // Need init and update steps |
436 | ICHECK_NE(self->reduce_axis.size(), 0U); |
437 | std::vector<std::vector<Stmt>> common(n.main_nest.begin(), |
438 | n.main_nest.begin() + n.num_common_loop + 1); |
439 | std::vector<std::vector<Stmt>> update_nest(n.main_nest.begin() + n.num_common_loop + 1, |
440 | n.main_nest.begin() + tloc + 1); |
441 | update_nest.emplace_back(MakeIfNest(n.main_predicates)); |
442 | |
443 | if (intrin->reduce_init.defined()) { |
444 | // init nest |
445 | std::vector<std::vector<Stmt>> init_nest(n.init_nest.begin(), n.init_nest.begin() + tloc + 1); |
446 | init_nest.emplace_back(MakeIfNest(n.init_predicates)); |
447 | Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); |
448 | init = te::Substitute(init, n.init_vmap); |
449 | init = MergeNest(init_nest, init); |
450 | // The update |
451 | Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); |
452 | update = MergeNest(input_bind_nest, update); |
453 | update = tir::Substitute(update, vmap); |
454 | update = MergeNest(binder.asserts(), update); |
455 | update = te::Substitute(update, n.main_vmap); |
456 | update = MergeNest(update_nest, update); |
457 | return MergeNest(common, SeqStmt::Flatten(init, update)); |
458 | } else { |
459 | // When init op is not available, use body op for reset in the first iter. |
460 | ICHECK(intrin->body.defined()) << "Normal body op for intrin " << intrin << " is not defined" ; |
461 | Stmt update = TransformUpdate(stage, dom_map, n, intrin->body, intrin->reduce_update); |
462 | update = MergeNest(output_bind_nest, update); |
463 | update = MergeNest(input_bind_nest, update); |
464 | update = tir::Substitute(update, vmap); |
465 | update = MergeNest(binder.asserts(), update); |
466 | update = te::Substitute(update, n.main_vmap); |
467 | update = MergeNest(update_nest, update); |
468 | return MergeNest(common, update); |
469 | } |
470 | } |
471 | } |
472 | |
473 | // Register functions for unittests |
474 | TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
475 | Stage stage = args[0]; |
476 | Map<IterVar, Range> dmap = args[1]; |
477 | std::unordered_map<IterVar, Range> out_dom; |
478 | std::unordered_map<Tensor, Array<Range>> in_region; |
479 | ICHECK(stage->op.as<ComputeOpNode>()); |
480 | InferTensorizeRegion(stage->op.as<ComputeOpNode>(), stage, as_unordered_map(dmap), &out_dom, |
481 | &in_region); |
482 | *ret = Array<ObjectRef>{Map<IterVar, Range>(out_dom), Map<Tensor, Array<Range>>(in_region)}; |
483 | }); |
484 | |
485 | TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
486 | Stage stage = args[0]; |
487 | Map<IterVar, Range> out_dom = args[1]; |
488 | Map<Tensor, Array<Range>> in_region = args[2]; |
489 | TensorIntrin intrin = args[3]; |
490 | Map<Var, Range> vrange; |
491 | ICHECK(stage->op.as<ComputeOpNode>()); |
492 | *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(), stage, {{}}, as_unordered_map(out_dom), |
493 | as_unordered_map(in_region), intrin, &vrange); |
494 | }); |
495 | } // namespace te |
496 | } // namespace tvm |
497 | |