1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file graph.cc
22 * \brief Utilities to get information about schedule graph.
23 */
#include "graph.h"

#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
34
35namespace tvm {
36namespace te {
37// key to specific tensor dimension.
38struct TensorDimKey {
39 Operation op;
40 int value_index;
41 int dim;
42 TensorDimKey() {}
43 TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {}
44 TensorDimKey(const Tensor& t, size_t dim)
45 : op(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
46 inline bool operator==(const TensorDimKey& other) const {
47 return op == other.op && value_index == other.value_index && dim == other.dim;
48 }
49 inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); }
50};
51} // namespace te
52} // namespace tvm
53
54namespace std {
55template <>
56struct hash<::tvm::te::TensorDimKey> {
57 std::size_t operator()(const ::tvm::te::TensorDimKey& k) const {
58 size_t lhs = ::tvm::ObjectPtrHash()(k.op);
59 size_t rhs = static_cast<size_t>(k.value_index) << 16UL | static_cast<size_t>(k.dim);
60 lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
61 return lhs;
62 }
63};
64} // namespace std
65
66namespace tvm {
67namespace te {
68
69// construct a read graph that gives readers of each operation
70// that the root depend on
71ReadGraph CreateReadGraph(const Array<Operation>& roots) {
72 ReadGraph rmap;
73 std::vector<Operation> stack;
74 std::unordered_set<const Object*> visited;
75 // initialize the roots
76 for (Operation op : roots) {
77 stack.push_back(op);
78 visited.insert(op.get());
79 }
80
81 while (!stack.empty()) {
82 Operation op = stack.back();
83 stack.pop_back();
84 Array<Tensor> deps = op->InputTensors();
85 rmap.Set(op, deps);
86 for (Tensor t : deps) {
87 if (t->op.defined() && visited.count(t->op.get()) == 0) {
88 visited.insert(t->op.get());
89 stack.push_back(t->op);
90 }
91 }
92 }
93 return rmap;
94}
95
96// Do DFS visit to get the subgraph.
97// Return if op is inside the subgraph.
98bool GetSubGraphByPostDFS_(const Operation& op, const std::unordered_set<const Object*>& boundary,
99 bool include_bounary, std::unordered_map<const Object*, bool>* visited,
100 Array<Operation>* result) {
101 if (visited->count(op.get())) {
102 return visited->at(op.get());
103 }
104 if (boundary.count(op.get())) {
105 (*visited)[op.get()] = true;
106 if (include_bounary) {
107 result->push_back(op);
108 }
109 return true;
110 }
111 // mark to avoid loop
112 // Not necessary for DAG.
113 (*visited)[op.get()] = false;
114 // check if we can reach boundary.
115 bool reach_boundary = false;
116 for (Tensor t : op->InputTensors()) {
117 if (GetSubGraphByPostDFS_(t->op, boundary, include_bounary, visited, result)) {
118 reach_boundary = true;
119 }
120 }
121 (*visited)[op.get()] = reach_boundary;
122 if (reach_boundary) {
123 result->push_back(op);
124 }
125 return reach_boundary;
126}
127
128Array<Operation> GetSubGraph(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
129 bool include_inputs) {
130 Array<Operation> result;
131 std::unordered_set<const Object*> boundary;
132 for (Tensor t : inputs) {
133 boundary.insert(t->op.get());
134 }
135 std::unordered_map<const Object*, bool> visited;
136 for (Tensor t : outputs) {
137 GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result);
138 }
139 return result;
140}
141
142void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set<Operation>* visited,
143 Array<Operation>* post_order) {
144 if (visited->count(op)) return;
145 visited->insert(op);
146 for (const auto& t : g.at(op)) {
147 PostDFSOrder(t->op, g, visited, post_order);
148 }
149 post_order->push_back(op);
150}
151
152Array<Operation> PostDFSOrder(const Array<Operation>& roots, const ReadGraph& g) {
153 std::unordered_set<Operation> visited;
154 Array<Operation> post_order;
155 for (Operation op : roots) {
156 PostDFSOrder(op, g, &visited, &post_order);
157 }
158 return post_order;
159}
160
161FeedGraph CreateFeedGraph(const ReadGraph& g) {
162 FeedGraph fg;
163 for (auto kv : g) {
164 for (Tensor t : kv.second) {
165 fg[t].push_back(kv.first);
166 }
167 }
168 return fg;
169}
170
AttachPath CreateAttachPath(Schedule sch) {
  AttachPath ret;
  // For every stage, walk the compute_at attach chain outward and collect
  // the loop itervars (from innermost attach point to outermost) that the
  // stage's computation is nested under.
  for (Stage stage : sch->stages) {
    std::unordered_set<const Object*> visited;
    Array<IterVar> path;
    for (Stage s = stage; s.defined();) {
      // A cycle in the attach chain would make the walk loop forever.
      ICHECK(!visited.count(s.get())) << "Find loop in compute_at attach group";
      visited.insert(s.get());
      Stage spec = s.GetAttachSpec();
      bool start_attach;
      IterVar attach_ivar;
      if (spec->attach_type == kScope) {
        // compute_at: start collecting only once attach_ivar is found
        // among the parent's leaf itervars.
        attach_ivar = spec->attach_ivar;
        s = spec->attach_stage;
        start_attach = false;
        ICHECK(attach_ivar.defined());
      } else if (spec->attach_type == kScanUpdate) {
        // Scan update: the stage is nested under all of the parent's loops.
        s = spec->attach_stage;
        start_attach = true;
      } else {
        // Root or inline: no further attachment; stop the walk.
        break;
      }
      ICHECK(s.defined());
      // Scan the parent's leaf itervars from innermost to outermost,
      // collecting everything from the attach point outward.
      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
        }
        if (start_attach) path.push_back(iv);
      }
      // If the attach itervar never showed up, the schedule is malformed.
      ICHECK(start_attach) << "Invalid Schedule: cannot find attach point " << attach_ivar
                           << " in the schedule of " << s->op;
    }
    // Only record the first path seen for an op (multi-output stages share op).
    if (!ret.count(stage->op)) {
      ret.Set(stage->op, path);
    }
  }
  return ret;
}
210
211// graph of push reach relation of tensor dimensions
212using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey>>;
213
ReachGraph GetReachGraph(const Array<Operation>& ops) {
  ReachGraph reach;
  // Set of op nodes inside the subgraph; reads from tensors produced outside
  // this set are ignored below.
  std::unordered_set<const Object*> bset;
  for (size_t i = 0; i < ops.size(); ++i) {
    bset.insert(ops[i].get());
  }

  for (Operation op : ops) {
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        // k starts at 1: dimension 0 of a scan output is the scan (time)
        // axis and is not tracked as a spatial dimension.
        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
          reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(update[i], k));
          reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Map each compute axis variable to the output dimension it indexes.
      std::unordered_map<const Object*, TensorDimKey> vmap;
      const auto& axis = compute_op->axis;
      Tensor t = op.output(0);
      for (size_t i = 0; i < axis.size(); ++i) {
        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
        reach[TensorDimKey(t, i)] = {};
      }
      // For every tensor read in the body: each axis var appearing in index
      // expression i makes the corresponding output dim "reach" the read
      // tensor's dimension i.
      auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
        if (auto* pload = n.as<tir::ProducerLoadNode>()) {
          Tensor t = Downcast<Tensor>(pload->producer);
          if (!bset.count(t->op.get())) return;
          for (size_t i = 0; i < pload->indices.size(); ++i) {
            TensorDimKey dkey(t, static_cast<int>(i));
            auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
              const VarNode* v = node.as<VarNode>();
              auto it = vmap.find(v);
              if (it != vmap.end()) {
                reach[it->second].push_back(dkey);
              }
            };
            tir::PostOrderVisit(pload->indices[i], fpush);
          }
        }
      };
      for (auto& e : compute_op->body) {
        tir::PostOrderVisit(e, fvisit);
      }
    }
  }
  return reach;
}
264
265Array<Operation> ScanGetBody(const Operation& scan_op) {
266 const ScanOpNode* scan = scan_op.as<ScanOpNode>();
267 // Get the body.
268 Array<Tensor> inputs;
269 for (Tensor t : scan->state_placeholder) {
270 inputs.push_back(t);
271 }
272 for (Tensor t : scan->inputs) {
273 inputs.push_back(t);
274 }
275 return GetSubGraph(scan->update, inputs, false);
276}
277
/*!
 * \brief For each spatial axis of a scan, decide whether the update at one
 *        time step depends only on the same spatial index of the previous
 *        state (a fix point). Axes proven to be fix points map to 1,
 *        otherwise 0.
 * \param scan_op The scan operation to analyze (cast to ScanOpNode below).
 * \return Map from each spatial IterVar to Int(32) constant 1 or 0.
 */
Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  Array<Operation> body = ScanGetBody(scan_op);

  // For a tensor dimension known to be exactly one spatial axis, maps the
  // dimension key to that axis's IterVarNode pointer.
  std::unordered_map<TensorDimKey, const Object*> exact_reach;
  // Spatial axes for which the exact-match property has been disproven.
  std::unordered_set<const Object*> fail_set;

  // Seed: each state-placeholder dimension (k >= 1; dim 0 is the scan axis)
  // is exactly the corresponding spatial axis by construction.
  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->state_placeholder[i], k);
      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
    }
  }
  // merge exact reach
  auto f_merge_key = [&exact_reach, &fail_set](const TensorDimKey& dst, const TensorDimKey& src) {
    auto sit = exact_reach.find(src);
    if (sit == exact_reach.end()) return;
    auto dit = exact_reach.find(dst);
    if (dit == exact_reach.end()) {
      exact_reach[dst] = sit->second;
    } else {
      // Two different axes reach the same destination: neither is exact.
      if (dit->second != sit->second) {
        fail_set.insert(dit->second);
        fail_set.insert(sit->second);
      }
    }
  };
  // prop exact reach back.
  for (size_t i = 0; i < body.size(); ++i) {
    const Operation& op = body[i];
    // NOTE(review): this inner `scan_op` shadows the function parameter.
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
          // Nested scan: outputs inherit exact reach through update and init.
          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Map each compute axis var to the dim keys of every output tensor.
      std::unordered_map<const Object*, std::vector<TensorDimKey>> vmap;
      const auto& axis = compute_op->axis;
      for (size_t i = 0; i < axis.size(); ++i) {
        std::vector<TensorDimKey> keys;
        for (int j = 0; j < op->num_outputs(); ++j) {
          keys.emplace_back(op.output(j), i);
        }
        vmap[axis[i]->var.get()] = std::move(keys);
      }
      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) {
        if (auto* pload = n.as<tir::ProducerLoadNode>()) {
          Tensor t = Downcast<Tensor>(pload->producer);
          for (size_t i = 0; i < pload->indices.size(); ++i) {
            // If index i is exactly an axis variable, propagate exact reach
            // to the outputs; any other index expression breaks exactness
            // of the source dimension.
            auto it = vmap.find(pload->indices[i].get());
            TensorDimKey src(t, static_cast<int>(i));
            if (it != vmap.end()) {
              const std::vector<TensorDimKey>& keys = it->second;
              for (const auto& key : keys) {
                f_merge_key(key, src);
              }
            } else {
              if (exact_reach.count(src)) {
                fail_set.insert(exact_reach.at(src));
              }
            }
          }
        }
      };
      for (auto& e : compute_op->body) {
        tir::PostOrderVisit(e, fvisit);
      }
    }
  }
  ReachGraph reach;
  Map<IterVar, PrimExpr> ret;
  // All state-placeholder dimension keys: the DFS below must not pass
  // through any placeholder dimension other than its own target.
  std::unordered_set<TensorDimKey> place_holder_ref;
  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
    }
  }

  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->update[i], k);
      TensorDimKey target(scan->state_placeholder[i], k);
      IterVar sp_iv = scan->spatial_axis_[sp_idx];
      // Fast rejection: axis already failed, or the update dim's exact reach
      // is missing or is a different axis.
      if (fail_set.count(sp_iv.get()) || !exact_reach.count(key) ||
          exact_reach.at(key) != sp_iv.get()) {
        ret.Set(sp_iv, make_const(DataType::Int(32), 0));
      } else {
        // now we proved exact match, need to prove no interference with other graph.
        if (reach.size() == 0) reach = GetReachGraph(body);
        // do a DFS
        std::unordered_set<TensorDimKey> visited;
        std::vector<TensorDimKey> stack{key};
        visited.insert(key);
        while (!stack.empty()) {
          TensorDimKey k = stack.back();
          // Reaching a placeholder dim other than the target is interference;
          // break with a non-empty stack to signal failure below.
          if (k != target && place_holder_ref.count(k)) break;
          stack.pop_back();
          if (!reach.count(k)) {
            LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim;
          }

          for (TensorDimKey kk : reach.at(k)) {
            if (visited.count(kk)) {
              continue;
            }
            visited.insert(kk);
            stack.push_back(kk);
          }
        }
        if (!stack.empty()) {
          // failed the prove.
          ret.Set(sp_iv, make_const(DataType::Int(32), 0));
        } else {
          ret.Set(sp_iv, make_const(DataType::Int(32), 1));
        }
      }
    }
  }
  return ret;
}
403
// Expose the schedule-graph utilities to the FFI under the "schedule." prefix.
TVM_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph);

TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
    .set_body_typed([](const Array<Operation>& roots, const ReadGraph& g) {
      return PostDFSOrder(roots, g);
    });

TVM_REGISTER_GLOBAL("schedule.CreateAttachPath").set_body_typed(CreateAttachPath);

TVM_REGISTER_GLOBAL("schedule.ScanGetBody").set_body_typed(ScanGetBody);

TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis").set_body_typed(ScanFixPointAnalysis);
416
417} // namespace te
418} // namespace tvm
419