1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file graph.cc |
22 | * \brief Utilities to get information about schedule graph. |
23 | */ |
24 | #include "graph.h" |
25 | |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/te/operation.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | |
31 | #include <unordered_map> |
32 | #include <unordered_set> |
33 | #include <utility> |
34 | |
35 | namespace tvm { |
36 | namespace te { |
37 | // key to specific tensor dimension. |
38 | struct TensorDimKey { |
39 | Operation op; |
40 | int value_index; |
41 | int dim; |
42 | TensorDimKey() {} |
43 | TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {} |
44 | TensorDimKey(const Tensor& t, size_t dim) |
45 | : op(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {} |
46 | inline bool operator==(const TensorDimKey& other) const { |
47 | return op == other.op && value_index == other.value_index && dim == other.dim; |
48 | } |
49 | inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); } |
50 | }; |
51 | } // namespace te |
52 | } // namespace tvm |
53 | |
54 | namespace std { |
55 | template <> |
56 | struct hash<::tvm::te::TensorDimKey> { |
57 | std::size_t operator()(const ::tvm::te::TensorDimKey& k) const { |
58 | size_t lhs = ::tvm::ObjectPtrHash()(k.op); |
59 | size_t rhs = static_cast<size_t>(k.value_index) << 16UL | static_cast<size_t>(k.dim); |
60 | lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); |
61 | return lhs; |
62 | } |
63 | }; |
64 | } // namespace std |
65 | |
66 | namespace tvm { |
67 | namespace te { |
68 | |
69 | // construct a read graph that gives readers of each operation |
70 | // that the root depend on |
71 | ReadGraph CreateReadGraph(const Array<Operation>& roots) { |
72 | ReadGraph rmap; |
73 | std::vector<Operation> stack; |
74 | std::unordered_set<const Object*> visited; |
75 | // initialize the roots |
76 | for (Operation op : roots) { |
77 | stack.push_back(op); |
78 | visited.insert(op.get()); |
79 | } |
80 | |
81 | while (!stack.empty()) { |
82 | Operation op = stack.back(); |
83 | stack.pop_back(); |
84 | Array<Tensor> deps = op->InputTensors(); |
85 | rmap.Set(op, deps); |
86 | for (Tensor t : deps) { |
87 | if (t->op.defined() && visited.count(t->op.get()) == 0) { |
88 | visited.insert(t->op.get()); |
89 | stack.push_back(t->op); |
90 | } |
91 | } |
92 | } |
93 | return rmap; |
94 | } |
95 | |
96 | // Do DFS visit to get the subgraph. |
97 | // Return if op is inside the subgraph. |
98 | bool GetSubGraphByPostDFS_(const Operation& op, const std::unordered_set<const Object*>& boundary, |
99 | bool include_bounary, std::unordered_map<const Object*, bool>* visited, |
100 | Array<Operation>* result) { |
101 | if (visited->count(op.get())) { |
102 | return visited->at(op.get()); |
103 | } |
104 | if (boundary.count(op.get())) { |
105 | (*visited)[op.get()] = true; |
106 | if (include_bounary) { |
107 | result->push_back(op); |
108 | } |
109 | return true; |
110 | } |
111 | // mark to avoid loop |
112 | // Not necessary for DAG. |
113 | (*visited)[op.get()] = false; |
114 | // check if we can reach boundary. |
115 | bool reach_boundary = false; |
116 | for (Tensor t : op->InputTensors()) { |
117 | if (GetSubGraphByPostDFS_(t->op, boundary, include_bounary, visited, result)) { |
118 | reach_boundary = true; |
119 | } |
120 | } |
121 | (*visited)[op.get()] = reach_boundary; |
122 | if (reach_boundary) { |
123 | result->push_back(op); |
124 | } |
125 | return reach_boundary; |
126 | } |
127 | |
128 | Array<Operation> GetSubGraph(const Array<Tensor>& outputs, const Array<Tensor>& inputs, |
129 | bool include_inputs) { |
130 | Array<Operation> result; |
131 | std::unordered_set<const Object*> boundary; |
132 | for (Tensor t : inputs) { |
133 | boundary.insert(t->op.get()); |
134 | } |
135 | std::unordered_map<const Object*, bool> visited; |
136 | for (Tensor t : outputs) { |
137 | GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result); |
138 | } |
139 | return result; |
140 | } |
141 | |
142 | void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set<Operation>* visited, |
143 | Array<Operation>* post_order) { |
144 | if (visited->count(op)) return; |
145 | visited->insert(op); |
146 | for (const auto& t : g.at(op)) { |
147 | PostDFSOrder(t->op, g, visited, post_order); |
148 | } |
149 | post_order->push_back(op); |
150 | } |
151 | |
152 | Array<Operation> PostDFSOrder(const Array<Operation>& roots, const ReadGraph& g) { |
153 | std::unordered_set<Operation> visited; |
154 | Array<Operation> post_order; |
155 | for (Operation op : roots) { |
156 | PostDFSOrder(op, g, &visited, &post_order); |
157 | } |
158 | return post_order; |
159 | } |
160 | |
161 | FeedGraph CreateFeedGraph(const ReadGraph& g) { |
162 | FeedGraph fg; |
163 | for (auto kv : g) { |
164 | for (Tensor t : kv.second) { |
165 | fg[t].push_back(kv.first); |
166 | } |
167 | } |
168 | return fg; |
169 | } |
170 | |
// Compute, for every stage in the schedule, the path of iteration variables
// its computation is nested under when attached via compute_at / scan update.
// The path is collected from the innermost attach point outward, following
// the chain of attach stages until a root stage is reached.
AttachPath CreateAttachPath(Schedule sch) {
  AttachPath ret;
  for (Stage stage : sch->stages) {
    std::unordered_set<const Object*> visited;
    Array<IterVar> path;
    // Walk the attach chain; `visited` detects cyclic compute_at groups.
    for (Stage s = stage; s.defined();) {
      ICHECK(!visited.count(s.get())) << "Find loop in compute_at attach group" ;
      visited.insert(s.get());
      Stage spec = s.GetAttachSpec();
      bool start_attach;
      IterVar attach_ivar;
      if (spec->attach_type == kScope) {
        // Attached inside a loop of another stage: start collecting once we
        // hit attach_ivar while scanning that stage's leaves inner-to-outer.
        attach_ivar = spec->attach_ivar;
        s = spec->attach_stage;
        start_attach = false;
        ICHECK(attach_ivar.defined());
      } else if (spec->attach_type == kScanUpdate) {
        // Scan update: all leaf itervars of the attach stage are on the path.
        s = spec->attach_stage;
        start_attach = true;
      } else {
        // Root stage (or inline): nothing further to follow.
        break;
      }
      ICHECK(s.defined());
      // Scan the attach stage's leaf itervars from innermost to outermost,
      // appending everything at or outside the attach point.
      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
        }
        if (start_attach) path.push_back(iv);
      }
      ICHECK(start_attach) << "Invalid Schedule: cannot find attach point " << attach_ivar
                           << " in the schedule of " << s->op;
    }
    // Keep the first path computed for an op (ops can appear in many stages).
    if (!ret.count(stage->op)) {
      ret.Set(stage->op, path);
    }
  }
  return ret;
}
210 | |
// graph of push reach relation of tensor dimensions
using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey>>;

// Build the reach graph over the given operations: an edge from key A to key B
// records that a value flowing through tensor dimension A is pushed to
// (i.e. consumed when computing) tensor dimension B. Only producers inside
// `ops` are considered; loads from tensors outside the set are ignored.
ReachGraph GetReachGraph(const Array<Operation>& ops) {
  ReachGraph reach;
  // Membership set for fast "is this op part of the sub-graph" checks.
  std::unordered_set<const Object*> bset;
  for (size_t i = 0; i < ops.size(); ++i) {
    bset.insert(ops[i].get());
  }

  for (Operation op : ops) {
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      // For scans, each spatial dimension (k >= 1; dim 0 is the scan axis)
      // of the output reaches the same dimension of its update and init.
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
          reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(update[i], k));
          reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Map each axis variable of the compute op to the output dimension
      // it indexes (output 0 is used as the representative tensor).
      std::unordered_map<const Object*, TensorDimKey> vmap;
      const auto& axis = compute_op->axis;
      Tensor t = op.output(0);
      for (size_t i = 0; i < axis.size(); ++i) {
        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
        reach[TensorDimKey(t, i)] = {};
      }
      // For every producer load in the body, connect each axis variable
      // appearing in index i to dimension i of the loaded tensor.
      auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
        if (auto* pload = n.as<tir::ProducerLoadNode>()) {
          Tensor t = Downcast<Tensor>(pload->producer);
          // Skip loads from tensors whose producers are outside the set.
          if (!bset.count(t->op.get())) return;
          for (size_t i = 0; i < pload->indices.size(); ++i) {
            TensorDimKey dkey(t, static_cast<int>(i));
            auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
              const VarNode* v = node.as<VarNode>();
              auto it = vmap.find(v);
              if (it != vmap.end()) {
                reach[it->second].push_back(dkey);
              }
            };
            tir::PostOrderVisit(pload->indices[i], fpush);
          }
        }
      };
      for (auto& e : compute_op->body) {
        tir::PostOrderVisit(e, fvisit);
      }
    }
  }
  return reach;
}
264 | |
265 | Array<Operation> ScanGetBody(const Operation& scan_op) { |
266 | const ScanOpNode* scan = scan_op.as<ScanOpNode>(); |
267 | // Get the body. |
268 | Array<Tensor> inputs; |
269 | for (Tensor t : scan->state_placeholder) { |
270 | inputs.push_back(t); |
271 | } |
272 | for (Tensor t : scan->inputs) { |
273 | inputs.push_back(t); |
274 | } |
275 | return GetSubGraph(scan->update, inputs, false); |
276 | } |
277 | |
// Analyze which spatial axes of a scan are "fix points": axes along which the
// update at index i only ever reads state at the same index i. For each
// spatial IterVar the result maps it to 1 (proven fix point) or 0 (not
// proven). Two phases: (1) propagate "exact reach" (which spatial axis a
// tensor dimension corresponds to) through the scan body, (2) for candidate
// axes, run a DFS on the reach graph to prove there is no interference via
// other state-placeholder dimensions.
Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  Array<Operation> body = ScanGetBody(scan_op);

  // exact_reach: tensor dimension -> the unique spatial IterVar it tracks.
  // fail_set: spatial IterVars proven NOT to be exactly tracked.
  std::unordered_map<TensorDimKey, const Object*> exact_reach;
  std::unordered_set<const Object*> fail_set;

  // Seed: dimension k (k >= 1; dim 0 is the scan axis) of each state
  // placeholder corresponds exactly to spatial axis sp_idx.
  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->state_placeholder[i], k);
      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
    }
  }
  // Merge src's exact-reach info into dst; a conflict marks both axes failed.
  auto f_merge_key = [&exact_reach, &fail_set](const TensorDimKey& dst, const TensorDimKey& src) {
    auto sit = exact_reach.find(src);
    if (sit == exact_reach.end()) return;
    auto dit = exact_reach.find(dst);
    if (dit == exact_reach.end()) {
      exact_reach[dst] = sit->second;
    } else {
      if (dit->second != sit->second) {
        fail_set.insert(dit->second);
        fail_set.insert(sit->second);
      }
    }
  };
  // Propagate exact reach forward through the body operations.
  for (size_t i = 0; i < body.size(); ++i) {
    const Operation& op = body[i];
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      // Nested scan: output dim k inherits from both update and init dim k.
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Map each axis variable to the dimension keys of every output tensor.
      std::unordered_map<const Object*, std::vector<TensorDimKey>> vmap;
      const auto& axis = compute_op->axis;
      for (size_t i = 0; i < axis.size(); ++i) {
        std::vector<TensorDimKey> keys;
        for (int j = 0; j < op->num_outputs(); ++j) {
          keys.emplace_back(op.output(j), i);
        }
        vmap[axis[i]->var.get()] = std::move(keys);
      }
      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) {
        if (auto* pload = n.as<tir::ProducerLoadNode>()) {
          Tensor t = Downcast<Tensor>(pload->producer);
          for (size_t i = 0; i < pload->indices.size(); ++i) {
            auto it = vmap.find(pload->indices[i].get());
            TensorDimKey src(t, static_cast<int>(i));
            if (it != vmap.end()) {
              // Index is a plain axis variable: propagate to all outputs.
              const std::vector<TensorDimKey>& keys = it->second;
              for (const auto& key : keys) {
                f_merge_key(key, src);
              }
            } else {
              // Index is a non-trivial expression: the corresponding axis
              // (if tracked) can no longer be an exact match.
              if (exact_reach.count(src)) {
                fail_set.insert(exact_reach.at(src));
              }
            }
          }
        }
      };
      for (auto& e : compute_op->body) {
        tir::PostOrderVisit(e, fvisit);
      }
    }
  }
  ReachGraph reach;  // built lazily, only if some axis survives phase 1
  Map<IterVar, PrimExpr> ret;
  // All state-placeholder dimensions; hitting one other than the target
  // during the DFS means possible interference through the state.
  std::unordered_set<TensorDimKey> place_holder_ref;
  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
    }
  }

  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->update[i], k);
      TensorDimKey target(scan->state_placeholder[i], k);
      IterVar sp_iv = scan->spatial_axis_[sp_idx];
      if (fail_set.count(sp_iv.get()) || !exact_reach.count(key) ||
          exact_reach.at(key) != sp_iv.get()) {
        // Phase 1 could not prove an exact match: not a fix point.
        ret.Set(sp_iv, make_const(DataType::Int(32), 0));
      } else {
        // now we proved exact match, need to prove no interference with other graph.
        if (reach.size() == 0) reach = GetReachGraph(body);
        // DFS from the update dimension; stop (leaving the stack non-empty)
        // if we hit a state-placeholder dimension other than the target.
        std::unordered_set<TensorDimKey> visited;
        std::vector<TensorDimKey> stack{key};
        visited.insert(key);
        while (!stack.empty()) {
          TensorDimKey k = stack.back();
          if (k != target && place_holder_ref.count(k)) break;
          stack.pop_back();
          if (!reach.count(k)) {
            LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim;
          }

          for (TensorDimKey kk : reach.at(k)) {
            if (visited.count(kk)) {
              continue;
            }
            visited.insert(kk);
            stack.push_back(kk);
          }
        }
        if (!stack.empty()) {
          // DFS aborted early: interference found, proof failed.
          ret.Set(sp_iv, make_const(DataType::Int(32), 0));
        } else {
          // Exhausted the reach graph without interference: fix point.
          ret.Set(sp_iv, make_const(DataType::Int(32), 1));
        }
      }
    }
  }
  return ret;
}
403 | |
// FFI registrations: expose the schedule graph utilities to the frontend.
TVM_REGISTER_GLOBAL("schedule.CreateReadGraph" ).set_body_typed(CreateReadGraph);

// Overloaded PostDFSOrder needs an explicit lambda to pick the array version.
TVM_REGISTER_GLOBAL("schedule.PostDFSOrder" )
    .set_body_typed([](const Array<Operation>& roots, const ReadGraph& g) {
      return PostDFSOrder(roots, g);
    });

TVM_REGISTER_GLOBAL("schedule.CreateAttachPath" ).set_body_typed(CreateAttachPath);

TVM_REGISTER_GLOBAL("schedule.ScanGetBody" ).set_body_typed(ScanGetBody);

TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis" ).set_body_typed(ScanFixPointAnalysis);
416 | |
417 | } // namespace te |
418 | } // namespace tvm |
419 | |