/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/op/nn/nn.h
 * \brief Shared property definitions of nn operators.
 */
#ifndef TVM_RELAY_OP_NN_NN_H_
#define TVM_RELAY_OP_NN_NN_H_

#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/relay/type.h>

#include <algorithm>
#include <utility>
#include <vector>

#include "../op_common.h"

namespace tvm {
namespace relay {

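/*!
 * \brief Type relation for matmul-like operators (nn.dense, nn.matmul).
 *
 * types = [tensor_a, tensor_b, output]. Infers tensor_b's type from
 * param->units when defined, and handles layouts rewritten by auto-scheduler
 * or meta-schedule. Operators typically attach this relation through
 * add_type_rel; a rough sketch (the exact call site may differ):
 *
 *   RELAY_REGISTER_OP("nn.matmul").add_type_rel("Matmul", MatmulRel<MatmulAttrs>);
 */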
template <typename AttrType>
bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* tensor_a = types[0].as<TensorTypeNode>();
  const auto* tensor_b = types[1].as<TensorTypeNode>();
  if (tensor_a == nullptr) return false;
  ICHECK(static_cast<int>(tensor_a->shape.size()) != 0);

  const AttrType* param = attrs.as<AttrType>();
  ICHECK(param != nullptr);
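  // If meta-schedule rewrote tensor_b's layout, recover its original shape
  // from the attributes so that shape inference can proceed.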
  TensorType meta_schedule_tensor_b{nullptr};
  if (param->meta_schedule_original_shape.size() > 0) {
    meta_schedule_tensor_b = TensorType(param->meta_schedule_original_shape,
                                        tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
    tensor_b = meta_schedule_tensor_b.get();
  }
  // Default set to dense layout
  bool transpose_a = false;
  bool transpose_b = true;
  const auto& mattrs = attrs.as<MatmulAttrs>();
  if (mattrs != nullptr) {
    transpose_a = mattrs->transpose_a;
    transpose_b = mattrs->transpose_b;
  }

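  // Derive the output shape from tensor_a's shape; the reduction (inner)
  // dimension moves when tensor_a is transposed.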
  const Array<tvm::PrimExpr>& dshape = tensor_a->shape;
  Array<tvm::PrimExpr> oshape = dshape;
  tvm::PrimExpr reduce = dshape[dshape.size() - 1];
  if (transpose_a) {
    reduce = dshape[dshape.size() - 2];
    oshape.Set((oshape.size() - 2), dshape[oshape.size() - 1]);
  }
  auto tensor_b_dtype = (tensor_b == nullptr ? tensor_a->dtype : tensor_b->dtype);
  if (param->units.defined()) {
    // Validate that tensor_b's shape is consistent with the declared units,
    // and assign tensor_b's type.
    const Array<IndexExpr>& wshape = transpose_b ? Array<IndexExpr>({param->units, reduce})
                                                 : Array<IndexExpr>({reduce, param->units});
    // tensor_b may be nullptr, in which case we use the data dtype as the
    // tensor_b dtype; if tensor_b's dtype is explicitly present, we use that.
    // When the layout was rewritten by auto-scheduler or meta-schedule, we
    // forcefully keep the provided layout and skip the normal inference logic.
    if (param->auto_scheduler_rewritten_layout.size() == 0 &&
        param->meta_schedule_original_shape.size() == 0) {
      // Normal case: assign the inferred tensor_b type to the reporter.
      reporter->Assign(types[1], TensorType(wshape, tensor_b_dtype));
    }
    oshape.Set((oshape.size() - 1), param->units);
  } else {
    if (tensor_b == nullptr) return false;
    const Array<tvm::PrimExpr>& wshape = tensor_b->shape;
    // When tensor_b's layout has been rewritten, figure it out based on the
    // total number of elements and input dimensions.
    if (param->auto_scheduler_rewritten_layout.size() != 0) {
      PrimExpr tensor_b_elements = 1;
      for (size_t i = 0; i < wshape.size(); i++) {
        tensor_b_elements = tensor_b_elements * wshape[i];
      }
      oshape.Set(oshape.size() - 1, tensor_b_elements / dshape[dshape.size() - 1]);
    } else {
      // Otherwise just pull it out of the tensor_b shape directly.
      ICHECK(static_cast<int>(tensor_b->shape.size()) == 2);
      if (param->auto_scheduler_rewritten_layout.size() == 0 &&
          param->meta_schedule_original_shape.size() == 0) {
        // Ensure the inner dimensions of data and weight match. If one inner
        // dimension is dynamic, it is inferred to match the other inner
        // dimension.
        std::vector<PrimExpr> A_shape(tensor_a->shape.begin(), tensor_a->shape.end());
        std::vector<PrimExpr> B_shape(tensor_b->shape.begin(), tensor_b->shape.end());
        auto sa = A_shape.size();
        auto sb = B_shape.size();
        size_t index_swap_A;
        size_t index_swap_B;
        if (transpose_a && transpose_b) {
          index_swap_A = sa - 2;
          index_swap_B = sb - 1;
        } else if (transpose_a) {
          index_swap_A = sa - 2;
          index_swap_B = sb - 2;
        } else if (transpose_b) {
          index_swap_A = sa - 1;
          index_swap_B = sb - 1;
        } else {
          index_swap_A = sa - 1;
          index_swap_B = sb - 2;
        }

        // Rewrite dynamic axes to static where constraints allow.
        auto tmp = A_shape[index_swap_A];
        if (A_shape[index_swap_A].as<tir::AnyNode>()) {
          A_shape[index_swap_A] = B_shape[index_swap_B];
        }
        if (B_shape[index_swap_B].as<tir::AnyNode>()) {
          B_shape[index_swap_B] = tmp;
        }

        // Update the input types with the newly constrained shapes.
        reporter->Assign(types[0], TensorType(A_shape, tensor_a->dtype));
        reporter->Assign(types[1], TensorType(B_shape, tensor_b_dtype));
      }
      oshape.Set(oshape.size() - 1, transpose_b ? wshape[0] : wshape[1]);
    }
  }

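  // Default the output dtype to tensor_a's dtype when none was specified.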
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = tensor_a->dtype;
  }
  // Assign the output type.
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}

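/*!
 * \brief Type relation for batch_matmul-like operators.
 *
 * types = [x, y, output], where x and y are rank-3 tensors. The batch
 * dimensions must match, or either may be 1 and is broadcast to the other.
 */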
template <typename AttrType>
bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* x = types[0].as<TensorTypeNode>();
  const auto* y = types[1].as<TensorTypeNode>();
  if (x == nullptr || y == nullptr) return false;

  const AttrType* param = attrs.as<AttrType>();
  ICHECK(param != nullptr);
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = x->dtype;
    if (x->dtype.bits() == 0) {
      out_dtype = y->dtype;
    }
  }
  TensorType meta_schedule_y{nullptr};
  if (param->meta_schedule_original_shape.size() != 0) {
    meta_schedule_y = TensorType(param->meta_schedule_original_shape, out_dtype);
    y = meta_schedule_y.get();
  }
  bool transpose_a = param->transpose_a;
  bool transpose_b = param->transpose_b;
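  // Resolve y's shape: prefer the auto-scheduler rewritten layout, then the
  // meta-schedule original shape, falling back to y's own inferred shape.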
  Array<PrimExpr> y_shape{nullptr};
  if (param->auto_scheduler_rewritten_layout.size() != 0) {
    y_shape = auto_scheduler::GetShapeFromRewrittenLayout(
        param->auto_scheduler_rewritten_layout,
        transpose_b ? tvm::runtime::Array<tvm::runtime::String>({"b", "j", "k"})
                    : tvm::runtime::Array<tvm::runtime::String>({"b", "k", "j"}));
  } else if (param->meta_schedule_original_shape.size() != 0) {
    y_shape = param->meta_schedule_original_shape;
  } else {
    y_shape = y->shape;
  }
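  // Read off the batch (b), row (i), reduction (k), and column (j) dimensions
  // according to the transpose flags.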
  ICHECK(x->shape.size() == 3 && y_shape.size() == 3);
  const PrimExpr& xb = x->shape[0];
  const PrimExpr& xi = x->shape[transpose_a ? 2 : 1];
  const PrimExpr& xk = x->shape[transpose_a ? 1 : 2];
  const PrimExpr& yb = y_shape[0];
  const PrimExpr& yk = y_shape[transpose_b ? 2 : 1];
  const PrimExpr& yj = y_shape[transpose_b ? 1 : 2];

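  // If any dimension of x or y is dynamic (Any), skip the static shape checks.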
  bool is_dyn = false;
  for (size_t i = 0; i < 3; ++i) {
    if (x->shape[i].as<tir::AnyNode>() != nullptr || y_shape[i].as<tir::AnyNode>() != nullptr) {
      is_dyn = true;
      break;
    }
  }
  if (!is_dyn) {
    ICHECK(reporter->AssertEQ(xb, yb) || reporter->AssertEQ(xb, 1) || reporter->AssertEQ(yb, 1))
        << "BatchDot: batch dimensions don't match, x shape=" << x->shape
        << ", y shape=" << y_shape;
    ICHECK(reporter->AssertEQ(xk, yk)) << "BatchDot: shapes of x and y are inconsistent, "
                                       << "x shape=" << x->shape << ", y shape=" << y_shape;
  }

  // Assign the output type. The output batch dimension broadcasts x's and
  // y's batch dimensions unless either is dynamic.
  const auto& out_b =
      xb->IsInstance<tir::AnyNode>() || yb->IsInstance<tir::AnyNode>() ? tir::Any() : max(xb, yb);
  reporter->Assign(types[2], TensorType(Array<tvm::PrimExpr>({out_b, xi, yj}), out_dtype));
  return true;
}

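/*! \brief Infer the correct input/output layouts for nn.dense. */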
InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
                                                 const Array<Layout>& new_in_layouts,
                                                 const Array<Layout>& old_in_layouts,
                                                 const Array<tvm::relay::Type>& old_in_types);

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_OP_NN_NN_H_