/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file schedule_lang.cc
 * \brief Implementation of the schedule language primitives (split, fuse,
 *        reorder, bind, compute_at, ...) for te::Stage and te::Schedule.
 */
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>

#include <algorithm>
#include <stack>
#include <unordered_set>
#include <vector>

#include "graph.h"

namespace tvm {
namespace te {

// Find the first occurrence of v in the array; returns the array size if not found.
template <typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Object* n = v.get();
  for (size_t i = 0; i < array_node->size(); ++i) {
    if (array_node->at(i).get() == n) return i;
  }
  return array_node->size();
}

size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
  size_t pos = FindNodeRef(leaf_vars, v);
  if (pos < leaf_vars->size()) return pos;

  if (FindNodeRef(all_vars, v) < all_vars->size()) {
    LOG(FATAL) << "Operate on iter var " << v << " that has already been split";
  } else {
    LOG(FATAL) << "Operate on iter var " << v << " that is not part of the schedule";
  }
  return 0;
}

// Pick an integer dtype wide enough to hold all the given scalar integer dtypes.
DataType MatchDataType(std::vector<DataType> dtypes) {
  int max_bits = -1;
  for (const auto& dtype : dtypes) {
    ICHECK(dtype.is_int());
    ICHECK(dtype.is_scalar());
    max_bits = std::max(max_bits, dtype.bits());
  }
  return DataType::Int(max_bits);
}

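// Split `parent` into an outer/inner pair and update the stage bookkeeping:
// the new vars are appended to all_iter_vars, and parent's slot in
// leaf_iter_vars is replaced by the pair, e.g.
//   leaf_iter_vars [i, j] --split(i)--> [i.outer, i.inner, j]
// Exactly one of `factor` (inner extent) and `nparts` (outer extent) is
// defined; the other is left as an undefined PrimExpr and derived later
// during bound inference.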
void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts,
                 IterVar* p_outer, IterVar* p_inner) {
  // Check if split is valid.
  ICHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce ||
         parent->iter_type == kOrdered)
      << "Cannot split on " << IterVarType2String(parent->iter_type);
  IterVar outer = IterVar(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type);
  IterVar inner = IterVar(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type);
  *p_outer = outer;
  *p_inner = inner;
  // The splits
  Array<IterVar>& all_vars = self->all_iter_vars;
  Array<IterVar>& leaf_vars = self->leaf_iter_vars;
  size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent);
  self->relations.push_back(Split(parent, outer, inner, factor, nparts));
  // add vars to all vars
  all_vars.push_back(outer);
  all_vars.push_back(inner);
  // replace the position.
  leaf_vars.erase(leaf_vars.begin() + pos);
  leaf_vars.insert(leaf_vars.begin() + pos, inner);
  leaf_vars.insert(leaf_vars.begin() + pos, outer);
}

Stage::Stage(Operation op) {
  auto n = make_object<StageNode>();
  n->op = op;
  n->origin_op = op;
  n->all_iter_vars = op->root_iter_vars();
  // remove opaque vars from leaf.
  Array<IterVar> clean;
  for (IterVar iv : n->all_iter_vars) {
    if (iv->iter_type != kOpaque) clean.push_back(iv);
  }
  if (clean.size() == n->all_iter_vars.size()) {
    n->leaf_iter_vars = n->all_iter_vars;
  } else {
    n->leaf_iter_vars = clean;
  }
  data_ = std::move(n);
}

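// A stage counts as scheduled once any primitive has touched it: a relation
// was added, it was attached or inlined, or its leaf vars diverged from the
// root iter vars.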
bool Stage::is_scheduled() const {
  const StageNode* n = operator->();
  return !(n->relations.empty() && n->attach_type == kGroupRoot &&
           n->all_iter_vars.same_as(n->leaf_iter_vars));
}

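// Walk up the group hierarchy to find the stage whose attachment actually
// decides where this stage is placed: a stage that is still at its group
// root inherits the enclosing group's attachment.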
Stage Stage::GetAttachSpec() const {
  Stage attach_spec = *this;
  while (attach_spec->attach_type == kGroupRoot && attach_spec->group.defined()) {
    attach_spec = attach_spec->group;
  }
  return attach_spec;
}

Stage& Stage::set_scope(std::string scope) {  // NOLINT(*)
  (*this)->scope = scope;
  return *this;
}

Stage& Stage::compute_at(Stage parent, IterVar scope) {  // NOLINT(*)
  ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates";
  // Group constraint checking.
  Stage group = (*this)->group;
  if (group.defined()) {
    Stage pg = parent->group;
    while (pg.defined() && !pg.same_as(group)) {
      pg = pg->group;
    }
    ICHECK(pg.same_as(group)) << "Can only assign compute_at to stages within the same group";
  }

  (*this)->attach_type = kScope;
  (*this)->attach_ivar = scope;
  (*this)->attach_stage = parent;
  bool found = false;
  for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
    if (scope == parent->leaf_iter_vars[i]) {
      found = true;
      break;
    }
  }
  ICHECK(found) << "Cannot find the axis " << scope << " in parent's leaf_iter_vars"
                << " parent=" << parent;
  return *this;
}

Stage& Stage::compute_inline() {  // NOLINT(*)
  ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_inline for scan updates";
  (*this)->attach_type = kInline;
  return *this;
}

Stage& Stage::compute_root() {  // NOLINT(*)
  ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_root for scan updates";
  (*this)->attach_type = kGroupRoot;
  return *this;
}

Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) {  // NOLINT(*)
  StageNode* self = operator->();
  ICHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce)
      << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
  ICHECK(thread_ivar->iter_type == kThreadIndex)
      << "Cannot bind to " << IterVarType2String(thread_ivar->iter_type)
      << ", only thread axis is allowed so far";
  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
  FindLeafVar(all_vars, leaf_vars, ivar);

  auto it = self->iter_var_attrs.find(ivar);
  ObjectPtr<IterVarAttrNode> n;
  if (it != self->iter_var_attrs.end()) {
    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
    if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) {
      LOG(WARNING) << "Axis " << ivar << " is already bound to another thread " << n->bind_thread;
    }
  } else {
    n = make_object<IterVarAttrNode>();
  }
  n->bind_thread = thread_ivar;
  self->iter_var_attrs.Set(ivar, IterVarAttr(n));
  return *this;
}

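// env_threads prepends the given thread iter vars (e.g. blockIdx/threadIdx
// for a GPU scan) to the leaf vars, so they enclose every axis of the
// composite op.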
Stage& Stage::env_threads(Array<IterVar> threads) {
  StageNode* self = operator->();
  ICHECK(self->op.defined() && self->op.as<ScanOpNode>())
      << "env_threads is only valid for composite ops such as ScanOp";
  ICHECK_EQ(self->env_threads.size(), 0U) << "Already set env_threads";
  Array<IterVar>& leaf_vars = self->leaf_iter_vars;
  Array<IterVar>& all_vars = self->all_iter_vars;
  std::vector<IterVar> temp;
  for (IterVar iv : threads) {
    temp.push_back(iv);
  }
  leaf_vars.insert(leaf_vars.begin(), temp.begin(), temp.end());
  all_vars.insert(all_vars.end(), temp.begin(), temp.end());
  self->env_threads = threads;
  return *this;
}

Stage& Stage::set_store_predicate(PrimExpr predicate) {
  StageNode* self = operator->();
  self->store_predicate = predicate;
  return *this;
}

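// The two public split entry points only differ in which extent is fixed:
// split() pins the inner extent (factor), split_by_nparts() pins the outer
// extent (nparts); SplitHelper leaves the other one undefined.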
Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer,
                    IterVar* p_inner) {  // NOLINT(*)
  SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
  return *this;
}

Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
                              IterVar* p_inner) {  // NOLINT(*)
  SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
  return *this;
}

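// Fuse two adjacent leaf iter vars into one, e.g.
//   leaf_iter_vars [i, j, k] --fuse(i, j)--> [i.j.fused, k]
// The pair may be given in either order, but must be consecutive in the
// current leaf order; the fused var takes the larger of the two iter types
// and an integer dtype wide enough for both.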
Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) {  // NOLINT(*)
  StageNode* self = operator->();
  ICHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce ||
         outer->iter_type == kOrdered)
      << "Cannot fuse " << IterVarType2String(outer->iter_type);
  ICHECK(inner->iter_type == kDataPar || inner->iter_type == kCommReduce ||
         inner->iter_type == kOrdered)
      << "Cannot fuse " << IterVarType2String(inner->iter_type);

  IterVarType iter_type = outer->iter_type;
  if (inner->iter_type > iter_type) iter_type = inner->iter_type;
  std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused";
  DataType iter_dtype = MatchDataType({inner->var.dtype(), outer->var.dtype()});

  IterVar fused = IterVar(Range(), Var(fused_name, iter_dtype), iter_type);

  Array<IterVar>& all_vars = self->all_iter_vars;
  Array<IterVar>& leaf_vars = self->leaf_iter_vars;

  size_t pos_inner = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), inner);
  size_t pos_outer = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), outer);
  // Accept the pair in either order: swap so that outer precedes inner.
  if (pos_inner + 1 == pos_outer) {
    std::swap(outer, inner);
    std::swap(pos_inner, pos_outer);
  }
  ICHECK_EQ(pos_inner, pos_outer + 1)
      << "Can only fuse iterations that are consecutive in the leaf iter var order";
  self->relations.push_back(Fuse(outer, inner, fused));
  all_vars.push_back(fused);
  leaf_vars.erase(leaf_vars.begin() + pos_outer, leaf_vars.begin() + pos_inner + 1);
  leaf_vars.insert(leaf_vars.begin() + pos_outer, fused);
  *p_target = fused;
  return *this;
}

Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) {  // NOLINT(*)
  if (axes.size() != 0) {
    IterVar fused = axes[0];
    for (size_t i = 1; i < axes.size(); ++i) {
      this->fuse(fused, axes[i], &fused);
    }
    *p_target = std::move(fused);
  } else {
    StageNode* self = operator->();
    // Special case: fusing an empty array of axes yields a unit-extent
    // singleton loop, inserted at the outermost position.
    IterVar singleton =
        IterVar(Range::FromMinExtent(0, 1), Var("singleton", DataType::Int(32)), kDataPar);
    self->relations.push_back(Singleton(singleton));
    Array<IterVar>& all_vars = self->all_iter_vars;
    Array<IterVar>& leaf_vars = self->leaf_iter_vars;
    all_vars.push_back(singleton);
    leaf_vars.insert(leaf_vars.begin(), singleton);
    *p_target = singleton;
  }
  return *this;
}
288
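// Reorder rewrites leaf_iter_vars in place: it records the current position
// of every requested axis, sorts those positions, then writes the axes back
// in the requested order, e.g.
//   leaf [i, j, k], reorder(k, i) --> [k, j, i]
// Axes not mentioned keep their slots.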
Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
  std::unordered_set<IterVar> seen_var;
  StageNode* self = operator->();
  for (IterVar iv : order) {
    ICHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce ||
           iv->iter_type == kThreadIndex)
        << "Cannot reorder IterVar(" << IterVarType2String(iv->iter_type) << ")";

    ICHECK_EQ(seen_var.count(iv), 0) << "Same axis cannot appear more than once " << iv;
    seen_var.insert(iv);
  }
  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
  std::vector<size_t> pos;

  for (size_t i = 0; i < order.size(); ++i) {
    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
  }
  std::vector<ObjectRef> temp;
  for (size_t i = 0; i < pos.size(); ++i) {
    temp.emplace_back(leaf_vars->at(pos[i]));
  }
  std::sort(pos.begin(), pos.end());
  for (size_t i = 0; i < pos.size(); ++i) {
    leaf_vars->SetItem(pos[i], temp[i]);
  }
  return *this;
}
317
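// tile is a composite primitive: two splits followed by one reorder,
// producing the loop nest (x.outer, y.outer, x.inner, y.inner).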
Stage& Stage::tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor,
                   IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner) {
  split(x_parent, x_factor, p_x_outer, p_x_inner);
  split(y_parent, y_factor, p_y_outer, p_y_inner);
  reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
  return *this;
}

template <typename FUpdate>
inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate,
                              bool need_leaf = true) {
  if (need_leaf) {
    ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
    ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
    FindLeafVar(all_vars, leaf_vars, var);
  }
  auto it = self->iter_var_attrs.find(var);
  ObjectPtr<IterVarAttrNode> n;
  if (it != self->iter_var_attrs.end()) {
    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
  } else {
    n = make_object<IterVarAttrNode>();
  }
  fupdate(n.get());
  self->iter_var_attrs.Set(var, IterVarAttr(n));
}

inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) {
  UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { n->iter_type = iter_type; });
}

Stage& Stage::vectorize(IterVar var) {  // NOLINT(*)
  ICHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled ||
         var->iter_type == kVectorized || var->iter_type == kTensorized ||
         var->iter_type == kParallelized)
      << "Cannot vectorize on " << IterVarType2String(var->iter_type);
  SetAttrIterType(operator->(), var, kVectorized);
  return *this;
}

Stage& Stage::tensorize(IterVar var, TensorIntrin f) {  // NOLINT(*)
  UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
    n->iter_type = kTensorized;
    n->tensor_intrin = f;
  });
  return *this;
}

Stage& Stage::unroll(IterVar var) {  // NOLINT(*)
  SetAttrIterType(operator->(), var, kUnrolled);
  return *this;
}

Stage& Stage::parallel(IterVar var) {  // NOLINT(*)
  SetAttrIterType(operator->(), var, kParallelized);
  return *this;
}

Stage& Stage::pragma(IterVar var, const std::string& pragma_type,
                     const PrimExpr& pragma_value) {  // NOLINT(*)
  if (pragma_type == "unroll") {
    this->unroll(var);
  } else if (pragma_type == "vectorize") {
    this->vectorize(var);
  } else {
    UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
      n->pragma_keys.push_back(tir::StringImm(pragma_type));
      n->pragma_values.push_back(pragma_value);
    });
  }
  return *this;
}

Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) {
  StageNode* self = operator->();
  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
  FindLeafVar(all_vars, leaf_vars, var);
  auto it = self->iter_var_attrs.find(var);
  ObjectPtr<IterVarAttrNode> n;
  if (it != self->iter_var_attrs.end()) {
    n = make_object<IterVarAttrNode>(*(*it).second.operator->());
  } else {
    n = make_object<IterVarAttrNode>();
  }
  n->prefetch_data.push_back(tensor);
  n->prefetch_offset.push_back(offset);
  self->iter_var_attrs.Set(var, IterVarAttr(n));
  return *this;
}

Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
  StageNode* self = operator->();
  UpdateIterVarAttr(
      self, axis,
      [factor, offset](IterVarAttrNode* n) {
        n->dim_align_factor = factor;
        n->dim_align_offset = offset;
      },
      false);
  return *this;
}

Stage& Stage::double_buffer() {
  StageNode* self = operator->();
  ICHECK(!self->is_output) << "Cannot apply double buffer on output";
  self->double_buffer = true;
  return *this;
}

Stage& Stage::rolling_buffer() {
  StageNode* self = operator->();
  ICHECK(!self->is_output) << "Cannot apply rolling buffer on output";
  self->rolling_buffer = true;
  return *this;
}
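
// transform_layout records an IndexMap for the stage's storage layout and,
// for compute ops, also rewrites the iteration space: the original data axes
// are replaced in leaf_iter_vars by new iter vars ranging over the
// transformed indices, and a Transform relation (with the inverse map) links
// the two.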
Stage& Stage::transform_layout(const Array<Var>& initial_indices,
                               const Array<PrimExpr>& final_indices,
                               Array<IterVar>* out_iter_vars) {
  StageNode* self = operator->();
  IndexMap map(initial_indices, final_indices);
  self->layout_transforms.push_back(map);

  auto* compute = self->op.as<ComputeOpNode>();

  // Can only rewrite the indices of compute op nodes.
  if (!compute) {
    return *this;
  }

  CHECK_EQ(initial_indices.size(), compute->axis.size())
      << "Expected number of initial indices in transformation to match the dimension of "
      << self->op->name;

  // Locate the IterVar objects for the data axes.
  auto leaf_iter_range = [&]() -> std::pair<size_t, size_t> {
    std::vector<size_t> leaf_var_indices;
    for (const auto& axis : compute->axis) {
      leaf_var_indices.push_back(
          FindLeafVar(self->all_iter_vars.CopyOnWrite(), self->leaf_iter_vars.CopyOnWrite(), axis));
    }
    auto minmax_element = std::minmax_element(leaf_var_indices.begin(), leaf_var_indices.end());
    return {*minmax_element.first, *minmax_element.second + 1};
  }();
  CHECK_EQ(leaf_iter_range.first + compute->axis.size(), leaf_iter_range.second)
      << "Cannot transform indices if they have already been reordered";

  // Determine the updated ranges of iteration.
  Array<Range> initial_ranges;
  for (const auto& iter_var : compute->axis) {
    initial_ranges.push_back(iter_var->dom);
  }
  Array<Range> final_ranges = map->MapRanges(initial_ranges);

  // Make IterVar objects to represent the new iterations.
  auto inverse = map.Inverse(initial_ranges);
  Array<IterVar> final_indices_iter;
  ICHECK_EQ(inverse->initial_indices.size(), final_ranges.size());
  for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
    final_indices_iter.push_back(IterVar(final_ranges[i], inverse->initial_indices[i], kDataPar));
  }

  // Append the new IterVar objects to all_iter_vars.
  for (const auto& iter_var : final_indices_iter) {
    self->all_iter_vars.push_back(iter_var);
  }

  // Replace the existing IterVar objects in leaf_iter_vars with the
  // new IterVar objects.
  self->leaf_iter_vars.erase(self->leaf_iter_vars.begin() + leaf_iter_range.first,
                             self->leaf_iter_vars.begin() + leaf_iter_range.second);
  self->leaf_iter_vars.insert(self->leaf_iter_vars.begin() + leaf_iter_range.first,
                              final_indices_iter.begin(), final_indices_iter.end());

  // Define a relationship for each new axis.
  self->relations.push_back(Transform(compute->axis, final_indices_iter, map, inverse));

  // Return the iteration variables as an output.
  if (out_iter_vars) {
    *out_iter_vars = final_indices_iter;
  }

  return *this;
}

Stage& Stage::set_axis_separators(const Array<IntImm>& axis_separators) {
  StageNode* self = operator->();
  self->axis_separators = axis_separators;
  return *this;
}
508
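// Schedule::copy performs a deep copy of every Stage (and group Stage) and
// then remaps every cross-stage reference (stage_map, attach_stage, group)
// from the old nodes to their copies, so mutating the copy never touches the
// original schedule.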
Stage CopyStage(const Stage& s) {
  ObjectPtr<StageNode> n = make_object<StageNode>(*s.operator->());
  return Stage(n);
}

Schedule Schedule::copy() const {
  // map of stages.
  const ScheduleNode* self = operator->();
  std::unordered_map<Stage, Stage, ObjectPtrHash, ObjectPtrEqual> smap;
  ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>();
  n->outputs = self->outputs;
  // Copy the stages.
  for (Stage s : self->stages) {
    Stage scopy = CopyStage(s);
    smap[s] = scopy;
    n->stages.push_back(scopy);
  }
  for (Stage g : self->groups) {
    Stage gcopy = CopyStage(g);
    smap[g] = gcopy;
    n->groups.push_back(gcopy);
  }
  // Remap the reference relations.
  for (auto kv : self->stage_map) {
    n->stage_map.Set(kv.first, smap.at(kv.second));
  }
  for (Stage s : n->stages) {
    if (s->attach_stage.defined()) {
      ICHECK(smap.find(s->attach_stage) != smap.end())
          << s->attach_stage << " not found in " << (*this);
      s->attach_stage = smap.at(s->attach_stage);
    }
    if (s->group.defined()) {
      ICHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this);
      s->group = smap.at(s->group);
    }
  }
  for (Stage s : n->groups) {
    if (s->attach_stage.defined()) {
      ICHECK(smap.find(s->attach_stage) != smap.end())
          << s->attach_stage << " not found in " << (*this);
      s->attach_stage = smap.at(s->attach_stage);
    }
    if (s->group.defined()) {
      ICHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this);
      s->group = smap.at(s->group);
    }
  }
  return Schedule(n);
}

Stage Schedule::operator[](const Operation& op) {
  auto it = (*this)->stage_map.find(op);
  ICHECK(it != (*this)->stage_map.end())
      << "Cannot find Stage for operator " << op << " in the schedule";
  return (*it).second;
}
566
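// Returns whichever of g1/g2 is an ancestor of the other in the group tree
// (a group counts as its own ancestor); returns an undefined Stage when
// neither contains the other, i.e. the common parent is the schedule root.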
Stage LeastCommonAncestor(Stage g1, Stage g2) {
  if (!g1.defined()) return g1;
  if (!g2.defined()) return g2;
  if (g1.same_as(g2)) return g1;
  Stage g = g1;
  while (g.defined()) {
    if (g.same_as(g2)) return g2;
    g = g->group;
  }
  g = g2;
  while (g.defined()) {
    if (g.same_as(g1)) return g1;
    g = g->group;
  }
  return g;
}

Array<Tensor> RemapTensor(ScheduleNode* self, const Array<Tensor>& arr) {
  self->InitCache();
  const auto& op2stage_cache = self->op2stage_cache_;
  Array<Tensor> ret;
  for (Tensor t : arr) {
    if (!op2stage_cache.count(t->op.get())) {
      ICHECK(self->stage_map.count(t->op)) << "Given tensor is not in the schedule plan";
      t = self->stage_map[t->op]->op.output(t->value_index);
    }
    ret.push_back(t);
  }
  return ret;
}
597
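// create_group gathers the sub-graph between `outputs` and `inputs` into a
// new group stage. A counter per existing group tracks how many of its child
// stages fall inside the sub-graph; only groups whose children are entirely
// covered may be re-parented, otherwise the requested region would cut
// across an existing group boundary.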
// Group the schedule stages.
Stage Schedule::create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
                             bool include_inputs) {
  ScheduleNode* self = operator->();
  self->InitCache();
  const auto& op2stage_cache = self->op2stage_cache_;
  // Get the ops.
  Array<Operation> ops =
      te::GetSubGraph(RemapTensor(self, outputs), RemapTensor(self, inputs), include_inputs);
  // Local counter entry, automatically initialized to 0 on creation.
  struct Entry {
    int count{0};
  };
  // Map of group->touched counter.
  std::unordered_map<Stage, Entry, ObjectPtrHash, ObjectPtrEqual> counter;
  // The parent group.
  Stage parent_group;
  // Detect common parent and child.
  for (size_t i = 0; i < ops.size(); ++i) {
    Operation op = ops[i];
    auto it = op2stage_cache.find(op.get());
    ICHECK(it != op2stage_cache.end());
    Stage op_group = it->second->group;
    if (i == 0) {
      parent_group = op_group;
    } else {
      parent_group = LeastCommonAncestor(parent_group, op_group);
    }
    if (op_group.defined()) {
      ++counter[op_group].count;
    }
  }
  // Create the new group stage.
  Stage gstage(make_object<StageNode>());
  gstage->group = parent_group;
  if (parent_group.defined()) {
    ++parent_group->num_child_stages;
  }
  // Propagate the counter statistics upward: whenever a subgroup is fully
  // covered, credit its parent group and continue.
  std::vector<Stage> stack;
  for (auto& kv : counter) {
    if (!kv.first.same_as(parent_group)) {
      if (kv.first->num_child_stages == kv.second.count) {
        stack.push_back(kv.first);
      }
    }
  }
  while (!stack.empty()) {
    Stage g = stack.back();
    stack.pop_back();
    if (g->group.defined() && !g->group.same_as(parent_group)) {
      Entry& e = counter[g->group];
      ++e.count;
      if (e.count == g->group->num_child_stages) {
        stack.push_back(g->group);
      }
    }
  }
  // Verify and remap the subgroups.
  for (auto& kv : counter) {
    if (kv.first.same_as(parent_group)) continue;
    ICHECK_EQ(kv.first->num_child_stages, kv.second.count)
        << "Trying to group a region that intersects with an already existing group";
    if (kv.first->group.same_as(parent_group)) {
      Stage s = kv.first;
      s->group = gstage;
      ++gstage->num_child_stages;
      if (parent_group.defined()) {
        --parent_group->num_child_stages;
      }
    }
  }
  // Remap the group of op stages.
  for (Operation op : ops) {
    auto it = op2stage_cache.find(op.get());
    ICHECK(it != op2stage_cache.end());
    Stage s = it->second;
    if (s->group.same_as(parent_group)) {
      s->group = gstage;
      ++gstage->num_child_stages;
      if (parent_group.defined()) {
        --parent_group->num_child_stages;
      }
    }
  }
  // Correct the attachments to keep everything inside the group.
  for (Operation op : ops) {
    auto it = op2stage_cache.find(op.get());
    ICHECK(it != op2stage_cache.end());
    Stage s = it->second;
    if (s->attach_type == kScope) {
      Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage);
      if (!cg.same_as(gstage)) {
        LOG(WARNING) << "group invalidates some previous compute_at relations,"
                     << " resetting to compute_root to keep the computation inside the group";
        s.compute_root();
      }
    }
  }

  self->groups.push_back(gstage);
  return gstage;
}
703
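// The op->stage cache is rebuilt lazily: InitCache is a no-op while the
// cache size matches the number of stages, and InvalidateCache forces a
// rebuild on the next InitCache call.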
void ScheduleNode::InvalidateCache() { op2stage_cache_.clear(); }

void ScheduleNode::InitCache() {
  if (op2stage_cache_.size() == stages.size()) return;
  InvalidateCache();
  for (Stage s : stages) {
    if (s->op.defined()) {
      op2stage_cache_[s->op.get()] = s;
    }
  }
  ICHECK_EQ(op2stage_cache_.size(), stages.size());
}

bool ScheduleNode::Contain(const Operation& op) const {
  return stage_map.find(op) != stage_map.end();
}

Schedule::Schedule(Array<Operation> ops) {
  auto n = make_object<ScheduleNode>();
  data_ = n;
  n->outputs = ops;
  auto g = te::CreateReadGraph(n->outputs);
  Array<Operation> post_order = te::PostDFSOrder(n->outputs, g);
  // output set.
  std::unordered_set<Operation> output_set;
  for (Operation x : ops) {
    output_set.insert(x);
  }
  for (Operation op : post_order) {
    Stage stage(op);
    stage->is_output = output_set.count(op) != 0;
    n->stages.push_back(stage);
    n->stage_map.Set(op, stage);
    // mark scan updates.
    if (const ScanOpNode* scan = op.as<ScanOpNode>()) {
      Array<Tensor> inputs;
      for (Tensor t : scan->state_placeholder) {
        inputs.push_back(t);
      }
      for (Tensor t : scan->inputs) {
        inputs.push_back(t);
      }
      // Create the scan group.
      Stage scan_group = this->create_group(scan->update, inputs, false);
      scan_group->attach_type = kScanUpdate;
      scan_group->attach_stage = stage;

      for (size_t i = 0; i < scan->update.size(); ++i) {
        Stage s = n->stage_map[scan->update[i]->op];
        ICHECK(scan_group.same_as(s->group));
      }
    }
  }
}
758
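// The relation nodes below record each schedule transformation so that later
// passes (e.g. bound inference and loop generation) can replay it: Split and
// Fuse mirror the primitives above, Rebase re-zeroes a loop's start,
// Singleton is the unit loop created by fusing zero axes, and Transform
// captures a layout rewrite together with its inverse.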
Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) {
  auto n = make_object<SplitNode>();
  n->parent = parent;
  n->outer = outer;
  n->inner = inner;
  n->factor = factor;
  n->nparts = nparts;
  data_ = std::move(n);
}

Fuse::Fuse(IterVar outer, IterVar inner, IterVar fused) {
  auto n = make_object<FuseNode>();
  n->outer = outer;
  n->inner = inner;
  n->fused = fused;
  data_ = std::move(n);
}

Rebase::Rebase(IterVar parent, IterVar rebased) {
  auto n = make_object<RebaseNode>();
  n->parent = parent;
  n->rebased = rebased;
  data_ = std::move(n);
}

Singleton::Singleton(IterVar iter) {
  auto n = make_object<SingletonNode>();
  n->iter = iter;
  data_ = std::move(n);
}

Transform::Transform(Array<IterVar> original_variables, Array<IterVar> transformed_variables,
                     IndexMap forward_transformation, IndexMap inverse_transformation) {
  auto n = make_object<TransformNode>();
  n->original_variables = original_variables;
  n->transformed_variables = transformed_variables;
  n->forward_transformation = forward_transformation;
  n->inverse_transformation = inverse_transformation;
  data_ = std::move(n);
}

SpecializedCondition::SpecializedCondition(Array<PrimExpr> conditions) {
  ObjectPtr<SpecializedConditionNode> n = make_object<SpecializedConditionNode>();
  n->clauses = std::move(conditions);
  data_ = std::move(n);
}

/*! \brief Entry to hold the SpecializedCondition context stack. */
struct TVMSpecializationThreadLocalEntry {
  /*! \brief The current specialized condition */
  std::stack<SpecializedCondition> condition_stack;
};

/*! \brief Thread local store to hold the SpecializedCondition context stack. */
typedef dmlc::ThreadLocalStore<TVMSpecializationThreadLocalEntry> TVMSpecializationThreadLocalStore;

void SpecializedCondition::EnterWithScope() {
  TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get();
  entry->condition_stack.push(*this);
}

void SpecializedCondition::ExitWithScope() {
  TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get();
  ICHECK(!entry->condition_stack.empty());
  ICHECK(entry->condition_stack.top().same_as(*this));
  entry->condition_stack.pop();
}

SpecializedCondition SpecializedCondition::Current() {
  TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get();
  SpecializedCondition cond;
  if (entry->condition_stack.size() > 0) {
    cond = entry->condition_stack.top();
  }
  return cond;
}

class SpecializedCondition::Internal {
 public:
  static void EnterScope(SpecializedCondition cond) { cond.EnterWithScope(); }

  static void ExitScope(SpecializedCondition cond) { cond.ExitWithScope(); }
};

TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
TVM_REGISTER_NODE_TYPE(SingletonNode);
TVM_REGISTER_NODE_TYPE(ScheduleNode);
TVM_REGISTER_NODE_TYPE(SpecializedConditionNode);

// Printer
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<StageNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const StageNode*>(node.get());
      if (op->op.defined()) {
        p->stream << "stage(" << op->origin_op->name << ", " << op->op << ")";
      } else {
        p->stream << "group-stage(" << op << ")";
      }
    })
    .set_dispatch<IterVarAttrNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const IterVarAttrNode*>(node.get());
      p->stream << IterVarType2String(op->iter_type);
    })
    .set_dispatch<SplitNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const SplitNode*>(node.get());
      p->stream << "split(parent=";
      p->Print(op->parent);
      p->stream << ", outer=";
      p->Print(op->outer);
      p->stream << ", inner=";
      p->Print(op->inner);
      if (op->factor.defined()) {
        p->stream << ", factor=";
        p->Print(op->factor);
      } else {
        p->stream << ", nparts=";
        p->Print(op->nparts);
      }
      p->stream << ')';
    })
    .set_dispatch<FuseNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const FuseNode*>(node.get());
      p->stream << "fuse(";
      p->stream << "outer=";
      p->Print(op->outer);
      p->stream << ", inner=";
      p->Print(op->inner);
      p->stream << ", fused=";
      p->Print(op->fused);
      p->stream << ')';
    })
    .set_dispatch<RebaseNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const RebaseNode*>(node.get());
      p->stream << "rebase(";
      p->stream << "parent=";
      p->Print(op->parent);
      p->stream << ", rebased=";
      p->Print(op->rebased);
      p->stream << ')';
    })
    .set_dispatch<SingletonNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const SingletonNode*>(node.get());
      p->stream << "singleton(";
      p->Print(op->iter);
      p->stream << ')';
    })
    .set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const ScheduleNode*>(node.get());
      p->stream << "schedule(" << op << ")";
    })
    .set_dispatch<SpecializedConditionNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const SpecializedConditionNode*>(node.get());
      p->stream << "specialized_condition(";
      p->Print(op->clauses);
      p->stream << ')';
    });
919
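// FFI registrations: expose the schedule primitives to the frontend language
// bindings (e.g. the Python te API) through the global registry.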
TVM_REGISTER_GLOBAL("te.CreateSchedule").set_body_typed(create_schedule);

TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope);

TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind);

TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
    .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
      IterVar outer, inner;
      stage.split(parent, factor, &outer, &inner);
      return Array<IterVar>({outer, inner});
    });

TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
    .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
      IterVar outer, inner;
      stage.split_by_nparts(parent, nparts, &outer, &inner);
      return Array<IterVar>({outer, inner});
    });

TVM_REGISTER_GLOBAL("te.StageFuse").set_body_typed([](Stage stage, Array<IterVar> axes) {
  IterVar fused;
  stage.fuse(axes, &fused);
  return fused;
});

TVM_REGISTER_GLOBAL("te.StageComputeAt").set_body_method(&Stage::compute_at);

TVM_REGISTER_GLOBAL("te.StageComputeInline").set_body_method(&Stage::compute_inline);

TVM_REGISTER_GLOBAL("te.StageComputeRoot").set_body_method(&Stage::compute_root);

TVM_REGISTER_GLOBAL("te.StageReorder").set_body_method(&Stage::reorder);

TVM_REGISTER_GLOBAL("te.StageTile")
    .set_body_typed([](Stage stage, IterVar x_parent, IterVar y_parent, PrimExpr x_factor,
                       PrimExpr y_factor) {
      IterVar x_outer, y_outer, x_inner, y_inner;
      stage.tile(x_parent, y_parent, x_factor, y_factor, &x_outer, &y_outer, &x_inner, &y_inner);
      return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
    });

TVM_REGISTER_GLOBAL("te.StageEnvThreads").set_body_method(&Stage::env_threads);

TVM_REGISTER_GLOBAL("te.StageSetStorePredicate").set_body_method(&Stage::set_store_predicate);

TVM_REGISTER_GLOBAL("te.StageUnroll").set_body_method(&Stage::unroll);

TVM_REGISTER_GLOBAL("te.StageVectorize").set_body_method(&Stage::vectorize);

TVM_REGISTER_GLOBAL("te.StageTensorize").set_body_method(&Stage::tensorize);

TVM_REGISTER_GLOBAL("te.StageParallel").set_body_method(&Stage::parallel);

TVM_REGISTER_GLOBAL("te.StagePragma").set_body_method(&Stage::pragma);

TVM_REGISTER_GLOBAL("te.StagePrefetch").set_body_method(&Stage::prefetch);

TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_align);

TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer);

TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buffer);

TVM_REGISTER_GLOBAL("te.StageTransformLayout")
    .set_body_typed([](Stage stage, const Array<Var>& initial_indices,
                       const Array<PrimExpr>& final_indices) {
      Array<IterVar> new_iter_vars;
      stage.transform_layout(initial_indices, final_indices, &new_iter_vars);
      return new_iter_vars;
    });

TVM_REGISTER_GLOBAL("te.StageSetAxisSeparators").set_body_method(&Stage::set_axis_separators);

TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize);

TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group);

TVM_REGISTER_GLOBAL("te.ScheduleCacheRead").set_body_method(&Schedule::cache_read);

TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite").set_body([](TVMArgs args, TVMRetValue* ret) {
  if (args[1].IsObjectRef<Tensor>()) {
    *ret = args[0].operator Schedule().cache_write(args[1].operator Tensor(), args[2]);
  } else {
    *ret = args[0].operator Schedule().cache_write(args[1].operator Array<Tensor>(), args[2]);
  }
});

TVM_REGISTER_GLOBAL("te.ScheduleRFactor").set_body_method(&Schedule::rfactor);

TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition").set_body_typed([](Array<PrimExpr> condition) {
  return SpecializedCondition(condition);
});

TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization").set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = SpecializedCondition::Current();
});

TVM_REGISTER_GLOBAL("te.EnterSpecializationScope")
    .set_body_typed(SpecializedCondition::Internal::EnterScope);

TVM_REGISTER_GLOBAL("te.ExitSpecializationScope")
    .set_body_typed(SpecializedCondition::Internal::ExitScope);

}  // namespace te
}  // namespace tvm