1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/ir/env_func.h
22 * \brief Serializable global function used in IR.
23 */
24#ifndef TVM_IR_ENV_FUNC_H_
25#define TVM_IR_ENV_FUNC_H_
26
27#include <tvm/node/reflection.h>
28
29#include <string>
30#include <utility>
31
32namespace tvm {
33/*!
34 * \brief A serializable function backed by TVM's global environment.
35 *
36 * This is a wrapper to enable serializable global PackedFunc.
37 * An EnvFunc is saved by its name in the global registry
38 * under the assumption that the same function is registered during load.
39 * \sa EnvFunc
40 */
41class EnvFuncNode : public Object {
42 public:
43 /*! \brief Unique name of the global function */
44 String name;
45 /*! \brief The internal packed function */
46 runtime::PackedFunc func;
47 /*! \brief constructor */
48 EnvFuncNode() {}
49
50 void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
51
52 bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
53 // name uniquely identifies the env function.
54 return name == other->name;
55 }
56
57 void SHashReduce(SHashReducer hash_reduce) const {
58 // Name uniquely identifies the env function.
59 hash_reduce(name);
60 }
61
62 static constexpr const char* _type_key = "EnvFunc";
63 static constexpr bool _type_has_method_sequal_reduce = true;
64 static constexpr bool _type_has_method_shash_reduce = true;
65 TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
66};
67
68/*!
69 * \brief Managed reference to EnvFuncNode.
70 * \sa EnvFuncNode
71 */
72class EnvFunc : public ObjectRef {
73 public:
74 EnvFunc() {}
75 explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
76 /*! \return The internal global function pointer */
77 const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
78 /*!
79 * \brief Invoke the function.
80 * \param args The arguments
81 * \returns The return value.
82 */
83 template <typename... Args>
84 runtime::TVMRetValue operator()(Args&&... args) const {
85 const EnvFuncNode* n = operator->();
86 ICHECK(n != nullptr);
87 return n->func(std::forward<Args>(args)...);
88 }
89 /*!
90 * \brief Get a global function based on the name.
91 * \param name The name of the global function.
92 * \return The created global function.
93 * \note The function can be unique
94 */
95 TVM_DLL static EnvFunc Get(const String& name);
96 /*! \brief specify container node */
97 using ContainerType = EnvFuncNode;
98};
99
/*!
 * \brief Primary template, declared only.
 *
 * The usable form is the partial specialization on a function signature; see
 * \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args...)>".
 * \tparam FType A function type of the form R(Args...).
 */
template <class FType>
class TypedEnvFunc;
105
106/*!
107 * \anchor TypedEnvFuncAnchor
108 * \brief A typed version of EnvFunc.
109 * It is backed by a GlobalFuncNode internally.
110 *
111 * \tparam R The return value of the function.
112 * \tparam Args The argument signature of the function.
113 * \sa EnvFunc
114 */
115template <typename R, typename... Args>
116class TypedEnvFunc<R(Args...)> : public ObjectRef {
117 public:
118 /*! \brief short hand for this function type */
119 using TSelf = TypedEnvFunc<R(Args...)>;
120 TypedEnvFunc() {}
121 explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
122 /*!
123 * \brief Assign global function to a TypedEnvFunc
124 * \param other Another global function.
125 * \return reference to self.
126 */
127 TSelf& operator=(const EnvFunc& other) {
128 ObjectRef::operator=(other);
129 return *this;
130 }
131 /*! \return The internal global function pointer */
132 const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
133 /*!
134 * \brief Invoke the function.
135 * \param args The arguments
136 * \returns The return value.
137 */
138 R operator()(Args... args) const {
139 const EnvFuncNode* n = operator->();
140 ICHECK(n != nullptr);
141 return runtime::detail::typed_packed_call_dispatcher<R>::run(n->func,
142 std::forward<Args>(args)...);
143 }
144 /*! \brief specify container node */
145 using ContainerType = EnvFuncNode;
146};
147
148} // namespace tvm
149#endif // TVM_IR_ENV_FUNC_H_
150