#include <ATen/Utils.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch {
namespace jit {

namespace {

using Tensor = at::Tensor;

bool supportedConvNode(Node* n) {
  switch (n->kind()) {
    case aten::conv1d:
    case aten::conv2d:
    case aten::conv3d:
      return true;
    case aten::_convolution: {
      auto transposed_conv =
          constant_as<bool>(n->namedInput("transposed")).value_or(true);
      // don't handle transposed convolutions, or a non-constant transposed
      // flag, yet
      return !transposed_conv;
    }
    default:
      return false;
  }
}

bool FoldFrozenConvBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvBatchnorm(block);
    }

    if (n->kind() == aten::batch_norm &&
        supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto bn = n;
      if (nonConstantParameters(conv) || nonConstantParameters(bn)) {
        continue;
      }
      if (conv->output()->uses().size() > 1) {
        continue;
      }

      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");
      // Check that running_mean and running_var have values; if either is
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() ||
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }

      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto conv_w = constant_as<Tensor>(conv->namedInput("weight")).value();

      // implementation taken from torch/nn/utils/fusion.py
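      // A sketch of that math: with
      //   scale = bn_w / sqrt(bn_rv + bn_eps)
      // the fused parameters computed below are
      //   fused_w = conv_w * scale  (scale reshaped to broadcast over the
      //                              output-channel dim of conv_w)
      //   fused_b = (conv_b - bn_rm) * scale + bn_b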
      Tensor conv_b;
      if (conv->namedInput("bias")->type() == NoneType::get()) {
        // If this is on GPU, bias is None, and the weight was half/bfloat16
        // while bn_rm was float, then this was probably a case where
        // autocasting cast the inputs to conv. Since the CUDA conv
        // implementation requires all of its inputs to have the same scalar
        // dtype, make this placeholder bias the same type as conv_w.
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = conv_w.scalar_type();
        if ((weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        conv_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        conv_b = constant_as<Tensor>(conv->namedInput("bias")).value();
      }
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (bn->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }

      ConvBNParameters params;
      params.conv_w = conv_w;
      params.conv_b = conv_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
      std::tuple<Tensor, Tensor> out = computeUpdatedConvWeightAndBias(params);
      WithInsertPoint guard(conv);
      auto fused_conv_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_conv_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto conv_w_value = conv->namedInput("weight");
      auto conv_b_value = conv->namedInput("bias");

      fused_conv_w->setDebugName(conv_w_value->debugName() + "_fused_bn");
      fused_conv_b->setDebugName(conv_b_value->debugName() + "_fused_bn");

      conv->replaceInputWith(conv_w_value, fused_conv_w);
      conv->replaceInputWith(conv_b_value, fused_conv_b);

      bn->output()->replaceAllUsesWith(conv->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

bool supportedAddOrSub(Node* n) {
  static const OperatorSet add_set{
      "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
      "aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
      // sub is equivalent to add
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
      "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
  };
  return n->isMemberOf(add_set);
}

// In order to fuse add/sub/mul/div with conv, the dimensions of its
// constant tensor must satisfy the following:
// - with resizing, broadcast to the weight/bias tensor shape
// - broadcast to the conv output shape
// It needs a shape that can be resized to the weight/bias tensor shape
// because we need to run the op with the conv weights/bias without
// changing their sizes.
// It needs to broadcast to the conv output shape so that we don't
// accidentally change the shape of the op's output by pre-fusing it,
// compared to eager.
// The only dimension value shared by weight/bias/conv output is that
// they all contain a dim with value == channels-out. In the conv output
// tensor, this is the second dimension, so the pointwise op tensor may
// have a second dimension of value == channels-out, but all the other
// dimensions have to be 1.
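// For example, with a conv2d weight of shape [C_out, C_in, kH, kW], an op
// tensor of shape [1, C_out, 1, 1] satisfies this check (as does any
// all-singleton shape), while one of shape [1, 1, kH, kW] with kH > 1
// does not.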
bool opDoesNotBroadCastWithConv(Tensor& op_tensor, Tensor& weight_tensor) {
  if (op_tensor.ndimension() > weight_tensor.ndimension()) {
    return false;
  }
  for (int64_t i = op_tensor.ndimension() - 1; i >= 0; i--) {
    // channels-out dimension == weight_tensor.size(0)
    if (i == 1 && op_tensor.size(i) == weight_tensor.size(0)) {
      continue;
    }
    if (op_tensor.size(i) != 1) {
      return false;
    }
  }
  return true;
}

bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) {
  if (nonConstantParameters(conv) || nonConstantParameters(op)) {
    return false;
  }

  if (conv->output()->uses().size() > 1) {
    return false;
  }

  Tensor weight_tensor =
      constant_as<Tensor>(conv->namedInput("weight")).value();

  // avoid fusing an op that would cause type promotion;
  // restricting to floating-point weights avoids int/float difficulties
  // with the Scalar overload
  if (!weight_tensor.is_floating_point()) {
    return false;
  }

  if (op->inputs().at(1)->type()->cast<TensorType>()) {
    auto op_tensor = constant_as<Tensor>(op->inputs().at(1)).value();
    if (!opDoesNotBroadCastWithConv(op_tensor, weight_tensor)) {
      return false;
    }

    // reject a non-floating-point op tensor whose promotion with the
    // weight would change the weight's dtype (and thus the output dtype)
    if (!op_tensor.is_floating_point() &&
        c10::promoteTypes(
            op_tensor.scalar_type(), weight_tensor.scalar_type()) !=
            weight_tensor.scalar_type()) {
      return false;
    }
  }
  return true;
}

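// A sketch of the behavior below: an int constant 2 with shape {8, 1, 1}
// yields an 8x1x1 tensor filled with 2, while a one-element tensor
// constant with shape {8} is reshaped to {1} and expanded to eight
// elements.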
Tensor resizeConstantScalarOrTensorToShape(
    Value* v,
    const std::vector<int64_t>& shape,
    at::TensorOptions options) {
  Tensor ret_tensor;
  if (v->type()->cast<TensorType>()) {
    ret_tensor = constant_as<Tensor>(v).value();
  } else {
    ret_tensor = at::zeros(shape, options);
    if (v->type()->cast<IntType>()) {
      ret_tensor.fill_(constant_as<int64_t>(v).value());
    } else {
      ret_tensor.fill_(constant_as<double>(v).value());
    }
  }

  if (ret_tensor.numel() == 1) {
    // expand errors out if the requested shape has fewer dims than the
    // tensor itself
    ret_tensor = ret_tensor.reshape({1});
    ret_tensor = ret_tensor.expand(shape);
  } else {
    TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
    ret_tensor = ret_tensor.view(shape);
  }
  return ret_tensor;
}

bool FoldFrozenConvAddOrSub(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvAddOrSub(block);
    }

    if (supportedAddOrSub(n) && supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto add_or_sub = n;

      if (!checkConvAndBroadcastingOpPreConditions(conv, add_or_sub)) {
        continue;
      }

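      // The fold relies on the identity conv(x, W, b) + c = conv(x, W, b + c)
      // for a constant c that broadcasts along the output channels. Rather
      // than reimplementing add/sub semantics (including alpha), the original
      // node is re-run below on two constants to produce the fused bias.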
      Tensor weight_tensor =
          constant_as<Tensor>(conv->namedInput("weight")).value();

      Tensor add_or_sub_tensor = resizeConstantScalarOrTensorToShape(
          add_or_sub->inputs().at(1),
          {weight_tensor.size(0)},
          weight_tensor.options());
      Tensor bias;
      if (conv->namedInput("bias")->type() == NoneType::get()) {
        bias = at::zeros_like(add_or_sub_tensor, weight_tensor.dtype());
      } else {
        bias = constant_as<Tensor>(conv->namedInput("bias")).value();
      }

      WithInsertPoint guard(conv);

      add_or_sub->replaceInputWith(
          conv->output(), b->owningGraph()->insertConstant(bias));
      add_or_sub->replaceInput(
          1, b->owningGraph()->insertConstant(add_or_sub_tensor));

      auto stack_out = runNodeIfInputsAreConstant(add_or_sub);
      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
      Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());

      auto fused_conv_b = b->owningGraph()->insertConstant(fuse_bias);
      auto conv_b_value = conv->namedInput("bias");

      fused_conv_b->setDebugName(
          conv_b_value->debugName() + "_fused_" +
          add_or_sub->kind().toUnqualString());
      conv->replaceInputWith(conv_b_value, fused_conv_b);
      add_or_sub->output()->replaceAllUsesWith(conv->output());
      graph_modified = true;
      // the DCE pass run afterwards cleans up the dead nodes
    }
  }
  return graph_modified;
}

bool supportedMulOrDiv(Node* n) {
  static const OperatorSet mul_div_set{
      "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor",
      // div is equivalent to mul
      "aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::div.Scalar(Tensor self, Scalar other) -> Tensor",
  };
  return n->isMemberOf(mul_div_set);
}

bool FoldFrozenConvMulOrDiv(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvMulOrDiv(block);
    }

    if (supportedMulOrDiv(n) && supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto mul_or_div = n;

      if (!checkConvAndBroadcastingOpPreConditions(conv, mul_or_div)) {
        continue;
      }

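      // The fold relies on conv being linear in its weight and bias:
      // conv(x, W, b) * s = conv(x, W * s, b * s) when s varies only along
      // the output-channel dimension. The scale is folded into the weight
      // here, and into the bias further below if one exists.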
      Tensor weight_tensor =
          constant_as<Tensor>(conv->namedInput("weight")).value();
      int64_t out_channels = weight_tensor.size(0);

      // We've already verified that the second input has numel == 1 or
      // numel == channels-out; resize it to a shape that broadcasts with
      // weight_tensor when the op is run, so that the weight's size is
      // unchanged (e.g. {out_channels, 1, 1, 1} for a conv2d weight).
      std::vector<int64_t> weight_compatible_size = {out_channels};
      for (const auto i : c10::irange(1, weight_tensor.ndimension())) {
        (void)i; // Suppress unused variable warning
        weight_compatible_size.push_back(1);
      }

      WithInsertPoint guard(conv);

      Tensor mul_tensor = resizeConstantScalarOrTensorToShape(
          mul_or_div->inputs().at(1),
          weight_compatible_size,
          weight_tensor.options());

      // First fold with the weight tensor
      mul_or_div->replaceInputWith(
          conv->output(), b->owningGraph()->insertConstant(weight_tensor));
      mul_or_div->replaceInput(1, b->owningGraph()->insertConstant(mul_tensor));

      auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
      Tensor fuse_weight = (*stack_out)[0].toTensor().to(weight_tensor.dtype());

      auto fused_conv_weight = b->owningGraph()->insertConstant(fuse_weight);
      auto conv_weight_value = conv->namedInput("weight");

      fused_conv_weight->setDebugName(
          conv_weight_value->debugName() + "_fused_" +
          mul_or_div->kind().toUnqualString());
      conv->replaceInputWith(conv_weight_value, fused_conv_weight);
      mul_or_div->output()->replaceAllUsesWith(conv->output());

      // now fold with the bias tensor
      if (conv->namedInput("bias")->type() != NoneType::get()) {
        Tensor bias = constant_as<Tensor>(conv->namedInput("bias")).value();
        // bias is of shape {channels_out}
        auto mul_tensor = resizeConstantScalarOrTensorToShape(
            mul_or_div->inputs().at(1), {out_channels}, bias.options());

        mul_or_div->replaceInput(0, b->owningGraph()->insertConstant(bias));
        mul_or_div->replaceInput(
            1, b->owningGraph()->insertConstant(mul_tensor));

        auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
        TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
        Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());

        auto fused_conv_bias = b->owningGraph()->insertConstant(fuse_bias);
        auto conv_b_value = conv->namedInput("bias");

        fused_conv_bias->setDebugName(
            conv_b_value->debugName() + "_fused_" +
            mul_or_div->kind().toUnqualString());
        conv->replaceInputWith(conv_b_value, fused_conv_bias);
      }
      graph_modified = true;
      // the DCE pass run afterwards cleans up the dead nodes
    }
  }
  return graph_modified;
}

} // namespace

bool FoldFrozenConvBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

bool FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvAddOrSub(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

bool FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvMulOrDiv(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}
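
// A hypothetical usage sketch (these passes are normally invoked via
// OptimizeFrozenGraph after torch::jit::freeze; the method name "forward"
// is an assumption):
//
//   std::shared_ptr<Graph> graph =
//       frozen_module.get_method("forward").graph();
//   FoldFrozenConvBatchnorm(graph);
//   FoldFrozenConvAddOrSub(graph);
//   FoldFrozenConvMulOrDiv(graph);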

} // namespace jit
} // namespace torch