1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file schedule_dataflow_rewrite.cc |
22 | */ |
23 | #include <tvm/te/operation.h> |
24 | #include <tvm/te/schedule.h> |
25 | #include <tvm/tir/op.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | |
28 | #include <unordered_set> |
29 | |
30 | #include "../../tir/transforms/ir_utils.h" |
31 | #include "message_passing.h" |
32 | #include "operation_inline.h" |
33 | |
34 | namespace tvm { |
35 | namespace te { |
// Find the first occurrence of v in the leaf array; returns the array size if not found.
37 | template <typename T> |
38 | size_t FindNodeRef(ArrayNode* array_node, const T& v) { |
39 | const Object* n = v.get(); |
40 | for (size_t i = 0; i < array_node->size(); ++i) { |
41 | if (array_node->at(i).get() == n) return i; |
42 | } |
43 | return array_node->size(); |
44 | } |
45 | |
// Variable substitution mutator used when constructing cache stages.
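// A minimal usage sketch (illustrative only; the variable names are hypothetical):
//   std::unordered_map<const VarNode*, PrimExpr> vsub;
//   vsub[iv->var.get()] = new_iv->var;             // substitute iv's var by new_iv's var
//   PrimExpr rewritten = VarReplacer(vsub)(expr);
// Reduce nodes need extra care because the combiner can also capture free
// variables; see MutateCommReducer below.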
47 | class VarReplacer : public tir::StmtExprMutator { |
48 | public: |
49 | explicit VarReplacer(const std::unordered_map<const VarNode*, PrimExpr>& vsub) : vsub_(vsub) {} |
50 | PrimExpr VisitExpr_(const VarNode* op) final { |
51 | auto it = vsub_.find(op); |
52 | if (it != vsub_.end()) return it->second; |
53 | return GetRef<PrimExpr>(op); |
54 | } |
55 | |
56 | tir::CommReducer MutateCommReducer(tir::CommReducer combiner) { |
57 | // Replace free variables in combiner |
58 | auto new_identity = tir::UpdateArray(combiner->identity_element, |
59 | [this](const PrimExpr& e) { return this->VisitExpr(e); }); |
60 | auto new_result = tir::UpdateArray(combiner->result, |
61 | [this](const PrimExpr& e) { return this->VisitExpr(e); }); |
62 | |
    if (combiner->identity_element.same_as(new_identity) &&
        combiner->result.same_as(new_result)) {
65 | return combiner; |
66 | } else { |
67 | return tir::CommReducer(combiner->lhs, combiner->rhs, new_result, new_identity); |
68 | } |
69 | } |
70 | |
71 | PrimExpr VisitExpr_(const tir::ReduceNode* op) final { |
72 | PrimExpr new_e = StmtExprMutator::VisitExpr_(op); |
73 | const tir::ReduceNode* new_reduce = new_e.as<tir::ReduceNode>(); |
74 | tir::CommReducer new_combiner = MutateCommReducer(op->combiner); |
75 | if (op->combiner.same_as(new_combiner)) { |
76 | return new_e; |
77 | } else { |
78 | return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition, |
79 | new_reduce->value_index, new_reduce->init); |
80 | } |
81 | } |
82 | |
83 | private: |
84 | const std::unordered_map<const VarNode*, PrimExpr>& vsub_; |
85 | }; |
86 | |
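// Fold a list of bound-check predicates into a compute body. For a Reduce body
// the predicates are and-ed into the reduce condition; otherwise the body is
// wrapped in a Select. Roughly (illustrative sketch):
//   InjectPredicate({i < n}, x + 1)  ==>  Select(i < n, x + 1, 0)
//   InjectPredicate({i < n}, Reduce(..., condition, ...))
//                                    ==>  Reduce(..., condition && (i < n), ...)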
87 | PrimExpr InjectPredicate(const Array<PrimExpr>& predicates, PrimExpr body) { |
88 | using tir::ReduceNode; |
89 | using tir::SelectNode; |
90 | if (predicates.size() == 0) return body; |
91 | const ReduceNode* reduce = body.as<ReduceNode>(); |
92 | |
93 | if (reduce) { |
94 | auto n = make_object<ReduceNode>(*reduce); |
95 | n->condition = foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); }, |
96 | n->condition, predicates); |
97 | return PrimExpr(n); |
98 | } |
99 | return Select(foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); }, |
100 | const_true(1), predicates), |
101 | body, make_zero(body.dtype())); |
102 | } |
103 | |
// Replace the dataflow in all stages given the tensor replacement map.
// Also update vmap so that subsequent dataflow can be replaced.
// The reverse map rvmap keeps vmap transitively closed: when a replacement
// tensor is itself replaced later, the original key is remapped to the newest tensor.
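// Example of the bookkeeping (illustrative): if vmap already records B -> B'
// and a later rewrite turns B' into B'', the rvmap lookup lets us update the
// original key so that vmap records B -> B'' rather than a chain B -> B' -> B''.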
107 | void ReplaceDataFlow(const Array<Stage>& stages, std::unordered_map<Tensor, Tensor>* vmap, |
108 | std::unordered_map<Tensor, Tensor>* rvmap) { |
109 | for (Stage s : stages) { |
110 | Operation op = s->op->ReplaceInputs(s->op, *vmap); |
111 | if (!op.same_as(s->op)) { |
112 | for (int i = 0; i < op->num_outputs(); ++i) { |
113 | auto it = rvmap->find(s->op.output(i)); |
114 | if (it != rvmap->end()) { |
115 | (*vmap)[it->second] = op.output(i); |
116 | } else { |
117 | (*vmap)[s->op.output(i)] = op.output(i); |
118 | (*rvmap)[op.output(i)] = s->op.output(i); |
119 | } |
120 | } |
121 | s->op = op; |
122 | } |
123 | } |
124 | } |
125 | |
126 | inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { |
127 | return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && |
128 | (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) && |
129 | ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); |
130 | } |
131 | |
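// Create a cache stage that reads `tensor` in the given scope and redirect the
// listed readers to it. A rough usage sketch (assuming the usual te::placeholder /
// te::compute / create_schedule helpers; all names are illustrative):
//   Tensor A  = placeholder({n, m}, DataType::Float(32), "A");
//   Tensor C  = compute({n, m}, [&](Var i, Var j) { return A(i, j) * 2.0f; }, "C");
//   Schedule s = create_schedule({C->op});
//   Tensor AA = s.cache_read(A, "shared", {C->op});  // C now reads AA instead of A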
132 | Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, |
133 | const Array<Operation>& readers) { |
134 | (*this)->InvalidateCache(); |
135 | // create identity mapping. |
136 | std::ostringstream os; |
137 | os << tensor->op->name; |
138 | if (tensor->op->num_outputs() != 1) { |
139 | os << ".v" << tensor->value_index; |
140 | } |
141 | |
  // When a schedule has multiple cache_read stages on the same tensor,
  // make sure their op names are unique, e.g., w.shared, w.d.shared, w.d.d.shared.
144 | for (auto pair : (*this)->stage_map) { |
145 | auto stage = pair.second; |
146 | if (stage->op->name == os.str() + "." + scope) { |
      os << ".d";
148 | } |
149 | } |
150 | os << "." << scope; |
151 | |
152 | std::unordered_map<Tensor, Tensor> vsub; |
153 | Stage s = operator[](tensor->op); |
154 | Tensor sugar_tensor = s->op.output(tensor->value_index); |
155 | Tensor cache = compute( |
156 | sugar_tensor->shape, |
157 | [&sugar_tensor](const Array<Var>& i) { |
158 | return sugar_tensor(Array<PrimExpr>(i.begin(), i.end())); |
159 | }, |
160 | os.str()); |
161 | vsub[sugar_tensor] = cache; |
162 | |
163 | std::unordered_map<Tensor, Tensor> vmap; |
164 | std::unordered_map<Tensor, Tensor> rvmap; |
165 | for (Operation op : readers) { |
166 | Stage s = operator[](op); |
167 | Operation repl_op = s->op->ReplaceInputs(s->op, vsub); |
168 | ICHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op; |
169 | vmap[s->op.output(0)] = repl_op.output(0); |
170 | rvmap[repl_op.output(0)] = s->op.output(0); |
171 | s->op = repl_op; |
172 | } |
173 | ReplaceDataFlow((*this)->stages, &vmap, &rvmap); |
174 | Array<Stage>& stages = (*this)->stages; |
175 | Stage op_stage = operator[](tensor->op); |
176 | size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage); |
177 | Stage cache_stage = Stage(cache->op); |
178 | cache_stage.set_scope(scope); |
179 | ICHECK_LT(pos, stages.size()); |
180 | stages.insert(stages.begin() + pos + 1, cache_stage); |
181 | (*this)->stage_map.Set(cache->op, cache_stage); |
182 | // Update group |
183 | cache_stage->group = op_stage->group; |
184 | if (cache_stage->group.defined()) { |
185 | ++cache_stage->group->num_child_stages; |
186 | } |
187 | return cache; |
188 | } |
189 | |
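// Set up the iteration-variable mapping shared by the cache_write paths: collect
// the reduction axes, build the cache stage's new data-parallel axes (suffixed
// with ".c"), and prepare the variable substitutions plus the bound-check
// predicates that relate the original stage's leaf iteration to its root axes.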
190 | template <typename OpType> |
191 | void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set<IterVar>* p_red_axis, |
192 | Array<IterVar>* p_new_axis, std::unordered_map<IterVar, Range>* p_dom_map, |
193 | std::unordered_map<const VarNode*, PrimExpr>* p_vsub, |
194 | std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar, |
195 | std::vector<PrimExpr>* p_predicates) { |
196 | auto& red_axis = *p_red_axis; |
197 | auto& new_axis = *p_new_axis; |
198 | auto& dom_map = *p_dom_map; |
199 | auto& vsub = *p_vsub; |
200 | auto& vsub2newvar = *p_vsub2newvar; |
201 | auto& predicates = *p_predicates; |
202 | arith::Analyzer analyzer; |
203 | |
204 | for (IterVar iv : op->reduce_axis) { |
205 | red_axis.insert(iv); |
206 | } |
207 | for (IterVar iv : op->axis) { |
208 | dom_map[iv] = iv->dom; |
209 | analyzer.Bind(iv->var, iv->dom); |
210 | } |
211 | te::PassDownDomain(orig_stage, &dom_map, &analyzer, true); |
212 | { |
213 | // The source->cache |
214 | std::unordered_map<IterVar, PrimExpr> value_map; |
215 | for (IterVar iv : orig_stage->leaf_iter_vars) { |
216 | if (red_axis.count(iv)) continue; |
      ICHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout within data parallel dimensions";
218 | Range dom = dom_map.at(iv); |
      IterVar new_iv = IterVar(dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
220 | new_axis.push_back(new_iv); |
      if (is_one(dom->extent)) {
222 | value_map[iv] = dom->min; |
223 | } else { |
224 | value_map[iv] = iv->var; |
225 | vsub2newvar[iv->var.get()] = new_iv->var; |
226 | } |
227 | } |
228 | // skip reduction iteration. |
229 | std::unordered_set<IterVar> skip_bound_check; |
230 | for (IterVar iv : op->reduce_axis) { |
231 | skip_bound_check.insert(iv); |
232 | } |
233 | PassUpIndex(orig_stage, dom_map, &value_map, true); |
234 | predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check); |
235 | // The root axis |
236 | for (IterVar iv : op->axis) { |
237 | if (value_map.count(iv)) { |
238 | vsub[iv->var.get()] = value_map.at(iv); |
239 | } // to handle tensor axis |
240 | } |
241 | } |
242 | } |
243 | |
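// Splice the freshly created cache stage into the schedule right before the
// original stage, point the original stage at `orig_new_op`, and patch up the
// dataflow of all downstream stages.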
244 | Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope, |
245 | Operation cache_op, Operation orig_new_op, size_t tensor_size) { |
246 | Array<Tensor> cache_tensor_list; |
247 | for (size_t i = 0; i < tensor_size; i++) { |
248 | Tensor cache_tensor = cache_op.output(i); |
249 | cache_tensor_list.push_back(cache_tensor); |
250 | } |
  // Replace the dataflow.
252 | std::unordered_map<Tensor, Tensor> vmap; |
253 | std::unordered_map<Tensor, Tensor> rvmap; |
254 | vmap[orig_stage->op.output(0)] = orig_new_op.output(0); |
255 | rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); |
256 | for (size_t i = 0; i < tensor_size; i++) { |
    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
259 | } |
260 | ReplaceDataFlow(sch->stages, &vmap, &rvmap); |
261 | // mutate orig stage |
262 | orig_stage->op = orig_new_op; |
263 | orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); |
264 | orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; |
265 | orig_stage->relations = Array<IterVarRelation>(); |
266 | // create schedule for new cached stage. |
267 | Array<Stage>& stages = sch->stages; |
268 | size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage); |
269 | Stage cache_stage = Stage(cache_op); |
270 | cache_stage.set_scope(scope); |
271 | ICHECK_LT(pos, stages.size()); |
272 | stages.insert(stages.begin() + pos, cache_stage); |
273 | sch->stage_map.Set(cache_op, cache_stage); |
274 | // Update group |
275 | cache_stage->group = orig_stage->group; |
276 | if (cache_stage->group.defined()) { |
277 | ++cache_stage->group->num_child_stages; |
278 | } |
279 | return cache_tensor_list; |
280 | } |
281 | |
282 | // Cache write and relayout the data according to loop pattern |
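// A rough picture of the rewrite (illustrative, assuming a single-output ComputeOp
// C and scope "local"):
//   before:  C[i, j]           = f(...)
//   after :  C.local[i.c, j.c] = f(...)        // new cache stage with ".c" axes
//            C[i, j]           = C.local[i, j] // original op becomes a copy
// The cache stage follows the loop structure of the original stage's leaf
// iteration, which is what "relayout according to the loop pattern" refers to.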
283 | Array<Tensor> CacheWriteWithReLayout(Schedule sch, const Array<Tensor>& tensor_array, |
284 | const std::string& scope) { |
285 | size_t tensor_size = tensor_array.size(); |
286 | sch->InvalidateCache(); |
287 | Tensor tensor = tensor_array[0]; |
288 | Stage orig_stage = sch[tensor->op]; |
289 | const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>(); |
290 | |
291 | std::unordered_set<IterVar> red_axis; |
292 | Array<IterVar> new_axis; |
293 | std::unordered_map<IterVar, Range> dom_map; |
294 | |
295 | std::unordered_map<const VarNode*, PrimExpr> vsub; |
296 | std::unordered_map<const VarNode*, PrimExpr> vsub2newvar; |
297 | std::vector<PrimExpr> predicates; |
298 | |
299 | PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, |
300 | &predicates); |
301 | |
302 | PrimExpr body; |
303 | Array<PrimExpr> body_list; |
304 | const tir::ReduceNode* first_reduce = nullptr; |
305 | for (auto cbody : compute->body) { |
306 | body = VarReplacer(vsub)(cbody); |
307 | body = InjectPredicate(predicates, body); |
308 | body = VarReplacer(vsub2newvar)(body); |
    // All Reduce nodes in one ComputeOp must be identical except for value_index.
    // This only holds if the original body already guarantees that invariant.
311 | if (body->IsInstance<tir::ReduceNode>()) { |
312 | const tir::ReduceNode* reduce_body = body.as<tir::ReduceNode>(); |
313 | if (first_reduce != nullptr) { |
314 | ICHECK(ReduceEqual(reduce_body, first_reduce)); |
315 | body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis, |
316 | first_reduce->condition, reduce_body->value_index, reduce_body->init); |
317 | } else { |
318 | first_reduce = reduce_body; |
319 | } |
320 | } else { |
      ICHECK(first_reduce == nullptr) << "cannot mix reduce and other nodes in one compute body";
322 | } |
323 | body_list.push_back(body); |
324 | } |
325 | // The reader args |
326 | Array<PrimExpr> args; |
327 | { |
328 | // cache->compute |
329 | std::unordered_map<IterVar, PrimExpr> value_map; |
330 | for (IterVar iv : compute->axis) { |
331 | value_map[iv] = iv->var; |
332 | } |
333 | te::PassDownIndex(orig_stage, dom_map, &value_map, true); |
334 | for (IterVar iv : orig_stage->leaf_iter_vars) { |
335 | if (red_axis.count(iv)) continue; |
336 | args.push_back(value_map.at(iv)); |
337 | } |
338 | } |
339 | Operation cache_op = |
340 | ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list); |
341 | |
342 | Array<PrimExpr> cache_expr_list; |
343 | for (size_t i = 0; i < tensor_size; i++) { |
344 | Tensor cache_tensor = cache_op.output(i); |
345 | cache_expr_list.push_back(cache_tensor(args)); |
346 | } |
347 | Operation orig_new_op = |
348 | ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list); |
349 | return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); |
350 | } |
351 | |
352 | // for tensor compute op |
353 | Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& tensor_array, |
354 | const std::string& scope) { |
355 | size_t tensor_size = tensor_array.size(); |
356 | sch->InvalidateCache(); |
357 | Tensor tensor = tensor_array[0]; |
358 | Stage orig_stage = sch[tensor->op]; |
359 | const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>(); |
360 | ICHECK_EQ(tensor_op->num_outputs(), 1) |
      << "cache write only supports single-output tensor_compute_op";
362 | |
363 | std::unordered_set<IterVar> red_axis; |
364 | Array<IterVar> new_axis; |
365 | std::unordered_map<IterVar, Range> dom_map; |
366 | |
367 | std::unordered_map<const VarNode*, PrimExpr> vsub; |
368 | std::unordered_map<const VarNode*, PrimExpr> vsub2newvar; |
369 | std::vector<PrimExpr> predicates; |
370 | |
371 | PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, |
372 | &predicates); |
373 | |
374 | for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) { |
375 | IterVar iv = tensor_op->axis[i]; |
    IterVar new_iv = IterVar(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
377 | new_axis.push_back(new_iv); |
378 | } |
379 | Array<Region> new_regions; |
380 | for (Region old_region : tensor_op->input_regions) { |
381 | Region region; |
382 | for (Range r : old_region) { |
383 | PrimExpr min = VarReplacer(vsub2newvar)(r->min); |
384 | PrimExpr extent = VarReplacer(vsub2newvar)(r->extent); |
385 | region.push_back(Range::FromMinExtent(min, extent)); |
386 | } |
387 | new_regions.push_back(region); |
388 | } |
389 | |
390 | Array<PrimExpr> new_scalar_inputs; |
391 | for (PrimExpr old_input : tensor_op->scalar_inputs) { |
392 | new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); |
393 | } |
394 | |
395 | Operation cache_op = |
396 | TensorComputeOp(tensor_op->name + "." + scope, tensor_op->tag, new_axis, |
397 | tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin, |
398 | tensor_op->inputs, new_regions, new_scalar_inputs); |
399 | |
400 | // axis will be used in generating compute op |
401 | Array<IterVar> compute_axis = tensor_op->axis; |
402 | for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) { |
403 | IterVar iv = tensor_op->axis[i]; |
404 | IterVar aiv = IterVar(iv->dom, iv->var, kDataPar); |
405 | compute_axis.Set(i, aiv); |
406 | } |
407 | |
408 | // The reader args |
409 | Array<PrimExpr> args; |
410 | { |
411 | // cache->compute |
412 | std::unordered_map<IterVar, PrimExpr> value_map; |
413 | for (IterVar iv : compute_axis) { |
414 | value_map[iv] = iv->var; |
415 | } |
416 | PassDownIndex(orig_stage, dom_map, &value_map, true); |
417 | for (IterVar iv : orig_stage->leaf_iter_vars) { |
418 | if (red_axis.count(iv)) continue; |
419 | args.push_back(value_map.at(iv)); |
420 | } |
421 | // tensorized region axis |
422 | for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) { |
423 | IterVar iv = compute_axis[i]; |
424 | args.push_back(value_map.at(iv)); |
425 | } |
426 | } |
427 | |
428 | Array<PrimExpr> cache_expr_list; |
429 | for (size_t i = 0; i < tensor_size; i++) { |
430 | Tensor cache_tensor = cache_op.output(i); |
431 | cache_expr_list.push_back(cache_tensor(args)); |
432 | } |
433 | Operation orig_new_op = |
434 | ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list); |
435 | return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); |
436 | } |
437 | |
438 | Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, const std::string& scope) { |
439 | (*this)->InvalidateCache(); |
  ICHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0";
441 | Tensor tensor = tensor_array[0]; |
442 | Stage orig_stage = operator[](tensor->op); |
443 | const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>(); |
  ICHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
      << "size of the input tensor list must be the same as the number of stage outputs";
446 | for (size_t i = 1; i < tensor_array.size(); i++) { |
447 | Stage tmp_stage = operator[](tensor_array[i]->op); |
    ICHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp";
449 | } |
450 | return CacheWriteWithReLayout(*this, tensor_array, scope); |
451 | } |
452 | |
453 | Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { |
454 | // support original compute and tensor compute both |
455 | (*this)->InvalidateCache(); |
456 | if (tensor->op.as<ComputeOpNode>()) { |
457 | return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; |
458 | } else if (tensor->op.as<TensorComputeOpNode>()) { |
459 | return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0]; |
460 | } else { |
    LOG(FATAL) << "cache write only takes ComputeOp or TensorComputeOp as writers";
462 | } |
463 | } |
464 | |
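// Rebase root iteration variables so that downstream passes always see loops of
// the form [0, extent). Roughly (illustrative): an axis i with dom [4, 4 + 8) is
// rebased to an axis with dom [0, 8), and index passing recovers i as
// (rebased + 4) through the Rebase relation.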
465 | void RebaseNonZeroMinLoop(ScheduleNode* sch) { |
466 | std::unordered_map<IterVar, IterVar> rebase_map; |
467 | for (Stage s : sch->stages) { |
468 | if (s->attach_type == kInlinedAlready) continue; |
469 | |
470 | auto root_iter_vars = s->op->root_iter_vars(); |
471 | ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); |
472 | for (IterVar iv : root_iter_vars) { |
473 | size_t idx = FindNodeRef(leaf_vars, iv); |
474 | auto it = s->iter_var_attrs.find(iv); |
      // No need to rebase iteration paths that are bound to threads.
476 | if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { |
477 | continue; |
478 | } |
479 | if (idx < leaf_vars->size()) { |
480 | // insert rebase |
        IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type);
482 | s->relations.push_back(te::Rebase(iv, rebased)); |
483 | if (s->iter_var_attrs.count(iv)) { |
484 | s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv)); |
485 | } |
486 | leaf_vars->SetItem(idx, rebased); |
487 | rebase_map[iv] = rebased; |
488 | } |
489 | } |
490 | } |
491 | // remap the parent relation |
492 | for (Stage s : sch->stages) { |
493 | if (s->attach_type != kScope) continue; |
494 | if (rebase_map.count(s->attach_ivar)) { |
495 | s->attach_ivar = rebase_map.at(s->attach_ivar); |
496 | } |
497 | } |
498 | for (Stage s : sch->groups) { |
499 | if (s->attach_type != kScope) continue; |
500 | if (rebase_map.count(s->attach_ivar)) { |
501 | s->attach_ivar = rebase_map.at(s->attach_ivar); |
502 | } |
503 | } |
504 | } |
505 | |
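// Fold the bodies of stages marked compute_inline directly into their consumers.
// A small worked example (illustrative):
//   B[i] = A[i] + 1        (scheduled with s[B].compute_inline())
//   C[i] = B[i] * 2
// becomes, after inlining,
//   C[i] = (A[i] + 1) * 2
// Reduce bodies need special care because all outputs of one ComputeOp share a
// single Reduce node except for value_index; this is handled below.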
void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
507 | sch->InvalidateCache(); |
508 | |
509 | std::vector<Array<PrimExpr>> new_body(sch->stages.size()); |
510 | std::vector<bool> changed(sch->stages.size(), false); |
511 | std::vector<Stmt> new_hybrid_body(sch->stages.size()); |
512 | std::vector<bool> hybrid_changed(sch->stages.size(), false); |
  // (sshtin): this workaround allows extern ops to be inlined into their consumers.
  // Inputs of an extern op must not be inlined, because inlining may happen
  // before TE generation for that particular extern op, which can lead to a
  // crash during lowering or building stages.
  // The problem description:
  // when operations are fused, inlining the arguments
  // prevents creation of a ProducerNode for the extern operation.
  // Instead, the operation argument is supposed to be used as an inlined buffer,
  // but extern_op TIR generation can be performed after the inlining procedure, so
  // the newly generated TIR has no reference to the input data at all.
523 | std::unordered_map<Operation, Operation> ext_ops; |
524 | for (size_t i = 0; i < sch->stages.size(); i++) { |
525 | Stage stage = sch->stages[i]; |
526 | auto ext_op = stage->op.as<ExternOpNode>(); |
527 | if (ext_op) { |
528 | auto inps = ext_op->InputTensors(); |
529 | for (size_t ii = 0; ii < inps.size(); ++ii) { |
530 | if (ext_ops.find(inps[ii]->op) == ext_ops.end()) { |
531 | ext_ops[inps[ii]->op] = stage->op; |
532 | } |
533 | } |
534 | } |
535 | } |
536 | // inline all the ops |
537 | for (size_t i = sch->stages.size(); i != 0; --i) { |
538 | Stage stage = sch->stages[i - 1]; |
539 | if (stage->attach_type == kInline) { |
540 | stage->attach_type = kInlinedAlready; |
541 | Array<Var> args; |
542 | PrimExpr body; |
543 | { |
544 | // setup args |
545 | const ComputeOpNode* compute = stage->op.as<ComputeOpNode>(); |
        ICHECK(compute) << "can only inline compute op";
547 | for (auto iv : compute->axis) { |
548 | args.push_back(iv->var); |
549 | } |
550 | if (ext_ops.find(stage->op) != ext_ops.end()) { |
          // sshtin: The extern op may try to access the input tensors as raw data,
          // which can lead to errors in the IR builder.
553 | stage->attach_type = kGroupRoot; |
554 | continue; |
555 | } |
        ICHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";
        if (feature_extraction_mode && compute->attrs.count("const_matrix")) {
          // Use a constant value to replace accesses to const matrices.
          // This produces wrong IR but is good enough for feature extraction purposes.
          // This simplification can accelerate feature extraction and the evolutionary search.
561 | body = make_const(compute->output_dtype(0), 1.0f); |
562 | } else { |
563 | body = compute->body[0]; |
564 | } |
565 | } |
566 | for (size_t j = i; j < sch->stages.size(); ++j) { |
567 | Stage s = sch->stages[j]; |
568 | const ComputeOpNode* compute = s->op.as<ComputeOpNode>(); |
569 | const HybridOpNode* hybrid = s->op.as<HybridOpNode>(); |
570 | if (compute) { |
571 | if (!new_body[j].size()) { |
572 | new_body[j] = compute->body; |
573 | } |
574 | if (new_body[j][0]->IsInstance<tir::ReduceNode>()) { |
            // Specially handle reduction inlining for multiple reductions.
576 | const tir::ReduceNode* reduce = new_body[j][0].as<tir::ReduceNode>(); |
577 | for (size_t k = 1; k < new_body[j].size(); ++k) { |
578 | const tir::ReduceNode* reduce_ = new_body[j][k].as<tir::ReduceNode>(); |
579 | ICHECK(reduce_); |
              ICHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
                                                   << "have the same attribute except value_index";
582 | } |
583 | PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body) |
584 | .as<tir::EvaluateNode>() |
585 | ->value; |
586 | if (!new_value.same_as(new_body[j][0])) { |
587 | changed[j] = true; |
588 | const tir::ReduceNode* r = new_value.as<tir::ReduceNode>(); |
589 | ICHECK(r != nullptr); |
590 | ICHECK_EQ(new_body[j].size(), r->source.size()); |
591 | for (size_t k = 0; k < new_body[j].size(); ++k) { |
592 | auto n = make_object<tir::ReduceNode>(*r); |
593 | n->value_index = static_cast<int>(k); |
594 | n->dtype = r->source[k].dtype(); |
595 | new_body[j].Set(k, PrimExpr(n)); |
596 | } |
597 | } |
598 | } else { |
599 | for (size_t k = 0; k < new_body[j].size(); ++k) { |
600 | PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body) |
601 | .as<tir::EvaluateNode>() |
602 | ->value; |
603 | if (!new_value.same_as(new_body[j][k])) { |
604 | new_body[j].Set(k, new_value); |
605 | changed[j] = true; |
606 | } |
607 | } |
608 | } |
609 | } else if (hybrid) { |
610 | if (!new_hybrid_body[j].defined()) { |
611 | new_hybrid_body[j] = hybrid->body; |
612 | } |
613 | Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body); |
614 | if (!new_stmt.same_as(new_hybrid_body[j])) { |
615 | new_hybrid_body[j] = new_stmt; |
616 | hybrid_changed[j] = true; |
617 | } |
618 | } |
619 | } |
620 | } |
621 | } |
622 | std::unordered_map<Tensor, Tensor> repl; |
623 | // rewrite dataflow |
624 | for (size_t i = 0; i < sch->stages.size(); ++i) { |
625 | Stage s = sch->stages[i]; |
626 | if (s->attach_type == kInlinedAlready) continue; |
627 | if (new_body[i].size()) { |
      // Logic from ReplaceDataFlow
629 | const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>(); |
630 | ICHECK(compute); |
631 | Operation op = s->op; |
632 | if (changed[i]) { |
633 | op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]); |
634 | } |
635 | op = op->ReplaceInputs(op, repl); |
636 | if (!op.same_as(s->op)) { |
637 | for (int idx = 0; idx < s->op->num_outputs(); ++idx) { |
638 | repl[s->op.output(idx)] = op.output(idx); |
639 | } |
640 | s->op = op; |
641 | } |
642 | } else if (hybrid_changed[i]) { |
643 | const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>(); |
644 | ICHECK(hybrid); |
645 | Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, |
646 | hybrid->outputs, new_hybrid_body[i]); |
647 | op = op->ReplaceInputs(op, repl); |
648 | for (int idx = 0; idx < s->op->num_outputs(); ++idx) { |
649 | repl[s->op.output(idx)] = op.output(idx); |
650 | } |
651 | s->op = op; |
652 | } else { |
653 | Operation op = s->op->ReplaceInputs(s->op, repl); |
654 | if (!op.same_as(s->op)) { |
655 | for (int j = 0; j < op->num_outputs(); ++j) { |
656 | repl[s->op.output(j)] = op.output(j); |
657 | } |
658 | s->op = op; |
659 | } |
660 | } |
661 | } |
662 | } |
663 | |
664 | void LegalizeInvalidAttach(ScheduleNode* sch) { |
665 | // Legalize the compute_at location if the target iterator of compute_at is split or fused. |
666 | // Case 1: If the target of compute_at is split, |
667 | // we will move the compute_at location to the inner iterator. |
668 | // Case 2: If the target of compute_at is fused, |
669 | // we will move the compute_at location to the newly fused iterator. |
670 | // Note that case 2 can only happen if the target of compute_at |
671 | // is the innermost operand of fuse operation. |
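  // For example (illustrative): if s[B].compute_at(s[C], i) was requested and i is
  // later split into (i.outer, i.inner), the attach point moves to i.inner; if i
  // was the inner operand of a fuse producing `fused`, the attach point moves to `fused`.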
672 | |
673 | // Map an old invalid attach point to its new valid attach point |
674 | std::unordered_map<IterVar, IterVar> replace_map; |
675 | |
676 | for (Stage stage : sch->stages) { |
677 | std::unordered_set<const Object*> visited; |
678 | for (Stage s = stage; s.defined();) { |
      // The following logic is similar to `CreateAttachPath` in `src/te/schedule/graph.h`,
680 | // because we follow the validation check in that function to legalize the attach. |
      ICHECK(!visited.count(s.get())) << "Found a loop in the compute_at attach group";
682 | visited.insert(s.get()); |
683 | Stage spec = s.GetAttachSpec(); |
684 | if (spec->attach_type != kScope) { |
685 | break; |
686 | } |
687 | bool start_attach = false; |
688 | IterVar attach_ivar = spec->attach_ivar; |
689 | s = spec->attach_stage; |
690 | ICHECK(attach_ivar.defined()); |
691 | ICHECK(s.defined()); |
692 | |
693 | for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { |
694 | IterVar iv = s->leaf_iter_vars[i - 1]; |
695 | if (!start_attach && iv.same_as(attach_ivar)) { |
696 | start_attach = true; |
697 | break; |
698 | } |
699 | } |
700 | |
701 | if (!start_attach) { |
702 | IterVar new_attach_ivar = attach_ivar; |
703 | bool updated = true; |
704 | // recursively update the relations |
705 | while (updated) { |
706 | updated = false; |
707 | for (const auto& rel : s->relations) { |
708 | if (const FuseNode* r = rel.as<FuseNode>()) { |
709 | if (new_attach_ivar.same_as(r->inner)) { |
710 | new_attach_ivar = r->fused; |
711 | updated = true; |
712 | } |
713 | } else if (const SplitNode* r = rel.as<SplitNode>()) { |
714 | if (new_attach_ivar.same_as(r->parent)) { |
715 | new_attach_ivar = r->inner; |
716 | updated = true; |
717 | } |
718 | } |
719 | } |
720 | replace_map[attach_ivar] = new_attach_ivar; |
721 | } |
722 | } |
723 | } |
724 | } |
725 | |
726 | // remap the parent relation |
727 | for (Stage s : sch->stages) { |
728 | if (s->attach_type != kScope) continue; |
729 | if (replace_map.count(s->attach_ivar)) { |
730 | s->attach_ivar = replace_map.at(s->attach_ivar); |
731 | } |
732 | } |
733 | for (Stage s : sch->groups) { |
734 | if (s->attach_type != kScope) continue; |
735 | if (replace_map.count(s->attach_ivar)) { |
736 | s->attach_ivar = replace_map.at(s->attach_ivar); |
737 | } |
738 | } |
739 | } |
740 | |
741 | Schedule Schedule::normalize() { |
742 | Schedule sn = copy(); |
743 | InjectInline(sn.operator->(), false); |
744 | RebaseNonZeroMinLoop(sn.operator->()); |
745 | LegalizeInvalidAttach(sn.operator->()); |
746 | return sn; |
747 | } |
748 | |
Schedule Schedule::normalize_for_feature_extraction() {
750 | Schedule sn = copy(); |
751 | InjectInline(sn.operator->(), true); |
752 | RebaseNonZeroMinLoop(sn.operator->()); |
753 | LegalizeInvalidAttach(sn.operator->()); |
754 | return sn; |
755 | } |
756 | |
757 | // Handle reduction factor. |
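// A typical use (illustrative sketch, default factor_axis = 0): given
//   B[i] = sum(A[i, k], k)        with k split into (ko, ki) by some factor,
// rfactor(B, ki) creates a factored op that keeps ki as a data-parallel axis and
// reduces only over ko, roughly
//   B.rf[ki, i] = sum(A[i, ko * factor + ki], ko)
// and rewrites the original stage to reduce over a new axis (named "ki.v" here),
//   B[i] = sum(B.rf[ki.v, i], ki.v)
// `factor_axis` controls where the new axis is placed in B.rf's shape.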
758 | Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) { |
759 | (*this)->InvalidateCache(); |
760 | using tir::ReduceNode; |
  ICHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis";
762 | Stage reduce_stage = operator[](tensor->op); |
763 | const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>(); |
  ICHECK(compute_op) << "Can only factor ComputeOp";
765 | ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite(); |
766 | { |
767 | size_t axis_pos = FindNodeRef(leaf_vars, axis); |
768 | ICHECK_NE(axis_pos, leaf_vars->size()) |
        << "Cannot find IterVar " << axis << " in leaf iter vars";
770 | } |
771 | // Find touched reduction axis. |
772 | std::unordered_map<IterVar, int> touch_map; |
773 | touch_map[axis] = 1; |
774 | te::PassUpBitMaskOr(reduce_stage, &touch_map, true); |
775 | te::PassDownBitMaskOr(reduce_stage, &touch_map, true); |
776 | // skip reduction iteration. |
777 | std::unordered_set<IterVar> skip_bound_check; |
778 | // Verify normal axis are not touched. |
779 | for (IterVar iv : compute_op->axis) { |
    ICHECK(!touch_map.count(iv)) << "Factor axis touches normal axis.";
781 | skip_bound_check.insert(iv); |
782 | } |
783 | // get analyzer. |
784 | arith::Analyzer analyzer; |
785 | // Get the replace index |
786 | std::unordered_map<IterVar, Range> dom_map; |
787 | std::unordered_map<IterVar, PrimExpr> value_map; |
788 | for (IterVar iv : compute_op->reduce_axis) { |
789 | if (touch_map.count(iv)) { |
790 | dom_map[iv] = iv->dom; |
791 | } else { |
792 | skip_bound_check.insert(iv); |
793 | } |
794 | analyzer.Bind(iv->var, iv->dom); |
795 | } |
796 | te::PassDownDomain(reduce_stage, &dom_map, &analyzer, true); |
797 | for (IterVar iv : reduce_stage->leaf_iter_vars) { |
798 | if (touch_map.count(iv)) { |
799 | Range dom = dom_map.at(iv); |
800 | if (is_one(dom->extent)) { |
801 | value_map[iv] = dom->min; |
802 | } else { |
803 | value_map[iv] = iv->var; |
804 | } |
805 | } |
806 | } |
807 | te::PassUpIndex(reduce_stage, dom_map, &value_map, true); |
808 | std::vector<PrimExpr> predicates = |
809 | MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check); |
810 | |
811 | // Get the factored op node. |
812 | const int factor_axis_pos = |
813 | factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis; |
814 | ICHECK_LE(factor_axis_pos, compute_op->axis.size()); |
815 | auto n = make_object<ComputeOpNode>(); |
  n->name = compute_op->name + ".rf";
817 | { |
    // axis replacement.
819 | IterVar iv(dom_map.at(axis), axis->var, kDataPar); |
    ICHECK(is_zero(iv->dom->min)) << "Can only factor reduction domain starting from 0";
821 | |
822 | const int size = compute_op->axis.size(); |
823 | for (int idx = 0; idx < size; ++idx) { |
824 | if (factor_axis_pos == idx) { |
825 | n->axis.push_back(iv); |
826 | } |
827 | n->axis.push_back(compute_op->axis[idx]); |
828 | } |
829 | if (factor_axis_pos == size) { |
830 | n->axis.push_back(iv); |
831 | } |
832 | } |
  // Predicate generation; copy untouched axes.
834 | int idx = tensor->value_index; |
835 | const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>(); |
  ICHECK(reduce) << "Can only rfactor non-inline reductions";
837 | predicates.push_back(reduce->condition); |
838 | |
839 | PrimExpr predicate = |
840 | likely(foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); }, |
841 | const_true(1), predicates)); |
842 | |
843 | std::unordered_map<const VarNode*, PrimExpr> vsub; |
844 | |
845 | for (IterVar iv : compute_op->reduce_axis) { |
846 | if (!touch_map.count(iv)) { |
847 | n->reduce_axis.push_back(iv); |
848 | } else { |
849 | ICHECK(value_map.count(iv)); |
850 | PrimExpr index = value_map.at(iv); |
851 | vsub[iv->var.get()] = index; |
852 | } |
853 | } |
854 | |
855 | // Copy touched axis. |
856 | for (IterVar iv : reduce_stage->leaf_iter_vars) { |
857 | if (touch_map.count(iv) && !iv.same_as(axis)) { |
858 | ICHECK_EQ(iv->iter_type, kCommReduce); |
859 | IterVar ncpy(dom_map.at(iv), iv->var, iv->iter_type, iv->thread_tag, iv->span); |
860 | n->reduce_axis.push_back(ncpy); |
861 | } |
862 | } |
863 | VarReplacer replacer(vsub); |
864 | Array<PrimExpr> new_source = |
865 | tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); }); |
866 | |
867 | PrimExpr new_pred = replacer(predicate); |
868 | |
869 | std::vector<PrimExpr> body; |
870 | for (size_t idx = 0; idx < reduce->source.size(); ++idx) { |
871 | body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {})); |
872 | } |
873 | n->body = Array<PrimExpr>(body); |
874 | // refresh relations, keep the un-touched relations. |
875 | Array<IterVarRelation> rels; |
876 | for (IterVarRelation rel : reduce_stage->relations) { |
877 | bool touched = false; |
878 | if (const SplitNode* r = rel.as<SplitNode>()) { |
879 | if (touch_map.count(r->parent)) touched = true; |
880 | } else if (const FuseNode* r = rel.as<FuseNode>()) { |
881 | if (touch_map.count(r->fused)) touched = true; |
882 | } else if (const RebaseNode* r = rel.as<RebaseNode>()) { |
883 | if (touch_map.count(r->parent)) touched = true; |
884 | } else { |
      LOG(FATAL) << "unknown relation type";
886 | } |
887 | if (!touched) { |
888 | rels.push_back(rel); |
889 | } |
890 | } |
891 | // initialize the factored stage. |
892 | Operation factor_op(n); |
893 | Array<Stage>& stages = (*this)->stages; |
894 | size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage); |
895 | Stage factor_stage = Stage(factor_op); |
896 | factor_stage->relations = rels; |
897 | ICHECK_LT(stage_pos, stages.size()); |
898 | stages.insert(stages.begin() + stage_pos, factor_stage); |
899 | (*this)->stage_map.Set(factor_op, factor_stage); |
900 | factor_stage->group = reduce_stage->group; |
901 | if (factor_stage->group.defined()) { |
902 | ++factor_stage->group->num_child_stages; |
903 | } |
904 | // Replace the old reduction. |
  IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v");
906 | Array<Tensor> factor_tensors; |
907 | Array<Tensor> old_tensors; |
908 | int size = factor_op->num_outputs(); |
909 | for (int idx = 0; idx < size; ++idx) { |
910 | factor_tensors.push_back(factor_op.output(idx)); |
911 | old_tensors.push_back(reduce_stage->op.output(idx)); |
912 | } |
913 | Array<Tensor> repl_tensors = compute( |
914 | old_tensors[0]->shape, |
915 | [&](const Array<Var>& i) { |
916 | Array<PrimExpr> indices; |
917 | const int idx_size = static_cast<int>(i.size()); |
918 | for (int idx = 0; idx < idx_size; ++idx) { |
919 | if (factor_axis_pos == idx) { |
920 | indices.push_back(repl_red_axis->var); |
921 | } |
922 | indices.push_back(i[idx]); |
923 | } |
924 | Array<PrimExpr> new_init = reduce->init; |
925 | if (!reduce->init.empty()) { |
926 | std::unordered_map<const VarNode*, PrimExpr> init_vsub; |
927 | for (const auto& init : reduce->init) { |
928 | if (init->IsInstance<ProducerLoadNode>()) { |
              ICHECK_EQ(compute_op->axis.size(), idx_size)
                  << "'init' should have the same number of dimensions as the output when used "
                     "with rfactor";
932 | for (int idx = 0; idx < idx_size; idx++) { |
933 | init_vsub[compute_op->axis[idx]->var.get()] = i[idx]; |
934 | } |
935 | } |
936 | } |
937 | VarReplacer init_replacer(init_vsub); |
938 | new_init = tir::UpdateArray( |
939 | reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); }); |
940 | } |
941 | if (factor_axis_pos == idx_size) { |
942 | indices.push_back(repl_red_axis->var); |
943 | } |
944 | Array<PrimExpr> factor_exprs; |
945 | for (int idx = 0; idx < size; ++idx) { |
946 | factor_exprs.push_back(factor_tensors[idx](indices)); |
947 | } |
948 | Array<PrimExpr> reductions; |
949 | Array<IterVar> axis = {repl_red_axis}; |
950 | PrimExpr cond = const_true(); |
951 | for (int idx = 0; idx < size; ++idx) { |
952 | reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init)); |
953 | } |
954 | return reductions; |
955 | }, |
      reduce_stage->op->name + ".repl");
957 | |
958 | std::unordered_map<Tensor, Tensor> vmap; |
959 | std::unordered_map<Tensor, Tensor> rvmap; |
960 | for (int idx = 0; idx < size; ++idx) { |
961 | vmap[old_tensors[idx]] = repl_tensors[idx]; |
962 | rvmap[repl_tensors[idx]] = old_tensors[idx]; |
963 | } |
964 | ReplaceDataFlow((*this)->stages, &vmap, &rvmap); |
965 | // revamp the reduction stage. |
966 | reduce_stage->op = repl_tensors[0]->op; |
967 | reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars(); |
968 | reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars; |
969 | reduce_stage->relations = Array<IterVarRelation>(); |
970 | return factor_tensors; |
971 | } |
972 | } // namespace te |
973 | } // namespace tvm |
974 | |