1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/function.h
22 * \brief Relay Function.
23 */
24#ifndef TVM_RELAY_FUNCTION_H_
25#define TVM_RELAY_FUNCTION_H_
26
27#include <tvm/ir/function.h>
28#include <tvm/relay/expr.h>
29
30#include <string>
31
32namespace tvm {
33namespace relay {
34
35/*!
36 * \brief Relay Function container
37 * \sa Function
38 */
39class FunctionNode : public BaseFuncNode {
40 public:
41 /*! \brief Function parameters */
42 tvm::Array<Var> params;
43 /*!
44 * \brief
45 * The expression which represents the computation of the function,
46 * the expression may reference the parameters, and the type of it
47 * or sub-expressions may reference the type variables.
48 */
49 Expr body;
50 /*! \brief User annotated return type of the function. */
51 Type ret_type;
52 /*!
53 * \brief Type parameters of the function.
54 * Enables the function to vary its type based on these.
55 * This corresponds to template paramaters in c++'s terminology.
56 *
57 * \note This can be usually empty for non-polymorphic functions.
58 */
59 tvm::Array<TypeVar> type_params;
60
61 void VisitAttrs(tvm::AttrVisitor* v) {
62 v->Visit("params", &params);
63 v->Visit("body", &body);
64 v->Visit("ret_type", &ret_type);
65 v->Visit("type_params", &type_params);
66 v->Visit("attrs", &attrs);
67 v->Visit("virtual_device_", &virtual_device_);
68 v->Visit("span", &span);
69 v->Visit("_checked_type_", &checked_type_);
70 }
71
72 bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
73 // Important to make def equal first.
74 equal->MarkGraphNode();
75 return equal.DefEqual(params, other->params) &&
76 equal.DefEqual(type_params, other->type_params) && equal(ret_type, other->ret_type) &&
77 equal(attrs, other->attrs) && equal(body, other->body);
78 }
79
80 void SHashReduce(SHashReducer hash_reduce) const {
81 hash_reduce->MarkGraphNode();
82 hash_reduce.DefHash(params);
83 hash_reduce.DefHash(type_params);
84 hash_reduce(ret_type);
85 hash_reduce(attrs);
86 hash_reduce(body);
87 }
88
89 /*!
90 * \brief Return the derived function annotation of this expression.
91 *
92 * \return The function type annotation.
93 * \note The function type annotation can contain IncompleteType.
94 */
95 TVM_DLL FuncType func_type_annotation() const;
96
97 static constexpr const char* _type_key = "relay.Function";
98 TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
99};
100
101/*!
102 * \brief Managed reference to FunctionNode.
103 * \sa FunctionNode
104 */
105class Function : public BaseFunc {
106 public:
107 /*!
108 * \brief Constructor
109 * \param params The parameters of the function.
110 * \param body The body of the function.
111 * \param ret_type The return type of the function.
112 * \param ty_params The type parameters.
113 * \param attrs Additional function attributes.
114 * \param span The span of the function.
115 */
116 TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params,
117 tvm::DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
118
119 TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
120 TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
121};
122
123/*!
124 * \brief Returns \p function with the given properties. A null property denotes 'no change'.
125 * Returns \p function if all properties are unchanged. Otherwise, returns a copy with the new
126 * fields.
127 */
128Function WithFields(Function function, Optional<Array<Var>> opt_params = Optional<Array<Var>>(),
129 Optional<Expr> opt_body = Optional<Expr>(),
130 Optional<Type> opt_ret_type = Optional<Type>(),
131 Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(),
132 Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
133 Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
134 Optional<Span> opt_span = Optional<Span>());
135
136/*
137 * \brief Returns the Relay FunctionNode represented by base_func if it should be optimized,
138 * otherwise returns nullptr.
139 *
140 * This means returns nullptr:
141 * - For PrimFuncs, since not Relay Functions.
142 * - For Functions marked for external compilation (with "Compiler").
143 * - For Functions marked as already having an external definition (with "ExternalSymbol").
144 * - For Functions marked as not to be optimized (with "SkipOptimization").
145 *
146 * TODO(mbs): Audit all enumerations of IRModule::functions to use this or some family of such.
147 */
148const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func);
149
/*!
 * \brief namespace of the attributes that can be attached to a relay::Function.
 */
namespace attr {

/*!
 * \brief Mark the function as representing a sub-graph which is to be lowered or compiled as
 * a unit. For example, the function may represent a kernel which TVM will lower to a PrimFunc.
 * If present should be bound to \p Integer(1). May be accompanied by "Compiler", see below.
 * The function body should be considered opaque by Relay, and many passes simply ignore these
 * functions.
 *
 * Type: Integer
 */
constexpr const char* kPrimitive = "Primitive";

/*!
 * \brief Mark the function as externally implemented, i.e. bound in a runtime::Module within the
 * IRModule's "external_mods" attribute. If present should be bound to \p Integer(1). Generally
 * the only attribute when present.
 *
 * Type: Integer
 */
constexpr const char* kExtern = "Extern";

/*!
 * \brief Indicates the name of the external codegen 'compiler' that should be used to lower
 * or compile the function other than TVM's default lowering pipeline. The name may correspond
 * to a TargetKind name. There may be a global function registered under 'relay.ext.{name}'.
 *
 * Type: String
 */
constexpr const char* kCompiler = "Compiler";

/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Mark that the function should not be optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
/*! \brief Mark the function as only composed of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";

}  // namespace attr
200
201} // namespace relay
202} // namespace tvm
203#endif // TVM_RELAY_FUNCTION_H_
204