1#include <torch/csrc/jit/passes/fold_conv_bn.h>
2
3#include <torch/csrc/jit/ir/subgraph_matcher.h>
4#include <torch/csrc/jit/jit_log.h>
5#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
6#include <torch/csrc/jit/passes/quantization/helper.h>
7
8#include <ATen/TensorOperators.h>
9
10#ifndef AT_PER_OPERATOR_HEADERS
11#include <ATen/Functions.h>
12#else
13#include <ATen/ops/empty_like.h>
14#include <ATen/ops/ones_like.h>
15#include <ATen/ops/rsqrt.h>
16#include <ATen/ops/zeros_like.h>
17#endif
18
19#include <stack>
20#include <utility>
21
22namespace torch {
23namespace jit {
24
25std::tuple<at::Tensor, at::Tensor> computeUpdatedConvWeightAndBias(
26 const ConvBNParameters& p) {
27 at::Tensor bn_var_rsqrt = at::rsqrt(p.bn_rv + p.bn_eps);
28 const int64_t ndim = p.conv_w.dim();
29 at::DimVector sizes(ndim, 1);
30 sizes.at(0) = -1;
31
32 auto conv_w_dtype = p.conv_w.dtype();
33 auto conv_b_dtype = p.conv_b.dtype();
34
35 at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape(sizes);
36 at::Tensor new_b = (p.conv_b - p.bn_rm) * bn_var_rsqrt * p.bn_w + p.bn_b;
37 return std::make_tuple(new_w.to(conv_w_dtype), new_b.to(conv_b_dtype));
38}
39
40namespace {
41using graph_rewrite_helper::PatternInfo;
42
43static bool hastensor(Module& m, const char* name) {
44 return m.hasattr(name) && m.attr(name).isTensor();
45}
46
47void replaceConvBiasWithGetAttr(Module& module) {
48 for (const auto& method : module.get_methods()) {
49 auto graph = method.graph();
50 // Only looks for _convolution pattern.
51 // Thus assumes that tracing will have always gotten rid of aten::conv2d or
52 // aten::conv3d. If it did not, BN folding will fail.
53 const PatternInfo& pattern_convolution = PatternInfo::parse_from_str(R"(
54 graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
55 %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
56 %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
57 %conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
58 %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
59 return (%conv_out) )");
60 const PatternInfo& pattern_convolution_deprecated =
61 PatternInfo::parse_from_str(R"(
62 graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
63 %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
64 %deterministic:bool, %cudnn_enabled:bool):
65 %conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
66 %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
67 return (%conv_out) )");
68 auto replace_pattern = [&](const PatternInfo& pattern_convolution) {
69 const Graph& pattern_convolution_graph =
70 *pattern_convolution.pattern_graph;
71 const auto& convolution_vmap = pattern_convolution.vmap;
72
73 const auto& matches =
74 findPatternMatches(pattern_convolution_graph, *graph);
75 for (const auto& match : matches) {
76 // We come here only if the bias was not present in the module.
77 // In that case, the corresponding graph will not have getAttr("bias")
78 // Insert that in the graph.
79 // And change _convolution to take the new value.
80 auto conv_node =
81 match.values_map.at(convolution_vmap.at("conv_out"))->node();
82 WithInsertPoint ins(conv_node);
83 Value* bias_attr_val = graph->insertGetAttr(graph->inputs()[0], "bias")
84 ->setType(TensorType::get());
85 constexpr size_t conv_bias_index = 2;
86 conv_node->replaceInput(conv_bias_index, bias_attr_val);
87 }
88 };
89 replace_pattern(pattern_convolution);
90 replace_pattern(pattern_convolution_deprecated);
91 }
92}
93
94void addBiasForConvIfNone(Module& module, const std::string& pattern_name) {
95 auto t = module.type()->expect<ClassType>();
96
97 const std::string real_typename = t->name()->qualifiedName();
98 const std::string demangled_typename = removeTorchMangle(real_typename);
99 bool is_floating_point_conv =
100 ((demangled_typename == "__torch__.torch.nn.modules.conv.Conv1d") ||
101 (demangled_typename == "__torch__.torch.nn.modules.conv.Conv2d") ||
102 (demangled_typename == "__torch__.torch.nn.modules.conv.Conv3d"));
103
104 if (is_floating_point_conv) {
105 if (!t->hasAttribute("bias")) {
106 auto optional_tensor_type = OptionalType::create(TensorType::get());
107 t->addAttribute("bias", std::move(optional_tensor_type), true);
108 auto optional_tensor = c10::optional<at::Tensor>();
109 module.setattr("bias", std::move(optional_tensor));
110 replaceConvBiasWithGetAttr(module);
111 }
112 }
113 for (Module m : module.children()) {
114 addBiasForConvIfNone(m, pattern_name);
115 }
116}
117
// Two-phase helper for folding Conv-BatchNorm pairs: analyze() gathers all
// matches and the folded parameters without touching the graphs; transform()
// then applies every recorded modification at once.
class FoldConvBatchNormHelper {
 public:
  /**
   * In this step we find all Conv - BatchNorm patterns in the graph
   * and extract the corresponding parameters for these two modules,
   * and record information for the modifications of the graph without
   * actually performing these modifications.
   */
  void analyze(Module& module, const PatternInfo& pattern);
  /**
   * In this step we perform all the modifications including
   * setting the attributes for conv module, rewriting values
   * and deleting nodes in the graph
   */
  void transform();

 private:
  // Fills `r` from the conv and bn modules; returns false if any required
  // parameter/attribute is missing.
  bool tryExtractingConvBNParameters(
      Module& conv,
      Module& bn,
      ConvBNParameters& r);

  // Folded (weight, bias) to write back into each matched conv module,
  // keyed by the conv module instance.
  std::unordered_map<ModulePtr, std::tuple<at::Tensor, at::Tensor>>
      conv_module_and_params_;

  // A map from graph to a list of tuple of paths of matched conv and bn module
  // e.g. if we have a graph `g` containing following code
  // x = self.sub.conv1(..)
  // x = self.sub.bn1(..)
  // x = self.sub.conv2(..)
  // x = self.sub.bn2(..)
  // then the value for graph `g` in this map will be:
  // [(['sub', 'conv1'], ['sub', 'bn1']), (['sub', 'conv2'], ['sub', 'bn2'])]
  // the first entry of the list is the paths to first conv-bn match
  // the second entry of the list is the path to second match
  std::unordered_map<
      Graph*,
      std::vector<
          std::tuple<std::vector<std::string>, std::vector<std::string>>>>
      conv_bn_paths_;

  // Maps each BN output value to the conv output value that replaces it.
  std::unordered_map<Value*, Value*> rewrite_map_;
  // Kept as a vector (not the map's keys) so rewrites happen in a
  // deterministic order.
  std::vector<Value*> values_to_rewrite_;
  // BN call nodes and their GetAttr submodule nodes, destroyed in transform().
  std::unordered_set<Node*> nodes_to_delete_;
};
163
164bool extractOptionalBNParams(const script::Module& bn, ConvBNParameters& r) {
165 auto bn_forward = bn.get_method("forward");
166 auto graph = bn_forward.graph();
167 const PatternInfo& pattern_bn = PatternInfo::parse_from_str(R"(
168 graph(%a, %weight, %bias, %running_mean, %running_var,
169 %training, %momentum, %eps, %cudnn_enabled):
170 %bn_out = aten::batch_norm(%a, %weight, %bias, %running_mean,
171 %running_var, %training, %momentum, %eps, %cudnn_enabled)
172 return (%bn_out) )");
173 const Graph& pattern_bn_graph = *pattern_bn.pattern_graph;
174 const auto& bn_vmap = pattern_bn.vmap;
175
176 const auto& matches = findPatternMatches(pattern_bn_graph, *graph);
177
178 if (matches.size() > 1) {
179 return false;
180 }
181
182 if (bn.hasattr("eps")) {
183 r.bn_eps = bn.attr("eps").toDouble();
184 } else {
185 auto optional_eps = toIValue(matches[0].values_map.at(bn_vmap.at("eps")));
186 if (!optional_eps) {
187 return false;
188 }
189 r.bn_eps = optional_eps.value().toDouble();
190 }
191 r.bn_w = at::ones_like(bn.attr("running_mean").toTensor());
192 if (bn.hasattr("weight")) {
193 if (bn.attr("weight").isTensor()) {
194 r.bn_w = bn.attr("weight").toTensor();
195 }
196 } else {
197 auto optional_bn_weight =
198 toIValue(matches[0].values_map.at(bn_vmap.at("weight")));
199 if (!optional_bn_weight) {
200 return false;
201 }
202 if (optional_bn_weight.value().isTensor()) {
203 r.bn_w = optional_bn_weight.value().toTensor();
204 }
205 }
206 r.bn_b = at::zeros_like(bn.attr("running_mean").toTensor());
207 if (bn.hasattr("bias")) {
208 if (bn.attr("bias").isTensor()) {
209 r.bn_b = bn.attr("bias").toTensor();
210 }
211 } else {
212 auto optional_bn_bias =
213 toIValue(matches[0].values_map.at(bn_vmap.at("bias")));
214 if (!optional_bn_bias) {
215 return false;
216 }
217
218 if (optional_bn_bias.value().isTensor()) {
219 r.bn_b = optional_bn_bias.value().toTensor();
220 }
221 }
222 return true;
223}
224
225bool FoldConvBatchNormHelper::tryExtractingConvBNParameters(
226 Module& conv,
227 Module& bn,
228 ConvBNParameters& r) {
229 if (!hastensor(conv, "weight") || !conv.hasattr("bias") ||
230 !hastensor(bn, "running_mean") || !hastensor(bn, "running_var")) {
231 return false;
232 }
233
234 r.bn_rm = bn.attr("running_mean").toTensor();
235 r.bn_rv = bn.attr("running_var").toTensor();
236 if (!extractOptionalBNParams(bn, r)) {
237 return false;
238 }
239
240 r.conv_w = conv.attr("weight").toTensor();
241 r.conv_b = at::zeros_like(r.bn_rm);
242 auto bias_opt = conv.attr("bias").toOptional<at::Tensor>();
243 if (bias_opt) {
244 r.conv_b = *bias_opt;
245 }
246
247 return true;
248}
249
// Walks every method of `module` (and all submodules), finds Conv-BN matches
// of `pattern`, and records: the folded (weight, bias) per conv module, the
// value rewrites, and the nodes to delete. No graph is modified here.
void FoldConvBatchNormHelper::analyze(
    Module& module,
    const PatternInfo& pattern) {
  // Anchor values/nodes of the pattern graph; their counterparts in each
  // match give us the concrete conv/bn calls.
  const Graph& pattern_graph = *pattern.pattern_graph;
  const auto& vmap = pattern.vmap;
  Value* pattern_conv_out = vmap.at("conv_out");
  Value* pattern_bn_out = vmap.at("bn_out");
  Value* pattern_bn_submodule = vmap.at("batchnorm");
  Node* pattern_conv = pattern_conv_out->node();
  Node* pattern_bn = pattern_bn_out->node();

  // We will put submodules into this worklist and keep processing items from it
  // one by one. We start by just putting the top module there.
  std::stack<Module> worklist({module});
  while (!worklist.empty()) {
    Module current = worklist.top();
    worklist.pop();

    // Queue submodules for processing
    for (const Module& submodule : current.children()) {
      worklist.push(submodule);
    }

    // Process all methods of the current module
    for (auto& method : current.get_methods()) {
      GRAPH_DUMP(
          current.type()->name()->name() + "::" + method.name() +
              "() before Conv-BatchNorm folding",
          method.graph());
      const auto& matches = findPatternMatches(pattern_graph, *method.graph());

      GRAPH_DEBUG("number of Conv-BatchNorm matches: ", matches.size());
      Graph* g = method.graph().get();
      if (!conv_bn_paths_.count(g)) {
        // This is to make sure we don't visit one graph multiple times
        // (modules of the same type share a graph).
        conv_bn_paths_[g] = {};
        for (const Match& match : matches) {
          // Skip matches rejected by the pattern's filters
          // (e.g. module-type checks like is_conv2d_module).
          if (!std::all_of(
                  pattern.filters.begin(),
                  pattern.filters.end(),
                  [&](const MatchFilter& f) { return f(match, vmap); })) {
            continue;
          }
          GRAPH_DEBUG("Checking next match...");
          // Get the conv and bn submodule
          Node* matched_conv = match.nodes_map.at(pattern_conv);
          Node* matched_bn = match.nodes_map.at(pattern_bn);
          Node* matched_bn_submodule =
              match.values_map.at(pattern_bn_submodule)->node();
          // input(0) of a CallMethod is the module instance being called.
          Value* conv_instance = matched_conv->input(0);
          Value* bn_instance = matched_bn->input(0);
          Value* self = g->inputs()[0];
          auto conv_module_path = getModuleAccessPath(conv_instance, self);
          auto bn_module_path = getModuleAccessPath(bn_instance, self);
          Module conv_submodule = findChildModule(current, conv_module_path);
          Module bn_submodule = findChildModule(current, bn_module_path);

          ConvBNParameters params;
          if (!tryExtractingConvBNParameters(
                  conv_submodule, bn_submodule, params)) {
            GRAPH_DEBUG(
                "Conv and BN modules didn't have all required parameters or attributes...");
            continue;
          }
          conv_bn_paths_[g].emplace_back(conv_module_path, bn_module_path);
          // We are using a separate vector for saving Values we want to rewrite
          // to make sure that the order in which we perform these
          // transformations is deterministic. Iterating through keys of
          // rewrite_map would result in non-determinism that might not manifest
          // as a bug now, but can bite us later.
          values_to_rewrite_.push_back(matched_bn->output());
          rewrite_map_[matched_bn->output()] = matched_conv->output();
          GRAPH_UPDATE(
              "Rewriting %",
              matched_bn->output()->debugName(),
              " with %",
              matched_conv->output()->debugName());

          nodes_to_delete_.insert(matched_bn);
          nodes_to_delete_.insert(matched_bn_submodule);
          GRAPH_UPDATE("Deleting ", *matched_bn);
          GRAPH_UPDATE("Deleting ", *matched_bn_submodule);

          // The folded bias is written via setattr in transform(); that
          // requires the conv module to expose bias as a parameter slot.
          auto slot = conv_submodule.type()->getAttributeSlot("bias");
          TORCH_CHECK(
              conv_submodule.type()->is_parameter(slot),
              "Expected conv module to have a bias parameter");
        } // matches
      }

      // Recompute the folded parameters for this (possibly shared) graph's
      // matches relative to the *current* module instance, since modules of
      // the same type share graphs but hold distinct parameter tensors.
      for (const auto& conv_bn : conv_bn_paths_.at(g)) {
        Module conv_submodule = findChildModule(current, std::get<0>(conv_bn));
        Module bn_submodule = findChildModule(current, std::get<1>(conv_bn));

        ConvBNParameters params;
        TORCH_INTERNAL_ASSERT(tryExtractingConvBNParameters(
            conv_submodule, bn_submodule, params));
        auto new_w_b = computeUpdatedConvWeightAndBias(params);
        conv_module_and_params_[conv_submodule._ivalue()] = new_w_b;
      } // conv_bn module
    } // methods
  } // while
}
353
354void FoldConvBatchNormHelper::transform() {
355 for (const auto& item : conv_module_and_params_) {
356 Module conv(item.first);
357 auto w_b = item.second;
358 conv.setattr("weight", std::get<0>(w_b));
359 conv.setattr("bias", std::get<1>(w_b));
360 }
361
362 // Perform planned rewritings
363 for (auto v : values_to_rewrite_) {
364 v->replaceAllUsesWith(rewrite_map_.at(v));
365 }
366
367 // Perform planned deletions
368 for (auto n : nodes_to_delete_) {
369 n->removeAllInputs();
370 }
371 for (auto n : nodes_to_delete_) {
372 n->destroy();
373 }
374}
375
376} // namespace
377
378Module FoldConvBatchNorm(const Module& module) {
379 Module m = module.clone();
380
381 addBiasForConvIfNone(m, "Conv2d");
382 addBiasForConvIfNone(m, "Conv3d");
383 // Conv2d + BatchNorm2d
384 const PatternInfo pattern2d = PatternInfo::parse_from_str(
385 R"(
386graph(%self, %input, %conv, %batchnorm):
387 %conv_out = prim::CallMethod[name="forward"](%conv, %input)
388 %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
389 return (%bn_out))",
390 {is_conv2d_module, is_batchnorm2d_module});
391 // Conv3d + BatchNorm3d
392 const PatternInfo pattern3d = PatternInfo::parse_from_str(
393 R"(
394graph(%self, %input, %conv, %batchnorm):
395 %conv_out = prim::CallMethod[name="forward"](%conv, %input)
396 %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
397 return (%bn_out))",
398 {is_conv3d_module, is_batchnorm3d_module});
399
400 const std::vector<std::reference_wrapper<const PatternInfo>> patterns = {
401 pattern2d, pattern3d};
402 for (const auto& pattern : patterns) {
403 FoldConvBatchNormHelper h;
404 h.analyze(m, pattern);
405 h.transform();
406 }
407 return m;
408}
409
410} // namespace jit
411} // namespace torch
412