/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file bitserial.cc
 * \brief Property def of bitserial operators.
 */

#include <tvm/relay/attrs/bitserial.h>
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>

#include "../../transforms/infer_layout_utils.h"
#include "../op_common.h"

namespace tvm {
namespace relay {

// relay.nn.bitpack
TVM_REGISTER_NODE_TYPE(BitPackAttrs);

template <typename T>
InferCorrectLayoutOutput BinaryConv2DInferCorrectLayout(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  const T* params = attrs.as<T>();

  // We always make other operators fit the layouts of convolution layers,
  // so this layout inference ignores all incoming layouts.
  return InferCorrectLayoutOutput({params->data_layout, params->kernel_layout},
                                  {params->data_layout}, attrs);
}

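// Shape relation for nn.bitpack: the packed axis shrinks by the bit width of
// pack_type, and a new axis of length `bits` is inserted at bit_axis. An
// illustrative example: data of shape (1, 64, 56, 56) packed with bits=2,
// pack_axis=1, bit_axis=1, pack_type=uint8 yields (1, 2, 8, 56, 56).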
bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
  const BitPackAttrs* param = attrs.as<BitPackAttrs>();
  ICHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  ICHECK(data);
  int ndim = data->shape.size();
  int bits = param->bits;
  int pack_axis = param->pack_axis;
  int bit_axis = param->bit_axis;
  DataType pack_type = param->pack_type;

  int pack_bits = pack_type.bits();

  Array<IndexExpr> out_shape;
  for (int i = 0; i < ndim; ++i) {
    if (i == bit_axis) {
      out_shape.push_back(bits);
      if (i == pack_axis) {
        out_shape.push_back(indexdiv(data->shape[i], pack_bits));
      } else {
        out_shape.push_back(data->shape[i]);
      }
    } else if (i == pack_axis) {
      out_shape.push_back(indexdiv(data->shape[i], pack_bits));
    } else {
      out_shape.push_back(data->shape[i]);
    }
  }
  // When bit_axis == ndim, the new bit axis is appended after the last input dimension.
  if (bit_axis == ndim) {
    out_shape.push_back(bits);
  }

  reporter->Assign(types[1], TensorType(out_shape, pack_type));
  return true;
}

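// Positional helper used by the frontend FFI to construct nn.bitpack.
// An illustrative call (hypothetical values): pack the channel axis of an
// NCHW tensor into uint8 words, one bit-plane per bit:
//   Expr packed = MakeBitPack(data, /*bits=*/1, /*pack_axis=*/1, /*bit_axis=*/1,
//                             DataType::UInt(8), "bitpack");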
Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type,
                 String name) {
  auto attrs = make_object<BitPackAttrs>();
  attrs->bits = bits;
  attrs->pack_axis = pack_axis;
  attrs->bit_axis = bit_axis;
  attrs->pack_type = pack_type;
  attrs->name = name;
  static const Op& op = Op::Get("nn.bitpack");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack);

RELAY_REGISTER_OP("nn.bitpack")
    .describe(R"code(Bitpack layer that prepares data for bitserial operations.

This layer packs the bits of an input into a single datatype, allowing
efficient implementation of bitserial operations.

- **data**: Input tensor of any shape; the dimension to be packed must be
  divisible by the bit width of the pack datatype.
- **out**: Packed tensor with shape appropriately compressed.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_attrs_type<BitPackAttrs>()
    .add_argument("data", "Tensor", "Input data.")
    .set_support_level(2)
    .add_type_rel("BitPack", BitPackRel)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

// relay.nn.bitserial_conv2d
TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs);

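// Shape relation for nn.bitserial_conv2d: transform the input shape to NCHW,
// infer the output spatial dimensions there, then map the result back to the
// layout recorded in the attrs. channels and kernel_size must be set.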
bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const BinaryConv2DAttrs* param = attrs.as<BinaryConv2DAttrs>();
  ICHECK(param != nullptr);

  static const Layout kNCHW("NCHW");

  const Layout in_layout(param->data_layout);
  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
  ICHECK(param->channels.defined());
  ICHECK(param->kernel_size.defined());
  Array<IndexExpr> oshape({dshape_nchw[0], param->channels, 0, 0});
  IndexExpr pad_h, pad_w;
  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
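  // pad_h and pad_w hold the total padding (both sides summed), so the usual
  // convolution output-size formula applies directly:
  //   out_dim = (in_dim + pad_total - kernel) / stride + 1
  // e.g. (illustrative) in_dim = 56, pad_total = 2, kernel = 3, stride = 1 -> 56.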
  oshape.Set(2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1);
  oshape.Set(3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1);
  DataType out_dtype = param->out_dtype;
  oshape = trans_in_layout.BackwardShape(oshape);
  // Assign output type.
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}

// Positional relay function to create binaryconv2d operator
// used by frontend FFI.
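// An illustrative call (hypothetical values): a 1-bit 3x3 convolution with
// 64 output channels on NCHW data:
//   Expr conv = MakeBinaryConv2D(data, weight, {1, 1}, {1, 1}, 64, {3, 3},
//                                /*activation_bits=*/1, /*weight_bits=*/1,
//                                "NCHW", "OIHW", DataType::UInt(32),
//                                DataType::Int(16), /*unipolar=*/true);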
Expr MakeBinaryConv2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                      IndexExpr channels, Array<IndexExpr> kernel_size, int activation_bits,
                      int weight_bits, String data_layout, String kernel_layout,
                      DataType pack_dtype, DataType out_dtype, bool unipolar) {
  auto attrs = make_object<BinaryConv2DAttrs>();
  attrs->strides = std::move(strides);
  attrs->padding = std::move(padding);
  attrs->channels = std::move(channels);
  attrs->kernel_size = std::move(kernel_size);
  attrs->activation_bits = activation_bits;
  attrs->weight_bits = weight_bits;
  attrs->data_layout = std::move(data_layout);
  attrs->kernel_layout = std::move(kernel_layout);
  attrs->pack_dtype = std::move(pack_dtype);
  attrs->out_dtype = std::move(out_dtype);
  attrs->unipolar = unipolar;
  static const Op& op = Op::Get("nn.bitserial_conv2d");
  return Call(op, {data, weight}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.bitserial_conv2d").set_body_typed(MakeBinaryConv2D);

RELAY_REGISTER_OP("nn.bitserial_conv2d")
    .describe(R"code(2D convolution using packed binary computation.

This layer creates a convolution kernel that is convolved with the
layer input using bitserial computation. This enables faster processing
on some platforms.

- **data**: 4D input tensor that can be either `NCHW` or `NHWC` layout.

- **weight**: Weight tensor that can be either prepacked (5D) or unpacked (4D).
  When data is NCHW, weight is expected to be OIHW or OIHWi.
  When data is NHWC, weight is expected to be HWIO or HWIOi.

- **out**: Output with same layout as input.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<BinaryConv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("BinaryConv2D", BinaryConv2DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   BinaryConv2DInferCorrectLayout<BinaryConv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

// relay.nn.bitserial_dense
TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs);

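// Shape relation for nn.bitserial_dense: the input keeps all leading
// dimensions and the last (reduction) dimension is replaced by `units`, i.e.
// (d0, ..., dk, input_dim) x (units, input_dim) -> (d0, ..., dk, units).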
bool BinaryDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const BinaryDenseAttrs* param = attrs.as<BinaryDenseAttrs>();
  ICHECK(param != nullptr);

  ICHECK(static_cast<int>(data->shape.size()) != 0);
  ICHECK(param->units.defined());

  Array<tvm::PrimExpr> oshape = data->shape;
  oshape.Set((oshape.size() - 1), param->units);

  DataType out_dtype = param->out_dtype;
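  // A zero-width out_dtype means it was left unspecified; fall back to the
  // input's dtype.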
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }

  // Assign output type.
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}

// Positional relay function to create bitserial dense operator used by frontend FFI.
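// An illustrative call (hypothetical values): a 1-bit dense layer with 128
// output units:
//   Expr dense = MakeBinaryDense(data, weight, 128, /*data_bits=*/1,
//                                /*weight_bits=*/1, DataType::UInt(32),
//                                DataType::Int(16), /*unipolar=*/true);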
Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int weight_bits,
                     DataType pack_dtype, DataType out_dtype, bool unipolar) {
  auto attrs = make_object<BinaryDenseAttrs>();
  attrs->units = units;
  attrs->data_bits = data_bits;
  attrs->weight_bits = weight_bits;
  attrs->pack_dtype = pack_dtype;
  attrs->out_dtype = out_dtype;
  attrs->unipolar = unipolar;
  static const Op& op = Op::Get("nn.bitserial_dense");
  return Call(op, {data, weight}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.bitserial_dense").set_body_typed(MakeBinaryDense);

RELAY_REGISTER_OP("nn.bitserial_dense")
    .describe(R"code(Applies a quantized linear transformation: :math:`Y = XW^T`.

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<BinaryDenseAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "2D Tensor", "Input data.")
    .add_argument("weight", "2D Tensor", "Weight matrix.")
    .set_support_level(1)
    .add_type_rel("BinaryDense", BinaryDenseRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

}  // namespace relay
}  // namespace tvm