/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/op_attr_types.h
 * \brief Attribute types used when registering Relay operators.
 */
#ifndef TVM_RELAY_OP_ATTR_TYPES_H_
#define TVM_RELAY_OP_ATTR_TYPES_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/target/generic_func.h>
#include <tvm/target/target.h>
#include <tvm/te/schedule.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/data_layout.h>

#include <string>

namespace tvm {
namespace relay {

using tir::BijectiveLayoutNode;
using tir::Layout;
using tir::LayoutAxis;

/*! \brief Operator pattern used in graph fusion. */
enum OpPatternKind {
  // Elementwise operation.
  kElemWise = 0,
  // Broadcasting operator; can always map the output axes to the input axes in order.
  // For example :code:`out[i, ax1, j, ax2] = input[i, j]`.
  // Note that the axes need to be in order, so transpose is not a broadcast operator.
  kBroadcast = 1,
  // Injective operator; can always injectively map an output axis to a single input axis.
  // All injective operators can still be safely fused with injective and reduction ops.
  kInjective = 2,
  // Commutative reduction operator.
  kCommReduce = 3,
  // Complex operation; can still fuse elemwise operations into its output,
  // but cannot be chained with another complex op.
  kOutEWiseFusable = 4,
  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
  // but is treated specially.
  kTuple = 7,
  // Opaque operation; cannot fuse anything.
  kOpaque = 8
};

/*! \brief The operator pattern. */
using TOpPattern = int;
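
/*
 * Example (illustrative sketch, not part of this header): an op registration
 * typically sets this attribute next to its other attributes via
 * RELAY_REGISTER_OP from <tvm/relay/op.h>. The op name "my.add" is a
 * hypothetical placeholder.
 *
 *   RELAY_REGISTER_OP("my.add")
 *       .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 */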

/*!
 * \brief Whether the operator is stateful or contains internal state.
 *
 * All the primitive ops we have registered so far are pure.
 * This attribute is kept for potential future compatibility reasons.
 * We can always work around stateful ops by adding an additional
 * handle argument and returning it.
 */
using TOpIsStateful = bool;

/*!
 * \brief Mark the operator as non-computational.
 */
using TNonComputational = bool;

/*!
 * \brief Mark the operator as a reshape op of its first input;
 * it can be turned into a nop when the first input and the output
 * share the same piece of memory.
 */
using TReshapeOp = bool;

/*!
 * \brief Mark whether the operator's output shape is data dependent, with one flag per input.
 */
using TShapeDataDependent = Array<Integer>;

/*!
 * \brief Computation description interface.
 *
 * \note This function has a special convention
 * for functions with tuple input/output.
 *
 * So far we restrict tuple support to the following cases:
 * - Functions which take a single tuple as input.
 * - Functions which output a single tuple.
 *
 * In both cases, the tuple is flattened into an array.
 *
 * \param attrs The attributes of the primitive op.
 * \param inputs The input tensors.
 * \param out_type The output type information;
 *                 these are always placeholders.
 * \return The output compute description of the operator.
 */
using FTVMCompute = runtime::TypedPackedFunc<Array<te::Tensor>(
    const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type)>;
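
/*
 * Example (illustrative sketch, not part of this header): registering a compute
 * rule for a hypothetical op "my.copy" using RELAY_REGISTER_OP from
 * <tvm/relay/op.h> and topi::identity from <tvm/topi/elemwise.h>.
 *
 *   RELAY_REGISTER_OP("my.copy")
 *       .set_attr<FTVMCompute>(
 *           "FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
 *                             const Type& out_type) -> Array<te::Tensor> {
 *             // Elementwise identity: a single output tensor mirroring the input.
 *             return {topi::identity(inputs[0])};
 *           });
 */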

/*!
 * \brief Build the computation schedule for
 * an op whose root is at the current op.
 *
 * \param attrs The attributes of the node.
 * \param outs The output tensors.
 * \param target The build target.
 * \return schedule The computation schedule.
 */
using FTVMSchedule = runtime::TypedPackedFunc<te::Schedule(
    const Attrs& attrs, const Array<te::Tensor>& outs, const Target& target)>;
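
/*
 * Example (illustrative sketch, not part of this header): a minimal schedule
 * function that simply creates a default schedule over the output operations;
 * a real implementation would apply target-specific scheduling here.
 *
 *   te::Schedule DefaultSchedule(const Attrs& attrs, const Array<te::Tensor>& outs,
 *                                const Target& target) {
 *     Array<te::Operation> out_ops;
 *     for (const te::Tensor& t : outs) out_ops.push_back(t->op);
 *     return te::create_schedule(out_ops);
 *   }
 */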

/*!
 * \brief Generate the strategy of operators. This function is a generic
 * function and can be re-defined for different targets.
 *
 * The signature of the generic function is:
 *   OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
 *              const Type& out_type, const Target& target)
 */
using FTVMStrategy = GenericFunc;

/*!
 * \brief Alter the layout of operators or replace the
 * operator with other expressions. This function will be invoked
 * in the AlterOpLayout pass.
 * \param attrs The attributes of the original node.
 * \param args The input symbols of the original node.
 * \param tinfos An array of placeholders, used for getting the inferred shapes
 *        and dtypes of the inputs.
 * \param out_type The inferred output type.
 * \return new_expr The modified expression.
 */
using FTVMAlterOpLayout =
    runtime::TypedPackedFunc<Expr(const Attrs& attrs, const Array<Expr>& args,
                                  const Array<te::Tensor>& tinfos, const Type& out_type)>;
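
/*
 * Example (illustrative sketch, not part of this header): a rule that re-creates
 * the call unchanged. A real rule would copy `attrs`, update its layout-related
 * fields, and return a call using the new attrs. The op name "my.op" is a
 * hypothetical placeholder.
 *
 *   [](const Attrs& attrs, const Array<Expr>& args,
 *      const Array<te::Tensor>& tinfos, const Type& out_type) -> Expr {
 *     return Call(Op::Get("my.op"), args, attrs);
 *   }
 */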

/*!
 * \brief Convert the layout of operators or replace the
 * operator with other expressions. This function will be invoked
 * in the ConvertLayout pass.
 * \param attrs The attributes of the original node.
 * \param args The input symbols of the original node.
 * \param tinfos An array of placeholders, used for getting the inferred shapes
 *        and dtypes of the inputs.
 * \param desired_layouts An array of desired layouts, one per input.
 *        For example, for a conv2d op, Array("NHWC", "OHWI") specifies the
 *        desired layouts for the data and then the kernel.
 * \return new_expr The modified expression.
 */
using FTVMConvertOpLayout = runtime::TypedPackedFunc<Expr(
    const Attrs& attrs, const Array<Expr>& args, const Array<te::Tensor>& tinfos,
    const Array<String>& desired_layouts)>;
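
/*
 * Example (illustrative sketch, not part of this header): reading the desired
 * data layout and re-creating the call. A real rule would copy `attrs`, set its
 * layout fields from `desired_layouts`, and return a call with the updated
 * attrs. The op name "my.conv_like" is a hypothetical placeholder.
 *
 *   [](const Attrs& attrs, const Array<Expr>& args, const Array<te::Tensor>& tinfos,
 *      const Array<String>& desired_layouts) -> Expr {
 *     String data_layout = desired_layouts[0];  // layout requested for the data input
 *     return Call(Op::Get("my.conv_like"), args, attrs);
 *   }
 */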

/*!
 * \brief Legalize an expression by replacing it with another expression. This
 * function will be invoked in the Legalize pass. It is a target-dependent pass.
 * \param attrs The attributes of the original node.
 * \param args The input symbols of the original node.
 * \param arg_types An array of placeholders, used for getting the inferred shapes
 *        and dtypes of the inputs.
 * \return new_expr The modified expression.
 */
using FTVMLegalize = runtime::TypedPackedFunc<Expr(const Attrs& attrs, const Array<Expr>& args,
                                                   const Array<tvm::relay::Type>& arg_types)>;
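
/*
 * Example (illustrative sketch, not part of this header): replacing a
 * hypothetical op "my.add" with the stock "add" op. A real rule would typically
 * insert casts or choose a target-friendly equivalent based on `arg_types`.
 *
 *   [](const Attrs& attrs, const Array<Expr>& args,
 *      const Array<tvm::relay::Type>& arg_types) -> Expr {
 *     return Call(Op::Get("add"), args);
 *   }
 */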

/*!
 * \brief Annotate an expression to indicate whether an op should be compiled using
 * the given compiler/target.
 * \param expr The original expr.
 * \return true if this op should be registered to invoke a specific compiler
 *         for codegen, otherwise false.
 */
using FTVMAnnotateTarget = runtime::TypedPackedFunc<bool(const Expr& expr)>;
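
/*
 * Example (illustrative sketch, not part of this header): offloading every call
 * to the "add" op to an external compiler. The compiler name "my_compiler" is a
 * hypothetical placeholder; the attribute name follows the "target.<compiler>"
 * convention used by the AnnotateTarget pass.
 *
 *   RELAY_REGISTER_OP("add").set_attr<FTVMAnnotateTarget>(
 *       "target.my_compiler", [](const Expr& expr) -> bool { return true; });
 */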

/*!
 * \brief Forward rewriting rule for a specific op.
 *
 * \param ref_call The reference old call type to be rewritten.
 *                 We can make use of the op and type information.
 * \param new_args The new arguments (some of them could be TempExpr).
 * \param ctx Optional context information about ref_call.
 * \return The rewritten result call; can also return nullptr,
 *         which indicates that the rewriter should use the default fallback
 *         rule that realizes all its inputs and composes the call.
 *
 * \note When we register the function, we can register
 *       a different signature with ctx being a specific node type.
 */
using FForwardRewrite = runtime::TypedPackedFunc<Expr(
    const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx)>;
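
/*
 * Example (illustrative sketch, not part of this header): a rewrite rule that
 * performs no rewrite and defers to the default fallback behaviour described above.
 *
 *   [](const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) -> Expr {
 *     // Return a null Expr so the caller realizes all inputs and re-composes the call.
 *     return Expr();
 *   }
 */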

/*!
 * \brief Gradient for a specific op.
 *
 * \param orig_call The original Expr.
 * \param output_grad The gradient of the Expr.
 * \return The gradient for each parameter.
 */
using FPrimalGradient =
    runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call, const Expr& output_grad)>;
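
/*
 * Example (illustrative sketch, not part of this header): the gradient of an
 * identity-like op with a single input is simply the incoming output gradient.
 *
 *   [](const Expr& orig_call, const Expr& output_grad) -> tvm::Array<Expr> {
 *     return {output_grad};
 *   }
 */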

/*!
 * \brief The code generation strategy for dynamic dimensions.
 */
enum AnyCodegenStrategy {
  /*! \brief The default strategy of using completely variable dimensions. */
  kVariableDimensions
};

/*! \brief A runtime representation of shape. */
using Shape = Array<IndexExpr>;

/*!
 * \brief Shape function: computes the output shapes of an operator at runtime.
 * \param attrs The attributes of the node.
 * \param inputs The input tensors (or their shape tensors, depending on
 *        TShapeDataDependent).
 * \param out_ndims The number of dimensions of each output.
 * \return The tensors describing the output shapes.
 */
using FShapeFunc = runtime::TypedPackedFunc<Array<te::Tensor>(
    const Attrs& attrs, const Array<te::Tensor>& inputs, const Array<IndexExpr>& out_ndims)>;
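
/*
 * Example (illustrative sketch, not part of this header): a shape function for
 * an op whose output shape equals its first input's shape, assuming the op is
 * not data dependent so `inputs[0]` is the 1-D shape tensor of the first argument.
 *
 *   [](const Attrs& attrs, const Array<te::Tensor>& inputs,
 *      const Array<IndexExpr>& out_ndims) -> Array<te::Tensor> {
 *     te::Tensor in_shape = inputs[0];
 *     return {te::compute({out_ndims[0]},
 *                         [&](tir::Var i) { return in_shape(i); }, "my_shape_func")};
 *   }
 */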

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_OP_ATTR_TYPES_H_