1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file mac_count.cc |
 * \brief Pass to roughly count the number of MACs (Multiply-Accumulate)
 * operations of a model. Only MACs in conv2d, conv2d_transpose, dense and
 * batch_matmul ops are counted.
 * This pass is only valid after the type inference pass has been run;
 * otherwise the count is 0.
27 | */ |
28 | |
29 | #include <tvm/relay/analysis.h> |
30 | #include <tvm/relay/attrs/nn.h> |
31 | #include <tvm/relay/expr_functor.h> |
32 | #include <tvm/relay/op.h> |
33 | #include <tvm/relay/op_attr_types.h> |
34 | |
35 | #include "../transforms/pattern_utils.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
40 | namespace mac_count { |
41 | |
42 | inline int64_t GetCartesianProd(Array<IndexExpr> arr) { |
43 | int64_t ret = 1; |
44 | for (size_t i = 0; i < arr.size(); i++) { |
45 | const auto* intImm = arr[i].as<IntImmNode>(); |
46 | ret *= static_cast<int64_t>(intImm->value); |
47 | } |
48 | return ret; |
49 | } |
50 | |
/*!
 * \brief Signature of a per-operator MAC counting function.
 *        Functions of this type are attached to operators through the
 *        "FMacCount" attribute and invoked on each matching call node.
 * \param call_node The call node.
 * \return The number of MACs.
 */
using FMacCount = runtime::TypedPackedFunc<int64_t(const Call& call_node)>;
57 | |
58 | //---------------------------------------------- |
59 | // Per operator defs for MAC count |
60 | //---------------------------------------------- |
61 | |
62 | int64_t ConvMacCount(const Call& call_node) { |
63 | if (!call_node->checked_type_.defined()) { |
64 | LOG(WARNING) << "The infer type pass should be called before the mac count pass" ; |
65 | return 0; |
66 | } |
67 | Array<Expr> args = call_node->args; |
68 | ICHECK_EQ(args.size(), 2) << "The number of input arguments of a CONV 2D node should be 2." ; |
69 | const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>(); |
70 | const auto* data_type = args[0]->checked_type().as<TensorTypeNode>(); |
71 | Array<IndexExpr> data_shape = data_type->shape; |
72 | std::string data_layout = conv_2d_attr->data_layout; |
73 | int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); |
74 | int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); |
75 | ICHECK_NE(C_ind, -1) << "There is no input channel dimension." ; |
76 | int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value); |
77 | if (c_ind != -1) input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value); |
78 | Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size; |
79 | ICHECK_EQ(kernel_size.size(), 2) << "The dimension of the kernel in Conv 2D should be 2." ; |
80 | const auto* expr = call_node->checked_type().as<TensorTypeNode>(); |
81 | Array<IndexExpr> output_tensor = expr->shape; |
82 | ICHECK(output_tensor.size() == 4 || output_tensor.size() == 5) |
83 | << "The dimension of the output tensor in Conv 2D should be 4 or 5." ; |
84 | int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); |
85 | ICHECK_EQ(input_channel % conv_2d_attr->groups, 0) |
86 | << "The number of input channels is not divisble by groups." ; |
87 | count *= input_channel / conv_2d_attr->groups; |
88 | return count; |
89 | } |
90 | |
91 | int64_t Conv2dTransposeMacCount(const Call& call_node) { |
92 | if (!call_node->checked_type_.defined()) { |
93 | LOG(WARNING) << "The infer type pass should be called before the mac count pass" ; |
94 | return 0; |
95 | } |
96 | Array<Expr> args = call_node->args; |
97 | ICHECK_EQ(args.size(), 2) |
98 | << "The number of input arguments of a CONV 2D Transpose node should be 2." ; |
99 | const auto* conv_2d_transpose_attr = call_node->attrs.as<Conv2DTransposeAttrs>(); |
100 | const auto* data_type = args[0]->checked_type().as<TensorTypeNode>(); |
101 | Array<IndexExpr> data_shape = data_type->shape; |
102 | std::string data_layout = conv_2d_transpose_attr->data_layout; |
103 | int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); |
104 | int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); |
105 | ICHECK_NE(C_ind, -1) << "There is no input channel dimension." ; |
106 | int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value); |
107 | if (c_ind != -1) input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value); |
108 | Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size; |
109 | ICHECK_EQ(kernel_size.size(), 2) |
110 | << "The dimension of the kernel in Conv 2D Transpose should be 2." ; |
111 | const auto* expr = call_node->checked_type().as<TensorTypeNode>(); |
112 | Array<IndexExpr> output_tensor = expr->shape; |
113 | ICHECK(output_tensor.size() == 4 || output_tensor.size() == 5) |
114 | << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5." ; |
115 | int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); |
116 | ICHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0) |
117 | << "The number of input channels is not divisble by groups." ; |
118 | count *= input_channel / conv_2d_transpose_attr->groups; |
119 | return count; |
120 | } |
121 | |
122 | int64_t DenseMacCount(const Call& call_node) { |
123 | if (!call_node->checked_type_.defined()) { |
124 | LOG(WARNING) << "The infer type pass should be called before the mac count pass" ; |
125 | return 0; |
126 | } |
127 | Array<Expr> args = call_node->args; |
128 | ICHECK_EQ(args.size(), 2) << "The number of input arguments of a Dense node should be 2." ; |
129 | const auto* data_type = args[0]->checked_type().as<TensorTypeNode>(); |
130 | const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>(); |
131 | Array<IndexExpr> data_shape = data_type->shape; |
132 | Array<IndexExpr> weight_shape = weight_type->shape; |
133 | ICHECK(data_shape.size() == 2 && weight_shape.size() == 2) |
134 | << "The dimension of an input tensor to Dense node should be 2." ; |
135 | int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImmNode>()->value); |
136 | int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImmNode>()->value); |
137 | int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImmNode>()->value); |
138 | int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImmNode>()->value); |
139 | ICHECK_EQ(d2, d4) << "The dimensions of input arguments do not match." ; |
140 | int64_t count = d1 * d2 * d3; |
141 | return count; |
142 | } |
143 | |
144 | int64_t BatchMatmulMacCount(const Call& call_node) { |
145 | if (!call_node->checked_type_.defined()) { |
146 | LOG(WARNING) << "The infer type pass should be called before the mac count pass" ; |
147 | return 0; |
148 | } |
149 | Array<Expr> args = call_node->args; |
150 | ICHECK_EQ(args.size(), 2); |
151 | Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape; |
152 | Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape; |
153 | int64_t batch = x_shape[0].as<IntImmNode>()->value; |
154 | int64_t m = x_shape[1].as<IntImmNode>()->value; |
155 | int64_t k = x_shape[2].as<IntImmNode>()->value; |
156 | int64_t n = y_shape[1].as<IntImmNode>()->value; |
157 | return batch * m * k * n; |
158 | } |
159 | |
// Attach a MAC counting function to each supported operator through its
// "FMacCount" attribute; MacCounter looks the attribute up per call node.

// 2D convolution.
RELAY_REGISTER_OP("nn.conv2d" ).set_attr<FMacCount>("FMacCount" , ConvMacCount);

// 2D transposed convolution.
RELAY_REGISTER_OP("nn.conv2d_transpose" ).set_attr<FMacCount>("FMacCount" , Conv2dTransposeMacCount);

// Dense (fully connected) layer.
RELAY_REGISTER_OP("nn.dense" ).set_attr<FMacCount>("FMacCount" , DenseMacCount);

// Batch matrix multiplication.
RELAY_REGISTER_OP("nn.batch_matmul" ).set_attr<FMacCount>("FMacCount" , BatchMatmulMacCount);
167 | |
168 | class MacCounter : private ExprVisitor { |
169 | public: |
170 | MacCounter() { count_ = 0; } |
171 | static int64_t GetTotalMacNumber(const Expr& expr) { |
172 | LOG(INFO) << "This pass only counts MACs in direct conv2d, " |
173 | << "conv2d_transpose, dense, and batch_matmul ops" ; |
174 | MacCounter counter; |
175 | counter(expr); |
176 | return counter.count_; |
177 | } |
178 | |
179 | private: |
180 | void VisitExpr_(const CallNode* call_node) final { |
181 | static const auto& fprep = Op::GetAttrMap<FMacCount>("FMacCount" ); |
182 | auto f = fprep.get(call_node->op, nullptr); |
183 | if (f != nullptr) count_ += f(GetRef<Call>(call_node)); |
184 | ExprVisitor::VisitExpr_(call_node); |
185 | } |
186 | |
187 | int64_t count_; |
188 | }; |
189 | |
190 | int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } |
191 | |
192 | TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber" ).set_body_typed(GetTotalMacNumber); |
193 | |
194 | } // namespace mac_count |
195 | } // namespace relay |
196 | } // namespace tvm |
197 | |