1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file relay/backend/te_compiler_cache.h
 * \brief Utilities for compiling tensor expressions inside of the Relay compiler.
 */
24 | #ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ |
25 | #define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ |
26 | |
27 | #include <tvm/ir/name_supply.h> |
28 | #include <tvm/node/structural_equal.h> |
29 | #include <tvm/node/structural_hash.h> |
30 | #include <tvm/relay/analysis.h> |
31 | #include <tvm/relay/attrs/memory.h> |
32 | #include <tvm/relay/expr.h> |
33 | #include <tvm/relay/op_strategy.h> |
34 | #include <tvm/relay/transform.h> |
35 | #include <tvm/runtime/module.h> |
36 | #include <tvm/topi/elemwise.h> |
37 | |
38 | #include <functional> |
39 | #include <string> |
40 | #include <tuple> |
41 | #include <unordered_map> |
42 | #include <utility> |
43 | |
44 | #include "../transforms/infer_layout_utils.h" |
45 | |
46 | namespace tvm { |
47 | namespace relay { |
48 | namespace tec { |
49 | |
/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */
enum ShapeFuncParamState {
  /*! \brief Neither the data nor the shape of the parameter is required. */
  kNoNeed = 0,
  /*! \brief The runtime data of the parameter is required. */
  kNeedInputData = 1,
  /*! \brief The shape of the parameter is required. */
  kNeedInputShape = 2,
  /*! \brief Both data and shape are required (kNeedInputData | kNeedInputShape). */
  kNeedBoth = 3,
};
57 | |
/*! \brief Node holding the tensors produced by lowering a Relay op, plus the chosen impl. */
struct LoweredOutputNode : public Object {
  /*! \brief The outputs to the function */
  tvm::Array<te::Tensor> outputs;
  /*! \brief The implementation used to compute the output */
  OpImplementation implementation;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("outputs", &outputs);
    v->Visit("implementation", &implementation);
  }
  static constexpr const char* _type_key = "relay.LoweredOutput";
  TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object);
};
71 | |
/*!
 * \brief Managed reference to LoweredOutputNode.
 * \sa LoweredOutputNode
 */
class LoweredOutput : public ObjectRef {
 public:
  /*!
   * \brief Construct a lowered output record.
   * \param outputs The output tensors of the lowered computation.
   * \param impl The op implementation that produced them.
   */
  TVM_DLL LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl);

  TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode);
};
78 | |
class CCacheKey;
/*! \brief Compile cache key */
class CCacheKeyNode : public Object {
 public:
  /*! \brief The source function to be lowered. */
  Function source_func;
  /*! \brief The hardware target. */
  Target target;
  /*! \brief The virtual device constraints. */
  VirtualDevice virtual_device;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("source_func", &source_func);
    v->Visit("target", &target);
    v->Visit("virtual_device", &virtual_device);
  }
  /*! \return The hash value of CCacheKey. Memoized in hash_ after the first call. */
  inline size_t Hash() const;
  /*!
   * \brief check content equality
   * \param other The other value.
   * \return The result of equality check.
   */
  inline bool Equal(const CCacheKeyNode* other) const;

  static constexpr const char* _type_key = "relay.CCacheKey";
  TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object);

 private:
  /*!
   * \brief Internal cached hash value; 0 is the "not yet computed" sentinel (see Hash()).
   */
  mutable size_t hash_{0};
};
113 | |
/*! \brief Managed reference to CCacheKeyNode; the key type of the compile cache. */
class CCacheKey : public ObjectRef {
 public:
  CCacheKey() {}
  explicit CCacheKey(ObjectPtr<Object> n) : ObjectRef(n) {}

  /*!
   * \brief The constructor
   * \param source_func The source function.
   * \param target The target device.
   * \param virtual_device The virtual device constraints; fully unconstrained by default.
   */
  TVM_DLL CCacheKey(Function source_func, Target target,
                    VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained());

  const CCacheKeyNode* operator->() const { return static_cast<const CCacheKeyNode*>(get()); }
  // Comparator: structural equality of the underlying nodes; both sides must be defined.
  inline bool operator==(const CCacheKey& other) const {
    ICHECK(defined() && other.defined());
    return (*this)->Equal(other.operator->());
  }
  using ContainerType = CCacheKeyNode;
};
136 | |
/*! \brief Node container to represent a cached function. */
struct CachedFuncNode : public Object {
  /*! \brief compiled target */
  tvm::Target target;
  /*! \brief Primitive Function Name */
  GlobalVar prim_fn_var;
  /*! \brief The inputs to the function */
  tvm::Array<te::Tensor> inputs;
  /*! \brief The outputs to the function */
  tvm::Array<te::Tensor> outputs;
  /*! \brief The schedule to the function */
  te::Schedule schedule;
  /*! \brief The TIR function if lowering in the meta schedule path */
  Optional<tir::PrimFunc> prim_func;
  /*! \brief Parameter usage states in the shape function. */
  tvm::Array<Integer> shape_func_param_states;
  /*! \brief The lowered functions to support the function. */
  IRModule funcs = IRModule(Map<GlobalVar, BaseFunc>({}));
  /*! \brief Map from constant nodes to the te::Tensor created for each of them. */
  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors;

  void VisitAttrs(tvm::AttrVisitor* v) {
    // NOTE(review): constant_tensors is deliberately absent here — presumably because
    // std::unordered_map has no AttrVisitor support; confirm if reflection of it is needed.
    v->Visit("target", &target);
    v->Visit("prim_fn_var", &prim_fn_var);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
    v->Visit("schedule", &schedule);
    v->Visit("prim_func", &prim_func);
    v->Visit("funcs", &funcs);
    v->Visit("shape_func_param_states", &shape_func_param_states);
  }

  static constexpr const char* _type_key = "relay.CachedFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object);
};
171 | |
172 | class CachedFunc : public ObjectRef { |
173 | public: |
174 | CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array<te::Tensor> inputs, |
175 | tvm::Array<te::Tensor> outputs, te::Schedule schedule, tir::PrimFunc prim_func, |
176 | tvm::Array<Integer> shape_func_param_states, |
177 | IRModule funcs = IRModule(Map<GlobalVar, BaseFunc>({})), |
178 | std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors = {}); |
179 | |
180 | public: |
181 | TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode); |
182 | }; |
183 | |
/*! \brief Node container for compile cache. */
class CCacheValueNode : public Object {
 public:
  /*! \brief The corresponding function */
  CachedFunc cached_func;
  /*! \brief Result of Packed function generated by JIT */
  PackedFunc packed_func;
  /*! \brief usage statistics */
  int use_count{0};

  void VisitAttrs(tvm::AttrVisitor* v) {
    // NOTE(review): packed_func is not visited — presumably PackedFunc is not
    // reflectable through AttrVisitor; confirm before relying on serialization.
    v->Visit("cached_func", &cached_func);
    v->Visit("use_count", &use_count);
  }
  static constexpr const char* _type_key = "relay.CCacheValue";
  TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object);
};
201 | |
/*! \brief cache entry used in compile engine */
class CCacheValue : public ObjectRef {
 public:
  CCacheValue() {}
  explicit CCacheValue(ObjectPtr<Object> n) : ObjectRef(n) {}
  // Mutable access: cache entries are updated in place (e.g. use_count bookkeeping).
  CCacheValueNode* operator->() { return static_cast<CCacheValueNode*>(get_mutable()); }
  const CCacheValueNode* operator->() const { return static_cast<const CCacheValueNode*>(get()); }
  using ContainerType = CCacheValueNode;
};
211 | |
/*!
 * \brief Returns a processed copy of the given tensor shape.
 * NOTE(review): the transformation applied is defined in the .cc file — presumably a
 * normalization of the dimension expressions; confirm against the implementation.
 * \param shape The input shape.
 * \return The processed shape.
 */
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);

/*!
 * \brief Lower Relay primitive Function to TE Compute
 * \param source_func The primitive function to be lowered.
 * \param target The compilation target.
 * \param constant_name_supply A name supplier for constants
 * across different invocations of this function.
 * \param return_inputs If true, prepend input tensors to the output array of tensors.
 * \return Tuple of the lowered TE compute, constant raw data, and fused function name.
 */
std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompute(
    const Function& source_func, Target target, NameSupply constant_name_supply,
    bool return_inputs = true);

/*!
 * \brief Lower Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc.
 * \param relay_func The primitive function to be lowered.
 * \param target The compilation target.
 * \param constant_name_supply A name supplier for constants
 * across different invocations of this function.
 * \return A pair of the created prim func and the name of the fused function.
 */
std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
                                                                Target target,
                                                                NameSupply constant_name_supply);

/*!
 * \brief Create schedule for target.
 * \param source_func The primitive function to be lowered.
 * \param target The compilation target.
 * \param global_var_supply A name supplier for global variables.
 * \param constant_name_supply A name supplier for constants.
 * \return Pair of schedule and cache.
 *  The funcs field in cache is not yet populated.
 */
CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
                       GlobalVarSupply global_var_supply, NameSupply constant_name_supply);
250 | |
251 | /*! \brief A specialization of PrimFuncFor, meant to be used when the names of constants do not |
252 | * matter. */ |
253 | inline CachedFunc PrimFuncFor(const Function& source_func, const Target& target) { |
254 | return PrimFuncFor(source_func, target, GlobalVarSupply(NameSupply("" )), NameSupply("" )); |
255 | } |
256 | |
/*!
 * \brief Create a cached shape function for the given primitive function.
 * \param prim_func The primitive function whose shape function is built.
 * \param target The compilation target.
 * \param global_var_supply A name supplier for global variables.
 * \return The cached shape function.
 */
CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
                        GlobalVarSupply global_var_supply);
259 | |
260 | // implementations |
261 | inline size_t CCacheKeyNode::Hash() const { |
262 | if (hash_ != 0) return hash_; |
263 | // do structral hash, avoid 0. |
264 | hash_ = tvm::StructuralHash()(this->source_func); |
265 | hash_ = dmlc::HashCombine(hash_, std::hash<std::string>()(target->str())); |
266 | if (hash_ == 0) hash_ = 1; |
267 | return hash_; |
268 | } |
269 | |
270 | inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { |
271 | if (Hash() != other->Hash()) return false; |
272 | return this->target->str() == other->target->str() && |
273 | this->virtual_device == other->virtual_device && |
274 | tvm::StructuralEqual()(this->source_func, other->source_func); |
275 | } |
276 | |
277 | } // namespace tec |
278 | } // namespace relay |
279 | } // namespace tvm |
280 | |
281 | namespace std { |
282 | // overload hash |
283 | template <> |
284 | struct hash<::tvm::relay::tec::CCacheKey> { |
285 | size_t operator()(const ::tvm::relay::tec::CCacheKey& key) const { |
286 | ICHECK(key.defined()); |
287 | return key->Hash(); |
288 | } |
289 | }; |
290 | } // namespace std |
291 | |
292 | #endif // TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_ |
293 | |