1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file mac_count.cc
 * \brief Pass to roughly count the number of MAC (Multiply-Accumulate)
 * operations of a model. Only MACs in conv2d, conv2d_transpose, dense and
 * batch_matmul ops are counted.
 * This pass is only valid after the type inference pass has been run;
 * otherwise the count is 0.
27 */
28
29#include <tvm/relay/analysis.h>
30#include <tvm/relay/attrs/nn.h>
31#include <tvm/relay/expr_functor.h>
32#include <tvm/relay/op.h>
33#include <tvm/relay/op_attr_types.h>
34
35#include "../transforms/pattern_utils.h"
36
37namespace tvm {
38namespace relay {
39
40namespace mac_count {
41
42inline int64_t GetCartesianProd(Array<IndexExpr> arr) {
43 int64_t ret = 1;
44 for (size_t i = 0; i < arr.size(); i++) {
45 const auto* intImm = arr[i].as<IntImmNode>();
46 ret *= static_cast<int64_t>(intImm->value);
47 }
48 return ret;
49}
50
51/*
52 * \brief Preparation function for MAC count.
53 * \param call_node The call node.
54 * \return The number of MACs.
55 */
56using FMacCount = runtime::TypedPackedFunc<int64_t(const Call& call_node)>;
57
58//----------------------------------------------
59// Per operator defs for MAC count
60//----------------------------------------------
61
62int64_t ConvMacCount(const Call& call_node) {
63 if (!call_node->checked_type_.defined()) {
64 LOG(WARNING) << "The infer type pass should be called before the mac count pass";
65 return 0;
66 }
67 Array<Expr> args = call_node->args;
68 ICHECK_EQ(args.size(), 2) << "The number of input arguments of a CONV 2D node should be 2.";
69 const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
70 const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
71 Array<IndexExpr> data_shape = data_type->shape;
72 std::string data_layout = conv_2d_attr->data_layout;
73 int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
74 int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
75 ICHECK_NE(C_ind, -1) << "There is no input channel dimension.";
76 int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value);
77 if (c_ind != -1) input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
78 Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
79 ICHECK_EQ(kernel_size.size(), 2) << "The dimension of the kernel in Conv 2D should be 2.";
80 const auto* expr = call_node->checked_type().as<TensorTypeNode>();
81 Array<IndexExpr> output_tensor = expr->shape;
82 ICHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
83 << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
84 int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
85 ICHECK_EQ(input_channel % conv_2d_attr->groups, 0)
86 << "The number of input channels is not divisble by groups.";
87 count *= input_channel / conv_2d_attr->groups;
88 return count;
89}
90
91int64_t Conv2dTransposeMacCount(const Call& call_node) {
92 if (!call_node->checked_type_.defined()) {
93 LOG(WARNING) << "The infer type pass should be called before the mac count pass";
94 return 0;
95 }
96 Array<Expr> args = call_node->args;
97 ICHECK_EQ(args.size(), 2)
98 << "The number of input arguments of a CONV 2D Transpose node should be 2.";
99 const auto* conv_2d_transpose_attr = call_node->attrs.as<Conv2DTransposeAttrs>();
100 const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
101 Array<IndexExpr> data_shape = data_type->shape;
102 std::string data_layout = conv_2d_transpose_attr->data_layout;
103 int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
104 int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
105 ICHECK_NE(C_ind, -1) << "There is no input channel dimension.";
106 int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value);
107 if (c_ind != -1) input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
108 Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
109 ICHECK_EQ(kernel_size.size(), 2)
110 << "The dimension of the kernel in Conv 2D Transpose should be 2.";
111 const auto* expr = call_node->checked_type().as<TensorTypeNode>();
112 Array<IndexExpr> output_tensor = expr->shape;
113 ICHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
114 << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5.";
115 int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
116 ICHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0)
117 << "The number of input channels is not divisble by groups.";
118 count *= input_channel / conv_2d_transpose_attr->groups;
119 return count;
120}
121
122int64_t DenseMacCount(const Call& call_node) {
123 if (!call_node->checked_type_.defined()) {
124 LOG(WARNING) << "The infer type pass should be called before the mac count pass";
125 return 0;
126 }
127 Array<Expr> args = call_node->args;
128 ICHECK_EQ(args.size(), 2) << "The number of input arguments of a Dense node should be 2.";
129 const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
130 const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
131 Array<IndexExpr> data_shape = data_type->shape;
132 Array<IndexExpr> weight_shape = weight_type->shape;
133 ICHECK(data_shape.size() == 2 && weight_shape.size() == 2)
134 << "The dimension of an input tensor to Dense node should be 2.";
135 int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImmNode>()->value);
136 int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImmNode>()->value);
137 int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImmNode>()->value);
138 int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImmNode>()->value);
139 ICHECK_EQ(d2, d4) << "The dimensions of input arguments do not match.";
140 int64_t count = d1 * d2 * d3;
141 return count;
142}
143
144int64_t BatchMatmulMacCount(const Call& call_node) {
145 if (!call_node->checked_type_.defined()) {
146 LOG(WARNING) << "The infer type pass should be called before the mac count pass";
147 return 0;
148 }
149 Array<Expr> args = call_node->args;
150 ICHECK_EQ(args.size(), 2);
151 Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape;
152 Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape;
153 int64_t batch = x_shape[0].as<IntImmNode>()->value;
154 int64_t m = x_shape[1].as<IntImmNode>()->value;
155 int64_t k = x_shape[2].as<IntImmNode>()->value;
156 int64_t n = y_shape[1].as<IntImmNode>()->value;
157 return batch * m * k * n;
158}
159
160RELAY_REGISTER_OP("nn.conv2d").set_attr<FMacCount>("FMacCount", ConvMacCount);
161
162RELAY_REGISTER_OP("nn.conv2d_transpose").set_attr<FMacCount>("FMacCount", Conv2dTransposeMacCount);
163
164RELAY_REGISTER_OP("nn.dense").set_attr<FMacCount>("FMacCount", DenseMacCount);
165
166RELAY_REGISTER_OP("nn.batch_matmul").set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);
167
168class MacCounter : private ExprVisitor {
169 public:
170 MacCounter() { count_ = 0; }
171 static int64_t GetTotalMacNumber(const Expr& expr) {
172 LOG(INFO) << "This pass only counts MACs in direct conv2d, "
173 << "conv2d_transpose, dense, and batch_matmul ops";
174 MacCounter counter;
175 counter(expr);
176 return counter.count_;
177 }
178
179 private:
180 void VisitExpr_(const CallNode* call_node) final {
181 static const auto& fprep = Op::GetAttrMap<FMacCount>("FMacCount");
182 auto f = fprep.get(call_node->op, nullptr);
183 if (f != nullptr) count_ += f(GetRef<Call>(call_node));
184 ExprVisitor::VisitExpr_(call_node);
185 }
186
187 int64_t count_;
188};
189
190int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); }
191
192TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber").set_body_typed(GetTotalMacNumber);
193
194} // namespace mac_count
195} // namespace relay
196} // namespace tvm
197