/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
/*!
 * \file src/relay/collage/sub_graph.h
 * \brief Represents a sub-graph of an overall Relay expression.
 */
24
25#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_
26#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_
27
28#include <tvm/ir/transform.h>
29#include <tvm/relay/op_attr_types.h>
30
31#include <string>
32#include <unordered_map>
33#include <utility>
34#include <vector>
35
36#include "../ir/dataflow_matcher_impl.h"
37#include "../ir/indexed_graph.h"
38#include "./dataflow_graph.h"
39#include "./index_set.h"
40
41namespace tvm {
42namespace relay {
43namespace collage {
44
45/*! \brief Returns operator pattern kind as single-letter string. */
46std::string KindToString(OpPatternKind kind);
47
48/*!
49 * \brief Returns a kind and label for the single \p sub_expr, ignoring its nested sub expressions.
50 */
51std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr);
52
53/*!
54 * \brief Returns a kind and label for all the nodes in \p inside.
55 */
56std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph,
57 const IndexSet& inside);
58
59/*!
60 * \brief Returns the index set representing all the sub-expression matched by \p matcher.
61 */
62IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher);
63
/*!
 * \brief Configuration controlling which sub-graphs are considered valid.
 *
 * All limits default to 'unrestricted' (zero), and taps are disallowed by default.
 */
struct SubGraphConfig {
  /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */
  size_t max_exits = 0;
  /*!
   * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside
   * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs
   * even with this flag false.
   */
  bool allow_taps = false;
  /*!
   * \brief Maximum allowed sub-graph depth, or zero if no limit.
   */
  size_t max_depth = 0;

  /*! \brief Returns a string representation of this configuration (defined out of line). */
  std::string ToString() const;
};
83
84class SubGraph;
85using FunctionAttrsMap = Map<String, ObjectRef>;
86
87/*!
88 * \brief A nested sub-graph is a sub-graph which is to be nested inside a function as part of some
89 * enclosing sub-graph.
90 *
91 * Extraction yields a function with input nodes replaced by parameters and exit nodes in the
92 * function result. Rewriting replaces the sub-graph with a call to that function, and all
93 * outputs with (projections from) the call result.
94 *
95 * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class.
96 * However we found the implementation was easier to understand in this form since it makes
97 * the result of \p Extract unambiguous.)
98 */
99class NestedSubGraphNode : public Object {
100 public:
101 /*! \brief The nested sub-graph. */
102 ObjectRef /* actually SubGraph */ sub_graph_obj_;
103 /*! \brief Attributes (possibly empty) to attach to the extracted function. */
104 FunctionAttrsMap attrs_;
105
106 void VisitAttrs(AttrVisitor* v);
107
108 SubGraph sub_graph() const;
109
110 bool operator==(const NestedSubGraphNode& that) const;
111 bool operator!=(const NestedSubGraphNode& that) const { return !(*this == that); }
112 bool operator<(const NestedSubGraphNode& that) const;
113 size_t hash() const;
114
115 std::string ToString() const;
116
117 /*!
118 * \brief Returns the function representing this nested sub-graph within the overall expression
119 * represented by \p dataflow_graph:
120 * - All sub-graph inputs become parameters.
121 * - All sub-graph outputs become function results (either directly or as a field in a tuple).
122 * - The function has attrs_ for attributes (which may be empty).
123 * - The function body accounts for any rewrites implied by the nested sub-graph.
124 */
125 Function Extract(const DataflowGraph& dataflow_graph) const;
126
127 /*!
128 * \brief Returns \p expr rewritten to encode the partitioning implied by this nested sub-graph.
129 *
130 * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
131 * inside this nested sub-graph must correspond to nodes shared between \p dataflow_graph.expr()
132 * and \p expr. See \p SubGraph::ParallelRewrite below.
133 */
134 Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
135
136 static constexpr const char* _type_key = "relay.collage.NestedSubGraph";
137 TVM_DECLARE_FINAL_OBJECT_INFO(NestedSubGraphNode, Object);
138};
139
140class NestedSubGraph : public ObjectRef {
141 public:
142 NestedSubGraph(SubGraph sub_graph, FunctionAttrsMap attrs);
143
144 /*!
145 * \brief Returns copy of this nested sub-graph with all indexes substituted according to
146 * \p subst, whose range is w.r.t. \p new_dataflow_graph.
147 */
148 NestedSubGraph Subst(const DataflowGraph& new_dataflow_graph,
149 const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
150
151 /*!
152 * \brief Returns true if this can be safely unioned.
153 */
154 bool TriviallyUnionable(const NestedSubGraph& that) const;
155
156 /*!
157 * \brief Returns the disjoint union of this and \p that nested sub-graphs, which must agree on
158 * their attributes.
159 */
160 NestedSubGraph DisjointUnion(const DataflowGraph& dataflow_graph,
161 const NestedSubGraph& that) const;
162
163 /*!
164 * \brief Returns \p expr rewritten according to all the given nested sub-graphs. The
165 * nested sub-graphs can be given in any order, but must be disjoint.
166 *
167 * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
168 * inside the nested sub-graphs must correspond to nodes shared between \p dataflow_graph.expr()
169 * and \p expr. See \p SubGraph::ParallelRewrite below.
170 */
171 static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, const Expr& expr,
172 std::vector<NestedSubGraph> nested_sub_graphs);
173
174 TVM_DEFINE_OBJECT_REF_METHODS(NestedSubGraph, ObjectRef, NestedSubGraphNode);
175};
176
177using NestedSubGraphs = Array<NestedSubGraph>;
178
179/*!
180 * \brief A compact representation of a sub-graph within an (implied) overall Relay expression.
181 *
182 * Sub-graphs can be used to represent partitions/kernels/composite functions without having to
183 * pay the cost of constructing or rewriting any expressions. We also allow 'extracting' a
184 * function to use for measuring a partition/kernel's latency independently from 'rewriting'
185 * the overall Relay expression since only a tiny subset of candidate partitions will end up being
186 * needed after Collage has completed its search.
187 *
188 * We expect O(thousands) of sub-graphs to be in flight while processing a given model, so we are
189 * mindful of space overhead.
190 *
191 * A sub-graph classifies every dataflow node of the overall expression as either 'inside' or
192 * 'outside' the sub-graph. Obviously not all such divisions make sense, for example it is not
193 * valid for an inside node to feed into another inside node via outside nodes. We provide the
194 * \p IsValid method to check for validity, and \p SubGraphConfig to control which validity rules
195 * apply (such as maximum depth).
196 *
197 * We generally work with the \p DataflowGraph representation of the overall Relay expression
198 * rather than the expression itself. We use the post-dfs visit index to uniquely refer to
199 * expression nodes.
200 *
201 * As well as 'inside' and 'outside' we have four other flavors of dataflow nodes, all uniquely
202 * determined from the 'inside' nodes:
203 * - 'entry' nodes are those inside with at least one dataflow input outside.
204 * - 'exit' nodes are those inside with at least one dataflow output outside, or which
205 * are considered 'external' in the underlying dataflow graph (eg because they represent
206 * the result of the overall function).
207 * - 'input' nodes are those outside with at least one dataflow output inside.
208 * - 'output' nodes are those outside with at least one dataflow input inside.
209 * Index sets for these are cached with the sub-graph for performance.
210 *
211 * It is valid to have multiple entry nodes (we can bind a parameter for each). It may be valid to
212 * have multiple exit nodes (we can build a tuple of all such). It may be valid to have exit nodes
213 * which also contribute to other inside nodes (ie represent a 'tap' on an intermediate result).
214 *
215 * Sub-graphs are closed under:
216 * - Disjoint union.
217 * - Wrapping by a function with given attributes (see \p NestedSubGraph above). This can be used
218 * to encode "Composite" functions, or to represent a candidate kernel within a "Primitive"
219 * function. (By combining 'wrapping' with 'union' we can encode, eg, 'this sub-graph should
220 * be placed inside a primitive function which itself may have calls to composite functions).
221 * - Substitution, which allows a sub-graph w.r.t. one dataflow graph to be transformed to
222 * match some other (typically smaller) dataflow graph.
223 *
224 * See the subclasses of \p PartitionRule for how sub-graphs are built and combined during Collage
225 * search.
226 *
227 * To support some of the \p OpPatternKind-based fusion rule processing we give sub-graphs
228 * a kind, which is generally the maximum of the kinds of all the operator calls appearing
229 * inside it. We also given sub-graphs a (not necessarily unique) label to help debugging
230 * and guide the selection of global symbol names.
231 */
232class SubGraphNode : public Object {
233 public:
234 /*!
235 * \brief Which sub-expressions are inside the sub-graph (using their post-dfs indexes w.r.t.
236 * the implied DataflowGraph).
237 */
238 IndexSet inside_;
239
240 /*!
241 * \brief Index of first and last inside nodes.
242 *
243 * Cached for performance, uniquely determined by inside_.
244 */
245 PostDfsIndex first_inside_index_ = 0;
246 PostDfsIndex last_inside_index_ = 0;
247
248 /*!
249 * \brief Which sub-expressions are entry/exit/input/output for this sub-graph.
250 *
251 * Cached for performance, uniquely determined by inside_.
252 */
253 IndexSet entry_;
254 IndexSet exit_;
255 IndexSet input_;
256 IndexSet output_;
257
258 /*!
259 * \brief Maximum depth of any dataflow path from an entry to an output sub-expression.
260 *
261 * Cached for performance, uniquely determined by inside_.
262 */
263 size_t depth_ = 0;
264
265 /*!
266 * \brief The \p OpPatternKind summarizing the input/output behavior of the sub-graph.
267 *
268 * A sub-graph consisting of a single Relay expression node is given kind:
269 * - For Call to a Relay operator, the "TOpPattern" attribute of that operator (provided the
270 * call does not involve data-dependent dynamic shapes).
271 * - For Call to Relay Function, the "TOpPattern" attribute of the function (provided it has
272 * that attribute)
273 * - For Constants, \p kElemWise.
274 * - For Tuple and tuple projections, \p kInjective (provided all tuple fields are of tensor
275 * type)
276 * - All other nodes \p kOpaque.
277 * Sub-graphs with more than one node have the maximum of the kind of each node.
278 *
279 * Cached for performance, uniquely determined by inside_.
280 */
281 OpPatternKind kind_ = kOpaque;
282
283 /*!
284 * \brief A label for the sub-graph. Not guaranteed to be unique, but is a human-readable summary
285 * of the sub-graph which can help with debugging and guide the selection of global symbol names.
286 */
287 String label_;
288
289 /*!
290 * \brief Nested sub-graphs of this sub-graph which must be represented by functions. These must
291 * be disjoint, but it's ok for this sub-graph to have nodes not inside any nested sub-graph.
292 */
293 NestedSubGraphs nested_sub_graphs_;
294
295 void VisitAttrs(AttrVisitor* v);
296
297 // TODO(mbs): 'Anchor nodes' and rules for unioning them.
298 // In FuseOps it's just the unique kEWiseFusable node, if any.
299 // I'd like to allow writing vertical fusion rules, eg if two candidates are directly
300 // connected and have nn.conv2d anchors allow their join.
301 // I'd also like to allow horizontal fusion rules, eg if two candidates are not directly
302 // connected but could be joined without producing invalid (eg cyclic) and have nn.conv2d anchors
303 // then do so. Come back to this.
304
305 /*! \brief Number of nodes in overall dataflow graph. */
306 size_t overall_size() const { return inside_.end_index(); }
307
308 bool IsEmpty() const { return inside_.IsZero(); }
309
310 /*! \brief Number of nodes in sub-graph. */
311 size_t Size() const { return inside_.PopCount(); }
312
313 /*!
314 * \brief Returns the dataflow nodes downstream of all exit nodes.
315 */
316 IndexSet Downstream(const DataflowGraph& dataflow_graph) const;
317
318 /*!
319 * \brief Returns true if this sub-graph is valid. Ie:
320 * - no output of the sub-graph can flow to any input of the sub-graph (otherwise we'd end up
321 * with a dataflow cycle when we partition).
322 * - all inputs and outputs of the sub-graph are in the same scope, ie not separated by
323 * control flow (otherwise there'd be no consistent program point at which to eval the
324 * partitioned function).
325 * - no more than config.max_outputs outputs are required.
326 * - if config.allow_taps is false, no inside node has outputs to nodes both inside and
327 * outside the sub-graph.
328 */
329 bool IsValid(const DataflowGraph& dataflow_graph, const SubGraphConfig& config) const;
330
331 /*!
332 * \brief Returns this sub-graph extracted as a stand-alone function. The function will have
333 * no attributes, and is suitable for building and profiling by the \p CostEstimator.
334 */
335 Function ExtractAsFunction(const DataflowGraph& dataflow_graph) const;
336
337 /*!
338 * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-graph.
339 *
340 * It is valid for \p expr to not be the same as \p dataflow_graph.expr(), however all nodes
341 * inside this sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and
342 * \p expr. See \p SubGraph::ParallelRewrite below.
343 */
344 Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const;
345
346 std::string ToString() const;
347
348 bool operator==(const SubGraphNode& that) const;
349 bool operator!=(const SubGraphNode& that) const { return !(*this == that); }
350 bool operator<(const SubGraphNode& that) const;
351 size_t hash() const;
352
353 private:
354 /*! \brief Initialize the entry/exit/input/output sets given the inside and \p dataflow_graph. */
355 void Init(const DataflowGraph& dataflow_graph);
356
357 /*! \brief Calculates and returns the maximum path depth. */
358 size_t Depth(const DataflowGraph& dataflow_graph) const;
359
360 /*! \brief Returns true if any (input/output) of node is (outside/inside) the sub-graph. */
361 bool AnyInputOutside(const DataflowGraph::Node* node) const;
362 bool AnyInputInside(const DataflowGraph::Node* node) const;
363 bool AnyOutputOutside(const DataflowGraph::Node* node) const;
364 bool AnyOutputInside(const DataflowGraph::Node* node) const;
365
366 public:
367 static constexpr const char* _type_key = "relay.collage.SubGraph";
368 TVM_DECLARE_FINAL_OBJECT_INFO(SubGraphNode, Object);
369
370 friend class SubGraph;
371};
372
373class SubGraph : public ObjectRef {
374 public:
375 /*! \brief Primitive constructor. The following constructors are generally more convenient. */
376 SubGraph(const DataflowGraph& dataflow_graph, IndexSet inside, OpPatternKind kind = kOpaque,
377 String label = {}, std::vector<NestedSubGraph> nested_sub_graphs = {});
378
379 /*! \brief Constructs the empty sub-graph for \p dataflow_graph. */
380 explicit SubGraph(const DataflowGraph& dataflow_graph);
381
382 /*! \brief Returns true if this and that are disjoint. */
383 bool AreDisjoint(const SubGraph& that) const;
384
385 /*!
386 * \brief Returns true if:
387 * - \p this and \p that are disjoint, and
388 * - an output node of \p this coincides with an entry node of \p that, and
389 * - \p this and \p that are not obviously invalid after \p DisjointUnion
390 * (eg because such a sub-graph would produce a cycle).
391 * Note however that the \p DisjointUnion may not necessarily be valid even with the above
392 * checks.
393 */
394 bool AreTouching(const DataflowGraph& dataflow_graph, const SubGraph& that) const;
395
396 /*!
397 * \brief Returns true if:
398 * - all the outputs of \p this are entries for \p that, and
399 * - all the inputs of \p that are exits for \p this.
400 */
401 bool AreSelfContained(const SubGraph& that) const;
402
403 /*!
404 * \brief Returns disjoint union of this and \p that sub-graphs. The result may not be valid.
405 */
406 SubGraph DisjointUnion(const DataflowGraph& dataflow_graph, const SubGraph& that) const;
407
408 /*!
409 * \brief Returns copy of this sub-graph with all nodes placed inside a nested sub-graph with
410 * given attributes.
411 */
412 SubGraph WithAttrs(const DataflowGraph& dataflow_graph, FunctionAttrsMap attrs) const;
413
414 /*!
415 * \brief Returns copy of this sub-graph with all indexes substituted according to \p subst,
416 * whose range is w.r.t. \p new_dataflow_graph.
417 */
418 SubGraph Subst(const DataflowGraph& new_dataflow_graph,
419 const std::unordered_map<PostDfsIndex, PostDfsIndex>& subst) const;
420
421 /*!
422 * \brief Returns the root expression of \p dataflow_graph rewritten according to all the
423 * given sub-graphs. The sub-graphs can be given in any order, but must be disjoint.
424 */
425 static Expr ParallelRewrite(const DataflowGraph& dataflow_graph,
426 std::vector<SubGraph> sub_graphs);
427
428 TVM_DEFINE_OBJECT_REF_METHODS(SubGraph, ObjectRef, SubGraphNode);
429};
430
431struct SubGraphEqual {
432 bool operator()(const SubGraph& left, const SubGraph& right) const {
433 return *left.get() == *right.get();
434 }
435};
436
437struct SubGraphHash {
438 size_t operator()(const SubGraph& sub_graph) const { return sub_graph->hash(); }
439};
440
441/*!
442 * \brief Pass to partition every global function according to the post-dfs indexes
443 * given in an array. Visible for testing from Python only, would never make sense to use
444 * as a generic pass!
445 */
446tvm::transform::Pass PartitionOnIndexesForTesting(Array<Integer> indexes);
447
448} // namespace collage
449} // namespace relay
450} // namespace tvm
451
452#endif // TVM_RELAY_COLLAGE_SUB_GRAPH_H_
453