1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/target/generic_func.h |
 * \brief Generic function that can be specialized on a per-target basis.
23 | */ |
24 | #ifndef TVM_TARGET_GENERIC_FUNC_H_ |
25 | #define TVM_TARGET_GENERIC_FUNC_H_ |
26 | |
27 | #include <tvm/runtime/packed_func.h> |
28 | #include <tvm/support/with.h> |
29 | #include <tvm/target/target.h> |
30 | |
31 | #include <string> |
32 | #include <unordered_map> |
33 | #include <utility> |
34 | #include <vector> |
35 | |
36 | namespace tvm { |
37 | |
38 | class GenericFuncNode; |
39 | |
40 | /*! |
41 | * \brief Generic function that can be specialized on a per-target basis. |
42 | */ |
43 | class GenericFunc : public ObjectRef { |
44 | public: |
45 | GenericFunc() {} |
46 | explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {} |
47 | |
48 | /*! |
49 | * \brief Set the default function implementaiton. |
50 | * \param value The default function |
51 | * \param allow_override If true, this call may override a previously registered function. If |
52 | * false, an error will be logged if the call would override a previously registered function. |
53 | * \return reference to self. |
54 | */ |
55 | TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false); |
56 | /*! |
57 | * \brief Register a specialized function |
58 | * \param tags The tags for this specialization |
59 | * \param value The specialized function |
60 | * \param allow_override If true, this call may override previously registered tags. If false, |
61 | * an error will be logged if the call would override previously registered tags. |
62 | * \return reference to self. |
63 | */ |
64 | TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags, |
65 | const runtime::PackedFunc value, bool allow_override = false); |
66 | /*! |
67 | * \brief Call generic function by directly passing in unpacked format. |
68 | * \param args Arguments to be passed. |
69 | * \tparam Args arguments to be passed. |
70 | * |
71 | * \code |
72 | * // Example code on how to call generic function |
73 | * void CallGeneric(GenericFunc f) { |
74 | * // call like normal functions by pass in arguments |
75 | * // return value is automatically converted back |
76 | * int rvalue = f(1, 2.0); |
77 | * } |
78 | * \endcode |
79 | */ |
80 | template <typename... Args> |
81 | inline runtime::TVMRetValue operator()(Args&&... args) const; |
82 | /*! |
83 | * \brief Invoke the relevant function for the current target context, set by set_target_context. |
84 | * Arguments are passed in packed format. |
85 | * \param args The arguments to pass to the function. |
86 | * \param ret The return value |
87 | */ |
88 | TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const; |
89 | /*! |
90 | * \brief Get the packed function specified for the current target context. |
91 | */ |
92 | TVM_DLL PackedFunc GetPacked() const; |
93 | /*! |
94 | * \brief Find or register the GenericFunc instance corresponding to the give name |
95 | * \param name The name of the registered GenericFunc |
96 | * \return The GenericFunc instance |
97 | */ |
98 | TVM_DLL static GenericFunc Get(const std::string& name); |
99 | |
100 | /*! |
101 | * \brief Add a GenericFunc instance to the registry |
102 | * \param func The GenericFunc instance |
103 | * \param name The name of the registered GenericFunc |
104 | */ |
105 | TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name); |
106 | |
107 | /*! |
108 | * \brief access the internal node container |
109 | * \return the pointer to the internal node container |
110 | */ |
111 | inline GenericFuncNode* operator->(); |
112 | |
113 | // declare container type |
114 | using ContainerType = GenericFuncNode; |
115 | |
116 | // Internal class. |
117 | struct Manager; |
118 | |
119 | private: |
120 | friend struct Manager; |
121 | }; |
122 | |
123 | template <typename... Args> |
124 | inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const { |
125 | const int kNumArgs = sizeof...(Args); |
126 | const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; |
127 | TVMValue values[kArraySize]; |
128 | int type_codes[kArraySize]; |
129 | runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes), |
130 | std::forward<Args>(args)...); |
131 | runtime::TVMRetValue rv; |
132 | CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv); |
133 | return rv; |
134 | } |
135 | |
136 | /*! |
137 | * \brief Represents a generic function that can be specialized on a per-target basis. |
138 | */ |
139 | class GenericFuncNode : public Object { |
140 | public: |
141 | /*! \brief name of the function */ |
142 | std::string name_; |
143 | /* \brief the generic builder */ |
144 | runtime::PackedFunc generic_func_; |
145 | /* \brief map from keys to registered functions */ |
146 | std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_; |
147 | |
148 | void VisitAttrs(AttrVisitor* v) {} |
149 | |
150 | static constexpr const char* _type_key = "GenericFunc" ; |
151 | TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object); |
152 | }; |
153 | |
154 | inline GenericFuncNode* GenericFunc::operator->() { |
155 | return static_cast<GenericFuncNode*>(get_mutable()); |
156 | } |
157 | |
/*!
 * \def TVM_GENERIC_FUNC_REG_VAR_DEF
 * \brief Declaration prefix used by TVM_REGISTER_GENERIC_FUNC.
 *
 * Expands to the start of a static `::tvm::GenericFunc&` variable declaration,
 * marked with TVM_ATTRIBUTE_UNUSED, whose identifier begins with `__mk_TVM`.
 */
#define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM

/*!
 * \def TVM_REGISTER_GENERIC_FUNC
 * \brief Register a new generic function, or set a device-specific variant
 *  of the corresponding function.
 *
 * The expansion ends in `::tvm::GenericFunc::Get(#name)`, so the macro is meant
 * to be followed by a chained call such as `.set_default(...)` or
 * `.register_func(...)`; those return the `GenericFunc&` that initializes the
 * declared static reference.
 *
 * \note `__COUNTER__` is presumably pasted onto the variable name by
 *       TVM_STR_CONCAT (defined elsewhere) to keep each registration's
 *       identifier unique — confirm against its definition.
 * \note NOTE(review): GenericFunc::Get returns by value, so the reference ends
 *       up bound to a temporary that does not outlive the initializing
 *       expression; the declared variable should not be read afterwards.
 *
 * \param name The name of the function
 */
#define TVM_REGISTER_GENERIC_FUNC(name) \
  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name)
169 | |
170 | } // namespace tvm |
171 | #endif // TVM_TARGET_GENERIC_FUNC_H_ |
172 | |