1 | #include <torch/csrc/jit/passes/fold_conv_bn.h> |
2 | |
3 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
4 | #include <torch/csrc/jit/jit_log.h> |
5 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
6 | #include <torch/csrc/jit/passes/quantization/helper.h> |
7 | |
8 | #include <ATen/TensorOperators.h> |
9 | |
10 | #ifndef AT_PER_OPERATOR_HEADERS |
11 | #include <ATen/Functions.h> |
12 | #else |
13 | #include <ATen/ops/empty_like.h> |
14 | #include <ATen/ops/ones_like.h> |
15 | #include <ATen/ops/rsqrt.h> |
16 | #include <ATen/ops/zeros_like.h> |
17 | #endif |
18 | |
19 | #include <stack> |
20 | #include <utility> |
21 | |
22 | namespace torch { |
23 | namespace jit { |
24 | |
25 | std::tuple<at::Tensor, at::Tensor> computeUpdatedConvWeightAndBias( |
26 | const ConvBNParameters& p) { |
27 | at::Tensor bn_var_rsqrt = at::rsqrt(p.bn_rv + p.bn_eps); |
28 | const int64_t ndim = p.conv_w.dim(); |
29 | at::DimVector sizes(ndim, 1); |
30 | sizes.at(0) = -1; |
31 | |
32 | auto conv_w_dtype = p.conv_w.dtype(); |
33 | auto conv_b_dtype = p.conv_b.dtype(); |
34 | |
35 | at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape(sizes); |
36 | at::Tensor new_b = (p.conv_b - p.bn_rm) * bn_var_rsqrt * p.bn_w + p.bn_b; |
37 | return std::make_tuple(new_w.to(conv_w_dtype), new_b.to(conv_b_dtype)); |
38 | } |
39 | |
40 | namespace { |
41 | using graph_rewrite_helper::PatternInfo; |
42 | |
43 | static bool hastensor(Module& m, const char* name) { |
44 | return m.hasattr(name) && m.attr(name).isTensor(); |
45 | } |
46 | |
// For every method graph in `module`, rewrite each aten::_convolution call so
// that its bias argument is read from the module's "bias" attribute via a
// prim::GetAttr node. Called only after addBiasForConvIfNone() has registered
// the (empty optional) "bias" attribute on bias-less conv modules.
void replaceConvBiasWithGetAttr(Module& module) {
  for (const auto& method : module.get_methods()) {
    auto graph = method.graph();
    // Only looks for _convolution pattern.
    // Thus assumes that tracing will have always gotten rid of aten::conv2d or
    // aten::conv3d. If it did not, BN folding will fail.
    const PatternInfo& pattern_convolution = PatternInfo::parse_from_str(R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
            %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
        return (%conv_out) )");
    // Older serialized graphs lack the trailing %allow_tf32 argument; match
    // that 12-argument form of aten::_convolution as well.
    const PatternInfo& pattern_convolution_deprecated =
        PatternInfo::parse_from_str(R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
            %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
        return (%conv_out) )");
    auto replace_pattern = [&](const PatternInfo& pattern_convolution) {
      const Graph& pattern_convolution_graph =
          *pattern_convolution.pattern_graph;
      const auto& convolution_vmap = pattern_convolution.vmap;

      const auto& matches =
          findPatternMatches(pattern_convolution_graph, *graph);
      for (const auto& match : matches) {
        // We come here only if the bias was not present in the module.
        // In that case, the corresponding graph will not have getAttr("bias")
        // Insert that in the graph.
        // And change _convolution to take the new value.
        auto conv_node =
            match.values_map.at(convolution_vmap.at("conv_out"))->node();
        // Insert the GetAttr just before the conv so the value dominates
        // its use; graph->inputs()[0] is %self.
        WithInsertPoint ins(conv_node);
        Value* bias_attr_val = graph->insertGetAttr(graph->inputs()[0], "bias")
                                   ->setType(TensorType::get());
        // Argument index 2 of aten::_convolution is the bias.
        constexpr size_t conv_bias_index = 2;
        conv_node->replaceInput(conv_bias_index, bias_attr_val);
      }
    };
    replace_pattern(pattern_convolution);
    replace_pattern(pattern_convolution_deprecated);
  }
}
93 | |
94 | void addBiasForConvIfNone(Module& module, const std::string& pattern_name) { |
95 | auto t = module.type()->expect<ClassType>(); |
96 | |
97 | const std::string real_typename = t->name()->qualifiedName(); |
98 | const std::string demangled_typename = removeTorchMangle(real_typename); |
99 | bool is_floating_point_conv = |
100 | ((demangled_typename == "__torch__.torch.nn.modules.conv.Conv1d" ) || |
101 | (demangled_typename == "__torch__.torch.nn.modules.conv.Conv2d" ) || |
102 | (demangled_typename == "__torch__.torch.nn.modules.conv.Conv3d" )); |
103 | |
104 | if (is_floating_point_conv) { |
105 | if (!t->hasAttribute("bias" )) { |
106 | auto optional_tensor_type = OptionalType::create(TensorType::get()); |
107 | t->addAttribute("bias" , std::move(optional_tensor_type), true); |
108 | auto optional_tensor = c10::optional<at::Tensor>(); |
109 | module.setattr("bias" , std::move(optional_tensor)); |
110 | replaceConvBiasWithGetAttr(module); |
111 | } |
112 | } |
113 | for (Module m : module.children()) { |
114 | addBiasForConvIfNone(m, pattern_name); |
115 | } |
116 | } |
117 | |
// Two-phase helper for folding Conv + BatchNorm pairs: analyze() records all
// planned changes without mutating anything; transform() applies them. The
// split keeps graph mutation out of the pattern-matching loop.
class FoldConvBatchNormHelper {
 public:
  /**
   * In this step we find all Conv - BatchNorm patterns in the graph
   * and extract the corresponding parameters for these two modules,
   * and record informations for the modifications of the graph without
   * actually performing these modifications.
   */
  void analyze(Module& module, const PatternInfo& pattern);
  /**
   * In this step we perform all the modifications including
   * setting the attributes for conv module, rewriting values
   * and deleting nodes in the graph
   */
  void transform();

 private:
  // Collects conv weight/bias and BN statistics into `r`; returns false
  // when a required attribute is missing from either module.
  bool tryExtractingConvBNParameters(
      Module& conv,
      Module& bn,
      ConvBNParameters& r);

  // Folded (weight, bias) computed per conv module instance, keyed by the
  // module's ivalue so shared submodules are only updated once.
  std::unordered_map<ModulePtr, std::tuple<at::Tensor, at::Tensor>>
      conv_module_and_params_;

  // A map from graph to a list of tuple of paths of matched conv and bn module
  // e.g. if we have a graph `g` containing following code
  // x = self.sub.conv1(..)
  // x = self.sub.bn1(..)
  // x = self.sub.conv2(..)
  // x = self.sub.bn2(..)
  // then the value for graph `g` in this map will be:
  // [(['sub', 'conv1'], ['sub', 'bn1']), (['sub', 'conv2'], ['sub', 'bn2'])]
  // the first entry of the list is the paths to first conv-bn match
  // the second entry of the list is the path to second match
  std::unordered_map<
      Graph*,
      std::vector<
          std::tuple<std::vector<std::string>, std::vector<std::string>>>>
      conv_bn_paths_;

  // BN output value -> conv output value it should be replaced with.
  std::unordered_map<Value*, Value*> rewrite_map_;
  // Rewrite order is kept in a vector for determinism (see analyze()).
  std::vector<Value*> values_to_rewrite_;
  // BN call nodes and their GetAttr submodule nodes, destroyed in transform().
  std::unordered_set<Node*> nodes_to_delete_;
};
163 | |
164 | bool (const script::Module& bn, ConvBNParameters& r) { |
165 | auto bn_forward = bn.get_method("forward" ); |
166 | auto graph = bn_forward.graph(); |
167 | const PatternInfo& pattern_bn = PatternInfo::parse_from_str(R"( |
168 | graph(%a, %weight, %bias, %running_mean, %running_var, |
169 | %training, %momentum, %eps, %cudnn_enabled): |
170 | %bn_out = aten::batch_norm(%a, %weight, %bias, %running_mean, |
171 | %running_var, %training, %momentum, %eps, %cudnn_enabled) |
172 | return (%bn_out) )" ); |
173 | const Graph& pattern_bn_graph = *pattern_bn.pattern_graph; |
174 | const auto& bn_vmap = pattern_bn.vmap; |
175 | |
176 | const auto& matches = findPatternMatches(pattern_bn_graph, *graph); |
177 | |
178 | if (matches.size() > 1) { |
179 | return false; |
180 | } |
181 | |
182 | if (bn.hasattr("eps" )) { |
183 | r.bn_eps = bn.attr("eps" ).toDouble(); |
184 | } else { |
185 | auto optional_eps = toIValue(matches[0].values_map.at(bn_vmap.at("eps" ))); |
186 | if (!optional_eps) { |
187 | return false; |
188 | } |
189 | r.bn_eps = optional_eps.value().toDouble(); |
190 | } |
191 | r.bn_w = at::ones_like(bn.attr("running_mean" ).toTensor()); |
192 | if (bn.hasattr("weight" )) { |
193 | if (bn.attr("weight" ).isTensor()) { |
194 | r.bn_w = bn.attr("weight" ).toTensor(); |
195 | } |
196 | } else { |
197 | auto optional_bn_weight = |
198 | toIValue(matches[0].values_map.at(bn_vmap.at("weight" ))); |
199 | if (!optional_bn_weight) { |
200 | return false; |
201 | } |
202 | if (optional_bn_weight.value().isTensor()) { |
203 | r.bn_w = optional_bn_weight.value().toTensor(); |
204 | } |
205 | } |
206 | r.bn_b = at::zeros_like(bn.attr("running_mean" ).toTensor()); |
207 | if (bn.hasattr("bias" )) { |
208 | if (bn.attr("bias" ).isTensor()) { |
209 | r.bn_b = bn.attr("bias" ).toTensor(); |
210 | } |
211 | } else { |
212 | auto optional_bn_bias = |
213 | toIValue(matches[0].values_map.at(bn_vmap.at("bias" ))); |
214 | if (!optional_bn_bias) { |
215 | return false; |
216 | } |
217 | |
218 | if (optional_bn_bias.value().isTensor()) { |
219 | r.bn_b = optional_bn_bias.value().toTensor(); |
220 | } |
221 | } |
222 | return true; |
223 | } |
224 | |
225 | bool FoldConvBatchNormHelper::( |
226 | Module& conv, |
227 | Module& bn, |
228 | ConvBNParameters& r) { |
229 | if (!hastensor(conv, "weight" ) || !conv.hasattr("bias" ) || |
230 | !hastensor(bn, "running_mean" ) || !hastensor(bn, "running_var" )) { |
231 | return false; |
232 | } |
233 | |
234 | r.bn_rm = bn.attr("running_mean" ).toTensor(); |
235 | r.bn_rv = bn.attr("running_var" ).toTensor(); |
236 | if (!extractOptionalBNParams(bn, r)) { |
237 | return false; |
238 | } |
239 | |
240 | r.conv_w = conv.attr("weight" ).toTensor(); |
241 | r.conv_b = at::zeros_like(r.bn_rm); |
242 | auto bias_opt = conv.attr("bias" ).toOptional<at::Tensor>(); |
243 | if (bias_opt) { |
244 | r.conv_b = *bias_opt; |
245 | } |
246 | |
247 | return true; |
248 | } |
249 | |
// Phase 1: walk the module tree, match the Conv->BN call pattern in every
// method graph, extract parameters, and record the planned rewrites and
// deletions. No graph or module is mutated here; transform() applies them.
void FoldConvBatchNormHelper::analyze(
    Module& module,
    const PatternInfo& pattern) {
  // Anchor values/nodes of the pattern graph, used to navigate each match.
  const Graph& pattern_graph = *pattern.pattern_graph;
  const auto& vmap = pattern.vmap;
  Value* pattern_conv_out = vmap.at("conv_out");
  Value* pattern_bn_out = vmap.at("bn_out");
  Value* pattern_bn_submodule = vmap.at("batchnorm");
  Node* pattern_conv = pattern_conv_out->node();
  Node* pattern_bn = pattern_bn_out->node();

  // We will put submodules into this worklist and keep processing items from it
  // one by one. We start by just putting the top module there.
  std::stack<Module> worklist({module});
  while (!worklist.empty()) {
    Module current = worklist.top();
    worklist.pop();

    // Queue submodules for processing
    for (const Module& submodule : current.children()) {
      worklist.push(submodule);
    }

    // Process all method of the current module
    for (auto& method : current.get_methods()) {
      GRAPH_DUMP(
          current.type()->name()->name() + "::" + method.name() +
              "() before Conv-BatchNorm folding",
          method.graph());
      const auto& matches = findPatternMatches(pattern_graph, *method.graph());

      GRAPH_DEBUG("number of Conv-BatchNorm matches: ", matches.size());
      Graph* g = method.graph().get();
      if (!conv_bn_paths_.count(g)) {
        // This is to make sure we don't visit one graph multiple times
        conv_bn_paths_[g] = {};
        for (const Match& match : matches) {
          // Skip matches rejected by the pattern's filters (e.g. the
          // conv2d/batchnorm2d module-type checks).
          if (!std::all_of(
                  pattern.filters.begin(),
                  pattern.filters.end(),
                  [&](const MatchFilter& f) { return f(match, vmap); })) {
            continue;
          }
          GRAPH_DEBUG("Checking next match...");
          // Get the conv and bn submodule
          Node* matched_conv = match.nodes_map.at(pattern_conv);
          Node* matched_bn = match.nodes_map.at(pattern_bn);
          Node* matched_bn_submodule =
              match.values_map.at(pattern_bn_submodule)->node();
          // input(0) of a prim::CallMethod is the callee module instance.
          Value* conv_instance = matched_conv->input(0);
          Value* bn_instance = matched_bn->input(0);
          Value* self = g->inputs()[0];
          // Resolve the GetAttr chains back to attribute-name paths so the
          // modules can be re-found later from any owner of this graph.
          auto conv_module_path = getModuleAccessPath(conv_instance, self);
          auto bn_module_path = getModuleAccessPath(bn_instance, self);
          Module conv_submodule = findChildModule(current, conv_module_path);
          Module bn_submodule = findChildModule(current, bn_module_path);

          ConvBNParameters params;
          if (!tryExtractingConvBNParameters(
                  conv_submodule, bn_submodule, params)) {
            GRAPH_DEBUG(
                "Conv and BN modules didn't have all required parameters or attributes...");
            continue;
          }
          conv_bn_paths_[g].emplace_back(conv_module_path, bn_module_path);
          // We are using a separate vector for saving Values we want to rewrite
          // to make sure that the order in which we perform these
          // transformations is deterministic. Iterating through keys of
          // rewrite_map would result in non-determinism that might not manifest
          // as a bug now, but can bite us later.
          values_to_rewrite_.push_back(matched_bn->output());
          rewrite_map_[matched_bn->output()] = matched_conv->output();
          GRAPH_UPDATE(
              "Rewriting %",
              matched_bn->output()->debugName(),
              " with %",
              matched_conv->output()->debugName());

          nodes_to_delete_.insert(matched_bn);
          nodes_to_delete_.insert(matched_bn_submodule);
          GRAPH_UPDATE("Deleting ", *matched_bn);
          GRAPH_UPDATE("Deleting ", *matched_bn_submodule);

          // transform() will setattr "bias"; that only works if the slot is
          // a parameter (addBiasForConvIfNone guarantees this beforehand).
          auto slot = conv_submodule.type()->getAttributeSlot("bias");
          TORCH_CHECK(
              conv_submodule.type()->is_parameter(slot),
              "Expected conv module to have a bias parameter");
        } // matches
      }

      // Even when the graph was matched before (shared graphs), recompute the
      // folded params relative to `current`, since module instances differ.
      for (const auto& conv_bn : conv_bn_paths_.at(g)) {
        Module conv_submodule = findChildModule(current, std::get<0>(conv_bn));
        Module bn_submodule = findChildModule(current, std::get<1>(conv_bn));

        ConvBNParameters params;
        TORCH_INTERNAL_ASSERT(tryExtractingConvBNParameters(
            conv_submodule, bn_submodule, params));
        auto new_w_b = computeUpdatedConvWeightAndBias(params);
        conv_module_and_params_[conv_submodule._ivalue()] = new_w_b;
      } // conv_bn module
    } // methods
  } // while
}
353 | |
354 | void FoldConvBatchNormHelper::transform() { |
355 | for (const auto& item : conv_module_and_params_) { |
356 | Module conv(item.first); |
357 | auto w_b = item.second; |
358 | conv.setattr("weight" , std::get<0>(w_b)); |
359 | conv.setattr("bias" , std::get<1>(w_b)); |
360 | } |
361 | |
362 | // Perform planned rewritings |
363 | for (auto v : values_to_rewrite_) { |
364 | v->replaceAllUsesWith(rewrite_map_.at(v)); |
365 | } |
366 | |
367 | // Perform planned deletions |
368 | for (auto n : nodes_to_delete_) { |
369 | n->removeAllInputs(); |
370 | } |
371 | for (auto n : nodes_to_delete_) { |
372 | n->destroy(); |
373 | } |
374 | } |
375 | |
376 | } // namespace |
377 | |
// Returns a clone of `module` in which every Conv2d+BatchNorm2d and
// Conv3d+BatchNorm3d forward()-call pair has been folded into the conv's
// weight and bias. The input module is left untouched.
Module FoldConvBatchNorm(const Module& module) {
  Module m = module.clone();

  // Ensure every conv has a "bias" attribute slot (possibly an empty
  // optional) so transform() can always setattr into it.
  addBiasForConvIfNone(m, "Conv2d");
  addBiasForConvIfNone(m, "Conv3d");
  // Conv2d + BatchNorm2d
  const PatternInfo pattern2d = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv, %batchnorm):
    %conv_out = prim::CallMethod[name="forward"](%conv, %input)
    %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
    return (%bn_out))",
      {is_conv2d_module, is_batchnorm2d_module});
  // Conv3d + BatchNorm3d
  const PatternInfo pattern3d = PatternInfo::parse_from_str(
      R"(
graph(%self, %input, %conv, %batchnorm):
    %conv_out = prim::CallMethod[name="forward"](%conv, %input)
    %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
    return (%bn_out))",
      {is_conv3d_module, is_batchnorm3d_module});

  // Run a fresh helper per pattern: each analyze()/transform() pair is a
  // complete, self-contained pass over the cloned module.
  const std::vector<std::reference_wrapper<const PatternInfo>> patterns = {
      pattern2d, pattern3d};
  for (const auto& pattern : patterns) {
    FoldConvBatchNormHelper h;
    h.analyze(m, pattern);
    h.transform();
  }
  return m;
}
409 | |
410 | } // namespace jit |
411 | } // namespace torch |
412 | |