#include <ATen/Utils.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch {
namespace jit {

namespace {

using Tensor = at::Tensor;

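// Returns true for the conv variants this pass knows how to fold constants
// into: conv1d/conv2d/conv3d, and aten::_convolution when it is statically
// known not to be a transposed convolution.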
bool supportedConvNode(Node* n) {
  switch (n->kind()) {
    case aten::conv1d:
    case aten::conv2d:
    case aten::conv3d:
      return true;
    case aten::_convolution: {
      auto transposed_conv =
          constant_as<bool>(n->namedInput("transposed")).value_or(true);
      // We don't handle transposed conv yet; a non-constant "transposed"
      // parameter is conservatively treated as transposed via value_or(true).
      return !transposed_conv;
    }
    default:
      return false;
  }
}

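// Folds aten::batch_norm nodes into the weights and bias of a preceding,
// frozen (constant-parameter) convolution. Roughly, a graph fragment such as
//   %y = aten::conv2d(%x, %w, %b, ...)
//   %z = aten::batch_norm(%y, %gamma, %beta, %mean, %var, ...)
// is rewritten so that %z's uses read directly from the conv, whose weight
// and bias constants are replaced by fused versions; the dead batch_norm is
// removed by the DCE pass that the public entry point runs afterwards.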
bool FoldFrozenConvBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvBatchnorm(block);
    }

    if (n->kind() == aten::batch_norm &&
        supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto bn = n;
      if (nonConstantParameters(conv) || nonConstantParameters(bn)) {
        continue;
      }
      // Only fold when the batch_norm is the sole consumer of the conv
      // output, since we rewrite the conv's constant weight and bias.
      if (conv->output()->uses().size() > 1) {
        continue;
      }

      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");
      // Check that running_mean and running_var have values; if they are
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() &&
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }

      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto conv_w = constant_as<Tensor>(conv->namedInput("weight")).value();

      // Implementation taken from torch/nn/utils/fusion.py.
      Tensor conv_b;
      if (conv->namedInput("bias")->type() == NoneType::get()) {
        // If this is on GPU, the bias is None, and the weight was
        // half/bfloat16 while bn_rm was float, then autocasting probably cast
        // the conv's inputs. Since the CUDA conv implementation requires all
        // of its inputs to have the same scalar dtype, give this placeholder
        // bias the same dtype as conv_w.
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = conv_w.scalar_type();
        if ((weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        conv_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        conv_b = constant_as<Tensor>(conv->namedInput("bias")).value();
      }
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (bn->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }

      ConvBNParameters params;
      params.conv_w = conv_w;
      params.conv_b = conv_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
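      // computeUpdatedConvWeightAndBias (fold_conv_bn.h) applies the standard
      // conv-batchnorm folding math, with the per-channel scale reshaped to
      // broadcast over the weight:
      //   scale = bn_w / sqrt(bn_rv + bn_eps)
      //   new_w = conv_w * scale
      //   new_b = (conv_b - bn_rm) * scale + bn_b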
      std::tuple<Tensor, Tensor> out = computeUpdatedConvWeightAndBias(params);
      WithInsertPoint guard(conv);
      auto fused_conv_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_conv_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto conv_w_value = conv->namedInput("weight");
      auto conv_b_value = conv->namedInput("bias");

      fused_conv_w->setDebugName(conv_w_value->debugName() + "_fused_bn");
      fused_conv_b->setDebugName(conv_b_value->debugName() + "_fused_bn");

      conv->replaceInputWith(conv_w_value, fused_conv_w);
      conv->replaceInputWith(conv_b_value, fused_conv_b);

      bn->output()->replaceAllUsesWith(conv->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

bool supportedAddOrSub(Node* n) {
  static const OperatorSet add_set{
      "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
      "aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
      // sub is folded the same way as add
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
      "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
  };
  return n->isMemberOf(add_set);
}

// In order to fuse add/sub/mul/div with conv, the dimensions of the op's
// constant tensor must satisfy the following:
// - after resizing, it broadcasts with the weight/bias tensor shape
// - it broadcasts to the conv output shape
// It needs a shape that can be resized to broadcast with the weight/bias
// tensor because we run the op against the conv weights/bias without
// changing their sizes.
// It needs to broadcast to the conv output shape so that we do not
// accidentally change the shape of the op's output by pre-fusing it,
// compared to running it eagerly.
// The only dimension value shared by the weight, bias, and conv output is
// channels-out. In the conv output tensor this is the second dimension, so
// the pointwise op's tensor may have a second dimension equal to
// channels-out, but all of its other dimensions have to be 1.
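// For example, with a conv2d weight of shape [C_out, C_in, kH, kW], operand
// shapes such as [], [1], [1, 1, 1, 1], or [1, C_out, 1, 1] pass the check
// below, while [C_out, 1, 1] (channels-out not in the second dimension) or
// anything with more dimensions than the weight is rejected.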
bool opDoesNotBroadCastWithConv(Tensor& op_tensor, Tensor& weight_tensor) {
  if (op_tensor.ndimension() > weight_tensor.ndimension()) {
    return false;
  }
  for (int64_t i = op_tensor.ndimension() - 1; i >= 0; i--) {
    // channels-out dimension == weight_tensor.size(0)
    if (i == 1 && op_tensor.size(i) == weight_tensor.size(0)) {
      continue;
    }
    if (op_tensor.size(i) != 1) {
      return false;
    }
  }
  return true;
}

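// Common preconditions for folding a broadcasting pointwise op into a frozen
// conv: all parameters are constant, the conv output is consumed only by the
// op, the weight is floating point, and a constant tensor operand must
// satisfy opDoesNotBroadCastWithConv and must not force type promotion.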
bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) {
  if (nonConstantParameters(conv) || nonConstantParameters(op)) {
    return false;
  }

  if (conv->output()->uses().size() > 1) {
    return false;
  }

  Tensor weight_tensor =
      constant_as<Tensor>(conv->namedInput("weight")).value();

  // Avoid fusing an op that causes type promotion;
  // restricting to float avoids int/float difficulties with the scalar overload.
  if (!weight_tensor.is_floating_point()) {
    return false;
  }

  if (op->inputs().at(1)->type()->cast<TensorType>()) {
    auto op_tensor = constant_as<Tensor>(op->inputs().at(1)).value();
    if (!opDoesNotBroadCastWithConv(op_tensor, weight_tensor)) {
      return false;
    }

    if (!op_tensor.is_floating_point() &&
        c10::promoteTypes(
            op_tensor.scalar_type(), weight_tensor.scalar_type()) !=
            weight_tensor.scalar_type()) {
      return false;
    }
  }
  return true;
}

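// Materializes the constant Value `v` (a scalar or a tensor) as a Tensor of
// exactly `shape`: scalars are splatted into a freshly allocated tensor,
// one-element tensors are expanded, and larger tensors are viewed as `shape`
// (their numel must already match).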
Tensor resizeConstantScalarOrTensorToShape(
    Value* v,
    const std::vector<int64_t>& shape,
    at::TensorOptions options) {
  Tensor ret_tensor;
  if (v->type()->cast<TensorType>()) {
    ret_tensor = constant_as<Tensor>(v).value();
  } else {
    ret_tensor = at::zeros(shape, options);
    if (v->type()->cast<IntType>()) {
      ret_tensor.fill_(constant_as<int64_t>(v).value());
    } else {
      ret_tensor.fill_(constant_as<double>(v).value());
    }
  }

  if (ret_tensor.numel() == 1) {
    // expand errors out if the shape input has fewer dims than the tensor
    // input, so flatten to a single element first
    ret_tensor = ret_tensor.reshape({1});
    ret_tensor = ret_tensor.expand(shape);
  } else {
    TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
    ret_tensor = ret_tensor.view(shape);
  }
  return ret_tensor;
}

bool FoldFrozenConvAddOrSub(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvAddOrSub(block);
    }

    if (supportedAddOrSub(n) && supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto add_or_sub = n;

      if (!checkConvAndBroadcastingOpPreConditions(conv, add_or_sub)) {
        continue;
      }

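      // Because convolution is affine in its bias, conv(x, w, b) + t equals
      // conv(x, w, b + t) when t is a per-output-channel (or scalar)
      // constant, and likewise for sub. The add/sub node is replayed on the
      // constant bias (resized to {channels_out}) and the result becomes the
      // conv's new bias constant.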
      Tensor weight_tensor =
          constant_as<Tensor>(conv->namedInput("weight")).value();

      Tensor add_or_sub_tensor = resizeConstantScalarOrTensorToShape(
          add_or_sub->inputs().at(1),
          {weight_tensor.size(0)},
          weight_tensor.options());
      Tensor bias;
      if (conv->namedInput("bias")->type() == NoneType::get()) {
        bias = at::zeros_like(add_or_sub_tensor, weight_tensor.dtype());
      } else {
        bias = constant_as<Tensor>(conv->namedInput("bias")).value();
      }

      WithInsertPoint guard(conv);

      add_or_sub->replaceInputWith(
          conv->output(), b->owningGraph()->insertConstant(bias));
      add_or_sub->replaceInput(
          1, b->owningGraph()->insertConstant(add_or_sub_tensor));

      auto stack_out = runNodeIfInputsAreConstant(add_or_sub);
      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
      Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());

      auto fused_conv_b = b->owningGraph()->insertConstant(fuse_bias);
      auto conv_b_value = conv->namedInput("bias");

      fused_conv_b->setDebugName(
          conv_b_value->debugName() + "_fused_" +
          add_or_sub->kind().toUnqualString());
      conv->replaceInputWith(conv_b_value, fused_conv_b);
      add_or_sub->output()->replaceAllUsesWith(conv->output());
      graph_modified = true;
      // The DCE run afterwards cleans up the now-dead add/sub node.
    }
  }
  return graph_modified;
}

bool supportedMulOrDiv(Node* n) {
  static const OperatorSet mul_set{
      "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor",
      // div is folded the same way as mul
      "aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::div.Scalar(Tensor self, Scalar other) -> Tensor",
  };
  return n->isMemberOf(mul_set);
}

bool FoldFrozenConvMulOrDiv(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvMulOrDiv(block);
    }

    if (supportedMulOrDiv(n) && supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto mul_or_div = n;

      if (!checkConvAndBroadcastingOpPreConditions(conv, mul_or_div)) {
        continue;
      }

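      // Because convolution is linear, conv(x, w, b) * t equals
      // conv(x, w * t, b * t) when t is a per-output-channel (or scalar)
      // constant, and likewise for div; the mul/div node is replayed on the
      // constant weight and then on the constant bias.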
      Tensor weight_tensor =
          constant_as<Tensor>(conv->namedInput("weight")).value();
      int64_t out_channels = weight_tensor.size(0);

      // We have already verified that the second input has numel == 1 or
      // numel == channels-out; resize it to a shape that broadcasts with
      // weight_tensor when the op is run, so we do not change the weight's
      // size.
      std::vector<int64_t> weight_compatible_size = {out_channels};
      for (const auto i : c10::irange(1, weight_tensor.ndimension())) {
        (void)i; // Suppress unused variable warning
        weight_compatible_size.push_back(1);
      }
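      // e.g. for a conv2d weight of shape [C_out, C_in, kH, kW] this builds
      // {C_out, 1, 1, 1}, which broadcasts elementwise against the weight.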

      WithInsertPoint guard(conv);

      Tensor mul_tensor = resizeConstantScalarOrTensorToShape(
          mul_or_div->inputs().at(1),
          weight_compatible_size,
          weight_tensor.options());

      // First fold with the weight tensor.
      mul_or_div->replaceInputWith(
          conv->output(), b->owningGraph()->insertConstant(weight_tensor));
      mul_or_div->replaceInput(1, b->owningGraph()->insertConstant(mul_tensor));

      auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
      Tensor fuse_weight = (*stack_out)[0].toTensor().to(weight_tensor.dtype());

      auto fused_conv_weight = b->owningGraph()->insertConstant(fuse_weight);
      auto conv_weight_value = conv->namedInput("weight");

      fused_conv_weight->setDebugName(
          conv_weight_value->debugName() + "_fused_" +
          mul_or_div->kind().toUnqualString());
      conv->replaceInputWith(conv_weight_value, fused_conv_weight);
      mul_or_div->output()->replaceAllUsesWith(conv->output());

      // Now fold with the bias tensor.
      if (conv->namedInput("bias")->type() != NoneType::get()) {
        Tensor bias = constant_as<Tensor>(conv->namedInput("bias")).value();
        // bias is of shape {channels_out}
        auto mul_tensor = resizeConstantScalarOrTensorToShape(
            mul_or_div->inputs().at(1), {out_channels}, bias.options());

        mul_or_div->replaceInput(0, b->owningGraph()->insertConstant(bias));
        mul_or_div->replaceInput(
            1, b->owningGraph()->insertConstant(mul_tensor));

        auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
        TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
        Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());

        auto fused_conv_bias = b->owningGraph()->insertConstant(fuse_bias);
        auto conv_b_value = conv->namedInput("bias");

        fused_conv_bias->setDebugName(
            conv_b_value->debugName() + "_fused_" +
            mul_or_div->kind().toUnqualString());
        conv->replaceInputWith(conv_b_value, fused_conv_bias);
      }
      graph_modified = true;
      // The DCE run afterwards cleans up the now-dead mul/div node.
    }
  }
  return graph_modified;
}

} // namespace

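// Public entry points: each folds within the graph's top-level block (and,
// recursively, any nested blocks) and then runs dead-code elimination to
// remove the now-unused batch_norm / pointwise nodes and stale constants.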
bool FoldFrozenConvBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

bool FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvAddOrSub(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

bool FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvMulOrDiv(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

} // namespace jit
} // namespace torch