1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file op_common.h
22 * \brief A set of utilities and common functionality
23 * for relay ops.
24 */
25#ifndef TVM_RELAY_OP_OP_COMMON_H_
26#define TVM_RELAY_OP_OP_COMMON_H_
27
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../transforms/infer_layout_utils.h"
#include "type_relations.h"
38
39namespace tvm {
40namespace relay {
41
/*! Quick helper macro
 * - Expose a positional make function to construct the node.
 * - Register op to the registry.
 *
 * We make the decision to always only expose positional argument.
 * We will do rewrapping in the frontend to support language
 * sugars such as keyword arguments and default value.
 *
 * Concretely, the expansion:
 * - registers the global function "relay.op._make.<OpName>", which takes a
 *   single `data` expression and returns `Call(op, {data}, Attrs(), {})`,
 *   i.e. a call with empty attrs and no type arguments. The Op handle is
 *   looked up once (function-local static) on first invocation;
 * - registers the op with one input named "data", the "Identity" type
 *   relation (IdentityRel), the kElemWise pattern, TOpIsStateful = false,
 *   and ElemwiseArbitraryLayout as its FInferCorrectLayout.
 *
 * The expansion ends in an open builder chain (no trailing semicolon), so
 * the use site supplies the terminating `;` and may append further chained
 * calls.
 *
 * \param OpName the name of registry. Must be a string literal: it is joined
 *        to "relay.op._make." by adjacent string-literal concatenation.
 */
#define RELAY_REGISTER_UNARY_OP(OpName)                                          \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr data) { \
    static const Op& op = Op::Get(OpName);                                     \
    return Call(op, {data}, Attrs(), {});                                      \
  });                                                                          \
  RELAY_REGISTER_OP(OpName)                                                    \
      .set_num_inputs(1)                                                       \
      .add_argument("data", "Tensor", "The input tensor.")                     \
      .add_type_rel("Identity", IdentityRel)                                   \
      .set_attr<TOpPattern>("TOpPattern", kElemWise)                           \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                         \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
64
/*! Quick helper macro
 * - Expose a positional make function to construct the node.
 * - Register op to the registry.
 *
 * We make the decision to always only expose positional argument.
 * We will do rewrapping in the frontend to support language
 * sugars such as keyword arguments and default value.
 *
 * Concretely, the expansion:
 * - registers the global function "relay.op._make.<OpName>", which takes
 *   `lhs` and `rhs` expressions and returns `Call(op, {lhs, rhs}, Attrs(), {})`,
 *   i.e. a call with empty attrs and no type arguments. The Op handle is
 *   looked up once (function-local static) on first invocation;
 * - registers the op with two inputs ("lhs", "rhs"), the "Broadcast" type
 *   relation (BroadcastRel), the kBroadcast pattern, TOpIsStateful = false,
 *   and BinaryBroadcastLayout as its FInferCorrectLayout.
 *
 * The expansion ends in an open builder chain (no trailing semicolon), so
 * the use site supplies the terminating `;` and may append further chained
 * calls.
 *
 * \param OpName the name of registry. Must be a string literal: it is joined
 *        to "relay.op._make." by adjacent string-literal concatenation.
 */
#define RELAY_REGISTER_BINARY_OP(OpName)                                                  \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr lhs, Expr rhs) { \
    static const Op& op = Op::Get(OpName);                                              \
    return Call(op, {lhs, rhs}, Attrs(), {});                                           \
  });                                                                                   \
  RELAY_REGISTER_OP(OpName)                                                             \
      .set_num_inputs(2)                                                                \
      .add_argument("lhs", "Tensor", "The left hand side tensor.")                      \
      .add_argument("rhs", "Tensor", "The right hand side tensor.")                     \
      .add_type_rel("Broadcast", BroadcastRel)                                          \
      .set_attr<TOpPattern>("TOpPattern", kBroadcast)                                   \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                                  \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BinaryBroadcastLayout)
88
/*! Quick helper macro for comparison operators.
 * - Expose a positional make function "relay.op._make.<OpName>" that takes
 *   `lhs` and `rhs` and returns `Call(op, {lhs, rhs}, Attrs(), {})`.
 * - Register op to the registry with two inputs ("lhs", "rhs"), the
 *   "BroadcastComp" type relation (BroadcastCompRel), the kBroadcast pattern,
 *   TOpIsStateful = false, and BinaryBroadcastLayout as FInferCorrectLayout.
 *
 * The expansion is the same as RELAY_REGISTER_BINARY_OP except that it uses
 * the "BroadcastComp" type relation instead of "Broadcast". The expansion
 * ends in an open builder chain (no trailing semicolon), so the use site
 * supplies the terminating `;`.
 *
 * \param OpName the name of registry. Must be a string literal: it is joined
 *        to "relay.op._make." by adjacent string-literal concatenation.
 */
#define RELAY_REGISTER_CMP_OP(OpName)                                                     \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr lhs, Expr rhs) { \
    static const Op& op = Op::Get(OpName);                                              \
    return Call(op, {lhs, rhs}, Attrs(), {});                                           \
  });                                                                                   \
  RELAY_REGISTER_OP(OpName)                                                             \
      .set_num_inputs(2)                                                                \
      .add_argument("lhs", "Tensor", "The left hand side tensor.")                      \
      .add_argument("rhs", "Tensor", "The right hand side tensor.")                     \
      .add_type_rel("BroadcastComp", BroadcastCompRel)                                  \
      .set_attr<TOpPattern>("TOpPattern", kBroadcast)                                   \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                                  \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BinaryBroadcastLayout)
103
104/*! \brief A helper class for matching and rewriting operators. */
105template <typename R>
106class OpMatch {
107 public:
108 using MatchFunc =
109 std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;
110
111 /*! \brief Match an operator with the given name.
112 * \param op_name The name of the operator to match.
113 * \param func The function to execute when it matches.
114 * \return A self-reference for builder style API.
115 */
116 inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
117 auto op = Op::Get(op_name);
118 match_map_.insert({op, func});
119 return *this;
120 }
121
122 /*! \brief Rewrite a call operation based on the operator and the registered
123 * match functions.
124 * \param call The call to rewrite.
125 * \return The result of rewriting.
126 */
127 inline R operator()(const Call& call) {
128 auto it = match_map_.find(Downcast<Op>(call->op));
129 if (it != match_map_.end()) {
130 return it->second(call->args, call->attrs, call->type_args);
131 } else {
132 if (default_ != nullptr) {
133 return default_(call->args, call->attrs, call->type_args);
134 } else {
135 LOG(FATAL) << "unexpected operation " << call->op;
136 }
137 }
138 }
139
140 private:
141 /*! \brief The match function map. */
142 std::unordered_map<Op, MatchFunc, ObjectPtrHash, ObjectPtrEqual> match_map_;
143 /*! \brief An optional default case. */
144 MatchFunc default_;
145};
146
147/*! \brief A utility function to get padding width from a 1 or 2 ints tuple. */
148inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) {
149 if (padding.size() == 1) {
150 *pad_w = padding[0] * 2;
151 } else if (padding.size() == 2) {
152 *pad_w = padding[0] + padding[1];
153 } else {
154 ICHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " << padding.size();
155 }
156}
157
158/*! \brief A utility function to get padding height and width from a 1, 2, 4 ints tuple. */
159inline void GetPaddingHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_h,
160 IndexExpr* pad_w) {
161 if (padding.size() == 1) {
162 *pad_h = padding[0] * 2;
163 *pad_w = padding[0] * 2;
164 } else if (padding.size() == 2) {
165 *pad_h = padding[0] * 2;
166 *pad_w = padding[1] * 2;
167 } else if (padding.size() == 4) {
168 *pad_h = padding[0] + padding[2];
169 *pad_w = padding[1] + padding[3];
170 } else {
171 ICHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " << padding.size();
172 }
173}
174
175/*! \brief A utility function to get padding depth, height and width from a 1, 3, 6 ints tuple. */
176inline void GetPaddingDepthHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_d,
177 IndexExpr* pad_h, IndexExpr* pad_w) {
178 if (padding.size() == 1) {
179 *pad_d = padding[0] * 2;
180 *pad_h = padding[0] * 2;
181 *pad_w = padding[0] * 2;
182 } else if (padding.size() == 3) {
183 *pad_d = padding[0] * 2;
184 *pad_h = padding[1] * 2;
185 *pad_w = padding[2] * 2;
186 } else if (padding.size() == 6) {
187 *pad_d = padding[0] + padding[3];
188 *pad_h = padding[1] + padding[4];
189 *pad_w = padding[2] + padding[5];
190 } else {
191 ICHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " << padding.size();
192 }
193}
194
195} // namespace relay
196} // namespace tvm
197
198#endif // TVM_RELAY_OP_OP_COMMON_H_
199