/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file schedule_dataflow_rewrite.cc
 */
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_set>

#include "../../tir/transforms/ir_utils.h"
#include "message_passing.h"
#include "operation_inline.h"
namespace tvm {
namespace te {
// Find the first occurrence of a node in the leaf array; return the array size if not found.
template <typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Object* n = v.get();
  for (size_t i = 0; i < array_node->size(); ++i) {
    if (array_node->at(i).get() == n) return i;
  }
  return array_node->size();
}

// Variable replacer used by the cache_write and rfactor rewrites below.
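// VarReplacer substitutes variables according to a VarNode -> PrimExpr map and
// also rewrites the free variables inside Reduce combiners consistently.
// Illustrative sketch (the variables here are hypothetical, not part of this file):
//   std::unordered_map<const VarNode*, PrimExpr> vsub;
//   vsub[old_iv->var.get()] = new_iv->var;        // substitute old_iv by new_iv
//   PrimExpr new_body = VarReplacer(vsub)(old_body);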
class VarReplacer : public tir::StmtExprMutator {
 public:
  explicit VarReplacer(const std::unordered_map<const VarNode*, PrimExpr>& vsub) : vsub_(vsub) {}
  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = vsub_.find(op);
    if (it != vsub_.end()) return it->second;
    return GetRef<PrimExpr>(op);
  }

  tir::CommReducer MutateCommReducer(tir::CommReducer combiner) {
    // Replace free variables in the combiner.
    auto new_identity = tir::UpdateArray(combiner->identity_element,
                                         [this](const PrimExpr& e) { return this->VisitExpr(e); });
    auto new_result = tir::UpdateArray(combiner->result,
                                       [this](const PrimExpr& e) { return this->VisitExpr(e); });

    if (combiner->identity_element.same_as(new_identity) && combiner->result.same_as(new_result)) {
      return combiner;
    } else {
      return tir::CommReducer(combiner->lhs, combiner->rhs, new_result, new_identity);
    }
  }

  PrimExpr VisitExpr_(const tir::ReduceNode* op) final {
    PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
    const tir::ReduceNode* new_reduce = new_e.as<tir::ReduceNode>();
    tir::CommReducer new_combiner = MutateCommReducer(op->combiner);
    if (op->combiner.same_as(new_combiner)) {
      return new_e;
    } else {
      return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition,
                         new_reduce->value_index, new_reduce->init);
    }
  }

 private:
  const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
};

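// Fold a list of predicates into a body expression: for a Reduce body the
// conjunction of the predicates is merged into the reduce condition, otherwise
// the body is wrapped as Select(all_predicates, body, 0).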
PrimExpr InjectPredicate(const Array<PrimExpr>& predicates, PrimExpr body) {
  using tir::ReduceNode;
  using tir::SelectNode;
  if (predicates.size() == 0) return body;
  const ReduceNode* reduce = body.as<ReduceNode>();

  if (reduce) {
    auto n = make_object<ReduceNode>(*reduce);
    n->condition = foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
                         n->condition, predicates);
    return PrimExpr(n);
  }
  return Select(foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
                      const_true(1), predicates),
                body, make_zero(body.dtype()));
}

// Replace the dataflow that appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
// Maintain the transitive-closure property of vmap with the help of a reverse map (rvmap).
void ReplaceDataFlow(const Array<Stage>& stages, std::unordered_map<Tensor, Tensor>* vmap,
                     std::unordered_map<Tensor, Tensor>* rvmap) {
  for (Stage s : stages) {
    Operation op = s->op->ReplaceInputs(s->op, *vmap);
    if (!op.same_as(s->op)) {
      for (int i = 0; i < op->num_outputs(); ++i) {
        auto it = rvmap->find(s->op.output(i));
        if (it != rvmap->end()) {
          (*vmap)[it->second] = op.output(i);
        } else {
          (*vmap)[s->op.output(i)] = op.output(i);
          (*rvmap)[op.output(i)] = s->op.output(i);
        }
      }
      s->op = op;
    }
  }
}

inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
  return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
         ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
}

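// Schedule::cache_read creates a cache stage that reads `tensor` into the given
// storage scope and redirects the listed readers to that cache.
// Minimal usage sketch (A, B, and s are hypothetical names, not from this file):
//   te::Tensor AA = s.cache_read(A, "shared", {B->op});
//   // B now consumes AA; the AA stage is inserted right after A's stage.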
Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
                            const Array<Operation>& readers) {
  (*this)->InvalidateCache();
  // create identity mapping.
  std::ostringstream os;
  os << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    os << ".v" << tensor->value_index;
  }

  // when a schedule has multiple cache_read on the same tensor,
  // we make sure their op names are unique. e.g., w.shared, w_d.shared, w_d_d.shared
  for (auto pair : (*this)->stage_map) {
    auto stage = pair.second;
    if (stage->op->name == os.str() + "." + scope) {
      os << ".d";
    }
  }
  os << "." << scope;

  std::unordered_map<Tensor, Tensor> vsub;
  Stage s = operator[](tensor->op);
  Tensor sugar_tensor = s->op.output(tensor->value_index);
  Tensor cache = compute(
      sugar_tensor->shape,
      [&sugar_tensor](const Array<Var>& i) {
        return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
      },
      os.str());
  vsub[sugar_tensor] = cache;

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (Operation op : readers) {
    Stage s = operator[](op);
    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
    ICHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op;
    vmap[s->op.output(0)] = repl_op.output(0);
    rvmap[repl_op.output(0)] = s->op.output(0);
    s->op = repl_op;
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  Array<Stage>& stages = (*this)->stages;
  Stage op_stage = operator[](tensor->op);
  size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage);
  Stage cache_stage = Stage(cache->op);
  cache_stage.set_scope(scope);
  ICHECK_LT(pos, stages.size());
  stages.insert(stages.begin() + pos + 1, cache_stage);
  (*this)->stage_map.Set(cache->op, cache_stage);
  // Update group
  cache_stage->group = op_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache;
}

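// PrepareAxisMapping builds the information needed to relayout a stage for
// cache_write: the set of reduction axes, the new ".c" cache axes derived from
// the original stage's leaf iteration, the iteration domains, the variable
// substitutions from the root axes to the relayouted indices, and the
// bound-check predicates for the relayouted iteration.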
template <typename OpType>
void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set<IterVar>* p_red_axis,
                        Array<IterVar>* p_new_axis, std::unordered_map<IterVar, Range>* p_dom_map,
                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
                        std::vector<PrimExpr>* p_predicates) {
  auto& red_axis = *p_red_axis;
  auto& new_axis = *p_new_axis;
  auto& dom_map = *p_dom_map;
  auto& vsub = *p_vsub;
  auto& vsub2newvar = *p_vsub2newvar;
  auto& predicates = *p_predicates;
  arith::Analyzer analyzer;

  for (IterVar iv : op->reduce_axis) {
    red_axis.insert(iv);
  }
  for (IterVar iv : op->axis) {
    dom_map[iv] = iv->dom;
    analyzer.Bind(iv->var, iv->dom);
  }
  te::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
  {
    // The source->cache
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      ICHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout within data parallel dimensions";
      Range dom = dom_map.at(iv);
      IterVar new_iv = IterVar(dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
      new_axis.push_back(new_iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
        vsub2newvar[iv->var.get()] = new_iv->var;
      }
    }
    // skip reduction iteration.
    std::unordered_set<IterVar> skip_bound_check;
    for (IterVar iv : op->reduce_axis) {
      skip_bound_check.insert(iv);
    }
    PassUpIndex(orig_stage, dom_map, &value_map, true);
    predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check);
    // The root axis
    for (IterVar iv : op->axis) {
      if (value_map.count(iv)) {
        vsub[iv->var.get()] = value_map.at(iv);
      }  // to handle tensor axis
    }
  }
}

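// ReplaceOriginalOp inserts the cache stage right before the original stage,
// rewires downstream dataflow from the original outputs to the new outputs,
// and resets the original stage so it simply iterates over the cached tensors.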
Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope,
                                Operation cache_op, Operation orig_new_op, size_t tensor_size) {
  Array<Tensor> cache_tensor_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_tensor_list.push_back(cache_tensor);
  }
  // The replacement of the dataflow: map every output of the original op
  // to the corresponding output of the new op.
  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (size_t i = 0; i < tensor_size; i++) {
    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
  }
  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
  orig_stage->relations = Array<IterVarRelation>();
  // create schedule for new cached stage.
  Array<Stage>& stages = sch->stages;
  size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  ICHECK_LT(pos, stages.size());
  stages.insert(stages.begin() + pos, cache_stage);
  sch->stage_map.Set(cache_op, cache_stage);
  // Update group
  cache_stage->group = orig_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache_tensor_list;
}

// Cache write and relayout the data according to loop pattern
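// Steps: (1) relayout the original leaf iteration into new ".c" cache axes via
// PrepareAxisMapping, (2) rewrite each body with the axis substitutions and the
// injected bound-check predicates, (3) build the cache ComputeOp named
// "<name>.<scope>" over the new axes, and (4) replace the original op with a
// copy that reads from the cache outputs (see ReplaceOriginalOp).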
Array<Tensor> CacheWriteWithReLayout(Schedule sch, const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const VarNode*, PrimExpr> vsub;
  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
  std::vector<PrimExpr> predicates;

  PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
                     &predicates);

  PrimExpr body;
  Array<PrimExpr> body_list;
  const tir::ReduceNode* first_reduce = nullptr;
  for (auto cbody : compute->body) {
    body = VarReplacer(vsub)(cbody);
    body = InjectPredicate(predicates, body);
    body = VarReplacer(vsub2newvar)(body);
    // Reduce nodes in ONE ComputeOp must be the same except for value_index.
    // This holds only if the original body guarantees that its Reduce nodes are the same.
    if (body->IsInstance<tir::ReduceNode>()) {
      const tir::ReduceNode* reduce_body = body.as<tir::ReduceNode>();
      if (first_reduce != nullptr) {
        ICHECK(ReduceEqual(reduce_body, first_reduce));
        body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis,
                           first_reduce->condition, reduce_body->value_index, reduce_body->init);
      } else {
        first_reduce = reduce_body;
      }
    } else {
      ICHECK(first_reduce == nullptr) << "cannot mix reduce and other nodes in ONE compute body";
    }
    body_list.push_back(body);
  }
  // The reader args
  Array<PrimExpr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : compute->axis) {
      value_map[iv] = iv->var;
    }
    te::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
  }
  Operation cache_op =
      ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list);

  Array<PrimExpr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op =
      ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}

// Cache write for a tensor compute op.
Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& tensor_array,
                                           const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
  ICHECK_EQ(tensor_op->num_outputs(), 1)
      << "cache write only supports a single-output tensor_compute_op";

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const VarNode*, PrimExpr> vsub;
  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
  std::vector<PrimExpr> predicates;

  PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
                     &predicates);

  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar new_iv = IterVar(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
    new_axis.push_back(new_iv);
  }
  Array<Region> new_regions;
  for (Region old_region : tensor_op->input_regions) {
    Region region;
    for (Range r : old_region) {
      PrimExpr min = VarReplacer(vsub2newvar)(r->min);
      PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
      region.push_back(Range::FromMinExtent(min, extent));
    }
    new_regions.push_back(region);
  }

  Array<PrimExpr> new_scalar_inputs;
  for (PrimExpr old_input : tensor_op->scalar_inputs) {
    new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
  }

  Operation cache_op =
      TensorComputeOp(tensor_op->name + "." + scope, tensor_op->tag, new_axis,
                      tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin,
                      tensor_op->inputs, new_regions, new_scalar_inputs);

  // axis will be used in generating compute op
  Array<IterVar> compute_axis = tensor_op->axis;
  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar aiv = IterVar(iv->dom, iv->var, kDataPar);
    compute_axis.Set(i, aiv);
  }

  // The reader args
  Array<PrimExpr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : compute_axis) {
      value_map[iv] = iv->var;
    }
    PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
    // tensorized region axis
    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
      IterVar iv = compute_axis[i];
      args.push_back(value_map.at(iv));
    }
  }

  Array<PrimExpr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op =
      ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}

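// Public cache_write entry points. Minimal usage sketch (s and C are
// hypothetical names, not defined in this file):
//   te::Tensor CC = s.cache_write(C, "local");
// Afterwards CC (named "<C's op name>.local") carries the original computation
// and C's stage is rewritten to copy from CC, so the compute is typically
// scheduled further through s[CC->op].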
Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, const std::string& scope) {
  (*this)->InvalidateCache();
  ICHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0";
  Tensor tensor = tensor_array[0];
  Stage orig_stage = operator[](tensor->op);
  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
  ICHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
      << "size of input tensor list must be same as number of stage outputs";
  for (size_t i = 1; i < tensor_array.size(); i++) {
    Stage tmp_stage = operator[](tensor_array[i]->op);
    ICHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp";
  }
  return CacheWriteWithReLayout(*this, tensor_array, scope);
}

Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) {
  // support both the original compute op and the tensor compute op
  (*this)->InvalidateCache();
  if (tensor->op.as<ComputeOpNode>()) {
    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
  } else if (tensor->op.as<TensorComputeOpNode>()) {
    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
  } else {
    LOG(FATAL) << "cache write only takes ComputeOp or TensorComputeOp as writers";
  }
}

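// Rebase the leaf iteration variables of every stage so that their loops start
// from zero; axes that are bound to threads are skipped. Attach points that
// referred to a rebased iteration variable are remapped to the new one.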
void RebaseNonZeroMinLoop(ScheduleNode* sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
  for (Stage s : sch->stages) {
    if (s->attach_type == kInlinedAlready) continue;

    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
      size_t idx = FindNodeRef(leaf_vars, iv);
      auto it = s->iter_var_attrs.find(iv);
      // no need to rebase iteration variables that are bound to threads.
      if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
        continue;
      }
      if (idx < leaf_vars->size()) {
        // insert rebase
        IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type);
        s->relations.push_back(te::Rebase(iv, rebased));
        if (s->iter_var_attrs.count(iv)) {
          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
        }
        leaf_vars->SetItem(idx, rebased);
        rebase_map[iv] = rebased;
      }
    }
  }
  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
}

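// Inline every stage marked with attach_type kInline into its consumers by
// substituting the producer body into the consumer compute expressions, then
// rewrite the dataflow of the remaining stages. In feature_extraction_mode,
// bodies of "const_matrix" compute ops are replaced by a constant, which is
// incorrect IR but sufficient and faster for feature extraction.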
void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
  sch->InvalidateCache();

  std::vector<Array<PrimExpr>> new_body(sch->stages.size());
  std::vector<bool> changed(sch->stages.size(), false);
  std::vector<Stmt> new_hybrid_body(sch->stages.size());
  std::vector<bool> hybrid_changed(sch->stages.size(), false);
  // (sshtin): This workaround allows extern ops to be inlined into their consumers.
  // Inputs of an extern op must not be inlined, because inlining may happen
  // before the TE generation of that particular extern op, which can lead to a
  // crash while lowering or building the stages.
  // Problem description: when operations are fused, inlining the arguments
  // prevents the creation of a ProducerNode for the extern operation.
  // The operation argument is then supposed to be used as an inlined buffer,
  // but the extern op's TIR generation can run after the inlining procedure,
  // so the newly generated TIR has no reference to the input data at all.
  std::unordered_map<Operation, Operation> ext_ops;
  for (size_t i = 0; i < sch->stages.size(); i++) {
    Stage stage = sch->stages[i];
    auto ext_op = stage->op.as<ExternOpNode>();
    if (ext_op) {
      auto inps = ext_op->InputTensors();
      for (size_t ii = 0; ii < inps.size(); ++ii) {
        if (ext_ops.find(inps[ii]->op) == ext_ops.end()) {
          ext_ops[inps[ii]->op] = stage->op;
        }
      }
    }
  }
  // inline all the ops
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage stage = sch->stages[i - 1];
    if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
      Array<Var> args;
      PrimExpr body;
      {
        // setup args
        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
        ICHECK(compute) << "can only inline compute op";
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
        if (ext_ops.find(stage->op) != ext_ops.end()) {
          // sshtin: the extern op may access the input tensors as raw data,
          // which can lead to errors in the IR builder; keep this op as a group root.
          stage->attach_type = kGroupRoot;
          continue;
        }
        ICHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";
        if (feature_extraction_mode && compute->attrs.count("const_matrix")) {
          // Use a constant value to replace accesses of const matrices.
          // This produces wrong IR but is good enough for feature extraction purposes.
          // This simplification can accelerate feature extraction and the evolutionary search.
          body = make_const(compute->output_dtype(0), 1.0f);
        } else {
          body = compute->body[0];
        }
      }
      for (size_t j = i; j < sch->stages.size(); ++j) {
        Stage s = sch->stages[j];
        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
        const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
        if (compute) {
          if (!new_body[j].size()) {
            new_body[j] = compute->body;
          }
          if (new_body[j][0]->IsInstance<tir::ReduceNode>()) {
            // Specially handle reduction inlining for multiple reductions.
            const tir::ReduceNode* reduce = new_body[j][0].as<tir::ReduceNode>();
            for (size_t k = 1; k < new_body[j].size(); ++k) {
              const tir::ReduceNode* reduce_ = new_body[j][k].as<tir::ReduceNode>();
              ICHECK(reduce_);
              ICHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
                                                   << "have the same attribute except value_index";
            }
            PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body)
                                     .as<tir::EvaluateNode>()
                                     ->value;
            if (!new_value.same_as(new_body[j][0])) {
              changed[j] = true;
              const tir::ReduceNode* r = new_value.as<tir::ReduceNode>();
              ICHECK(r != nullptr);
              ICHECK_EQ(new_body[j].size(), r->source.size());
              for (size_t k = 0; k < new_body[j].size(); ++k) {
                auto n = make_object<tir::ReduceNode>(*r);
                n->value_index = static_cast<int>(k);
                n->dtype = r->source[k].dtype();
                new_body[j].Set(k, PrimExpr(n));
              }
            }
          } else {
            for (size_t k = 0; k < new_body[j].size(); ++k) {
              PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body)
                                       .as<tir::EvaluateNode>()
                                       ->value;
              if (!new_value.same_as(new_body[j][k])) {
                new_body[j].Set(k, new_value);
                changed[j] = true;
              }
            }
          }
        } else if (hybrid) {
          if (!new_hybrid_body[j].defined()) {
            new_hybrid_body[j] = hybrid->body;
          }
          Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body);
          if (!new_stmt.same_as(new_hybrid_body[j])) {
            new_hybrid_body[j] = new_stmt;
            hybrid_changed[j] = true;
          }
        }
      }
    }
  }
  std::unordered_map<Tensor, Tensor> repl;
  // rewrite dataflow
  for (size_t i = 0; i < sch->stages.size(); ++i) {
    Stage s = sch->stages[i];
    if (s->attach_type == kInlinedAlready) continue;
    if (new_body[i].size()) {
      // Logic from ReplaceDataFlow
      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
      ICHECK(compute);
      Operation op = s->op;
      if (changed[i]) {
        op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]);
      }
      op = op->ReplaceInputs(op, repl);
      if (!op.same_as(s->op)) {
        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
          repl[s->op.output(idx)] = op.output(idx);
        }
        s->op = op;
      }
    } else if (hybrid_changed[i]) {
      const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
      ICHECK(hybrid);
      Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
                              hybrid->outputs, new_hybrid_body[i]);
      op = op->ReplaceInputs(op, repl);
      for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
        repl[s->op.output(idx)] = op.output(idx);
      }
      s->op = op;
    } else {
      Operation op = s->op->ReplaceInputs(s->op, repl);
      if (!op.same_as(s->op)) {
        for (int j = 0; j < op->num_outputs(); ++j) {
          repl[s->op.output(j)] = op.output(j);
        }
        s->op = op;
      }
    }
  }
}

void LegalizeInvalidAttach(ScheduleNode* sch) {
  // Legalize the compute_at location if the target iterator of compute_at is split or fused.
  // Case 1: If the target of compute_at is split,
  //         we will move the compute_at location to the inner iterator.
  // Case 2: If the target of compute_at is fused,
  //         we will move the compute_at location to the newly fused iterator.
  // Note that case 2 can only happen if the target of compute_at
  // is the innermost operand of the fuse operation.
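  // For example (a hypothetical schedule, not from this file): if B was attached
  // with compute_at(s[C], i) and i is later split into (io, ii), the attach point
  // is moved to ii; if i is fused (as the inner operand) into f, it is moved to f.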

  // Map an old invalid attach point to its new valid attach point
  std::unordered_map<IterVar, IterVar> replace_map;

  for (Stage stage : sch->stages) {
    std::unordered_set<const Object*> visited;
    for (Stage s = stage; s.defined();) {
      // The following logic is similar to `CreateAttachPath` in `src/te/schedule/graph.h`,
      // because we follow the validation check in that function to legalize the attach.
      ICHECK(!visited.count(s.get())) << "Find loop in compute_at attach group";
      visited.insert(s.get());
      Stage spec = s.GetAttachSpec();
      if (spec->attach_type != kScope) {
        break;
      }
      bool start_attach = false;
      IterVar attach_ivar = spec->attach_ivar;
      s = spec->attach_stage;
      ICHECK(attach_ivar.defined());
      ICHECK(s.defined());

      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
          break;
        }
      }

      if (!start_attach) {
        IterVar new_attach_ivar = attach_ivar;
        bool updated = true;
        // recursively update the relations
        while (updated) {
          updated = false;
          for (const auto& rel : s->relations) {
            if (const FuseNode* r = rel.as<FuseNode>()) {
              if (new_attach_ivar.same_as(r->inner)) {
                new_attach_ivar = r->fused;
                updated = true;
              }
            } else if (const SplitNode* r = rel.as<SplitNode>()) {
              if (new_attach_ivar.same_as(r->parent)) {
                new_attach_ivar = r->inner;
                updated = true;
              }
            }
          }
          replace_map[attach_ivar] = new_attach_ivar;
        }
      }
    }
  }

  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (replace_map.count(s->attach_ivar)) {
      s->attach_ivar = replace_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (replace_map.count(s->attach_ivar)) {
      s->attach_ivar = replace_map.at(s->attach_ivar);
    }
  }
}

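// Schedule::normalize prepares the schedule for lowering: it inlines stages
// marked as inline, rebases loops so they start at zero, and legalizes
// compute_at attach points invalidated by later split/fuse operations.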
Schedule Schedule::normalize() {
  Schedule sn = copy();
  InjectInline(sn.operator->(), false);
  RebaseNonZeroMinLoop(sn.operator->());
  LegalizeInvalidAttach(sn.operator->());
  return sn;
}

Schedule Schedule::normalize_for_feature_extraction() {
  Schedule sn = copy();
  InjectInline(sn.operator->(), true);
  RebaseNonZeroMinLoop(sn.operator->());
  LegalizeInvalidAttach(sn.operator->());
  return sn;
}

// Handle reduction factor.
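// Schedule::rfactor factors the reduction along `axis` into a new ".rf" stage
// whose output carries that axis as a data-parallel dimension at position
// factor_axis; the original stage is rewritten to reduce over the factored
// tensors along a fresh ".v" reduction axis. Minimal usage sketch (s, B, and
// ko are hypothetical names, not from this file):
//   Array<te::Tensor> BF = s.rfactor(B, ko, 0);  // BF[0] is "B.rf"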
Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) {
  (*this)->InvalidateCache();
  using tir::ReduceNode;
  ICHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis";
  Stage reduce_stage = operator[](tensor->op);
  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
  ICHECK(compute_op) << "Can only factor ComputeOp";
  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
  {
    size_t axis_pos = FindNodeRef(leaf_vars, axis);
    ICHECK_NE(axis_pos, leaf_vars->size())
        << "Cannot find IterVar " << axis << " in leaf iter vars";
  }
  // Find touched reduction axis.
  std::unordered_map<IterVar, int> touch_map;
  touch_map[axis] = 1;
  te::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  te::PassDownBitMaskOr(reduce_stage, &touch_map, true);
  // skip reduction iteration.
  std::unordered_set<IterVar> skip_bound_check;
  // Verify that normal axes are not touched.
  for (IterVar iv : compute_op->axis) {
    ICHECK(!touch_map.count(iv)) << "Factor axis touches normal axis.";
    skip_bound_check.insert(iv);
  }
  // get analyzer.
  arith::Analyzer analyzer;
  // Get the replace index
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, PrimExpr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
    if (touch_map.count(iv)) {
      dom_map[iv] = iv->dom;
    } else {
      skip_bound_check.insert(iv);
    }
    analyzer.Bind(iv->var, iv->dom);
  }
  te::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv)) {
      Range dom = dom_map.at(iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
      }
    }
  }
  te::PassUpIndex(reduce_stage, dom_map, &value_map, true);
  std::vector<PrimExpr> predicates =
      MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check);

  // Get the factored op node.
  const int factor_axis_pos =
      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
  ICHECK_LE(factor_axis_pos, compute_op->axis.size());
  auto n = make_object<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
  {
    // axis replacement.
    IterVar iv(dom_map.at(axis), axis->var, kDataPar);
    ICHECK(is_zero(iv->dom->min)) << "Can only factor reduction domain starting from 0";

    const int size = compute_op->axis.size();
    for (int idx = 0; idx < size; ++idx) {
      if (factor_axis_pos == idx) {
        n->axis.push_back(iv);
      }
      n->axis.push_back(compute_op->axis[idx]);
    }
    if (factor_axis_pos == size) {
      n->axis.push_back(iv);
    }
  }
  // predicate generation, copy untouched axes.
  int idx = tensor->value_index;
  const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
  ICHECK(reduce) << "Can only rfactor non-inline reductions";
  predicates.push_back(reduce->condition);

  PrimExpr predicate =
      likely(foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
                   const_true(1), predicates));

  std::unordered_map<const VarNode*, PrimExpr> vsub;

  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
    } else {
      ICHECK(value_map.count(iv));
      PrimExpr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
    }
  }

  // Copy touched axis.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
      ICHECK_EQ(iv->iter_type, kCommReduce);
      IterVar ncpy(dom_map.at(iv), iv->var, iv->iter_type, iv->thread_tag, iv->span);
      n->reduce_axis.push_back(ncpy);
    }
  }
  VarReplacer replacer(vsub);
  Array<PrimExpr> new_source =
      tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); });

  PrimExpr new_pred = replacer(predicate);

  std::vector<PrimExpr> body;
  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
    body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {}));
  }
  n->body = Array<PrimExpr>(body);
  // refresh relations, keep the un-touched relations.
  Array<IterVarRelation> rels;
  for (IterVarRelation rel : reduce_stage->relations) {
    bool touched = false;
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (touch_map.count(r->fused)) touched = true;
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
    if (!touched) {
      rels.push_back(rel);
    }
  }
  // initialize the factored stage.
  Operation factor_op(n);
  Array<Stage>& stages = (*this)->stages;
  size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
  Stage factor_stage = Stage(factor_op);
  factor_stage->relations = rels;
  ICHECK_LT(stage_pos, stages.size());
  stages.insert(stages.begin() + stage_pos, factor_stage);
  (*this)->stage_map.Set(factor_op, factor_stage);
  factor_stage->group = reduce_stage->group;
  if (factor_stage->group.defined()) {
    ++factor_stage->group->num_child_stages;
  }
  // Replace the old reduction.
  IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v");
  Array<Tensor> factor_tensors;
  Array<Tensor> old_tensors;
  int size = factor_op->num_outputs();
  for (int idx = 0; idx < size; ++idx) {
    factor_tensors.push_back(factor_op.output(idx));
    old_tensors.push_back(reduce_stage->op.output(idx));
  }
  Array<Tensor> repl_tensors = compute(
      old_tensors[0]->shape,
      [&](const Array<Var>& i) {
        Array<PrimExpr> indices;
        const int idx_size = static_cast<int>(i.size());
        for (int idx = 0; idx < idx_size; ++idx) {
          if (factor_axis_pos == idx) {
            indices.push_back(repl_red_axis->var);
          }
          indices.push_back(i[idx]);
        }
        Array<PrimExpr> new_init = reduce->init;
        if (!reduce->init.empty()) {
          std::unordered_map<const VarNode*, PrimExpr> init_vsub;
          for (const auto& init : reduce->init) {
            if (init->IsInstance<ProducerLoadNode>()) {
              ICHECK_EQ(compute_op->axis.size(), idx_size)
                  << "'init' should have the same number of dimensions as the output when used "
                     "with rfactor";
              for (int idx = 0; idx < idx_size; idx++) {
                init_vsub[compute_op->axis[idx]->var.get()] = i[idx];
              }
            }
          }
          VarReplacer init_replacer(init_vsub);
          new_init = tir::UpdateArray(
              reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); });
        }
        if (factor_axis_pos == idx_size) {
          indices.push_back(repl_red_axis->var);
        }
        Array<PrimExpr> factor_exprs;
        for (int idx = 0; idx < size; ++idx) {
          factor_exprs.push_back(factor_tensors[idx](indices));
        }
        Array<PrimExpr> reductions;
        Array<IterVar> axis = {repl_red_axis};
        PrimExpr cond = const_true();
        for (int idx = 0; idx < size; ++idx) {
          reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init));
        }
        return reductions;
      },
      reduce_stage->op->name + ".repl");

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (int idx = 0; idx < size; ++idx) {
    vmap[old_tensors[idx]] = repl_tensors[idx];
    rvmap[repl_tensors[idx]] = old_tensors[idx];
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  // revamp the reduction stage.
  reduce_stage->op = repl_tensors[0]->op;
  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}
}  // namespace te
}  // namespace tvm