/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file bitserial.cc
 * \brief Property def of bitserial operators.
 */

25 | #include <tvm/relay/attrs/bitserial.h> |
26 | #include <tvm/relay/op.h> |
27 | #include <tvm/tir/data_layout.h> |
28 | |
29 | #include "../../transforms/infer_layout_utils.h" |
30 | #include "../op_common.h" |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
// relay.nn.bitpack
// Register the attrs node so BitPackAttrs instances can be created and
// reflected over via the TVM object/FFI machinery.
TVM_REGISTER_NODE_TYPE(BitPackAttrs);
37 | |
38 | template <typename T> |
39 | InferCorrectLayoutOutput BinaryConv2DInferCorrectLayout( |
40 | const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts, |
41 | const Array<tvm::relay::Type>& old_in_types) { |
42 | const T* params = attrs.as<T>(); |
43 | |
44 | // We always make other operators to fit the layouts of convolution layers |
45 | // So this inference ignores all inputs |
46 | return InferCorrectLayoutOutput({params->data_layout, params->kernel_layout}, |
47 | {params->data_layout}, attrs); |
48 | } |
49 | |
50 | bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
51 | const TypeReporter& reporter) { |
52 | const BitPackAttrs* param = attrs.as<BitPackAttrs>(); |
53 | ICHECK_EQ(types.size(), 2); |
54 | const auto* data = types[0].as<TensorTypeNode>(); |
55 | ICHECK(data); |
56 | int ndim = data->shape.size(); |
57 | int bits = param->bits; |
58 | int pack_axis = param->pack_axis; |
59 | int bit_axis = param->bit_axis; |
60 | DataType pack_type = param->pack_type; |
61 | |
62 | int pack_bits = pack_type.bits(); |
63 | |
64 | Array<IndexExpr> out_shape; |
65 | for (int i = 0; i < ndim; ++i) { |
66 | if (i == bit_axis) { |
67 | out_shape.push_back(bits); |
68 | if (i == pack_axis) { |
69 | out_shape.push_back(indexdiv(data->shape[i], pack_bits)); |
70 | } else { |
71 | out_shape.push_back(data->shape[i]); |
72 | } |
73 | } else if (i == pack_axis) { |
74 | out_shape.push_back(indexdiv(data->shape[i], pack_bits)); |
75 | } else { |
76 | out_shape.push_back(data->shape[i]); |
77 | } |
78 | } |
79 | // Add extra check for last axis expansion. |
80 | if (bit_axis == ndim) { |
81 | out_shape.push_back(bits); |
82 | } |
83 | |
84 | reporter->Assign(types[1], TensorType(out_shape, pack_type)); |
85 | return true; |
86 | } |
87 | |
88 | Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type, |
89 | String name) { |
90 | auto attrs = make_object<BitPackAttrs>(); |
91 | attrs->bits = bits; |
92 | attrs->pack_axis = pack_axis; |
93 | attrs->bit_axis = bit_axis; |
94 | attrs->pack_type = pack_type; |
95 | attrs->name = name; |
96 | static const Op& op = Op::Get("nn.bitpack" ); |
97 | return Call(op, {data}, Attrs(attrs), {}); |
98 | } |
99 | |
// Expose MakeBitPack to the Python frontend.
TVM_REGISTER_GLOBAL("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack);

// Operator registration for nn.bitpack: one input, typed by BitPackRel,
// marked kInjective (element-wise bit packing, safe to fuse).
RELAY_REGISTER_OP("nn.bitpack")
    .describe(R"code(Bitpack layer that prepares data for bitserial operations.

This layer backs the bits of an input into a single datatype, allowing
efficient implementation of bitserial operations.

- **data**: Input tensor of any shape, dimension that is to be
packed must be divisible by number of bits.
- **out**: Packed tensor with shape appropriately compressed.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_attrs_type<BitPackAttrs>()
    .add_argument("data", "Tensor", "Input data.")
    .set_support_level(2)
    .add_type_rel("BitPack", BitPackRel)
    .set_attr<TOpPattern>("TOpPattern", kInjective);
118 | |
// relay.nn.bitserial_conv2d
// Register the attrs node so BinaryConv2DAttrs instances can be created and
// reflected over via the TVM object/FFI machinery.
TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs);
121 | |
// Type relation for nn.bitserial_conv2d.
//
// types = [data, weight, output]. Only the data type is inspected: the
// weight is not checked here because it may arrive either pre-packed (5D)
// or unpacked (4D). The output shape is computed in NCHW space and then
// mapped back to the attribute-specified layout.
bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  // Input type not resolved yet: defer inference.
  if (data == nullptr) return false;

  const BinaryConv2DAttrs* param = attrs.as<BinaryConv2DAttrs>();
  ICHECK(param != nullptr);

  static const Layout kNCHW("NCHW");

  // Map the input shape into canonical NCHW space for the size arithmetic.
  const Layout in_layout(param->data_layout);
  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
  // Channels and kernel size cannot be inferred from the (possibly packed)
  // weight, so they must be given explicitly in the attributes.
  ICHECK(param->channels.defined());
  ICHECK(param->kernel_size.defined());
  Array<IndexExpr> oshape({dshape_nchw[0], param->channels, 0, 0});
  IndexExpr pad_h, pad_w;
  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
  // NOTE(review): the formulas below add pad_h/pad_w once (no factor of 2),
  // which implies GetPaddingHeightWidth returns the combined padding of both
  // sides per dimension — confirm against its definition.
  oshape.Set(2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1);
  oshape.Set(3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1);
  DataType out_dtype = param->out_dtype;
  // Map the computed NCHW shape back to the original data layout.
  oshape = trans_in_layout.BackwardShape(oshape);
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}
149 | |
150 | // Positional relay function to create binaryconv2d operator |
151 | // used by frontend FFI. |
152 | Expr MakeBinaryConv2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding, |
153 | IndexExpr channels, Array<IndexExpr> kernel_size, int activation_bits, |
154 | int weight_bits, String data_layout, String kernel_layout, |
155 | DataType pack_dtype, DataType out_dtype, bool unipolar) { |
156 | auto attrs = make_object<BinaryConv2DAttrs>(); |
157 | attrs->strides = std::move(strides); |
158 | attrs->padding = std::move(padding); |
159 | attrs->channels = std::move(channels); |
160 | attrs->kernel_size = std::move(kernel_size); |
161 | attrs->activation_bits = activation_bits; |
162 | attrs->weight_bits = weight_bits; |
163 | attrs->data_layout = std::move(data_layout); |
164 | attrs->kernel_layout = std::move(kernel_layout); |
165 | attrs->pack_dtype = std::move(pack_dtype); |
166 | attrs->out_dtype = std::move(out_dtype); |
167 | attrs->unipolar = unipolar; |
168 | static const Op& op = Op::Get("nn.bitserial_conv2d" ); |
169 | return Call(op, {data, weight}, Attrs(attrs), {}); |
170 | } |
171 | |
// Expose MakeBinaryConv2D to the Python frontend.
TVM_REGISTER_GLOBAL("relay.op.nn._make.bitserial_conv2d").set_body_typed(MakeBinaryConv2D);

// Operator registration for nn.bitserial_conv2d: two inputs (data, weight),
// typed by BinaryConv2DRel, layout inference via
// BinaryConv2DInferCorrectLayout, fusion pattern kOutEWiseFusable
// (element-wise ops may fuse onto the convolution output).
RELAY_REGISTER_OP("nn.bitserial_conv2d")
    .describe(R"code(2D convolution using packed binary computation.

This layer creates a convolution kernel that is convolved with the
layer input using bitserial computation. This enables faster processing
on some platforms.

- **data**: 4D input tensor that can be either `NCHW` or `NHWC` layout.

- **weight**: Weight tensor that can either be prepacked (5D) or unpacked (4D).
When data is NCHW, weight is expected to be OIHW or OIHWi.
When data is NHWC weight is expected to be HWIO or HWIOi.

- **out**: Output with same layout as input.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<BinaryConv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("BinaryConv2D", BinaryConv2DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   BinaryConv2DInferCorrectLayout<BinaryConv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
198 | |
// relay.nn.bitserial_dense
// Register the attrs node so BinaryDenseAttrs instances can be created and
// reflected over via the TVM object/FFI machinery.
TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs);
201 | |
202 | bool BinaryDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
203 | const TypeReporter& reporter) { |
204 | ICHECK_EQ(types.size(), 3); |
205 | const auto* data = types[0].as<TensorTypeNode>(); |
206 | if (data == nullptr) return false; |
207 | |
208 | const BinaryDenseAttrs* param = attrs.as<BinaryDenseAttrs>(); |
209 | ICHECK(param != nullptr); |
210 | |
211 | ICHECK(static_cast<int>(data->shape.size()) != 0); |
212 | ICHECK(param->units.defined()); |
213 | |
214 | Array<tvm::PrimExpr> oshape = data->shape; |
215 | oshape.Set((oshape.size() - 1), param->units); |
216 | |
217 | DataType out_dtype = param->out_dtype; |
218 | if (out_dtype.bits() == 0) { |
219 | out_dtype = data->dtype; |
220 | } |
221 | |
222 | // Assign output type. |
223 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
224 | return true; |
225 | } |
226 | |
227 | // Positional relay function to create bitserial dense operator used by frontend FFI. |
228 | Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int weight_bits, |
229 | DataType pack_dtype, DataType out_dtype, bool unipolar) { |
230 | auto attrs = make_object<BinaryDenseAttrs>(); |
231 | attrs->units = units; |
232 | attrs->data_bits = data_bits; |
233 | attrs->weight_bits = weight_bits; |
234 | attrs->pack_dtype = pack_dtype; |
235 | attrs->out_dtype = out_dtype; |
236 | attrs->unipolar = unipolar; |
237 | static const Op& op = Op::Get("nn.bitserial_dense" ); |
238 | return Call(op, {data, weight}, Attrs(attrs), {}); |
239 | } |
240 | |
// Expose MakeBinaryDense to the Python frontend.
TVM_REGISTER_GLOBAL("relay.op.nn._make.bitserial_dense").set_body_typed(MakeBinaryDense);

// Operator registration for nn.bitserial_dense: two inputs (data, weight),
// typed by BinaryDenseRel, fusion pattern kOutEWiseFusable.
RELAY_REGISTER_OP("nn.bitserial_dense")
    .describe(R"code(Applies a quantized linear transformation: :math:`Y = XW^T`.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<BinaryDenseAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "2D Tensor", "Input data.")
    .add_argument("weight", "2D Tensor", "Weight matrix.")
    .set_support_level(1)
    .add_type_rel("BinaryDense", BinaryDenseRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
258 | |
259 | } // namespace relay |
260 | } // namespace tvm |
261 | |