1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file op_common.h |
22 | * \brief A set of utilities and common functionality |
23 | * for relay ops. |
24 | */ |
25 | #ifndef TVM_RELAY_OP_OP_COMMON_H_ |
26 | #define TVM_RELAY_OP_OP_COMMON_H_ |
27 | |
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../transforms/infer_layout_utils.h"
#include "type_relations.h"
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | |
/*!
 * \brief Quick helper macro that registers a unary elementwise operator.
 *
 * - Exposes a positional make function, "relay.op._make." OpName, that
 *   constructs a Call node to the operator.
 * - Registers the operator itself: one "data" tensor input, the "Identity"
 *   type relation, the kElemWise pattern, not stateful, and
 *   ElemwiseArbitraryLayout for layout inference.
 *
 * We make the decision to always expose only positional arguments.
 * Rewrapping is done in the frontend to support language sugar such as
 * keyword arguments and default values.
 *
 * \param OpName The registry name of the operator. It must be a string
 *        literal, because it is concatenated with "relay.op._make.".
 */
#define RELAY_REGISTER_UNARY_OP(OpName)                                           \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr operand) { \
    static const Op& cached_op = Op::Get(OpName);                                 \
    return Call(cached_op, {operand}, Attrs(), {});                               \
  });                                                                             \
  RELAY_REGISTER_OP(OpName)                                                       \
      .set_num_inputs(1)                                                          \
      .add_argument("data", "Tensor", "The input tensor.")                        \
      .add_type_rel("Identity", IdentityRel)                                      \
      .set_attr<TOpPattern>("TOpPattern", kElemWise)                              \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                            \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
64 | |
/*!
 * \brief Quick helper macro that registers a binary broadcasting operator.
 *
 * - Exposes a positional make function, "relay.op._make." OpName, that
 *   constructs a Call node to the operator from two operands.
 * - Registers the operator itself: "lhs" and "rhs" tensor inputs, the
 *   "Broadcast" type relation (BroadcastRel), the kBroadcast pattern, not
 *   stateful, and BinaryBroadcastLayout for layout inference.
 *
 * We make the decision to always expose only positional arguments.
 * Rewrapping is done in the frontend to support language sugar such as
 * keyword arguments and default values.
 *
 * \param OpName The registry name of the operator. It must be a string
 *        literal, because it is concatenated with "relay.op._make.".
 */
#define RELAY_REGISTER_BINARY_OP(OpName)                                                         \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr left, Expr right) { \
    static const Op& cached_op = Op::Get(OpName);                                                \
    return Call(cached_op, {left, right}, Attrs(), {});                                          \
  });                                                                                            \
  RELAY_REGISTER_OP(OpName)                                                                      \
      .set_num_inputs(2)                                                                         \
      .add_argument("lhs", "Tensor", "The left hand side tensor.")                               \
      .add_argument("rhs", "Tensor", "The right hand side tensor.")                              \
      .add_type_rel("Broadcast", BroadcastRel)                                                   \
      .set_attr<TOpPattern>("TOpPattern", kBroadcast)                                            \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                                           \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BinaryBroadcastLayout)
88 | |
/*!
 * \brief Quick helper macro that registers a binary comparison operator.
 *
 * Identical to the binary broadcasting registration except that the type
 * relation is "BroadcastComp" (BroadcastCompRel) instead of "Broadcast".
 *
 * - Exposes a positional make function, "relay.op._make." OpName, that
 *   constructs a Call node to the operator from two operands.
 * - Registers the operator itself: "lhs" and "rhs" tensor inputs, the
 *   kBroadcast pattern, not stateful, and BinaryBroadcastLayout for layout
 *   inference.
 *
 * \param OpName The registry name of the operator. It must be a string
 *        literal, because it is concatenated with "relay.op._make.".
 */
#define RELAY_REGISTER_CMP_OP(OpName)                                                            \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr left, Expr right) { \
    static const Op& cached_op = Op::Get(OpName);                                                \
    return Call(cached_op, {left, right}, Attrs(), {});                                          \
  });                                                                                            \
  RELAY_REGISTER_OP(OpName)                                                                      \
      .set_num_inputs(2)                                                                         \
      .add_argument("lhs", "Tensor", "The left hand side tensor.")                               \
      .add_argument("rhs", "Tensor", "The right hand side tensor.")                              \
      .add_type_rel("BroadcastComp", BroadcastCompRel)                                           \
      .set_attr<TOpPattern>("TOpPattern", kBroadcast)                                            \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)                                           \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BinaryBroadcastLayout)
103 | |
104 | /*! \brief A helper class for matching and rewriting operators. */ |
105 | template <typename R> |
106 | class OpMatch { |
107 | public: |
108 | using MatchFunc = |
109 | std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>; |
110 | |
111 | /*! \brief Match an operator with the given name. |
112 | * \param op_name The name of the operator to match. |
113 | * \param func The function to execute when it matches. |
114 | * \return A self-reference for builder style API. |
115 | */ |
116 | inline OpMatch& Match(const std::string& op_name, MatchFunc func) { |
117 | auto op = Op::Get(op_name); |
118 | match_map_.insert({op, func}); |
119 | return *this; |
120 | } |
121 | |
122 | /*! \brief Rewrite a call operation based on the operator and the registered |
123 | * match functions. |
124 | * \param call The call to rewrite. |
125 | * \return The result of rewriting. |
126 | */ |
127 | inline R operator()(const Call& call) { |
128 | auto it = match_map_.find(Downcast<Op>(call->op)); |
129 | if (it != match_map_.end()) { |
130 | return it->second(call->args, call->attrs, call->type_args); |
131 | } else { |
132 | if (default_ != nullptr) { |
133 | return default_(call->args, call->attrs, call->type_args); |
134 | } else { |
135 | LOG(FATAL) << "unexpected operation " << call->op; |
136 | } |
137 | } |
138 | } |
139 | |
140 | private: |
141 | /*! \brief The match function map. */ |
142 | std::unordered_map<Op, MatchFunc, ObjectPtrHash, ObjectPtrEqual> match_map_; |
143 | /*! \brief An optional default case. */ |
144 | MatchFunc default_; |
145 | }; |
146 | |
147 | /*! \brief A utility function to get padding width from a 1 or 2 ints tuple. */ |
148 | inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) { |
149 | if (padding.size() == 1) { |
150 | *pad_w = padding[0] * 2; |
151 | } else if (padding.size() == 2) { |
152 | *pad_w = padding[0] + padding[1]; |
153 | } else { |
154 | ICHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " << padding.size(); |
155 | } |
156 | } |
157 | |
158 | /*! \brief A utility function to get padding height and width from a 1, 2, 4 ints tuple. */ |
159 | inline void GetPaddingHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_h, |
160 | IndexExpr* pad_w) { |
161 | if (padding.size() == 1) { |
162 | *pad_h = padding[0] * 2; |
163 | *pad_w = padding[0] * 2; |
164 | } else if (padding.size() == 2) { |
165 | *pad_h = padding[0] * 2; |
166 | *pad_w = padding[1] * 2; |
167 | } else if (padding.size() == 4) { |
168 | *pad_h = padding[0] + padding[2]; |
169 | *pad_w = padding[1] + padding[3]; |
170 | } else { |
171 | ICHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " << padding.size(); |
172 | } |
173 | } |
174 | |
175 | /*! \brief A utility function to get padding depth, height and width from a 1, 3, 6 ints tuple. */ |
176 | inline void GetPaddingDepthHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_d, |
177 | IndexExpr* pad_h, IndexExpr* pad_w) { |
178 | if (padding.size() == 1) { |
179 | *pad_d = padding[0] * 2; |
180 | *pad_h = padding[0] * 2; |
181 | *pad_w = padding[0] * 2; |
182 | } else if (padding.size() == 3) { |
183 | *pad_d = padding[0] * 2; |
184 | *pad_h = padding[1] * 2; |
185 | *pad_w = padding[2] * 2; |
186 | } else if (padding.size() == 6) { |
187 | *pad_d = padding[0] + padding[3]; |
188 | *pad_h = padding[1] + padding[4]; |
189 | *pad_w = padding[2] + padding[5]; |
190 | } else { |
191 | ICHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " << padding.size(); |
192 | } |
193 | } |
194 | |
195 | } // namespace relay |
196 | } // namespace tvm |
197 | |
198 | #endif // TVM_RELAY_OP_OP_COMMON_H_ |
199 | |