1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file bound.cc |
22 | * \brief The bound inference logic. |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/te/operation.h> |
26 | #include <tvm/te/schedule_pass.h> |
27 | |
28 | #include <unordered_map> |
29 | #include <unordered_set> |
30 | |
31 | #include "../../runtime/thread_storage_scope.h" |
32 | #include "graph.h" |
33 | #include "message_passing.h" |
34 | |
35 | namespace tvm { |
36 | namespace te { |
37 | |
38 | using runtime::StorageRank; |
39 | using runtime::StorageScope; |
40 | using runtime::ThreadScope; |
41 | |
/*! \brief The graph context used during bound inference. */
struct GraphContext {
  /*! \brief The feed graph: maps each output tensor to the set of operations that read it. */
  FeedGraph feed_graph;
  /*! \brief Attachment path: for each op, the loop iter vars it is nested under (compute_at). */
  AttachPath attach_path;
  /*! \brief The bind map: schedule iter var -> the thread iter var it is bound to. */
  std::unordered_map<IterVar, IterVar> bind_map;
  /*! \brief map from op (raw object pointer) to the stage that owns it */
  std::unordered_map<const Object*, Stage> op2stage_;
};
53 | |
54 | bool NeedRelax(const IterVar& iv, bool found_attach, |
55 | const std::unordered_map<IterVar, IterVar>& bind_map, |
56 | const runtime::StorageScope& scope) { |
57 | auto it = bind_map.find(iv); |
58 | const std::string& tag = (it != bind_map.end() ? it->second->thread_tag : iv->thread_tag); |
59 | if (tag.length() == 0 || tag == "pipeline" ) { |
60 | return !found_attach; |
61 | } |
62 | ThreadScope ts = ThreadScope::Create(tag); |
63 | |
64 | // When there is warp memory |
65 | // threadIdx.x must be set to be warp index. |
66 | if (scope.rank == StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) { |
67 | return true; |
68 | } |
69 | return static_cast<int>(scope.rank) <= ts.rank; |
70 | } |
71 | |
72 | // infer storage scope, if not given |
73 | StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) { |
74 | if (stage->scope.length() != 0) { |
75 | return StorageScope::Create(stage->scope); |
76 | } |
77 | int max_rank = -1; |
78 | for (IterVar iv : ctx.attach_path.at(stage->op)) { |
79 | auto it = ctx.bind_map.find(iv); |
80 | const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); |
81 | if (tag != "pipeline" && tag.length() != 0) { |
82 | max_rank = std::max(max_rank, ThreadScope::Create(tag).rank); |
83 | } |
84 | } |
85 | StorageScope s; |
86 | s.rank = runtime::DefaultStorageRank(max_rank); |
87 | return s; |
88 | } |
89 | |
/*!
 * \brief Infer the ranges of a stage's root iter vars from its consumers.
 *
 *  Propagates the regions required by every consumer of the stage's outputs
 *  back to the stage, applying the relaxation rules (NeedRelax) driven by the
 *  attachment position and the output's storage scope, then gathers the
 *  resulting root ranges into rmap via GatherBound.
 *
 * \param stage The stage whose root bounds are inferred.
 * \param ctx The graph context (feed graph, attach paths, bind map, op->stage).
 * \param rmap In/out map of inferred iter var ranges. Consumers' leaf ranges
 *             must already be present (stages are processed output-to-input).
 */
void InferRootBound(const Stage& stage, const GraphContext& ctx,
                    std::unordered_map<IterVar, Range>* rmap) {
  ICHECK_NE(stage->attach_type, kInline) << "call schedule.normalize before scheduleops" ;
  // Already handled by inlining; nothing to infer.
  if (stage->attach_type == kInlinedAlready) return;
  if (stage->is_output) {
    // verify correctness.
    ICHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) << "Output must be attached at root" ;
  }
  // Outputs and placeholders always take their full declared domain.
  if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
    for (auto iv : stage->op->root_iter_vars()) {
      ICHECK(iv->dom.defined());
      ICHECK(!rmap->count(iv));
      (*rmap)[iv] = iv->dom;
    }
    return;
  }
  // The tensor domain: per-output accumulator of regions consumers touch.
  std::unordered_map<Tensor, TensorDom> tmap;
  // The consumers of the op.
  std::unordered_set<Operation> consumers;
  for (int i = 0; i < stage->op->num_outputs(); ++i) {
    Tensor t = stage->op.output(i);
    tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
    auto it = ctx.feed_graph.find(t);
    if (it != ctx.feed_graph.end()) {
      for (const Operation& op : it->second) {
        consumers.insert(op);
      }
    } else {
      LOG(INFO) << "not in feed graph consumer = " << stage->op;
    }
  }
  // storage scope.
  runtime::StorageScope scope = InferStorageScope(stage, ctx);
  // Bound prop by other consumers.
  // - Compute bound by relaxation rules: NeedRelax
  //   - For normal index, use relative location of loop nest.
  //   - For thread index, use the thread scope.
  //
  Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
  // The parent set.
  for (const Operation& op : consumers) {
    Map<Var, IntSet> relax_set;
    std::unordered_map<IterVar, IntSet> up_state;
    bool found_attach = false;
    ICHECK(ctx.op2stage_.count(op.get()));
    const Stage& op_stage = ctx.op2stage_.at(op.get());
    // Consumer nest: walk the consumer's leaf loops innermost to outermost,
    // recording whether each loop collapses to a point or must be relaxed.
    for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
      IterVar iv = op_stage->leaf_iter_vars[i - 1];
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      auto it = rmap->find(iv);
      ICHECK(it != rmap->end());
      const Range& vrange = it->second;
      if (is_one(vrange->extent)) {
        // Single-iteration loop: collapses to a fixed point.
        up_state[iv] = IntSet::SinglePoint(vrange->min);
      } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        ICHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, "
                                     << " call schedule.normalize to achieve this. " ;
        if (ctx.bind_map.count(iv)) {
          // Keep the loop symbolic via the thread var it is bound to.
          up_state[iv] = IntSet::SinglePoint(ctx.bind_map.at(iv)->var);
        } else {
          up_state[iv] = IntSet::SinglePoint(iv->var);
        }
      } else {
        // Relaxed: the producer must cover the loop's entire range.
        up_state[iv] = IntSet::FromRange(vrange);
      }
    }
    // Consumer's attach nest: loops the consumer itself is nested under.
    for (IterVar iv : ctx.attach_path.at(op)) {
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      Range vrange = rmap->at(iv);
      ICHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, "
                                   << "call schedule.normalize to achieve this." ;
      if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        relax_set.Set(iv->var, IntSet::FromRange(vrange));
        if (ctx.bind_map.count(iv)) {
          // Also relax over the bound thread var so both names are covered.
          relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::FromRange(vrange));
        }
      }
    }
    ICHECK(found_attach || stage_attach.size() == 0)
        << "Invalid Schedule, cannot find the producer " << stage->op
        << " along the loop nest specified by compute_at of consumer " << op;
    // Get the domain of the consumer
    PassUpDomain(op_stage, *rmap, &up_state);
    // Relax if needed.
    std::unordered_map<const VarNode*, IntSet> dom_map;
    arith::Analyzer analyzer;
    // Seed the analyzer with all ranges known so far for simplification.
    for (auto entry : *rmap) {
      analyzer.Bind(entry.first->var, entry.second);
    }
    for (auto iv : op->root_iter_vars()) {
      Range r;
      if (up_state.count(iv)) {
        // Clip the propagated set to the iter var's declared domain.
        r = up_state.at(iv).CoverRange(iv->dom);
      } else {
        r = iv->dom;
      }
      if (relax_set.size() != 0) {
        // Widen [min, max] over every var in the relax set.
        dom_map[iv->var.get()] =
            IntSet::Interval(analyzer.int_set(r->min, relax_set).min(),
                             analyzer.int_set(r->min + r->extent - 1, relax_set).max());
      } else {
        dom_map[iv->var.get()] = IntSet::FromRange(r);
      }
      analyzer.Bind(iv->var, r, true);
    }
    // Let the consumer op translate root-domain bounds into input regions.
    op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
  }
  // Combine all consumer requirements into this stage's root ranges.
  stage->op->GatherBound(stage->op, tmap, rmap);
}
206 | |
207 | Map<IterVar, Range> InferBound(const Schedule& sch) { |
208 | // Prepare context |
209 | GraphContext ctx; |
210 | Array<Operation> roots; |
211 | arith::Analyzer analyzer; |
212 | |
213 | for (Operation op : sch->outputs) { |
214 | roots.push_back(sch->stage_map[op]->op); |
215 | } |
216 | ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots)); |
217 | |
218 | for (Stage stage : sch->stages) { |
219 | for (auto kv : stage->iter_var_attrs) { |
220 | if (kv.second->bind_thread.defined()) { |
221 | ICHECK(!ctx.bind_map.count(kv.first)); |
222 | ctx.bind_map[kv.first] = kv.second->bind_thread; |
223 | } |
224 | } |
225 | ctx.op2stage_[stage->op.get()] = stage; |
226 | } |
227 | ctx.attach_path = CreateAttachPath(sch); |
228 | // Run inference. |
229 | std::unordered_map<IterVar, Range> ret; |
230 | for (size_t i = sch->stages.size(); i != 0; --i) { |
231 | const Stage& stage = sch->stages[i - 1]; |
232 | InferRootBound(stage, ctx, &ret); |
233 | |
234 | // bind bound of root iter vars. |
235 | for (auto iv : stage->op->root_iter_vars()) { |
236 | auto it = ret.find(iv); |
237 | if (it != ret.end()) { |
238 | analyzer.Bind(iv->var, it->second); |
239 | } |
240 | } |
241 | |
242 | // pass down to get bound of all iter vars. |
243 | PassDownDomain(stage, &ret, &analyzer); |
244 | for (IterVar iv : stage->env_threads) { |
245 | ICHECK(iv->dom.defined()); |
246 | ret[iv] = iv->dom; |
247 | } |
248 | } |
249 | for (auto it = ret.begin(); it != ret.end(); it++) { |
250 | DataType var_type = it->first->var.dtype(); |
251 | it->second = Range::FromMinExtent( |
252 | // The range associated with each itervar must have the same dtype as the var |
253 | analyzer.Simplify(cast(var_type, it->second->min)), |
254 | analyzer.Simplify(cast(var_type, it->second->extent))); |
255 | } |
256 | return Map<IterVar, Range>(ret.begin(), ret.end()); |
257 | } |
258 | |
// Expose bound inference to the FFI layer under the name "schedule.InferBound".
TVM_REGISTER_GLOBAL("schedule.InferBound" ).set_body_typed(InferBound);
260 | |
261 | } // namespace te |
262 | } // namespace tvm |
263 | |