1/*r
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/auto_scheduler/compute_dag.h
22 * \brief The auto-scheduler's computational graph and related program analyses.
23 *
24 * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
25 * subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and
 * some static analysis results for the DAG (e.g. the total floating-point operation count,
 * consumer/producer relations of operations, whether an operation stage should be tiled/compute
 * inlined ...).
28 * These analyses can help the search policy to make decisions during the search.
 * ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and
 * the TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, providing
 * `LoopState` with extra information obtained from the TVM schedule ...).
32 */
33
34#ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
35#define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
36
37#include <tvm/auto_scheduler/loop_state.h>
38#include <tvm/runtime/c_runtime_api.h>
39#include <tvm/te/schedule.h>
40
41#include <unordered_map>
42#include <unordered_set>
43#include <utility>
44#include <vector>
45
46namespace tvm {
47namespace auto_scheduler {
48
49/*! \brief Static analyzer for a ComputeDAG */
50class AccessAnalyzerNode : public Object {
51 public:
52 template <class T>
53 using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
54
55 /*! \brief Map an operation to all operations it reads from.
56 * For each operation pair, use a two-dimensional array for multiple multi-dimensional accesses
57 * The inner vector represents the indices of multi-dimensional access.*/
58 OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
59 /*! \brief Map an operation to all operations it is read by.
60 * For each operation pair, use a two-dimensional array for multiple multi-dimensional accesses
61 * The inner vector represents the indices of multi-dimensional access.*/
62 OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
63 /*! \brief Store the number of common outer iterators for operation pairs that have
64 * read-write relations. */
65 OperationMap<OperationMap<int>> num_common_outer_iterators;
66 /*! \brief Store whether the operation is an op with only simple access.
67 * (e.g., injective, broadcast and elementwise ops without reduction) */
68 OperationMap<bool> is_simple_access;
69 /*! \brief Store whether the operation is strictly inlineable
70 * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
71 */
72 OperationMap<bool> is_strictly_inlineable;
73 /*! \brief Store whether the operation needs multi-level tiling
74 * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */
75 OperationMap<bool> needs_multi_level_tiling;
76 /*! \brief Store whether the operation is an output operation */
77 OperationMap<bool> is_output;
78 /*! \brief Store the topological order of operations */
79 Array<te::Operation> ops_topo_order;
80
81 static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer";
82 TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object);
83};
84
85/*!
86 * \brief Managed reference to AccessAnalyzerNode.
87 * \sa AccessAnalyzerNode
88 */
89class AccessAnalyzer : public ObjectRef {
90 public:
91 explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
92
93 /*!
94 * \brief Return whether this operation is an op with simple access
95 * (e.g., injective, broadcast and elementwise ops without reduction)
96 * \param op The operation
97 */
98 TVM_DLL bool IsSimpleAccess(const te::Operation& op) const;
99
100 /*!
101 * \brief Return whether this operation is strictly inlineable
102 * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
103 * \param op The operation
104 */
105 TVM_DLL bool IsStrictlyInlineable(const te::Operation& op) const;
106
107 /*!
108 * \brief Return whether this operation needs multi-level tiling
109 * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d)
110 * \param op The operation
111 */
112 TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
113
114 /*!
115 * \brief Return whether this operation is an output operation
116 * \param op The operation
117 */
118 TVM_DLL bool IsOutput(const te::Operation& op) const;
119
120 /*!
121 * \brief Get all consumers of an operation
122 * \param state The current loop state
123 * \param op The operation
124 * \return The set of consumers
125 * \note This function propagates the relation for inlined ops
126 */
127 TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
128 const State& state, const te::Operation& op) const;
129
130 /*!
131 * \brief Get all producers of an operation
132 * \param state The current loop state
133 * \param op The operation
134 * \return The set of producers
135 * \note This function propagates the relation for inlined ops
136 */
137 TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
138 const State& state, const te::Operation& op) const;
139
140 /*!
141 * \brief Get all direct producers of an operation
142 * \param op The operation
143 * \return The set of direct producers
144 * \note This function DOES NOT propagate the relation for inlined ops
145 */
146 TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
147 const te::Operation& op) const;
148
149 /*!
150 * \brief Get the number of common outer iterators.
151 * \param op The operation
152 * \param target_op The target operation
153 * \note This function propagates the relation for chains with multiple ops.
154 */
155 TVM_DLL int GetNumCommonOuterIterator(const te::Operation& op,
156 const te::Operation& target_op) const;
157
158 /*!
159 * \brief Return whether two operations are elementwise-matched
160 * (e.g. conv2d and relu are elementwise-matched)
161 * \note This function propagates the relation for chains with multiple ops.
162 */
163 TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;
164
165 TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode);
166};
167
168/*! \brief The auto-scheduler's computational graph and related program analyses. */
169class ComputeDAGNode : public Object {
170 public:
171 /*!
172 * \brief Input and output tensors.
173 * This is used as the input of `tvm.lower` or `tvm.build`.
174 */
175 Array<te::Tensor> tensors;
176 /*! \brief All used operations in topo order. */
177 Array<te::Operation> ops;
178 /*! \brief The number of float operations in this ComputeDAG. */
179 double flop_ct;
180 /*! \brief The initial state without any transform steps. */
181 State init_state;
182 /*! \brief The static read-write access analyzer. */
183 AccessAnalyzer access_analyzer;
184
185 void VisitAttrs(tvm::AttrVisitor* v) {
186 v->Visit("tensors", &tensors);
187 v->Visit("ops", &ops);
188 v->Visit("flop_ct", &flop_ct);
189 v->Visit("init_state", &init_state);
190 v->Visit("access_analyzer", &access_analyzer);
191 }
192
193 static constexpr const char* _type_key = "auto_scheduler.ComputeDAG";
194 TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
195};
196
/*!
 * \brief Options for applying layout rewrite.
 *
 * Layout rewrite is an optimization that rewrites the layout of input tensors according to the
 * schedule that was found.
 */
enum class LayoutRewriteOption : int {
  /*! \brief Do not perform layout rewrite. */
  NoRewrite = 0,
  /*! \brief Insert layout transformation stages for the input placeholders in the compute DAG. */
  InsertTransformStage = 1,
  /*!
   * \brief Do not insert layout transformation stages; assume the input placeholders are
   * already pre-transformed.
   * \note The lowered function built with this option does not accept the original input
   * shapes, so this option must be used together with the `AutoSchedulerLayoutRewrite` pass
   * in Relay.
   */
  RewriteForPreTransformed = 2
};
214
215/*!
216 * \brief Managed reference to ComputeDAGNode.
217 * \sa ComputeDAGNode
218 */
219class ComputeDAG : public ObjectRef {
220 public:
221 /*! \brief Construct a DAG from a list of output tensors.
222 * \param tensors `te::Tensor`s for a compute declaration.
223 */
224 TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);
225
226 /*! \brief Construct a DAG based on a schedule.
227 * \param sch `te::Schedule`s for a compute declaration.
228 */
229 TVM_DLL explicit ComputeDAG(const te::Schedule& sch);
230
231 /*!
232 * \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders`
233 * according to the loop nest derived with `transform_steps`.
234 * \param transform_steps Transform steps of a state.
235 * \param layout_rewrite Different options in layout rewrite.
236 * \return The updated ComputeDAG after layout rewrite.
237 */
238 ComputeDAG RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite) const;
239
240 /*!
241 * \brief Apply the history transform steps to get a TVM schedule.
242 * \param transform_steps Transform steps of a state.
243 * \param stages The list of stages after applying the steps.
244 * Pass a valid pointer if this information needs to be used outside this function.
245 * \param stage_to_axes The map that stores all axes for one stage.
246 * Pass a valid pointer if this information needs to be used outside this function.
247 * \param layout_rewrite Rewrite the layout of placeholders specified by
248 * attr `layout_free_placeholders`.
249 * \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower`
250 * or `tvm.build`.
251 */
252 std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
253 const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
254 StageToAxesMap* stage_to_axes = nullptr,
255 LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const;
256
257 /*!
258 * \brief Print transform steps as equivalent python schedule API.
259 * This can be used for debugging.
260 * \param transform_steps Transform steps of a state.
261 * \return The Python schedule code.
262 */
263 String PrintStepsAsPython(const Array<Step>& transform_steps) const;
264
265 /*!
266 * \brief Print the compute DAG to a string. This is also used to generate the ComputeDAG hash.
267 * \param simple_mode Simple mode will only include the op names and brief compute.
268 * \return The ComputeDAG in a string.
269 */
270 String PrintDAG(bool simple_mode = false) const;
271
272 /*!
273 * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
274 * The states can lose complete bound information after some transform steps (e.g., compute_at).
275 * We can call this function to infer and fill all the bound information.
276 * This function calls TVM InferBound pass internally to get the bound.
277 * The returned state of this function is guaranteed to have complete bound information.
278 * \param state The input state.
279 * \return The State with complete bound information
280 */
281 State InferBound(const State& state) const;
282
283 /*!
284 * \brief Fill the correct bound information for the given states by calling ir_pass::InferBound.
285 * The states can lose complete bound information after some transform steps (e.g., compute_at).
286 * We can call this function to infer and fill all the bound information.
287 * This function calls TVM InferBound pass internally to get the bound.
288 * The returned state of this function is guaranteed to have complete bound information.
289 * \param states The input states.
290 * \return The States with complete bound information.
291 * \note The returned array will contains empty State, if there're infer bound failure on some
292 * states.
293 */
294 Array<State> InferBound(const Array<State>& states) const;
295
296 /*!
297 * \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
298 * ComputeDAG may not be up-to-date. This function replays the given transform steps from the
299 * initial state and returns an up-to-date ComputeDAG.
300 * \param steps The steps to be replayed. Usually we'll filter out the unused steps to speed up
301 * the replay process, since we only intend to get a ComputeDAG with the up-to-date op stage
302 * structure.
303 * \return The up-to-date ComputeDAG.
304 */
305 ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;
306
307 static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders";
308
309 TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
310 TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
311};
312
313/*!
314 * \brief Get the orginal shape from a rewritten layout string.
315 * \param rewritten_layout The layout after auto-scheduler's layout rewrite.
316 * \param axis_names Specifiy the names of axes.
317 * \return shape The original shape.
318 */
319Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names);
320
321} // namespace auto_scheduler
322} // namespace tvm
323
324#endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_
325