1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/generic_func.h
 * \brief Generic function that can be specialized on a per-target basis.
23 */
24#ifndef TVM_TARGET_GENERIC_FUNC_H_
25#define TVM_TARGET_GENERIC_FUNC_H_
26
27#include <tvm/runtime/packed_func.h>
28#include <tvm/support/with.h>
29#include <tvm/target/target.h>
30
31#include <string>
32#include <unordered_map>
33#include <utility>
34#include <vector>
35
36namespace tvm {
37
38class GenericFuncNode;
39
40/*!
41 * \brief Generic function that can be specialized on a per-target basis.
42 */
43class GenericFunc : public ObjectRef {
44 public:
45 GenericFunc() {}
46 explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
47
48 /*!
49 * \brief Set the default function implementaiton.
50 * \param value The default function
51 * \param allow_override If true, this call may override a previously registered function. If
52 * false, an error will be logged if the call would override a previously registered function.
53 * \return reference to self.
54 */
55 TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);
56 /*!
57 * \brief Register a specialized function
58 * \param tags The tags for this specialization
59 * \param value The specialized function
60 * \param allow_override If true, this call may override previously registered tags. If false,
61 * an error will be logged if the call would override previously registered tags.
62 * \return reference to self.
63 */
64 TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
65 const runtime::PackedFunc value, bool allow_override = false);
66 /*!
67 * \brief Call generic function by directly passing in unpacked format.
68 * \param args Arguments to be passed.
69 * \tparam Args arguments to be passed.
70 *
71 * \code
72 * // Example code on how to call generic function
73 * void CallGeneric(GenericFunc f) {
74 * // call like normal functions by pass in arguments
75 * // return value is automatically converted back
76 * int rvalue = f(1, 2.0);
77 * }
78 * \endcode
79 */
80 template <typename... Args>
81 inline runtime::TVMRetValue operator()(Args&&... args) const;
82 /*!
83 * \brief Invoke the relevant function for the current target context, set by set_target_context.
84 * Arguments are passed in packed format.
85 * \param args The arguments to pass to the function.
86 * \param ret The return value
87 */
88 TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
89 /*!
90 * \brief Get the packed function specified for the current target context.
91 */
92 TVM_DLL PackedFunc GetPacked() const;
93 /*!
94 * \brief Find or register the GenericFunc instance corresponding to the give name
95 * \param name The name of the registered GenericFunc
96 * \return The GenericFunc instance
97 */
98 TVM_DLL static GenericFunc Get(const std::string& name);
99
100 /*!
101 * \brief Add a GenericFunc instance to the registry
102 * \param func The GenericFunc instance
103 * \param name The name of the registered GenericFunc
104 */
105 TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
106
107 /*!
108 * \brief access the internal node container
109 * \return the pointer to the internal node container
110 */
111 inline GenericFuncNode* operator->();
112
113 // declare container type
114 using ContainerType = GenericFuncNode;
115
116 // Internal class.
117 struct Manager;
118
119 private:
120 friend struct Manager;
121};
122
123template <typename... Args>
124inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const {
125 const int kNumArgs = sizeof...(Args);
126 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
127 TVMValue values[kArraySize];
128 int type_codes[kArraySize];
129 runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
130 std::forward<Args>(args)...);
131 runtime::TVMRetValue rv;
132 CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
133 return rv;
134}
135
136/*!
137 * \brief Represents a generic function that can be specialized on a per-target basis.
138 */
139class GenericFuncNode : public Object {
140 public:
141 /*! \brief name of the function */
142 std::string name_;
143 /* \brief the generic builder */
144 runtime::PackedFunc generic_func_;
145 /* \brief map from keys to registered functions */
146 std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
147
148 void VisitAttrs(AttrVisitor* v) {}
149
150 static constexpr const char* _type_key = "GenericFunc";
151 TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object);
152};
153
154inline GenericFuncNode* GenericFunc::operator->() {
155 return static_cast<GenericFuncNode*>(get_mutable());
156}
157
/*!
 * \brief Declaration prefix used by TVM_REGISTER_GENERIC_FUNC.
 *
 * Expands to a `static`, TVM_ATTRIBUTE_UNUSED reference to ::tvm::GenericFunc whose
 * name is the single token `__mk_TVM` (the `##` pastes `__mk_` and `TVM`).
 * TVM_REGISTER_GENERIC_FUNC passes this macro together with __COUNTER__ to TVM_STR_CONCAT,
 * presumably so that the trailing `__mk_TVM` token is pasted with the counter value
 * to give each registration a unique variable name.
 *
 * NOTE(review): identifiers containing `__` are reserved for the implementation in
 * standard C++; consider a non-reserved prefix if this macro is ever revisited.
 */
#define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM
159
/*!
 * \def TVM_REGISTER_GENERIC_FUNC
 * \brief Register a new generic function, or set a device-specific variant
 * of the corresponding function.
 *
 * Expands to the definition of a uniquely named static `::tvm::GenericFunc&` that is
 * initialized from `::tvm::GenericFunc::Get(#name)`. Get finds or registers the
 * GenericFunc under the stringized \p name. Usage is chained with set_default or
 * register_func, e.g.
 * \code
 *   TVM_REGISTER_GENERIC_FUNC(my_op).set_default(my_default_impl);
 * \endcode
 *
 * NOTE(review): Get returns GenericFunc by value.
 * - Used without a chained call, the expansion binds a non-const lvalue reference
 *   to a prvalue, which is ill-formed in standard C++.
 * - With a chained call returning `GenericFunc&` ("reference to self"), the static
 *   reference refers to the temporary returned by Get. That temporary is destroyed
 *   after initialization, so the reference dangles.
 * - This is harmless only because the variable is never used (TVM_ATTRIBUTE_UNUSED).
 *   The registration presumably persists through the node shared with the registry
 *   entry that Get finds or creates; confirm against the implementation.
 *
 * \param name The name of the function (an identifier; it is stringized with #name)
 */
#define TVM_REGISTER_GENERIC_FUNC(name) \
  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name)
169
170} // namespace tvm
171#endif // TVM_TARGET_GENERIC_FUNC_H_
172