/*
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/auto_scheduler/compute_dag.h |
22 | * \brief The auto-scheduler's computational graph and related program analyses. |
23 | * |
24 | * We convert a compute declaration described by `tvm.compute` (could be a single operator or a |
25 | * subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and |
26 | * some static analysis results for the DAG (e.g. the total float operation count, consumer/producer |
27 | * relations of operations, whether an operation stage should be tiled/compute inlined ...). |
28 | * These analyses can help the search policy to make decisions during the search. |
29 | * ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and |
30 | * TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, providing |
 * `LoopState` with extra information obtained from the TVM schedule ...).
32 | */ |
33 | |
34 | #ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ |
35 | #define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ |
36 | |
37 | #include <tvm/auto_scheduler/loop_state.h> |
38 | #include <tvm/runtime/c_runtime_api.h> |
39 | #include <tvm/te/schedule.h> |
40 | |
41 | #include <unordered_map> |
42 | #include <unordered_set> |
43 | #include <utility> |
44 | #include <vector> |
45 | |
46 | namespace tvm { |
47 | namespace auto_scheduler { |
48 | |
49 | /*! \brief Static analyzer for a ComputeDAG */ |
50 | class AccessAnalyzerNode : public Object { |
51 | public: |
52 | template <class T> |
53 | using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>; |
54 | |
55 | /*! \brief Map an operation to all operations it reads from. |
56 | * For each operation pair, use a two-dimensional array for multiple multi-dimensional accesses |
57 | * The inner vector represents the indices of multi-dimensional access.*/ |
58 | OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from; |
59 | /*! \brief Map an operation to all operations it is read by. |
60 | * For each operation pair, use a two-dimensional array for multiple multi-dimensional accesses |
61 | * The inner vector represents the indices of multi-dimensional access.*/ |
62 | OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by; |
63 | /*! \brief Store the number of common outer iterators for operation pairs that have |
64 | * read-write relations. */ |
65 | OperationMap<OperationMap<int>> num_common_outer_iterators; |
66 | /*! \brief Store whether the operation is an op with only simple access. |
67 | * (e.g., injective, broadcast and elementwise ops without reduction) */ |
68 | OperationMap<bool> is_simple_access; |
69 | /*! \brief Store whether the operation is strictly inlineable |
70 | * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations) |
71 | */ |
72 | OperationMap<bool> is_strictly_inlineable; |
73 | /*! \brief Store whether the operation needs multi-level tiling |
74 | * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */ |
75 | OperationMap<bool> needs_multi_level_tiling; |
76 | /*! \brief Store whether the operation is an output operation */ |
77 | OperationMap<bool> is_output; |
78 | /*! \brief Store the topological order of operations */ |
79 | Array<te::Operation> ops_topo_order; |
80 | |
81 | static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer" ; |
82 | TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); |
83 | }; |
84 | |
/*!
 * \brief Managed reference to AccessAnalyzerNode.
 * \sa AccessAnalyzerNode
 */
class AccessAnalyzer : public ObjectRef {
 public:
  /*!
   * \brief Run the static access analysis over a compute declaration.
   * \param tensors The input/output tensors of the compute declaration.
   */
  explicit AccessAnalyzer(const Array<te::Tensor>& tensors);

  /*!
   * \brief Return whether this operation is an op with simple access
   * (e.g., injective, broadcast and elementwise ops without reduction)
   * \param op The operation
   * \return Whether the operation only has simple access.
   */
  TVM_DLL bool IsSimpleAccess(const te::Operation& op) const;

  /*!
   * \brief Return whether this operation is strictly inlineable
   * (e.g., injective, broadcast and elementwise without reduction, branch or expensive operations)
   * \param op The operation
   * \return Whether the operation is strictly inlineable.
   */
  TVM_DLL bool IsStrictlyInlineable(const te::Operation& op) const;

  /*!
   * \brief Return whether this operation needs multi-level tiling
   * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d)
   * \param op The operation
   * \return Whether the operation needs multi-level tiling.
   */
  TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;

  /*!
   * \brief Return whether this operation is an output operation
   * \param op The operation
   * \return Whether the operation is an output operation.
   */
  TVM_DLL bool IsOutput(const te::Operation& op) const;

  /*!
   * \brief Get all consumers of an operation
   * \param state The current loop state
   * \param op The operation
   * \return The set of consumers
   * \note This function propagates the relation for inlined ops
   */
  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetConsumers(
      const State& state, const te::Operation& op) const;

  /*!
   * \brief Get all producers of an operation
   * \param state The current loop state
   * \param op The operation
   * \return The set of producers
   * \note This function propagates the relation for inlined ops
   */
  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetProducers(
      const State& state, const te::Operation& op) const;

  /*!
   * \brief Get all direct producers of an operation
   * \param op The operation
   * \return The set of direct producers
   * \note This function DOES NOT propagate the relation for inlined ops
   */
  TVM_DLL std::unordered_set<te::Operation, ObjectHash, ObjectEqual> GetDirectProducers(
      const te::Operation& op) const;

  /*!
   * \brief Get the number of common outer iterators.
   * \param op The operation
   * \param target_op The target operation
   * \return The number of common outer iterators between `op` and `target_op`.
   * \note This function propagates the relation for chains with multiple ops.
   */
  TVM_DLL int GetNumCommonOuterIterator(const te::Operation& op,
                                        const te::Operation& target_op) const;

  /*!
   * \brief Return whether two operations are elementwise-matched
   * (e.g. conv2d and relu are elementwise-matched)
   * \param op The operation
   * \param target_op The target operation
   * \return Whether the two operations are elementwise-matched.
   * \note This function propagates the relation for chains with multiple ops.
   */
  TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;

  TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode);
};
167 | |
168 | /*! \brief The auto-scheduler's computational graph and related program analyses. */ |
169 | class ComputeDAGNode : public Object { |
170 | public: |
171 | /*! |
172 | * \brief Input and output tensors. |
173 | * This is used as the input of `tvm.lower` or `tvm.build`. |
174 | */ |
175 | Array<te::Tensor> tensors; |
176 | /*! \brief All used operations in topo order. */ |
177 | Array<te::Operation> ops; |
178 | /*! \brief The number of float operations in this ComputeDAG. */ |
179 | double flop_ct; |
180 | /*! \brief The initial state without any transform steps. */ |
181 | State init_state; |
182 | /*! \brief The static read-write access analyzer. */ |
183 | AccessAnalyzer access_analyzer; |
184 | |
185 | void VisitAttrs(tvm::AttrVisitor* v) { |
186 | v->Visit("tensors" , &tensors); |
187 | v->Visit("ops" , &ops); |
188 | v->Visit("flop_ct" , &flop_ct); |
189 | v->Visit("init_state" , &init_state); |
190 | v->Visit("access_analyzer" , &access_analyzer); |
191 | } |
192 | |
193 | static constexpr const char* _type_key = "auto_scheduler.ComputeDAG" ; |
194 | TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); |
195 | }; |
196 | |
/*!
 * \brief Options for applying layout rewrite.
 * This is an optimization to rewrite the layout of input tensors according to the schedule we get.
 */
enum class LayoutRewriteOption : int {
  /*! \brief Do not perform layout rewrite. */
  NoRewrite = 0,
  /*! \brief Insert layout transformation stages for input placeholders in the compute DAG */
  InsertTransformStage = 1,
  /*!
   * \brief Do not insert layout transformation stages and assume the input placeholders
   * are pre-transformed.
   * \note The lowered function with this option does not accept the original input shapes,
   * so this option must be used along with `AutoSchedulerLayoutRewrite` pass in Relay.
   */
  RewriteForPreTransformed = 2,
};
214 | |
215 | /*! |
216 | * \brief Managed reference to ComputeDAGNode. |
217 | * \sa ComputeDAGNode |
218 | */ |
219 | class ComputeDAG : public ObjectRef { |
220 | public: |
221 | /*! \brief Construct a DAG from a list of output tensors. |
222 | * \param tensors `te::Tensor`s for a compute declaration. |
223 | */ |
224 | TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors); |
225 | |
226 | /*! \brief Construct a DAG based on a schedule. |
227 | * \param sch `te::Schedule`s for a compute declaration. |
228 | */ |
229 | TVM_DLL explicit ComputeDAG(const te::Schedule& sch); |
230 | |
231 | /*! |
232 | * \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders` |
233 | * according to the loop nest derived with `transform_steps`. |
234 | * \param transform_steps Transform steps of a state. |
235 | * \param layout_rewrite Different options in layout rewrite. |
236 | * \return The updated ComputeDAG after layout rewrite. |
237 | */ |
238 | ComputeDAG RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite) const; |
239 | |
240 | /*! |
241 | * \brief Apply the history transform steps to get a TVM schedule. |
242 | * \param transform_steps Transform steps of a state. |
243 | * \param stages The list of stages after applying the steps. |
244 | * Pass a valid pointer if this information needs to be used outside this function. |
245 | * \param stage_to_axes The map that stores all axes for one stage. |
246 | * Pass a valid pointer if this information needs to be used outside this function. |
247 | * \param layout_rewrite Rewrite the layout of placeholders specified by |
248 | * attr `layout_free_placeholders`. |
249 | * \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower` |
250 | * or `tvm.build`. |
251 | */ |
252 | std::pair<te::Schedule, Array<te::Tensor>> ApplySteps( |
253 | const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr, |
254 | StageToAxesMap* stage_to_axes = nullptr, |
255 | LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const; |
256 | |
257 | /*! |
258 | * \brief Print transform steps as equivalent python schedule API. |
259 | * This can be used for debugging. |
260 | * \param transform_steps Transform steps of a state. |
261 | * \return The Python schedule code. |
262 | */ |
263 | String PrintStepsAsPython(const Array<Step>& transform_steps) const; |
264 | |
265 | /*! |
266 | * \brief Print the compute DAG to a string. This is also used to generate the ComputeDAG hash. |
267 | * \param simple_mode Simple mode will only include the op names and brief compute. |
268 | * \return The ComputeDAG in a string. |
269 | */ |
270 | String PrintDAG(bool simple_mode = false) const; |
271 | |
272 | /*! |
273 | * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. |
274 | * The states can lose complete bound information after some transform steps (e.g., compute_at). |
275 | * We can call this function to infer and fill all the bound information. |
276 | * This function calls TVM InferBound pass internally to get the bound. |
277 | * The returned state of this function is guaranteed to have complete bound information. |
278 | * \param state The input state. |
279 | * \return The State with complete bound information |
280 | */ |
281 | State InferBound(const State& state) const; |
282 | |
283 | /*! |
284 | * \brief Fill the correct bound information for the given states by calling ir_pass::InferBound. |
285 | * The states can lose complete bound information after some transform steps (e.g., compute_at). |
286 | * We can call this function to infer and fill all the bound information. |
287 | * This function calls TVM InferBound pass internally to get the bound. |
288 | * The returned state of this function is guaranteed to have complete bound information. |
289 | * \param states The input states. |
290 | * \return The States with complete bound information. |
291 | * \note The returned array will contains empty State, if there're infer bound failure on some |
292 | * states. |
293 | */ |
294 | Array<State> InferBound(const Array<State>& states) const; |
295 | |
296 | /*! |
297 | * \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial |
298 | * ComputeDAG may not be up-to-date. This function replays the given transform steps from the |
299 | * initial state and returns an up-to-date ComputeDAG. |
300 | * \param steps The steps to be replayed. Usually we'll filter out the unused steps to speed up |
301 | * the replay process, since we only intend to get a ComputeDAG with the up-to-date op stage |
302 | * structure. |
303 | * \return The up-to-date ComputeDAG. |
304 | */ |
305 | ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const; |
306 | |
307 | static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders" ; |
308 | |
309 | TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); |
310 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); |
311 | }; |
312 | |
/*!
 * \brief Get the original shape from a rewritten layout string.
 * \param rewritten_layout The layout after auto-scheduler's layout rewrite.
 * \param axis_names Specify the names of axes.
 * \return shape The original shape.
 */
Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names);
320 | |
321 | } // namespace auto_scheduler |
322 | } // namespace tvm |
323 | |
324 | #endif // TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ |
325 | |