/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Binary neural network op constructions
 * \file nn/bnn.h
 */
#ifndef TVM_TOPI_NN_BNN_H_
#define TVM_TOPI_NN_BNN_H_

#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
 * \brief Binarization and bit-packing along a certain axis.
 *
 * \param data N-D tensor, can be any layout
 * \param axis The axis along which to do binarization and bit-packing. This axis
 * must have a size equal to an integer multiple of 32.
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return Output tensor with dtype uint32
 */
inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis,
                                     std::string name = "PackedInput",
                                     std::string tag = "binarize_pack") {
  auto ishape = data->shape;
  ICHECK_EQ(GetConstInt(ishape[axis]) % 32, 0)
      << "binarize_pack: axis size must be a multiple of 32";

  arith::Analyzer analyzer;
  auto n = ishape.size();
  Array<PrimExpr> oshape;
  for (size_t i = 0; i < n; ++i) {
    oshape.push_back(i == static_cast<size_t>(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32))
                                                    : ishape[i]);
  }

  return tvm::te::compute(
      oshape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> start_idx;
        for (size_t i = 0; i < n; ++i) {
          start_idx.push_back(i == static_cast<size_t>(axis) ? indices[i] * 32
                                                             : static_cast<PrimExpr>(indices[i]));
        }
        // Pack 32 consecutive sign bits (1 if data >= 0, 0 otherwise) into one uint32
        // word; the first element of each group ends up in the most significant bit.
        auto packed = make_const(DataType::UInt(32), 0);
        for (size_t j = 0; j < 32; ++j) {
          Array<PrimExpr> idx;
          for (size_t i = 0; i < n; ++i) {
            idx.push_back(i == static_cast<size_t>(axis) ? start_idx[i] + static_cast<int>(j)
                                                         : start_idx[i]);
          }
          auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0);
          packed = (packed | sign);
          if (j == 31) {
            return packed;
          }
          packed = packed << 1;
        }
        return packed;  // never reached, but suppresses a compiler warning
      },
      name, tag);
}
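
/*
 * Usage sketch (illustrative only; the tensor names and shapes below are
 * hypothetical and not part of this header):
 *
 *   // [batch, in_dim] float32 activations; the packed axis must be a multiple of 32.
 *   tvm::te::Tensor act = tvm::te::placeholder({16, 1024}, DataType::Float(32), "act");
 *   // Packs 32 sign bits per word along axis 1: result shape is [16, 32], dtype uint32.
 *   tvm::te::Tensor packed = binarize_pack(act, 1);
 */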

/*!
 * \brief Binary matrix multiplication using xor and bit-count
 *
 * \param data Tensor with shape [batch, in_dim], dtype is uint32
 * \param weight Tensor with shape [out_dim, in_dim], dtype is uint32
 *
 * \return Tensor with shape [batch, out_dim], dtype is float32
 */
inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) {
  ICHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data";
  ICHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight";
  ICHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data";
  ICHECK_EQ(weight->dtype, DataType::UInt(32)) << "binary_dense requires uint32 weight";

  auto batch = data->shape[0];
  auto in_dim = data->shape[1];
  auto out_dim = weight->shape[0];

  auto k = tvm::te::reduce_axis(Range(0, in_dim), "k");
  auto matmul = tvm::te::compute(
      {batch, out_dim},
      [&](Var i, Var j) { return tvm::sum(popcount(data(i, k) ^ weight(j, k)), {k}); }, "tensor",
      "binary_dense");

  // Each uint32 packs 32 sign bits (bit 1 encodes +1, bit 0 encodes -1). The +/-1 dot
  // product over the 32 * in_dim underlying bits is (#matching bits) - (#differing bits),
  // i.e. 32 * in_dim - 2 * popcount(data ^ weight).
  return tvm::te::compute(
      {batch, out_dim}, [&](Var i, Var j) { return 32 * in_dim - 2.0f * matmul(i, j); }, "tensor",
      kElementWise);
}
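
/*
 * Usage sketch (illustrative only; names and shapes are hypothetical): a binarized
 * dense layer built by packing both operands with binarize_pack along their in_dim
 * axis and then calling binary_dense on the packed tensors.
 *
 *   tvm::te::Tensor act = tvm::te::placeholder({16, 1024}, DataType::Float(32), "act");
 *   tvm::te::Tensor wgt = tvm::te::placeholder({64, 1024}, DataType::Float(32), "wgt");
 *   tvm::te::Tensor packed_act = binarize_pack(act, 1);          // [16, 32], uint32
 *   tvm::te::Tensor packed_wgt = binarize_pack(wgt, 1);          // [64, 32], uint32
 *   tvm::te::Tensor out = binary_dense(packed_act, packed_wgt);  // [16, 64], float32
 */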

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_BNN_H_