/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
/*!
 * \file relay/backend/te_compiler_cache.h
 * \brief Utilities for compiling tensor expressions inside of the Relay compiler.
 */
24#ifndef TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_
25#define TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_
26
27#include <tvm/ir/name_supply.h>
28#include <tvm/node/structural_equal.h>
29#include <tvm/node/structural_hash.h>
30#include <tvm/relay/analysis.h>
31#include <tvm/relay/attrs/memory.h>
32#include <tvm/relay/expr.h>
33#include <tvm/relay/op_strategy.h>
34#include <tvm/relay/transform.h>
35#include <tvm/runtime/module.h>
36#include <tvm/topi/elemwise.h>
37
38#include <functional>
39#include <string>
40#include <tuple>
41#include <unordered_map>
42#include <utility>
43
44#include "../transforms/infer_layout_utils.h"
45
46namespace tvm {
47namespace relay {
48namespace tec {
49
/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */
enum ShapeFuncParamState {
  /*! \brief Neither the data nor the shape of this parameter is needed. */
  kNoNeed = 0,
  /*! \brief Only the data of this parameter is needed. */
  kNeedInputData = 1,
  /*! \brief Only the shape of this parameter is needed. */
  kNeedInputShape = 2,
  /*! \brief Both the data and the shape of this parameter are needed (kNeedInputData | kNeedInputShape). */
  kNeedBoth = 3,
};
57
/*! \brief Node holding the result of lowering a call to TE: the produced output tensors
 *  together with the op implementation that was selected to compute them. */
struct LoweredOutputNode : public Object {
  /*! \brief The outputs to the function */
  tvm::Array<te::Tensor> outputs;
  /*! \brief The implementation used to compute the output */
  OpImplementation implementation;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("outputs", &outputs);
    v->Visit("implementation", &implementation);
  }
  static constexpr const char* _type_key = "relay.LoweredOutput";
  TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object);
};
71
/*! \brief Managed reference to LoweredOutputNode. \sa LoweredOutputNode */
class LoweredOutput : public ObjectRef {
 public:
  /*!
   * \brief The constructor.
   * \param outputs The output tensors of the lowered computation.
   * \param impl The op implementation used to produce \p outputs.
   */
  TVM_DLL LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl);

  TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode);
};
78
class CCacheKey;
/*! \brief Compile cache key */
class CCacheKeyNode : public Object {
 public:
  /*! \brief The source function to be lowered. */
  Function source_func;
  /*! \brief The hardware target. */
  Target target;
  /*! \brief The virtual device constraints. */
  VirtualDevice virtual_device;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("source_func", &source_func);
    v->Visit("target", &target);
    v->Visit("virtual_device", &virtual_device);
  }
  /*!
   * \return The hash value of CCacheKey.
   * \note The value is computed lazily and memoized in hash_. As the inline definition below
   *       shows, the hash combines the structural hash of source_func with the target string
   *       only — virtual_device is NOT mixed in, so keys differing only in virtual_device
   *       share a hash bucket (Equal() still distinguishes them, so lookups stay correct).
   */
  inline size_t Hash() const;
  /*!
   * \brief check content equality
   * \param other The other value.
   * \return The result of equality check.
   */
  inline bool Equal(const CCacheKeyNode* other) const;

  static constexpr const char* _type_key = "relay.CCacheKey";
  TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object);

 private:
  /*!
   * \brief internal cached hash value.
   * \note 0 means "not yet computed"; Hash() never stores 0 for a computed hash.
   *       mutable so the cache can be filled in from the const Hash() accessor.
   */
  mutable size_t hash_{0};
};
113
/*! \brief cache entry used in compile engine */
class CCacheKey : public ObjectRef {
 public:
  /*! \brief Default constructor; creates an undefined (null) key. */
  CCacheKey() {}
  /*! \brief Wrap an existing node as a key. \param n The node to wrap. */
  explicit CCacheKey(ObjectPtr<Object> n) : ObjectRef(n) {}

  /*!
   * \brief The constructor
   * \param source_func The source function.
   * \param target The target device.
   * \param virtual_device The virtual device constraints (unconstrained by default).
   */
  TVM_DLL CCacheKey(Function source_func, Target target,
                    VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained());

  const CCacheKeyNode* operator->() const { return static_cast<const CCacheKeyNode*>(get()); }
  // Content-based comparator; both keys must be defined (checked), delegates to
  // CCacheKeyNode::Equal which compares target, virtual_device, and source_func structurally.
  inline bool operator==(const CCacheKey& other) const {
    ICHECK(defined() && other.defined());
    return (*this)->Equal(other.operator->());
  }
  using ContainerType = CCacheKeyNode;
};
136
/*! \brief Node container to represent a cached function. */
struct CachedFuncNode : public Object {
  /*! \brief compiled target */
  tvm::Target target;
  /*! \brief Primitive Function Name */
  GlobalVar prim_fn_var;
  /*! \brief The inputs to the function */
  tvm::Array<te::Tensor> inputs;
  /*! \brief The outputs to the function */
  tvm::Array<te::Tensor> outputs;
  /*! \brief The schedule to the function */
  te::Schedule schedule;
  /*! \brief The TIR function if lowering in the meta schedule path */
  Optional<tir::PrimFunc> prim_func;
  /*! \brief Parameter usage states in the shape function. */
  tvm::Array<Integer> shape_func_param_states;
  /*! \brief The lowered functions to support the function. */
  IRModule funcs = IRModule(Map<GlobalVar, BaseFunc>({}));
  /*! \brief TE tensors standing in for constants, keyed by the originating ConstantNode.
   *  Note: this field is not visited in VisitAttrs below (std::unordered_map is not an
   *  attr-visitable type), so it is invisible to reflection/serialization. */
  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("target", &target);
    v->Visit("prim_fn_var", &prim_fn_var);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
    v->Visit("schedule", &schedule);
    v->Visit("prim_func", &prim_func);
    v->Visit("funcs", &funcs);
    v->Visit("shape_func_param_states", &shape_func_param_states);
  }

  static constexpr const char* _type_key = "relay.CachedFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(CachedFuncNode, Object);
};
171
/*! \brief Managed reference to CachedFuncNode. \sa CachedFuncNode */
class CachedFunc : public ObjectRef {
 public:
  /*!
   * \brief The constructor; parameters mirror the CachedFuncNode fields one-to-one.
   * \param target The compilation target.
   * \param prim_fn_name The global var naming the primitive function.
   * \param inputs The input tensors.
   * \param outputs The output tensors.
   * \param schedule The TE schedule.
   * \param prim_func The TIR PrimFunc (meta schedule path); may be undefined.
   * \param shape_func_param_states Parameter usage states of the shape function.
   * \param funcs The supporting lowered functions (empty module by default).
   * \param constant_tensors TE tensors for embedded constants (empty by default).
   */
  CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array<te::Tensor> inputs,
             tvm::Array<te::Tensor> outputs, te::Schedule schedule, tir::PrimFunc prim_func,
             tvm::Array<Integer> shape_func_param_states,
             IRModule funcs = IRModule(Map<GlobalVar, BaseFunc>({})),
             std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors = {});

 public:
  TVM_DEFINE_OBJECT_REF_METHODS(CachedFunc, ObjectRef, CachedFuncNode);
};
183
/*! \brief Node container for compile cache. */
class CCacheValueNode : public Object {
 public:
  /*! \brief The corresponding function */
  CachedFunc cached_func;
  /*! \brief Result of Packed function generated by JIT.
   *  Note: not visited in VisitAttrs below, so it is invisible to reflection. */
  PackedFunc packed_func;
  /*! \brief usage statistics */
  int use_count{0};

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("cached_func", &cached_func);
    v->Visit("use_count", &use_count);
  }
  static constexpr const char* _type_key = "relay.CCacheValue";
  TVM_DECLARE_FINAL_OBJECT_INFO(CCacheValueNode, tvm::Object);
};
201
/*! \brief cache entry used in compile engine */
class CCacheValue : public ObjectRef {
 public:
  /*! \brief Default constructor; creates an undefined (null) value. */
  CCacheValue() {}
  /*! \brief Wrap an existing node. \param n The node to wrap. */
  explicit CCacheValue(ObjectPtr<Object> n) : ObjectRef(n) {}
  // Mutable accessor: the cache updates use_count/packed_func in place.
  CCacheValueNode* operator->() { return static_cast<CCacheValueNode*>(get_mutable()); }
  const CCacheValueNode* operator->() const { return static_cast<const CCacheValueNode*>(get()); }
  using ContainerType = CCacheValueNode;
};
211
/*! \brief Legalize/normalize a tensor shape for lowering and return the result.
 *  NOTE(review): semantics are not visible from this header — confirm against the
 *  definition in the corresponding .cc file. */
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);
213
/*!
 * \brief Lower Relay primitive Function to TE Compute
 * \param source_func The primitive function to be lowered.
 * \param target The compilation target.
 * \param constant_name_supply A name supplier for constants
 * across different invocations of this function.
 * \param return_inputs If true, prepend input tensors to the output array of tensors.
 * \return Tuple of the lowered TE compute, constant raw data, and fused function name.
 */
std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompute(
    const Function& source_func, Target target, NameSupply constant_name_supply,
    bool return_inputs = true);
226
/*!
 * \brief Lower Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc.
 * \param relay_func The primitive function to be lowered.
 * \param target The compilation target.
 * \param constant_name_supply A name supplier for constants
 * across different invocations of this function.
 * \return A pair of the created prim func and the name of the fused function.
 * \note The PrimFunc is Optional — callers must handle the undefined case.
 *       NOTE(review): the condition under which it is undefined is not visible from this
 *       header; confirm against the definition.
 */
std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
                                                                Target target,
                                                                NameSupply constant_name_supply);
238
/*!
 * \brief Create schedule for target.
 * \param source_func The primitive function to be lowered.
 * \param target The compilation target.
 * \param global_var_supply A name supplier for global variables.
 * \param constant_name_supply A name supplier for constants.
 * \return Pair of schedule and cache.
 *  The funcs field in cache is not yet populated.
 */
CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
                       GlobalVarSupply global_var_supply, NameSupply constant_name_supply);
250
251/*! \brief A specialization of PrimFuncFor, meant to be used when the names of constants do not
252 * matter. */
253inline CachedFunc PrimFuncFor(const Function& source_func, const Target& target) {
254 return PrimFuncFor(source_func, target, GlobalVarSupply(NameSupply("")), NameSupply(""));
255}
256
/*!
 * \brief Create a CachedFunc for the shape function of \p prim_func.
 * \param prim_func The primitive function whose shape function is needed.
 * \param target The compilation target.
 * \param global_var_supply A name supplier for global variables.
 * \return The cached function for the shape computation.
 *  NOTE(review): details inferred from the signature and the ShapeFuncParamState enum above;
 *  confirm against the definition.
 */
CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
                        GlobalVarSupply global_var_supply);
259
260// implementations
261inline size_t CCacheKeyNode::Hash() const {
262 if (hash_ != 0) return hash_;
263 // do structral hash, avoid 0.
264 hash_ = tvm::StructuralHash()(this->source_func);
265 hash_ = dmlc::HashCombine(hash_, std::hash<std::string>()(target->str()));
266 if (hash_ == 0) hash_ = 1;
267 return hash_;
268}
269
270inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
271 if (Hash() != other->Hash()) return false;
272 return this->target->str() == other->target->str() &&
273 this->virtual_device == other->virtual_device &&
274 tvm::StructuralEqual()(this->source_func, other->source_func);
275}
276
277} // namespace tec
278} // namespace relay
279} // namespace tvm
280
namespace std {
// Specialize std::hash so CCacheKey can serve as a key in std::unordered_map/unordered_set.
template <>
struct hash<::tvm::relay::tec::CCacheKey> {
  // Requires a defined (non-null) key; forwards to the node's cached hash.
  size_t operator()(const ::tvm::relay::tec::CCacheKey& key) const {
    ICHECK(key.defined());
    return key->Hash();
  }
};
}  // namespace std
291
292#endif // TVM_RELAY_BACKEND_TE_COMPILER_CACHE_H_
293