/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/op/nn/nn.h
 * \brief Shared property definitions of nn operators.
 */
#ifndef TVM_RELAY_OP_NN_NN_H_
#define TVM_RELAY_OP_NN_NN_H_

#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/relay/type.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "../op_common.h"

namespace tvm {
namespace relay {

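/*!
 * \brief Type relation for matmul-like operators (e.g. nn.matmul, nn.dense).
 *
 * types = [tensor_a, tensor_b, output]. When `units` is defined, tensor_b's type is
 * inferred from it; otherwise dynamic inner dimensions of the two inputs are constrained
 * against each other. Layouts rewritten by the auto-scheduler or meta-schedule bypass
 * the normal tensor_b shape inference.
 */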
template <typename AttrType>
bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* tensor_a = types[0].as<TensorTypeNode>();
  const auto* tensor_b = types[1].as<TensorTypeNode>();
  if (tensor_a == nullptr) return false;
  ICHECK(static_cast<int>(tensor_a->shape.size()) != 0);

  const AttrType* param = attrs.as<AttrType>();
  ICHECK(param != nullptr);
  TensorType meta_schedule_tensor_b{nullptr};
  if (param->meta_schedule_original_shape.size() > 0) {
    meta_schedule_tensor_b = TensorType(param->meta_schedule_original_shape,
                                        tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
    tensor_b = meta_schedule_tensor_b.get();
  }
  // Default to the dense layout (transpose_b = true).
  bool transpose_a = false;
  bool transpose_b = true;
  const auto* mattrs = attrs.as<MatmulAttrs>();
  if (mattrs != nullptr) {
    transpose_a = mattrs->transpose_a;
    transpose_b = mattrs->transpose_b;
  }

  const Array<tvm::PrimExpr>& dshape = tensor_a->shape;
  Array<tvm::PrimExpr> oshape = dshape;
  tvm::PrimExpr reduce = dshape[dshape.size() - 1];
  if (transpose_a) {
    reduce = dshape[dshape.size() - 2];
    oshape.Set((oshape.size() - 2), dshape[oshape.size() - 1]);
  }
  auto tensor_b_dtype = (tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
  if (param->units.defined()) {
    // `units` is defined: validate tensor_b's shape against it and assign tensor_b's type.
    const Array<IndexExpr>& wshape = transpose_b ? Array<IndexExpr>({param->units, reduce})
                                                 : Array<IndexExpr>({reduce, param->units});
    // tensor_b may be nullptr, in which case the data dtype is used as the tensor_b
    // dtype; if tensor_b's dtype is explicitly present, it takes precedence.
    if (param->auto_scheduler_rewritten_layout.size() != 0) {
      // If the layout is rewritten by auto-scheduler or meta-schedule,
      // we just forcefully apply the layout provided by auto-scheduler and
      // skip the normal inference logic.
      {}  // do nothing
    } else if (param->meta_schedule_original_shape.size() == 0) {
      // Normal case: assign result to reporter
      reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
    }
    oshape.Set((oshape.size() - 1), param->units);
  } else {
    if (tensor_b == nullptr) return false;
    const Array<tvm::PrimExpr>& wshape = tensor_b->shape;
    // When tensor_b's layout has been rewritten by the auto-scheduler, recover the output
    // dimension from tensor_b's total number of elements and the input's reduction dimension.
    if (param->auto_scheduler_rewritten_layout.size() != 0) {
      PrimExpr tensor_b_elements = 1;
      for (size_t i = 0; i < wshape.size(); i++) {
        tensor_b_elements = tensor_b_elements * wshape[i];
      }
      oshape.Set(oshape.size() - 1, tensor_b_elements / dshape[dshape.size() - 1]);
      // Otherwise just pull it out of the tensor_b shape directly.
    } else {
      ICHECK(static_cast<int>(tensor_b->shape.size()) == 2);
      if (param->auto_scheduler_rewritten_layout.size() == 0 &&
          param->meta_schedule_original_shape.size() == 0) {
        // Ensure the inner dimensions of data and weight match. If one inner
        // dimension is dynamic, it is inferred to match the other inner
        // dimension.
        std::vector<PrimExpr> A_shape(tensor_a->shape.begin(), tensor_a->shape.end());
        std::vector<PrimExpr> B_shape(tensor_b->shape.begin(), tensor_b->shape.end());
        auto sa = A_shape.size();
        auto sb = B_shape.size();
        size_t index_swap_A;
        size_t index_swap_B;
        if (transpose_a && transpose_b) {
          index_swap_A = sa - 2;
          index_swap_B = sb - 1;
        } else if (transpose_a) {
          index_swap_A = sa - 2;
          index_swap_B = sb - 2;
        } else if (transpose_b) {
          index_swap_A = sa - 1;
          index_swap_B = sb - 1;
        } else {
          index_swap_A = sa - 1;
          index_swap_B = sb - 2;
        }

        // Rewrite dynamic axes to static where constraints allow.
        auto tmp = A_shape[index_swap_A];
        if (A_shape[index_swap_A].as<tir::AnyNode>()) {
          A_shape[index_swap_A] = B_shape[index_swap_B];
        }
        if (B_shape[index_swap_B].as<tir::AnyNode>()) {
          B_shape[index_swap_B] = tmp;
        }

        // Update input types with new constrained shapes.
        reporter->Assign(types[0], TensorType(A_shape, tensor_a->dtype));
        reporter->Assign(types[1], TensorType(B_shape, tensor_b_dtype));
      }
      oshape.Set(oshape.size() - 1, transpose_b ? wshape[0] : wshape[1]);
    }
  }

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = tensor_a->dtype;
  }
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}
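// A minimal sketch of how a type relation such as MatmulRel is typically attached to an
// operator at registration time. This normally lives in the op's .cc file, not in this
// header, and the argument descriptions below are illustrative assumptions:
//
//   RELAY_REGISTER_OP("nn.matmul")
//       .set_attrs_type<MatmulAttrs>()
//       .set_num_inputs(2)
//       .add_argument("tensor_a", "nD Tensor", "The first input.")
//       .add_argument("tensor_b", "2D Tensor", "The second input.")
//       .add_type_rel("Matmul", MatmulRel<MatmulAttrs>);
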
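/*!
 * \brief Type relation for batch_matmul-like operators.
 *
 * Expects rank-3 inputs x and y (modulo transpose_a / transpose_b), checks that the
 * batch and reduction dimensions are consistent when they are static, and assigns the
 * output type (broadcast batch, M, N) to types[2].
 */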
template <typename AttrType>
bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* x = types[0].as<TensorTypeNode>();
  const auto* y = types[1].as<TensorTypeNode>();
  if (x == nullptr || y == nullptr) return false;

  const AttrType* param = attrs.as<AttrType>();
  ICHECK(param != nullptr);
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = x->dtype;
    if (x->dtype.bits() == 0) {
      out_dtype = y->dtype;
    }
  }
  TensorType meta_schedule_y{nullptr};
  if (param->meta_schedule_original_shape.size() != 0) {
    meta_schedule_y = TensorType(param->meta_schedule_original_shape, out_dtype);
    y = meta_schedule_y.get();
  }
  bool transpose_a = param->transpose_a;
  bool transpose_b = param->transpose_b;
  Array<PrimExpr> y_shape{nullptr};
  if (param->auto_scheduler_rewritten_layout.size() != 0) {
    y_shape = auto_scheduler::GetShapeFromRewrittenLayout(
        param->auto_scheduler_rewritten_layout,
        transpose_b ? tvm::runtime::Array<tvm::runtime::String>({"b", "j", "k"})
                    : tvm::runtime::Array<tvm::runtime::String>({"b", "k", "j"}));
  } else if (param->meta_schedule_original_shape.size() != 0) {
    y_shape = param->meta_schedule_original_shape;
  } else {
    y_shape = y->shape;
  }
  ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
  const PrimExpr& xb = x->shape[0];
  const PrimExpr& xi = x->shape[transpose_a ? 2 : 1];
  const PrimExpr& xk = x->shape[transpose_a ? 1 : 2];
  const PrimExpr& yb = y_shape[0];
  const PrimExpr& yk = y_shape[transpose_b ? 2 : 1];
  const PrimExpr& yj = y_shape[transpose_b ? 1 : 2];

  bool is_dyn = false;
  for (size_t i = 0; i < 3; ++i) {
    if (x->shape[i].as<tir::AnyNode>() != nullptr || y_shape[i].as<tir::AnyNode>() != nullptr) {
      is_dyn = true;
      break;
    }
  }
  if (!is_dyn) {
    ICHECK(reporter->AssertEQ(xb, yb) || reporter->AssertEQ(xb, 1) || reporter->AssertEQ(yb, 1))
        << "BatchDot: batch dimensions don't match, "
        << "x shape=" << x->shape << ", y shape=" << y_shape;
    ICHECK(reporter->AssertEQ(xk, yk)) << "BatchDot: shapes of x and y are inconsistent, "
                                       << "x shape=" << x->shape << ", y shape=" << y_shape;
  }

  // assign output type
  const auto& out_b =
      xb->IsInstance<tir::AnyNode>() || yb->IsInstance<tir::AnyNode>() ? tir::Any() : max(xb, yb);
  reporter->Assign(types[2], TensorType(Array<tvm::PrimExpr>({out_b, xi, yj}), out_dtype));
  return true;
}
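// A minimal sketch of the corresponding registration for a batch_matmul-style operator
// (illustrative; the authoritative registration lives in the nn .cc sources):
//
//   RELAY_REGISTER_OP("nn.batch_matmul")
//       .set_attrs_type<BatchMatmulAttrs>()
//       .set_num_inputs(2)
//       .add_argument("x", "3D Tensor", "The first input.")
//       .add_argument("y", "3D Tensor", "The second input.")
//       .add_type_rel("BatchMatmul", BatchMatmulRel<BatchMatmulAttrs>);
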
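/*!
 * \brief Layout-inference hook for dense-like operators, used by layout-transforming
 *        passes (e.g. ConvertLayout / AlterOpLayout) to choose input layouts.
 */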
InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
                                                 const Array<Layout>& new_in_layouts,
                                                 const Array<Layout>& old_in_layouts,
                                                 const Array<tvm::relay::Type>& old_in_types);
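// A minimal sketch of how such a hook is usually wired up at op registration time
// (illustrative assumption, not the authoritative registration):
//
//   RELAY_REGISTER_OP("nn.dense")
//       .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout);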

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_OP_NN_NN_H_