/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | |
/*!
 * \file tvm/relay/function.h
 * \brief Relay Function.
 */
24 | #ifndef TVM_RELAY_FUNCTION_H_ |
25 | #define TVM_RELAY_FUNCTION_H_ |
26 | |
27 | #include <tvm/ir/function.h> |
28 | #include <tvm/relay/expr.h> |
29 | |
30 | #include <string> |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
35 | /*! |
36 | * \brief Relay Function container |
37 | * \sa Function |
38 | */ |
39 | class FunctionNode : public BaseFuncNode { |
40 | public: |
41 | /*! \brief Function parameters */ |
42 | tvm::Array<Var> params; |
43 | /*! |
44 | * \brief |
45 | * The expression which represents the computation of the function, |
46 | * the expression may reference the parameters, and the type of it |
47 | * or sub-expressions may reference the type variables. |
48 | */ |
49 | Expr body; |
50 | /*! \brief User annotated return type of the function. */ |
51 | Type ret_type; |
52 | /*! |
53 | * \brief Type parameters of the function. |
54 | * Enables the function to vary its type based on these. |
55 | * This corresponds to template paramaters in c++'s terminology. |
56 | * |
57 | * \note This can be usually empty for non-polymorphic functions. |
58 | */ |
59 | tvm::Array<TypeVar> type_params; |
60 | |
61 | void VisitAttrs(tvm::AttrVisitor* v) { |
62 | v->Visit("params" , ¶ms); |
63 | v->Visit("body" , &body); |
64 | v->Visit("ret_type" , &ret_type); |
65 | v->Visit("type_params" , &type_params); |
66 | v->Visit("attrs" , &attrs); |
67 | v->Visit("virtual_device_" , &virtual_device_); |
68 | v->Visit("span" , &span); |
69 | v->Visit("_checked_type_" , &checked_type_); |
70 | } |
71 | |
72 | bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { |
73 | // Important to make def equal first. |
74 | equal->MarkGraphNode(); |
75 | return equal.DefEqual(params, other->params) && |
76 | equal.DefEqual(type_params, other->type_params) && equal(ret_type, other->ret_type) && |
77 | equal(attrs, other->attrs) && equal(body, other->body); |
78 | } |
79 | |
80 | void SHashReduce(SHashReducer hash_reduce) const { |
81 | hash_reduce->MarkGraphNode(); |
82 | hash_reduce.DefHash(params); |
83 | hash_reduce.DefHash(type_params); |
84 | hash_reduce(ret_type); |
85 | hash_reduce(attrs); |
86 | hash_reduce(body); |
87 | } |
88 | |
89 | /*! |
90 | * \brief Return the derived function annotation of this expression. |
91 | * |
92 | * \return The function type annotation. |
93 | * \note The function type annotation can contain IncompleteType. |
94 | */ |
95 | TVM_DLL FuncType func_type_annotation() const; |
96 | |
97 | static constexpr const char* _type_key = "relay.Function" ; |
98 | TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); |
99 | }; |
100 | |
101 | /*! |
102 | * \brief Managed reference to FunctionNode. |
103 | * \sa FunctionNode |
104 | */ |
105 | class Function : public BaseFunc { |
106 | public: |
107 | /*! |
108 | * \brief Constructor |
109 | * \param params The parameters of the function. |
110 | * \param body The body of the function. |
111 | * \param ret_type The return type of the function. |
112 | * \param ty_params The type parameters. |
113 | * \param attrs Additional function attributes. |
114 | * \param span The span of the function. |
115 | */ |
116 | TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params, |
117 | tvm::DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span()); |
118 | |
119 | TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); |
120 | TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); |
121 | }; |
122 | |
123 | /*! |
124 | * \brief Returns \p function with the given properties. A null property denotes 'no change'. |
125 | * Returns \p function if all properties are unchanged. Otherwise, returns a copy with the new |
126 | * fields. |
127 | */ |
128 | Function WithFields(Function function, Optional<Array<Var>> opt_params = Optional<Array<Var>>(), |
129 | Optional<Expr> opt_body = Optional<Expr>(), |
130 | Optional<Type> opt_ret_type = Optional<Type>(), |
131 | Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(), |
132 | Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(), |
133 | Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(), |
134 | Optional<Span> opt_span = Optional<Span>()); |
135 | |
136 | /* |
137 | * \brief Returns the Relay FunctionNode represented by base_func if it should be optimized, |
138 | * otherwise returns nullptr. |
139 | * |
140 | * This means returns nullptr: |
141 | * - For PrimFuncs, since not Relay Functions. |
142 | * - For Functions marked for external compilation (with "Compiler"). |
143 | * - For Functions marked as already having an external definition (with "ExternalSymbol"). |
144 | * - For Functions marked as not to be optimized (with "SkipOptimization"). |
145 | * |
146 | * TODO(mbs): Audit all enumerations of IRModule::functions to use this or some family of such. |
147 | */ |
148 | const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func); |
149 | |
/*!
 * \brief Keys of the attributes that may be attached to a relay::Function.
 */
namespace attr {

/*!
 * \brief Marks the function as a sub-graph to be lowered or compiled as a single unit,
 * e.g. a kernel which TVM will lower to a PrimFunc. When present it should be bound to
 * \p Integer(1), and it may be accompanied by "Compiler" (see kCompiler). Relay treats
 * the body of such a function as opaque, and many passes skip these functions entirely.
 *
 * Type: Integer
 */
constexpr const char* kPrimitive = "Primitive";

/*!
 * \brief Marks the function as implemented externally, i.e. bound in a runtime::Module
 * held in the IRModule's "external_mods" attribute. When present it should be bound to
 * \p Integer(1), and it is generally the only attribute on the function.
 *
 * Type: Integer
 */
constexpr const char* kExtern = "Extern";

/*!
 * \brief Names the external codegen 'compiler' used to lower or compile the function
 * instead of TVM's default lowering pipeline. The name may match a TargetKind name, and a
 * global function may be registered under 'relay.ext.{name}'.
 *
 * Type: String
 */
constexpr const char* kCompiler = "Compiler";

/*! \brief Flags the function as a closure. */
constexpr const char* kClosure = "Closure";

/*! \brief Holds a mapping from Var to parameter/Constant for a Function. */
constexpr const char* kParams = "__params__";

/*! \brief Flags the function as one that should not be optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";

/*! \brief Treats the function as a composite operator. */
constexpr const char* kComposite = "Composite";

/*! \brief Requests that the function be inlined. */
constexpr const char* kInline = "Inline";

/*! \brief Records that the function was created by the Pattern Partitioning pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";

/*! \brief Flags the function as composed solely of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";

}  // namespace attr
200 | |
201 | } // namespace relay |
202 | } // namespace tvm |
203 | #endif // TVM_RELAY_FUNCTION_H_ |
204 | |