1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file bound.cc
22 * \brief The bound inference logic.
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/te/operation.h>
26#include <tvm/te/schedule_pass.h>
27
28#include <unordered_map>
29#include <unordered_set>
30
31#include "../../runtime/thread_storage_scope.h"
32#include "graph.h"
33#include "message_passing.h"
34
35namespace tvm {
36namespace te {
37
38using runtime::StorageRank;
39using runtime::StorageScope;
40using runtime::ThreadScope;
41
42/*! \brief The graph context used during bound inference. */
43struct GraphContext {
44 /*! \brief The feed graph */
45 FeedGraph feed_graph;
46 /*! \brief Attachment path */
47 AttachPath attach_path;
48 /*! \brief The bind map */
49 std::unordered_map<IterVar, IterVar> bind_map;
50 /*! \brief map from op to stage */
51 std::unordered_map<const Object*, Stage> op2stage_;
52};
53
54bool NeedRelax(const IterVar& iv, bool found_attach,
55 const std::unordered_map<IterVar, IterVar>& bind_map,
56 const runtime::StorageScope& scope) {
57 auto it = bind_map.find(iv);
58 const std::string& tag = (it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
59 if (tag.length() == 0 || tag == "pipeline") {
60 return !found_attach;
61 }
62 ThreadScope ts = ThreadScope::Create(tag);
63
64 // When there is warp memory
65 // threadIdx.x must be set to be warp index.
66 if (scope.rank == StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) {
67 return true;
68 }
69 return static_cast<int>(scope.rank) <= ts.rank;
70}
71
72// infer storage scope, if not given
73StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) {
74 if (stage->scope.length() != 0) {
75 return StorageScope::Create(stage->scope);
76 }
77 int max_rank = -1;
78 for (IterVar iv : ctx.attach_path.at(stage->op)) {
79 auto it = ctx.bind_map.find(iv);
80 const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
81 if (tag != "pipeline" && tag.length() != 0) {
82 max_rank = std::max(max_rank, ThreadScope::Create(tag).rank);
83 }
84 }
85 StorageScope s;
86 s.rank = runtime::DefaultStorageRank(max_rank);
87 return s;
88}
89
90void InferRootBound(const Stage& stage, const GraphContext& ctx,
91 std::unordered_map<IterVar, Range>* rmap) {
92 ICHECK_NE(stage->attach_type, kInline) << "call schedule.normalize before scheduleops";
93 if (stage->attach_type == kInlinedAlready) return;
94 if (stage->is_output) {
95 // verify correctness.
96 ICHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) << "Output must be attached at root";
97 }
98 if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
99 for (auto iv : stage->op->root_iter_vars()) {
100 ICHECK(iv->dom.defined());
101 ICHECK(!rmap->count(iv));
102 (*rmap)[iv] = iv->dom;
103 }
104 return;
105 }
106 // The tensor domain.
107 std::unordered_map<Tensor, TensorDom> tmap;
108 // The consumers of the op.
109 std::unordered_set<Operation> consumers;
110 for (int i = 0; i < stage->op->num_outputs(); ++i) {
111 Tensor t = stage->op.output(i);
112 tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
113 auto it = ctx.feed_graph.find(t);
114 if (it != ctx.feed_graph.end()) {
115 for (const Operation& op : it->second) {
116 consumers.insert(op);
117 }
118 } else {
119 LOG(INFO) << "not in feed graph consumer = " << stage->op;
120 }
121 }
122 // storage scope.
123 runtime::StorageScope scope = InferStorageScope(stage, ctx);
124 // Bound prop by other consumers.
125 // - Compute bound by relaxation rules: NeedRelax
126 // - For normal index, use relative location of loop nest./
127 // - For thread index, use the thread scope.
128 //
129 Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
130 // The parent set.
131 for (const Operation& op : consumers) {
132 Map<Var, IntSet> relax_set;
133 std::unordered_map<IterVar, IntSet> up_state;
134 bool found_attach = false;
135 ICHECK(ctx.op2stage_.count(op.get()));
136 const Stage& op_stage = ctx.op2stage_.at(op.get());
137 // Consumer nest
138 for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
139 IterVar iv = op_stage->leaf_iter_vars[i - 1];
140 if (stage_attach.size() != 0 && iv == stage_attach[0]) {
141 found_attach = true;
142 }
143 auto it = rmap->find(iv);
144 ICHECK(it != rmap->end());
145 const Range& vrange = it->second;
146 if (is_one(vrange->extent)) {
147 up_state[iv] = IntSet::SinglePoint(vrange->min);
148 } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
149 ICHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, "
150 << " call schedule.normalize to achieve this. ";
151 if (ctx.bind_map.count(iv)) {
152 up_state[iv] = IntSet::SinglePoint(ctx.bind_map.at(iv)->var);
153 } else {
154 up_state[iv] = IntSet::SinglePoint(iv->var);
155 }
156 } else {
157 up_state[iv] = IntSet::FromRange(vrange);
158 }
159 }
160 // Consumer's attach nest
161 for (IterVar iv : ctx.attach_path.at(op)) {
162 if (stage_attach.size() != 0 && iv == stage_attach[0]) {
163 found_attach = true;
164 }
165 Range vrange = rmap->at(iv);
166 ICHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, "
167 << "call schedule.normalize to achieve this.";
168 if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
169 relax_set.Set(iv->var, IntSet::FromRange(vrange));
170 if (ctx.bind_map.count(iv)) {
171 relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::FromRange(vrange));
172 }
173 }
174 }
175 ICHECK(found_attach || stage_attach.size() == 0)
176 << "Invalid Schedule, cannot find the producer " << stage->op
177 << " along the loop nest specified by compute_at of consumer " << op;
178 // Get the domain of the consumer
179 PassUpDomain(op_stage, *rmap, &up_state);
180 // Relax if needed.
181 std::unordered_map<const VarNode*, IntSet> dom_map;
182 arith::Analyzer analyzer;
183 for (auto entry : *rmap) {
184 analyzer.Bind(entry.first->var, entry.second);
185 }
186 for (auto iv : op->root_iter_vars()) {
187 Range r;
188 if (up_state.count(iv)) {
189 r = up_state.at(iv).CoverRange(iv->dom);
190 } else {
191 r = iv->dom;
192 }
193 if (relax_set.size() != 0) {
194 dom_map[iv->var.get()] =
195 IntSet::Interval(analyzer.int_set(r->min, relax_set).min(),
196 analyzer.int_set(r->min + r->extent - 1, relax_set).max());
197 } else {
198 dom_map[iv->var.get()] = IntSet::FromRange(r);
199 }
200 analyzer.Bind(iv->var, r, true);
201 }
202 op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
203 }
204 stage->op->GatherBound(stage->op, tmap, rmap);
205}
206
207Map<IterVar, Range> InferBound(const Schedule& sch) {
208 // Prepare context
209 GraphContext ctx;
210 Array<Operation> roots;
211 arith::Analyzer analyzer;
212
213 for (Operation op : sch->outputs) {
214 roots.push_back(sch->stage_map[op]->op);
215 }
216 ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
217
218 for (Stage stage : sch->stages) {
219 for (auto kv : stage->iter_var_attrs) {
220 if (kv.second->bind_thread.defined()) {
221 ICHECK(!ctx.bind_map.count(kv.first));
222 ctx.bind_map[kv.first] = kv.second->bind_thread;
223 }
224 }
225 ctx.op2stage_[stage->op.get()] = stage;
226 }
227 ctx.attach_path = CreateAttachPath(sch);
228 // Run inference.
229 std::unordered_map<IterVar, Range> ret;
230 for (size_t i = sch->stages.size(); i != 0; --i) {
231 const Stage& stage = sch->stages[i - 1];
232 InferRootBound(stage, ctx, &ret);
233
234 // bind bound of root iter vars.
235 for (auto iv : stage->op->root_iter_vars()) {
236 auto it = ret.find(iv);
237 if (it != ret.end()) {
238 analyzer.Bind(iv->var, it->second);
239 }
240 }
241
242 // pass down to get bound of all iter vars.
243 PassDownDomain(stage, &ret, &analyzer);
244 for (IterVar iv : stage->env_threads) {
245 ICHECK(iv->dom.defined());
246 ret[iv] = iv->dom;
247 }
248 }
249 for (auto it = ret.begin(); it != ret.end(); it++) {
250 DataType var_type = it->first->var.dtype();
251 it->second = Range::FromMinExtent(
252 // The range associated with each itervar must have the same dtype as the var
253 analyzer.Simplify(cast(var_type, it->second->min)),
254 analyzer.Simplify(cast(var_type, it->second->extent)));
255 }
256 return Map<IterVar, Range>(ret.begin(), ret.end());
257}
258
// Expose bound inference to the FFI under the global name "schedule.InferBound".
TVM_REGISTER_GLOBAL("schedule.InferBound").set_body_typed(InferBound);
260
261} // namespace te
262} // namespace tvm
263