1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/ir/env_func.h |
22 | * \brief Serializable global function used in IR. |
23 | */ |
24 | #ifndef TVM_IR_ENV_FUNC_H_ |
25 | #define TVM_IR_ENV_FUNC_H_ |
26 | |
27 | #include <tvm/node/reflection.h> |
28 | |
29 | #include <string> |
30 | #include <utility> |
31 | |
32 | namespace tvm { |
33 | /*! |
34 | * \brief A serializable function backed by TVM's global environment. |
35 | * |
36 | * This is a wrapper to enable serializable global PackedFunc. |
37 | * An EnvFunc is saved by its name in the global registry |
38 | * under the assumption that the same function is registered during load. |
39 | * \sa EnvFunc |
40 | */ |
41 | class EnvFuncNode : public Object { |
42 | public: |
43 | /*! \brief Unique name of the global function */ |
44 | String name; |
45 | /*! \brief The internal packed function */ |
46 | runtime::PackedFunc func; |
47 | /*! \brief constructor */ |
48 | EnvFuncNode() {} |
49 | |
50 | void VisitAttrs(AttrVisitor* v) { v->Visit("name" , &name); } |
51 | |
52 | bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { |
53 | // name uniquely identifies the env function. |
54 | return name == other->name; |
55 | } |
56 | |
57 | void SHashReduce(SHashReducer hash_reduce) const { |
58 | // Name uniquely identifies the env function. |
59 | hash_reduce(name); |
60 | } |
61 | |
62 | static constexpr const char* _type_key = "EnvFunc" ; |
63 | static constexpr bool _type_has_method_sequal_reduce = true; |
64 | static constexpr bool _type_has_method_shash_reduce = true; |
65 | TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object); |
66 | }; |
67 | |
68 | /*! |
69 | * \brief Managed reference to EnvFuncNode. |
70 | * \sa EnvFuncNode |
71 | */ |
72 | class EnvFunc : public ObjectRef { |
73 | public: |
74 | EnvFunc() {} |
75 | explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {} |
76 | /*! \return The internal global function pointer */ |
77 | const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); } |
78 | /*! |
79 | * \brief Invoke the function. |
80 | * \param args The arguments |
81 | * \returns The return value. |
82 | */ |
83 | template <typename... Args> |
84 | runtime::TVMRetValue operator()(Args&&... args) const { |
85 | const EnvFuncNode* n = operator->(); |
86 | ICHECK(n != nullptr); |
87 | return n->func(std::forward<Args>(args)...); |
88 | } |
89 | /*! |
90 | * \brief Get a global function based on the name. |
91 | * \param name The name of the global function. |
92 | * \return The created global function. |
93 | * \note The function can be unique |
94 | */ |
95 | TVM_DLL static EnvFunc Get(const String& name); |
96 | /*! \brief specify container node */ |
97 | using ContainerType = EnvFuncNode; |
98 | }; |
99 | |
/*!
 * \brief Primary template of TypedEnvFunc; only declared here, not defined.
 *
 * The implementation is supplied by the partial specialization for function
 * signature types.
 * Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args..)>"
 *
 * \tparam FType The function signature type, in the form R(Args...).
 */
template <typename FType>
class TypedEnvFunc;
105 | |
106 | /*! |
107 | * \anchor TypedEnvFuncAnchor |
108 | * \brief A typed version of EnvFunc. |
109 | * It is backed by a GlobalFuncNode internally. |
110 | * |
111 | * \tparam R The return value of the function. |
112 | * \tparam Args The argument signature of the function. |
113 | * \sa EnvFunc |
114 | */ |
115 | template <typename R, typename... Args> |
116 | class TypedEnvFunc<R(Args...)> : public ObjectRef { |
117 | public: |
118 | /*! \brief short hand for this function type */ |
119 | using TSelf = TypedEnvFunc<R(Args...)>; |
120 | TypedEnvFunc() {} |
121 | explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {} |
122 | /*! |
123 | * \brief Assign global function to a TypedEnvFunc |
124 | * \param other Another global function. |
125 | * \return reference to self. |
126 | */ |
127 | TSelf& operator=(const EnvFunc& other) { |
128 | ObjectRef::operator=(other); |
129 | return *this; |
130 | } |
131 | /*! \return The internal global function pointer */ |
132 | const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); } |
133 | /*! |
134 | * \brief Invoke the function. |
135 | * \param args The arguments |
136 | * \returns The return value. |
137 | */ |
138 | R operator()(Args... args) const { |
139 | const EnvFuncNode* n = operator->(); |
140 | ICHECK(n != nullptr); |
141 | return runtime::detail::typed_packed_call_dispatcher<R>::run(n->func, |
142 | std::forward<Args>(args)...); |
143 | } |
144 | /*! \brief specify container node */ |
145 | using ContainerType = EnvFuncNode; |
146 | }; |
147 | |
148 | } // namespace tvm |
149 | #endif // TVM_IR_ENV_FUNC_H_ |
150 | |