1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/sub_graph.h |
22 | * \brief Represents a sub-graph of an overall Relay expression. |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_ |
26 | #define TVM_RELAY_COLLAGE_SUB_GRAPH_H_ |
27 | |
28 | #include <tvm/ir/transform.h> |
29 | #include <tvm/relay/op_attr_types.h> |
30 | |
31 | #include <string> |
32 | #include <unordered_map> |
33 | #include <utility> |
34 | #include <vector> |
35 | |
36 | #include "../ir/dataflow_matcher_impl.h" |
37 | #include "../ir/indexed_graph.h" |
38 | #include "./dataflow_graph.h" |
39 | #include "./index_set.h" |
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | namespace collage { |
44 | |
45 | /*! \brief Returns operator pattern kind as single-letter string. */ |
46 | std::string KindToString(OpPatternKind kind); |
47 | |
48 | /*! |
49 | * \brief Returns a kind and label for the single \p sub_expr, ignoring its nested sub expressions. |
50 | */ |
51 | std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr); |
52 | |
53 | /*! |
54 | * \brief Returns a kind and label for all the nodes in \p inside. |
55 | */ |
56 | std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph, |
57 | const IndexSet& inside); |
58 | |
59 | /*! |
60 | * \brief Returns the index set representing all the sub-expression matched by \p matcher. |
61 | */ |
62 | IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher); |
63 | |
/*!
 * \brief Knobs deciding which sub-graphs count as valid.
 *
 * A default-constructed config imposes no limit on exits or depth and disallows taps.
 */
struct SubGraphConfig {
  /*! \brief Upper bound on the number of exit nodes of a sub-graph. Zero means unbounded. */
  size_t max_exits{0};

  /*!
   * \brief When true, a node inside the sub-graph may feed nodes both inside and outside of
   * it (such a node is called a 'tap'). Even when false a sub-graph may still end up with
   * multiple outputs.
   */
  bool allow_taps{false};

  /*! \brief Upper bound on the depth of a sub-graph. Zero means unbounded. */
  size_t max_depth{0};

  /*! \brief Renders this config as a human-readable string. */
  std::string ToString() const;
};
83 | |
84 | class SubGraph; |
85 | using FunctionAttrsMap = Map<String, ObjectRef>; |
86 | |
87 | /*! |
88 | * \brief A nested sub-graph is a sub-graph which is to be nested inside a function as part of some |
89 | * enclosing sub-graph. |
90 | * |
91 | * Extraction yields a function with input nodes replaced by parameters and exit nodes in the |
92 | * function result. Rewriting replaces the sub-graph with a call to that function, and all |
93 | * outputs with (projections from) the call result. |
94 | * |
95 | * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class. |
96 | * However we found the implementation was easier to understand in this form since it makes |
97 | * the result of \p Extract unambiguous.) |
98 | */ |
99 | class NestedSubGraphNode : public Object { |
100 | public: |
101 | /*! \brief The nested sub-graph. */ |
102 | ObjectRef /* actually SubGraph */ sub_graph_obj_; |
103 | /*! \brief Attributes (possibly empty) to attach to the extracted function. */ |
104 | FunctionAttrsMap attrs_; |
105 | |
106 | void VisitAttrs(AttrVisitor* v); |
107 | |
108 | SubGraph sub_graph() const; |
109 | |
110 | bool operator==(const NestedSubGraphNode& that) const; |
111 | bool operator!=(const NestedSubGraphNode& that) const { return !(*this == that); } |
112 | bool operator<(const NestedSubGraphNode& that) const; |
113 | size_t hash() const; |
114 | |
115 | std::string ToString() const; |
116 | |
117 | /*! |
118 | * \brief Returns the function representing this nested sub-graph within the overall expression |
119 | * represented by \p dataflow_graph: |
120 | * - All sub-graph inputs become parameters. |
121 | * - All sub-graph outputs become function results (either directly or as a field in a tuple). |
122 | * - The function has attrs_ for attributes (which may be empty). |
123 | * - The function body accounts for any rewrites implied by the nested sub-graph. |
124 | */ |
125 | Function (const DataflowGraph& dataflow_graph) const; |
126 | |
127 | /*! |
128 | * \brief Returns \p expr rewritten to encode the partitioning implied by this nested sub-graph. |
129 | * |
130 | * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes |
131 | * inside this nested sub-graph must correspond to nodes shared between \p dataflow_graph.expr() |
132 | * and \p expr. See \p SubGraph::ParallelRewrite below. |
133 | */ |
134 | Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const; |
135 | |
136 | static constexpr const char* _type_key = "relay.collage.NestedSubGraph" ; |
137 | TVM_DECLARE_FINAL_OBJECT_INFO(NestedSubGraphNode, Object); |
138 | }; |
139 | |
140 | class NestedSubGraph : public ObjectRef { |
141 | public: |
142 | NestedSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs); |
143 | |
144 | /*! |
145 | * \brief Returns copy of this nested sub-graph with all indexes substituted according to |
146 | * \p subst, whose range is w.r.t. \p new_dataflow_graph. |
147 | */ |
148 | NestedSubGraph Subst(const DataflowGraph& new_dataflow_graph, |
149 | const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const; |
150 | |
151 | /*! |
152 | * \brief Returns true if this can be safely unioned. |
153 | */ |
154 | bool TriviallyUnionable(const NestedSubGraph& that) const; |
155 | |
156 | /*! |
157 | * \brief Returns the disjoint union of this and \p that nested sub-graphs, which must agree on |
158 | * their attributes. |
159 | */ |
160 | NestedSubGraph DisjointUnion(const DataflowGraph& dataflow_graph, |
161 | const NestedSubGraph& that) const; |
162 | |
163 | /*! |
164 | * \brief Returns \p expr rewritten according to all the given nested sub-graphs. The |
165 | * nested sub-graphs can be given in any order, but must be disjoint. |
166 | * |
167 | * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes |
168 | * inside the nested sub-graphs must correspond to nodes shared between \p dataflow_graph.expr() |
169 | * and \p expr. See \p SubGraph::ParallelRewrite below. |
170 | */ |
171 | static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr, |
172 | std::vector<NestedSubGraph> nested_sub_graphs); |
173 | |
174 | TVM_DEFINE_OBJECT_REF_METHODS(NestedSubGraph, ObjectRef, NestedSubGraphNode); |
175 | }; |
176 | |
177 | using NestedSubGraphs = Array<NestedSubGraph>; |
178 | |
179 | /*! |
180 | * \brief A compact representation of a sub-graph within an (implied) overall Relay expression. |
181 | * |
182 | * Sub-graphs can be used to represent partitions/kernels/composite functions without having to |
183 | * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a |
184 | * function to use for measuring a partition/kernel's latency independently from 'rewriting' |
185 | * the overall Relay expression since only a tiny subset of candidate partitions will end up being |
186 | * needed after Collage has completed its search. |
187 | * |
188 | * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so we are |
189 | * mindful of space overhead. |
190 | * |
191 | * A sub-graph classifies every dataflow node of the overall expression as either 'inside' or |
192 | * 'outside' the sub-graph. Obviously not all such divisions make sense, for example it is not |
193 | * valid for an inside node to feed into another inside node via outside nodes. We provide the |
194 | * \p IsValid method to check for validity, and \p SubGraphConfig to control which validity rules |
195 | * apply (such as maximum depth). |
196 | * |
197 | * We generally work with the \p DataflowGraph representation of the overall Relay expression |
198 | * rather than the expression itself. We use the post-dfs visit index to uniquely refer to |
199 | * expression nodes. |
200 | * |
201 | * As well as 'inside' and 'outside' we have four other flavors of dataflow nodes, all uniquely |
202 | * determined from the 'inside' nodes: |
203 | * - 'entry' nodes are those inside with at least one dataflow input outside. |
204 | * - 'exit' nodes are those inside with at least one dataflow output outside, or which |
205 | * are considered 'external' in the underlying dataflow graph (eg because they represent |
206 | * the result of the overall function). |
207 | * - 'input' nodes are those outside with at least one dataflow output inside. |
208 | * - 'output' nodes are those outside with at least one dataflow input inside. |
209 | * Index sets for these are cached with the sub-graph for performance. |
210 | * |
211 | * It is valid to have multiple entry nodes (we can bind a parameter for each). It may be valid to |
212 | * have multiple exit nodes (we can build a tuple of all such). It may be valid to have exit nodes |
213 | * which also contribute to other inside nodes (ie represent a 'tap' on an intermediate result). |
214 | * |
215 | * Sub-graphs are closed under: |
216 | * - Disjoint union. |
217 | * - Wrapping by a function with given attributes (see \p NestedSubGraph above). This can be used |
218 | * to encode "Composite" functions, or to represent a candidate kernel within a "Primitive" |
219 | * function. (By combining 'wrapping' with 'union' we can encode, eg, 'this sub-graph should |
220 | * be placed inside a primitive function which itself may have calls to composite functions). |
221 | * - Substitution, which allows a sub-graph w.r.t. one dataflow graph to be transformed to |
222 | * match some other (typically smaller) dataflow graph. |
223 | * |
224 | * See the subclasses of \p PartitionRule for how sub-graphs are built and combined during Collage |
225 | * search. |
226 | * |
227 | * To support some of the \p OpPatternKind-based fusion rule processing we give sub-graphs |
228 | * a kind, which is generally the maximum of the kinds of all the operator calls appearing |
229 | * inside it. We also given sub-graphs a (not necessarily unique) label to help debugging |
230 | * and guide the selection of global symbol names. |
231 | */ |
232 | class SubGraphNode : public Object { |
233 | public: |
234 | /*! |
235 | * \brief Which sub-expressions are inside the sub-graph (using their post-dfs indexes w.r.t. |
236 | * the implied DataflowGraph). |
237 | */ |
238 | IndexSet inside_; |
239 | |
240 | /*! |
241 | * \brief Index of first and last inside nodes. |
242 | * |
243 | * Cached for performance, uniquely determined by inside_. |
244 | */ |
245 | PostDfsIndex first_inside_index_ = 0; |
246 | PostDfsIndex last_inside_index_ = 0; |
247 | |
248 | /*! |
249 | * \brief Which sub-expressions are entry/exit/input/output for this sub-graph. |
250 | * |
251 | * Cached for performance, uniquely determined by inside_. |
252 | */ |
253 | IndexSet entry_; |
254 | IndexSet exit_; |
255 | IndexSet input_; |
256 | IndexSet output_; |
257 | |
258 | /*! |
259 | * \brief Maximum depth of any dataflow path from an entry to an output sub-expression. |
260 | * |
261 | * Cached for performance, uniquely determined by inside_. |
262 | */ |
263 | size_t depth_ = 0; |
264 | |
265 | /*! |
266 | * \brief The \p OpPatternKind summarizing the input/output behavior of the sub-graph. |
267 | * |
268 | * A sub-graph consisting of a single Relay expression node is given kind: |
269 | * - For Call to a Relay operator, the "TOpPattern" attribute of that operator (provided the |
270 | * call does not involve data-dependent dynamic shapes). |
271 | * - For Call to Relay Function, the "TOpPattern" attribute of the function (provided it has |
272 | * that attribute) |
273 | * - For Constants, \p kElemWise. |
274 | * - For Tuple and tuple projections, \p kInjective (provided all tuple fields are of tensor |
275 | * type) |
276 | * - All other nodes \p kOpaque. |
277 | * Sub-graphs with more than one node have the maximum of the kind of each node. |
278 | * |
279 | * Cached for performance, uniquely determined by inside_. |
280 | */ |
281 | OpPatternKind kind_ = kOpaque; |
282 | |
283 | /*! |
284 | * \brief A label for the sub-graph. Not guaranteed to be unique, but is a human-readable summary |
285 | * of the sub-graph which can help with debugging and guide the selection of global symbol names. |
286 | */ |
287 | String label_; |
288 | |
289 | /*! |
290 | * \brief Nested sub-graphs of this sub-graph which must be represented by functions. These must |
291 | * be disjoint, but it's ok for this sub-graph to have nodes not inside any nested sub-graph. |
292 | */ |
293 | NestedSubGraphs nested_sub_graphs_; |
294 | |
295 | void VisitAttrs(AttrVisitor* v); |
296 | |
297 | // TODO(mbs): 'Anchor nodes' and rules for unioning them. |
298 | // In FuseOps it's just the unique kEWiseFusable node, if any. |
299 | // I'd like to allow writing vertical fusion rules, eg if two candidates are directly |
300 | // connected and have nn.conv2d anchors allow their join. |
301 | // I'd also like to allow horizontal fusion rules, eg if two candidates are not directly |
302 | // connected but could be joined without producing invalid (eg cyclic) and have nn.conv2d anchors |
303 | // then do so. Come back to this. |
304 | |
305 | /*! \brief Number of nodes in overall dataflow graph. */ |
306 | size_t overall_size() const { return inside_.end_index(); } |
307 | |
308 | bool IsEmpty() const { return inside_.IsZero(); } |
309 | |
310 | /*! \brief Number of nodes in sub-graph. */ |
311 | size_t Size() const { return inside_.PopCount(); } |
312 | |
313 | /*! |
314 | * \brief Returns the dataflow nodes downstream of all exit nodes. |
315 | */ |
316 | IndexSet Downstream(const DataflowGraph& dataflow_graph) const; |
317 | |
318 | /*! |
319 | * \brief Returns true if this sub-graph is valid. Ie: |
320 | * - no output of the sub-graph can flow to any input of the sub-graph (otherwise we'd end up |
321 | * with a dataflow cycle when we partition). |
322 | * - all inputs and outputs of the sub-graph are in the same scope, ie not separated by |
323 | * control flow (otherwise there'd be no consistent program point at which to eval the |
324 | * partitioned function). |
325 | * - no more than config.max_outputs outputs are required. |
326 | * - if config.allow_taps is false, no inside node has outputs to nodes both inside and |
327 | * outside the sub-graph. |
328 | */ |
329 | bool IsValid(const DataflowGraph& dataflow_graph, const SubGraphConfig& config) const; |
330 | |
331 | /*! |
332 | * \brief Returns this sub-graph extracted as a stand-alone function. The function will have |
333 | * no attributes, and is suitable for building and profiling by the \p CostEstimator. |
334 | */ |
335 | Function (const DataflowGraph& dataflow_graph) const; |
336 | |
337 | /*! |
338 | * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-graph. |
339 | * |
340 | * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes |
341 | * inside this sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and |
342 | * \p expr. See \p SubGraph::ParallelRewrite below. |
343 | */ |
344 | Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const; |
345 | |
346 | std::string ToString() const; |
347 | |
348 | bool operator==(const SubGraphNode& that) const; |
349 | bool operator!=(const SubGraphNode& that) const { return !(*this == that); } |
350 | bool operator<(const SubGraphNode& that) const; |
351 | size_t hash() const; |
352 | |
353 | private: |
354 | /*! \brief Initialize the entry/exit/input/output sets given the inside and \p dataflow_graph. */ |
355 | void Init(const DataflowGraph& dataflow_graph); |
356 | |
357 | /*! \brief Calculates and returns the maximum path depth. */ |
358 | size_t Depth(const DataflowGraph& dataflow_graph) const; |
359 | |
360 | /*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */ |
361 | bool AnyInputOutside(const DataflowGraph::Node* node) const; |
362 | bool AnyInputInside(const DataflowGraph::Node* node) const; |
363 | bool AnyOutputOutside(const DataflowGraph::Node* node) const; |
364 | bool AnyOutputInside(const DataflowGraph::Node* node) const; |
365 | |
366 | public: |
367 | static constexpr const char* _type_key = "relay.collage.SubGraph" ; |
368 | TVM_DECLARE_FINAL_OBJECT_INFO(SubGraphNode, Object); |
369 | |
370 | friend class SubGraph; |
371 | }; |
372 | |
373 | class SubGraph : public ObjectRef { |
374 | public: |
375 | /*! \brief Primitive constructor. The following constructors are generally more convenient. */ |
376 | SubGraph(const DataflowGraph& dataflow_graph, IndexSet inside, OpPatternKind kind = kOpaque, |
377 | String label = {}, std::vector<NestedSubGraph> nested_sub_graphs = {}); |
378 | |
379 | /*! \brief Constructs the empty sub-graph for \p dataflow_graph. */ |
380 | explicit SubGraph(const DataflowGraph& dataflow_graph); |
381 | |
382 | /*! \brief Returns true if this and that are disjoint. */ |
383 | bool AreDisjoint(const SubGraph& that) const; |
384 | |
385 | /*! |
386 | * \brief Returns true if: |
387 | * - \p this and \p that are disjoint, and |
388 | * - an output node of \p this coincides with an entry node of \p that, and |
389 | * - \p this and \p that are not obviously invalid after \p DisjointUnion |
390 | * (eg because such a sub-graph would produce a cycle). |
391 | * Note however that the \p DisjointUnion may not necessarily be valid even with the above |
392 | * checks. |
393 | */ |
394 | bool AreTouching(const DataflowGraph& dataflow_graph, const SubGraph& that) const; |
395 | |
396 | /*! |
397 | * \brief Returns true if: |
398 | * - all the outputs of \p this are entries for \p that, and |
399 | * - all the inputs of \p that are exits for \p this. |
400 | */ |
401 | bool AreSelfContained(const SubGraph& that) const; |
402 | |
403 | /*! |
404 | * \brief Returns disjoint union of this and \p that sub-graphs. The result may not be valid. |
405 | */ |
406 | SubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubGraph& that) const; |
407 | |
408 | /*! |
409 | * \brief Returns copy of this sub-graph with all nodes placed inside a nested sub-graph with |
410 | * given attributes. |
411 | */ |
412 | SubGraph WithAttrs(const DataflowGraph& dataflow_graph, FunctionAttrsMap attrs) const; |
413 | |
414 | /*! |
415 | * \brief Returns copy of this sub-graph with all indexes substituted according to \p subst, |
416 | * whose range is w.r.t. \p new_dataflow_graph. |
417 | */ |
418 | SubGraph Subst(const DataflowGraph& new_dataflow_graph, |
419 | const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const; |
420 | |
421 | /*! |
422 | * \brief Returns the root expression of \p dataflow_graph rewritten according to all the |
423 | * given sub-graphs. The sub-graphs can be given in any order, but must be disjoint. |
424 | */ |
425 | static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, |
426 | std::vector<SubGraph> sub_graphs); |
427 | |
428 | TVM_DEFINE_OBJECT_REF_METHODS(SubGraph, ObjectRef, SubGraphNode); |
429 | }; |
430 | |
431 | struct SubGraphEqual { |
432 | bool operator()(const SubGraph& left, const SubGraph& right) const { |
433 | return *left.get() == *right.get(); |
434 | } |
435 | }; |
436 | |
437 | struct SubGraphHash { |
438 | size_t operator()(const SubGraph& sub_graph) const { return sub_graph->hash(); } |
439 | }; |
440 | |
441 | /*! |
442 | * \brief Pass to partition every global function according to the post-dfs indexes |
443 | * given in an array. Visible for testing from Python only, would never make sense to use |
444 | * as a generic pass! |
445 | */ |
446 | tvm::transform::Pass PartitionOnIndexesForTesting(Array<Integer> indexes); |
447 | |
448 | } // namespace collage |
449 | } // namespace relay |
450 | } // namespace tvm |
451 | |
452 | #endif // TVM_RELAY_COLLAGE_SUB_GRAPH_H_ |
453 | |