1 | #include <torch/csrc/jit/passes/freeze_module.h> |
2 | |
3 | #include <torch/csrc/jit/jit_log.h> |
4 | |
5 | #include <c10/util/irange.h> |
6 | #include <torch/csrc/jit/api/function_impl.h> |
7 | #include <torch/csrc/jit/ir/alias_analysis.h> |
8 | #include <torch/csrc/jit/passes/autocast.h> |
9 | #include <torch/csrc/jit/passes/clear_profiling.h> |
10 | #include <torch/csrc/jit/passes/eliminate_no_ops.h> |
11 | #include <torch/csrc/jit/passes/inliner.h> |
12 | #include <torch/csrc/jit/passes/lower_tuples.h> |
13 | #include <torch/csrc/jit/passes/remove_mutation.h> |
14 | #include <torch/csrc/jit/runtime/graph_executor_impl.h> |
15 | |
16 | #include <stack> |
17 | #include <utility> |
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | |
22 | namespace { |
23 | |
// Splits a dotted qualified name ("a.b.c") into its components using
// std::getline with '.' as the delimiter. An empty input yields no
// components, and a trailing '.' does not produce a trailing empty component.
std::vector<std::string> splitName(const std::string& name) {
  std::vector<std::string> parts;
  std::istringstream stream(name);
  for (std::string part; std::getline(stream, part, '.');) {
    parts.emplace_back(std::move(part));
  }
  return parts;
}
33 | |
template <typename Iter>
// Joins the strings in [begin, end) with '.' separators. A separator is only
// emitted once the accumulated result is non-empty, so leading empty
// components do not produce a leading '.'.
std::string concatName(const Iter& begin, const Iter& end) {
  std::string joined;
  for (Iter cursor = begin; cursor != end; ++cursor) {
    if (!joined.empty()) {
      joined.push_back('.');
    }
    joined.append(*cursor);
  }
  return joined;
}
46 | |
47 | class AttributePropagator { |
48 | public: |
49 | AttributePropagator( |
50 | Module& module, |
51 | std::vector<std::string>& preservedAttrs, |
52 | bool freezeInterfaces, |
53 | bool preserveParameters) |
54 | : module_(module), |
55 | freezeInterfaces_(freezeInterfaces), |
56 | preserveParameters_(preserveParameters) { |
57 | auto checkName = [this](std::string& name) { |
58 | const auto resolved_name = resolveName(name); |
59 | |
60 | if (resolved_name) { |
61 | const auto& parent_module = resolved_name->first; |
62 | const auto& attr_name = resolved_name->second; |
63 | if (parent_module.hasattr(attr_name)) { |
64 | auto value = parent_module.attr(attr_name); |
65 | // Freezing client wants to presever this submodule. When cleaning |
66 | // the frozen module, make sure it will be preserved entirely. |
67 | if (value.isModule()) { |
68 | preservedSubModule_.insert(value.toModule()._ivalue()); |
69 | } |
70 | insertMutableAttr(attr_name, value, parent_module._ivalue()); |
71 | } else { |
72 | auto fn = parent_module.get_method(attr_name); |
73 | preservedMethods_.insert(&fn.function()); |
74 | } |
75 | return true; |
76 | } |
77 | |
78 | return false; |
79 | }; |
80 | |
81 | // forward is preserved by default, but |
82 | // not all modules have a forward function defined |
83 | if (module_.find_method("forward" )) { |
84 | auto method = module_.get_method("forward" ); |
85 | preservedMethods_.insert(&method.function()); |
86 | } |
87 | |
88 | for (auto name : preservedAttrs) { |
89 | TORCH_CHECK(checkName(name), "Unknown name: " + name); |
90 | } |
91 | } |
92 | |
93 | void optimizeSubGraphs( |
94 | std::shared_ptr<Graph>& graph, |
95 | const std::function<void(std::shared_ptr<Graph>&)>& func) { |
96 | func(graph); |
97 | std::stack<Block*> blocks({graph->block()}); |
98 | while (!blocks.empty()) { |
99 | Block* block = blocks.top(); |
100 | blocks.pop(); |
101 | for (auto n : block->nodes()) { |
102 | for (Block* sub_block : n->blocks()) { |
103 | blocks.push(sub_block); |
104 | } |
105 | if (n->kind() == prim::fork) { |
106 | auto subgraph = n->g(attr::Subgraph); |
107 | optimizeSubGraphs(subgraph, func); |
108 | } |
109 | } |
110 | } |
111 | } |
112 | |
113 | void run() { |
114 | auto applyInline = [](std::shared_ptr<Graph>& subgraph) { |
115 | Inline(*subgraph); |
116 | ClearProfilingInformation(subgraph); |
117 | }; |
118 | auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) { |
119 | #ifndef C10_MOBILE |
120 | Autocast(subgraph); |
121 | #endif |
122 | runOptimization( |
123 | subgraph, |
124 | /* unroll_non_constant_loops? */ false, |
125 | /* const_prop_user_classes? */ false); |
126 | EliminateNoOps(subgraph); |
127 | LowerSimpleTuples(subgraph); |
128 | }; |
129 | |
130 | std::unordered_map<std::string, std::unordered_set<std::string>> |
131 | interfacesToReassignType; |
132 | |
133 | for (auto function : preservedMethods_) { |
134 | GRAPH_DEBUG("Analyzing function: " + function->name()); |
135 | auto graph = toGraphFunction(*function).graph(); |
136 | optimizeSubGraphs(graph, applyInline); |
137 | if (freezeInterfaces_) { |
138 | inlineInterfaceCalls(graph, interfacesToReassignType); |
139 | } |
140 | } |
141 | |
142 | reassignInterfaceTypes(interfacesToReassignType); |
143 | |
144 | for (auto function : preservedMethods_) { |
145 | GRAPH_DEBUG("Recording mutable attrs for function: " + function->name()); |
146 | auto graph = toGraphFunction(*function).graph(); |
147 | // Record Attributes that are explicitly set in the module. |
148 | // They cannot be folded. |
149 | recordMutableAttrs(graph); |
150 | } |
151 | |
152 | for (auto function : preservedMethods_) { |
153 | GRAPH_DEBUG("Propagating function: " + function->name()); |
154 | auto graph = toGraphFunction(*function).graph(); |
155 | propagateAttributes(graph); |
156 | optimizeSubGraphs(graph, applyOptimizations); |
157 | } |
158 | GRAPH_DEBUG("Cleaning up module" ); |
159 | cleanupFrozenModule(); |
160 | } |
161 | |
162 | private: |
163 | using ResolvedName = std::pair<Module, std::string>; |
164 | |
165 | // Try to resolve qualified names (submodule1.submodule2.foo). If |
166 | // the qualified name exists in the root module, return the unqualified |
167 | // attribute/function name and the parent module. Else, return nullopt. |
168 | // Examples: |
169 | // submodule1.submodule2.foo -> {submodule2, "foo"} |
170 | // submodule1.non_existent_module.foo -> nullopt |
171 | c10::optional<ResolvedName> resolveName(const std::string& name) { |
172 | auto sub_names = splitName(name); |
173 | if (sub_names.empty()) { |
174 | return c10::nullopt; |
175 | } |
176 | auto& attr_name = sub_names.back(); |
177 | auto cur_module = module_; |
178 | std::vector<ResolvedName> attr_infos; |
179 | attr_infos.reserve(sub_names.size() - 1); |
180 | |
181 | for (size_t i = 0; i < sub_names.size() - 1; ++i) { |
182 | bool found = false; |
183 | const auto& sub_name = sub_names[i]; |
184 | for (const auto& child_module : cur_module.named_children()) { |
185 | if (child_module.name == sub_name) { |
186 | attr_infos.emplace_back(cur_module._ivalue(), child_module.name); |
187 | cur_module = child_module.value; |
188 | found = true; |
189 | break; |
190 | } |
191 | } |
192 | if (!found) { |
193 | return c10::nullopt; |
194 | } |
195 | } |
196 | |
197 | if (cur_module.hasattr(attr_name) || cur_module.find_method(attr_name)) { |
198 | // We don't want to mark these modules as mutable yet; that could |
199 | // interfere with the inlining procedure. Instead, we'll record |
200 | // the fact that the user wants to preserve them. They will be |
201 | // processed during clean-up preparation (recordReferenceAttrs) |
202 | for (auto& attr_info : attr_infos) { |
203 | const auto& parent_module = attr_info.first; |
204 | auto& sub_name = attr_info.second; |
205 | userPreservedAttrs_[parent_module._ivalue()].insert( |
206 | std::move(sub_name)); |
207 | } |
208 | return std::make_pair(std::move(cur_module), std::move(attr_name)); |
209 | } |
210 | |
211 | return c10::nullopt; |
212 | } |
213 | |
214 | bool _loadModulePath(Value* input, std::shared_ptr<Graph>& graph) { |
215 | Node* node = input->node(); |
216 | names_.clear(); |
217 | while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) { |
218 | if (node->kind() == prim::GetAttr) { |
219 | names_.push_front(node->s(attr::name)); |
220 | node = node->inputs()[0]->node(); |
221 | } else { |
222 | return false; |
223 | } |
224 | } |
225 | |
226 | return true; |
227 | } |
228 | |
229 | c10::optional<std::deque<std::string>> getModulePath( |
230 | Value* input, |
231 | std::shared_ptr<Graph>& graph) { |
232 | bool success = _loadModulePath(input, graph); |
233 | if (!success) { |
234 | return c10::nullopt; |
235 | } |
236 | return names_; |
237 | } |
238 | |
239 | template <typename Iter> |
240 | bool getModuleFromPath( |
241 | Module& attrModule, |
242 | const Iter& begin, |
243 | const Iter& end) { |
244 | for (Iter it = begin; it != end; ++it) { |
245 | const std::string& moduleName = *it; |
246 | if (preservedAttrs_.count(attrModule.attr(moduleName))) { |
247 | return false; |
248 | } |
249 | attrModule = attrModule.attr(moduleName).toModule(); |
250 | } |
251 | return true; |
252 | } |
253 | |
254 | // findConstantAttr function locates the sub Module where attributes are |
255 | // defined. The algorithm chases getAttr chains to locate the submodules. |
256 | // For example: |
257 | // module M { |
258 | // attributes { |
259 | // A = <SubModule at ...> |
260 | // } |
261 | // ... |
262 | // %A = prim::GetAttr[name="A"](%self) |
263 | // ... |
264 | // %B = prim::GetAttr[name="B"](%A) |
265 | // ... |
266 | // %weight = prim::GetAttr[name="scale"](%B) |
267 | // ... |
268 | // submodules { |
269 | // module SubModule { |
270 | // attributes { |
271 | // B = <SubModule2 at ...> |
272 | // } |
273 | // submodules { |
274 | // module SubModule2 { |
275 | // attributes { |
276 | // scale = 2 |
277 | // } |
278 | // } |
279 | // } |
280 | // } |
281 | // } |
282 | // |
283 | // findConstantAttr(%B, "scale", M) returns true because there are no |
284 | // explicit SetAttr that modifies %B. attrModule points to the module where |
285 | // attribute lives (in this example it is <SubModule2 at ...>). |
286 | // |
287 | // Note inplace mutations to attributes are checked later using alias |
288 | // analysis. |
289 | // |
290 | // We can use a more efficient algorithm to hash each constant GetAttr to its |
291 | // corresponding value. Based on initial test on resnet50 and other torch |
292 | // vision tests. GetAttrs are not too frequent so it is ok to chase GetAttr |
293 | // chain to retrieve their values. |
294 | bool findConstantAttr( |
295 | Value* input, |
296 | std::string& name, |
297 | Module& attrModule, |
298 | std::shared_ptr<Graph>& graph) { |
299 | if (!input->type()->cast<InterfaceType>() && |
300 | !input->type()->expectRef<ClassType>().is_module()) { |
301 | return false; |
302 | } |
303 | |
304 | // loads the path into this->names_ |
305 | if (!_loadModulePath(input, graph)) { |
306 | return false; |
307 | } |
308 | |
309 | // reassigns attrModule to the module in names_ |
310 | if (!getModuleFromPath(attrModule, names_.begin(), names_.end())) { |
311 | return false; |
312 | } |
313 | |
314 | auto attr = attrModule.attr(name); |
315 | if (!AliasDb::isMutableType(attr.type())) { |
316 | auto it = preservedScalarAttrs_.find(attrModule._ivalue()); |
317 | return it == preservedScalarAttrs_.end() || !it->second.count(name); |
318 | } |
319 | |
320 | if (preservedAttrs_.count(attr)) { |
321 | return false; |
322 | } |
323 | if (!attr.type()->cast<ClassType>()) { |
324 | for (auto& ivalue : preservedAttrs_) { |
325 | if (!ivalue.isObject() && ivalue.overlaps(attr)) { |
326 | return false; |
327 | } |
328 | } |
329 | } |
330 | return true; |
331 | } |
332 | |
333 | void insertMutableAttr( |
334 | const std::string& name, |
335 | const IValue& attr, |
336 | const ModulePtr& attrModule) { |
337 | if (AliasDb::isMutableType(attr.type())) { |
338 | preservedAttrs_.insert(attr); |
339 | } else { |
340 | preservedScalarAttrs_[attrModule].insert(name); |
341 | } |
342 | } |
343 | |
344 | void recordMutableAttrs(std::shared_ptr<Graph>& graph) { |
345 | std::stack<Block*> blocks({graph->block()}); |
346 | std::unique_ptr<AliasDb> aliasDb = |
347 | torch::make_unique<AliasDb>(graph, /* isFrozen */ true); |
348 | while (!blocks.empty()) { |
349 | Block* block = blocks.top(); |
350 | blocks.pop(); |
351 | for (auto n : block->nodes()) { |
352 | for (Block* sub_block : n->blocks()) { |
353 | blocks.push(sub_block); |
354 | } |
355 | |
356 | // Modules with prim::ModuleContainerIndex cannot be frozen because they |
357 | // return InterfaceTypes. |
358 | TORCH_CHECK( |
359 | n->kind() != prim::ModuleContainerIndex, |
360 | "Freezing modules containing prim::ModuleContainerIndex is not supported" ); |
361 | |
362 | if (n->kind() == prim::SetAttr || n->kind() == prim::GetAttr) { |
363 | // By default if interface attributes are present then fail freezing. |
364 | // If freezingInterfaces is on then Interfaces are folded similarly |
365 | // to other attributes. |
366 | TORCH_CHECK( |
367 | freezeInterfaces_ || |
368 | !(n->kind() == prim::GetAttr && |
369 | n->output()->type()->cast<InterfaceType>()), |
370 | "attempted to freeze a module that uses interface attributes" ); |
371 | auto name = n->s(attr::name); |
372 | auto attrModule = module_; |
373 | if (!findConstantAttr(n->inputs()[0], name, attrModule, graph)) { |
374 | continue; |
375 | } |
376 | |
377 | auto attr = attrModule.attr(name); |
378 | if (n->kind() == prim::GetAttr) { |
379 | auto type = n->output()->type(); |
380 | // Do not record submodules. Their attributes are tracked |
381 | // individually. |
382 | if (attr.isObject() || !AliasDb::isMutableType(attr.type())) { |
383 | continue; |
384 | } |
385 | usedAttrs_.insert(attr); |
386 | } |
387 | |
388 | if (n->kind() == prim::SetAttr || aliasDb->hasOutputWriters(n)) { |
389 | GRAPH_DEBUG( |
390 | n->kind() == prim::GetAttr ? "attribute: " + name + " in %" + |
391 | n->output()->debugName() + " has inplace writer" |
392 | : "attribute: " + name + " is set" ); |
393 | auto mptr = attrModule._ivalue(); |
394 | insertMutableAttr(name, attr, mptr); |
395 | } |
396 | } else if (n->kind() == prim::fork) { |
397 | applyToForkSubgraph( |
398 | n, |
399 | graph, |
400 | // NOLINTNEXTLINE(modernize-avoid-bind) |
401 | std::bind( |
402 | &AttributePropagator::recordMutableAttrs, |
403 | *this, |
404 | std::placeholders::_1)); |
405 | } |
406 | } |
407 | } |
408 | // FIXME: Current Alias analysis fails to track subvalues. |
409 | // This is not a common scenario, for freezing, detect and error out. |
410 | IValue::HashAliasedIValues seen; |
411 | for (auto& val : usedAttrs_) { |
412 | IValue::HashAliasedIValues subValues; |
413 | val.getSubValues(subValues); |
414 | TORCH_CHECK( |
415 | std::all_of( |
416 | subValues.begin(), |
417 | subValues.end(), |
418 | [&seen](const IValue& v) { return seen.count(v) == 0; }), |
419 | "module contains attributes values that overlaps " , |
420 | val); |
421 | seen.insert(subValues.begin(), subValues.end()); |
422 | } |
423 | } |
424 | |
425 | IValue overrideGradient(IValue attr) { |
426 | if (attr.isTensor()) { |
427 | auto& t = attr.toTensor(); |
428 | if (t.requires_grad()) { |
429 | auto detached = t.detach(); |
430 | detached.set_requires_grad(false); |
431 | attr = IValue(std::move(detached)); |
432 | } |
433 | } else if (attr.isTuple()) { |
434 | auto tuple = std::move(attr).toTuple(); |
435 | const auto& elems = tuple->elements(); |
436 | for (const auto idx : c10::irange(elems.size())) { |
437 | tuple->unsafeSetElement(idx, overrideGradient(elems[idx])); |
438 | } |
439 | attr = std::move(tuple); |
440 | } else if (attr.isList()) { |
441 | c10::List<IValue> elems = std::move(attr).toList(); |
442 | for (const auto i : c10::irange(elems.size())) { |
443 | elems.set(i, overrideGradient(elems.extract(i))); |
444 | } |
445 | attr = elems; |
446 | } else if (attr.isGenericDict()) { |
447 | auto dict = std::move(attr).toGenericDict(); |
448 | for (const auto& pair : dict) { |
449 | auto val = pair.value(); |
450 | val = overrideGradient(std::move(val)); |
451 | } |
452 | attr = dict; |
453 | } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) { |
454 | auto obj_type = attr.type()->expect<ClassType>(); |
455 | auto obj_value = std::move(attr).toObject(); |
456 | auto sub_attributes = obj_type->getAttributes(); |
457 | for (const auto& sub_attr : sub_attributes) { |
458 | auto sub_attr_val = obj_value->getAttr(sub_attr.getName()); |
459 | sub_attr_val = overrideGradient(std::move(sub_attr_val)); |
460 | } |
461 | return obj_value; |
462 | } |
463 | |
464 | return attr; |
465 | } |
466 | |
467 | // This method is invoked only when 'freezeInterfaces' parameter is on. |
468 | // The module associated with Interface is retrieved and the invoked method |
469 | // is inlined. |
470 | bool inlineInterfaceCall(Node* n, const IValue& attr) { |
471 | auto class_type = attr.type()->expect<ClassType>(); |
472 | bool inlined = false; |
473 | for (auto use : n->output()->uses()) { |
474 | auto user_node = use.user; |
475 | if (user_node->kind() == prim::CallMethod) { |
476 | const std::string& methodName = user_node->s(attr::name); |
477 | Function& function = class_type->getMethod(methodName); |
478 | if (auto graphFunction = tryToGraphFunction(function)) { |
479 | GRAPH_UPDATE( |
480 | "Inlining interface method '" , |
481 | function.name(), |
482 | "' to " , |
483 | *user_node); |
484 | |
485 | GRAPH_UPDATE("Function body: " , graphFunction->optimized_graph()); |
486 | inlineCallTo(user_node, graphFunction); |
487 | inlined = true; |
488 | } |
489 | } |
490 | } |
491 | return inlined; |
492 | } |
493 | |
494 | // [Note: Inlining interfaces strategy] |
495 | // There's two structures that are relevant to freezing: |
496 | // - the graph describing the computation in a method |
497 | // - the module describing the data structure of the module instance. |
498 | // |
499 | // First, in inlineInterfaceCalls, we inline interfaces. This is done in a |
500 | // separate step from normal inlining because CallMethod on an interface type |
501 | // requires extra steps compared to inlining a normal CallMethod. |
502 | // |
503 | // Next we need to simplify the structure of the module data structure, which |
504 | // is done for the most part by the usual steps in cleanupFrozenModule. |
505 | // |
506 | // However, there's a complication that comes from the fact that within a |
507 | // method, you can change the value of an interface to another module that |
508 | // implements that interface. |
509 | // |
510 | // For example: |
511 | // |
512 | // impl: MyInterface |
513 | // ... |
514 | // def forward(self, x): |
515 | // if x > 0: |
516 | // self.impl = my_interface_impl |
517 | // |
518 | // This is disallowed in freezing, because in this case we can't flatten out |
519 | // the module structure, since the type of self.impl will change. |
520 | // |
521 | // To handle this, we do the following: |
522 | // 1. inlineInterfaceCalls: |
523 | // a. inline the graph, and in the process record all interfaces |
524 | // b. simultaneously, check (throw error) for disallowed SetAttr calls. |
525 | // 2. call reassignInterfaceTypes, which reassigns interface types to their |
526 | // concrete types. This is done in a separate step to avoid interfering |
527 | // with inlineInterfaceCalls (note: this may not need to be done as a |
528 | // separate step) |
529 | // 3. eventually cleanupFrozenModule will reorder the module data structure |
530 | // and it will expect that all interface types have been removed. |
531 | void inlineInterfaceCalls( |
532 | std::shared_ptr<Graph>& graph, |
533 | std::unordered_map<std::string, std::unordered_set<std::string>>& |
534 | interfacesToRetype) { |
535 | auto block = graph->block(); |
536 | std::stack<Block*> blocks({block}); |
537 | |
538 | while (!blocks.empty()) { |
539 | Block* block = blocks.top(); |
540 | blocks.pop(); |
541 | for (auto n : block->nodes()) { |
542 | for (Block* sub_block : n->blocks()) { |
543 | blocks.push(sub_block); |
544 | } |
545 | if (n->kind() == prim::GetAttr) { |
546 | if (!n->output()->type()->cast<InterfaceType>()) { |
547 | continue; |
548 | } |
549 | auto name = n->s(attr::name); |
550 | auto attrModule = module_; |
551 | auto input = n->inputs()[0]; |
552 | TORCH_CHECK( |
553 | findConstantAttr(input, name, attrModule, graph), |
554 | "failed to freeze interface attribute '" + name + "'" ); |
555 | TORCH_INTERNAL_ASSERT(attrModule.hasattr(name)); |
556 | auto attr = attrModule.attr(name); |
557 | inlineInterfaceCall(n, attr); |
558 | // Reset the GetAttr to concrete module type. |
559 | n->output()->setType(attr.type()); |
560 | |
561 | // Record this so that we can reassign the type later |
562 | // in reassignInterfaceTypes() |
563 | // See [Note: Inlining interfaces strategy] |
564 | auto path = getModulePath(input, graph); |
565 | TORCH_INTERNAL_ASSERT(path.has_value()); |
566 | auto path_str = concatName(path->begin(), path->end()); |
567 | interfacesToRetype[path_str].insert(name); |
568 | } else if (n->kind() == prim::SetAttr) { |
569 | // Check to make sure we're not assigning the value of any parameters |
570 | // that are interface types. |
571 | // See [Note: Inlining interfaces strategy] |
572 | auto name = n->s(attr::name); |
573 | auto attrModule = module_; |
574 | auto input = n->inputs()[0]; |
575 | |
576 | if (!input->type()->cast<InterfaceType>() && |
577 | !input->type()->expectRef<ClassType>().is_module()) { |
578 | // we only care if we're setattr["thing"](%mod) if %mod |
579 | continue; |
580 | } |
581 | |
582 | // note: this will modify attrModule until it is the parent of the |
583 | // "name" attr. In other words, attrModule is now the module that |
584 | // matches "input". |
585 | // We can't use findConstantAttr in case the base item is an object, |
586 | // instead of a module/interface. |
587 | auto path = getModulePath(input, graph); |
588 | TORCH_INTERNAL_ASSERT(path.has_value()); |
589 | getModuleFromPath(attrModule, path->begin(), path->end()); |
590 | |
591 | const auto& attrType = attrModule.type()->getAttribute(name); |
592 | TORCH_INTERNAL_ASSERT( |
593 | !attrType->cast<InterfaceType>(), |
594 | "Freezing does not support SetAttr on an interface type. " , |
595 | "SetAttr is attempted on '" , |
596 | name, |
597 | "'" ); |
598 | } else if (n->kind() == prim::fork) { |
599 | applyToForkSubgraph( |
600 | n, |
601 | graph, |
602 | // NOLINTNEXTLINE(modernize-avoid-bind) |
603 | std::bind( |
604 | &AttributePropagator::inlineInterfaceCalls, |
605 | *this, |
606 | std::placeholders::_1, |
607 | interfacesToRetype)); |
608 | } |
609 | } |
610 | } |
611 | } |
612 | |
613 | // See [Note: Inlining interfaces strategy] |
614 | // This modifies the internal structure of module types to reassign the |
615 | // type from an interface type to its concrete type. |
616 | void reassignInterfaceTypes( |
617 | const std::unordered_map<std::string, std::unordered_set<std::string>>& |
618 | interfacesToRetype) { |
619 | for (const auto& it : interfacesToRetype) { |
620 | const std::string& modulePath = it.first; |
621 | const std::vector<std::string>& splitPath = splitName(modulePath); |
622 | Module attrModule = module_; |
623 | getModuleFromPath(attrModule, splitPath.begin(), splitPath.end()); |
624 | |
625 | for (const std::string& name : it.second) { |
626 | auto subvalue = attrModule.attr(name); |
627 | auto subvalueType = subvalue.type(); |
628 | attrModule.type()->unsafeChangeAttributeType(name, subvalueType); |
629 | } |
630 | } |
631 | } |
632 | |
633 | void propagateAttributes(std::shared_ptr<Graph>& graph) { |
634 | std::unordered_map<ModulePtr, std::unordered_map<std::string, Value*>> |
635 | attrValues; |
636 | auto isEval = !module_.hasattr("training" ) || !module_.is_training(); |
637 | GRAPH_DEBUG("Freezing Module: " , module_.type()->name()->name()); |
638 | auto block = graph->block(); |
639 | std::stack<Block*> blocks({block}); |
640 | |
641 | Node* m = *block->nodes().begin(); |
642 | WithInsertPoint guard(m); |
643 | while (!blocks.empty()) { |
644 | Block* block = blocks.top(); |
645 | blocks.pop(); |
646 | for (auto it = block->nodes().begin(); it != block->nodes().end();) { |
647 | Node* n = *it; |
648 | it++; // advance iterator bc the current node may be destroyed |
649 | |
650 | for (Block* sub_block : n->blocks()) { |
651 | blocks.push(sub_block); |
652 | } |
653 | if (n->kind() == prim::GetAttr) { |
654 | auto name = n->s(attr::name); |
655 | auto attrModule = module_; |
656 | auto input = n->inputs()[0]; |
657 | if (!findConstantAttr(input, name, attrModule, graph)) { |
658 | GRAPH_DEBUG( |
659 | input->type()->cast<InterfaceType>() || |
660 | input->type()->expectRef<ClassType>().is_module() |
661 | ? "attribute: " + name + " is mutable." |
662 | : "" ); |
663 | continue; |
664 | } |
665 | TORCH_INTERNAL_ASSERT(attrModule.hasattr(name)); |
666 | Value* paramConst = nullptr; |
667 | auto iter = attrValues.find(attrModule._ivalue()); |
668 | if (iter != attrValues.end()) { |
669 | auto iter2 = iter->second.find(name); |
670 | if (iter2 != iter->second.end()) |
671 | paramConst = iter2->second; |
672 | } |
673 | if (!paramConst) { |
674 | auto attr = attrModule.attr(name); |
675 | if (!isEval || preserveParameters_) { |
676 | auto type = attrModule.type(); |
677 | auto slot = *type->findAttributeSlot(name); |
678 | if (type->is_parameter(slot) || type->is_buffer(slot) || |
679 | (attr.isObject() && |
680 | !attr.toObjectRef().type()->is_module())) { |
681 | continue; |
682 | } else { |
683 | attr = overrideGradient(attr); |
684 | } |
685 | if (!isEval && name == "training" ) { |
686 | continue; |
687 | } |
688 | } else { |
689 | attr = overrideGradient(attr); |
690 | } |
691 | if (attr.isObject()) { |
692 | if (object_memo_.count(attr.toObject())) { |
693 | attr = object_memo_[attr.toObject()]; |
694 | } else { |
695 | auto weak_class_obj = |
696 | attr.toObject()->copy_to_weak_compilation_ref(); |
697 | object_memo_[attr.toObject()] = weak_class_obj; |
698 | attr = weak_class_obj; |
699 | } |
700 | } |
701 | if (auto attrVal = tryInsertConstant(*graph, attr)) { |
702 | paramConst = *attrVal; |
703 | } else { |
704 | GRAPH_DEBUG( |
705 | attr.type()->cast<ClassType>() ? "" : "attribute: " , |
706 | name, |
707 | " is not materializable." ); |
708 | continue; |
709 | } |
710 | std::string fullName("self." ); |
711 | for (auto& name : names_) { |
712 | fullName += name + '.'; |
713 | } |
714 | fullName += name; |
715 | paramConst->setDebugName(fullName); |
716 | attrValues[attrModule._ivalue()][name] = paramConst; |
717 | } |
718 | GRAPH_UPDATE( |
719 | "Folding GetAttr %" , |
720 | n->outputs()[0]->debugName(), |
721 | " with " , |
722 | paramConst->debugName()); |
723 | n->outputs().at(0)->replaceAllUsesWith(paramConst); |
724 | n->removeAllInputs(); |
725 | } else if (n->kind() == prim::fork) { |
726 | applyToForkSubgraph( |
727 | n, |
728 | graph, |
729 | // NOLINTNEXTLINE(modernize-avoid-bind) |
730 | std::bind( |
731 | &AttributePropagator::propagateAttributes, |
732 | *this, |
733 | std::placeholders::_1)); |
734 | } |
735 | } |
736 | } |
737 | } |
738 | |
739 | void applyToForkSubgraph( |
740 | Node* n, |
741 | std::shared_ptr<Graph>& graph, |
742 | const std::function<void(std::shared_ptr<Graph>&)>& func) { |
743 | TORCH_CHECK(n->kind() == prim::fork); |
744 | auto attrModule = module_; |
745 | auto node = n->inputs()[0]->node(); |
746 | // Check if first parameter of fork is a module. This module is used |
747 | // as the base module (similar to 'self' in forward) to resolve GetAttrs. |
748 | // Otherwise freezing is applied using module_ |
749 | if (node->kind() == prim::GetAttr && |
750 | node->output()->type()->cast<ClassType>()) { |
751 | auto name = node->s(attr::name); |
752 | auto input = node->inputs()[0]; |
753 | if (!findConstantAttr(input, name, attrModule, graph)) { |
754 | // Module needs to be preserved. |
755 | return; |
756 | } |
757 | attrModule = attrModule.attr(name).toModule(); |
758 | std::swap(module_, attrModule); |
759 | } |
760 | |
761 | auto subgraph = n->g(attr::Subgraph); |
762 | func(subgraph); |
763 | module_ = attrModule; |
764 | } |
765 | |
766 | bool moduleEscapes(Module& subModule, std::shared_ptr<Graph>& graph) { |
767 | for (auto& output : graph->outputs()) { |
768 | if (subModule.type()->isSubtypeOf(*output->type())) { |
769 | return true; |
770 | } |
771 | } |
772 | return preservedSubModule_.count(subModule._ivalue()); |
773 | } |
774 | |
775 | void (Block* b) { |
776 | auto nodes = b->nodes(); |
777 | for (auto it = nodes.begin(); it != nodes.end(); it++) { |
778 | auto node = *it; |
779 | if (node->kind() != aten::wait) { |
780 | continue; |
781 | } |
782 | TORCH_INTERNAL_ASSERT(node->inputs().size() == 1); |
783 | TORCH_INTERNAL_ASSERT(node->outputs().size() == 1); |
784 | // If input type is not a from aten::fork call then the |
785 | // aten::wait operator can be deleted. |
786 | if (node->input()->type()->kind() != TypeKind::FutureType) { |
787 | node->output()->replaceAllUsesWith(node->input()); |
788 | it.destroyCurrent(); |
789 | } |
790 | } |
791 | // For the remaining nodes, recurse. |
792 | for (auto it = nodes.begin(); it != nodes.end(); it++) { |
793 | auto node = *it; |
794 | for (auto sub_b : node->blocks()) { |
795 | removeExtraWaitCalls(sub_b); |
796 | } |
797 | } |
798 | } |
799 | |
800 | // cleanupFrozenModule function cleans up the Frozen module. It performs the |
801 | // following: |
802 | // 1) Remove unused attributes. |
803 | // 2) Remove unreferenced submodules |
804 | // 3) Remove non public unreferenced methods. |
805 | void cleanupFrozenModule() { |
806 | for (auto function : preservedMethods_) { |
807 | auto graph = toGraphFunction(*function).graph(); |
808 | recordReferencedAttrs(graph); |
809 | handleSharedClassType(module_, graph); |
810 | removeExtraWaitCalls(graph->block()); |
811 | toGraphFunction(*function).clear_optimized_graphs(); |
812 | } |
813 | removeUnusedAttrs(); |
814 | } |
815 | |
816 | // Prepraring for clean up phase. At this point, record all subModules that |
817 | // contains mutable attributes. |
818 | void recordReferencedAttrs(std::shared_ptr<Graph>& graph) { |
819 | std::stack<Block*> blocks({graph->block()}); |
820 | std::set<ModulePtr> modules({module_._ivalue()}); |
821 | while (!blocks.empty()) { |
822 | Block* block = blocks.top(); |
823 | blocks.pop(); |
824 | for (auto n : block->nodes()) { |
825 | for (Block* subBlock : n->blocks()) { |
826 | blocks.push(subBlock); |
827 | } |
828 | if (n->kind() == prim::GetAttr) { |
829 | auto& name = n->s(attr::name); |
830 | // For now, use all module ivalues which are the same type |
831 | // and could be the module that this GetAttr resolves to |
832 | // TODO: we could attempt to follow the GetAttr chain and |
833 | // find the exact ivalue, we would have to be careful |
834 | // that the chain does not contain any attributes which |
835 | // get written to (setAttr calls) |
836 | for (auto& mptr : modules) { |
837 | auto module = Module(mptr); |
838 | if (module.type() == n->inputs()[0]->type()) { |
839 | TORCH_INTERNAL_ASSERT(module.hasattr(name)); |
840 | auto module = Module(mptr); |
841 | auto attr = module.attr(name); |
842 | // TODO: this could be insertReferencedAttr to be more clear, |
843 | // these are attributes we could not inline, which include |
844 | // other reasons besides mutation (unsupported constant, |
845 | // getAttr resolving to non-getAttr node, etc) |
846 | insertMutableAttr(name, attr, mptr); |
847 | if (attr.isModule()) { |
848 | modules.insert(attr.toModule()._ivalue()); |
849 | } |
850 | } |
851 | } |
852 | } else if (n->kind() == prim::fork) { |
853 | applyToForkSubgraph( |
854 | n, |
855 | graph, |
856 | // NOLINTNEXTLINE(modernize-avoid-bind) |
857 | std::bind( |
858 | &AttributePropagator::recordReferencedAttrs, |
859 | *this, |
860 | std::placeholders::_1)); |
861 | } |
862 | } |
863 | } |
864 | // We have to process the attributes that the user wants to preserve |
865 | // separately since it's possible that the user-preserved module is |
866 | // never referenced in the graph. |
867 | for (const auto& attr_info : userPreservedAttrs_) { |
868 | const auto& parent_module = attr_info.first; |
869 | for (const auto& attr_name : attr_info.second) { |
870 | const auto value = parent_module->getAttr(attr_name); |
871 | insertMutableAttr(attr_name, value, parent_module); |
872 | } |
873 | } |
874 | } |
875 | |
// This function recursively iterates over submodules to determine, for each
// class type, the attribute slots that need to be preserved.
//
// Note: 'attrsToKeep_[type].insert(type->numAttributes())' means all
// attribute slots of 'type' and all of its methods are preserved. A submodule
// is preserved this way when it escapes (meaning it is returned).
882 | void handleSharedClassType(Module& module, std::shared_ptr<Graph>& graph) { |
883 | auto type = module.type(); |
884 | size_t N = type->numAttributes(); |
885 | if (moduleEscapes(module, graph)) { |
886 | // Perserve all its attributes and methods. |
887 | attrsToKeep_[type].insert(N); |
888 | return; |
889 | } |
890 | auto it2 = preservedScalarAttrs_.find(module._ivalue()); |
891 | SharedTypeSubModules_[type].insert(module._ivalue()); |
892 | attrsToKeep_[type].insert({}); |
893 | for (const auto i : c10::irange(N)) { |
894 | auto name = type->getAttributeName(i); |
895 | auto attr = module.attr(name); |
896 | auto attrTy = attr.type(); |
897 | |
898 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
899 | bool isMutable; |
900 | if (AliasDb::isMutableType(attrTy)) { |
901 | isMutable = preservedAttrs_.count(attr); |
902 | } else { |
903 | isMutable = |
904 | it2 != preservedScalarAttrs_.end() && it2->second.count(name); |
905 | } |
906 | if (isMutable) { |
907 | attrsToKeep_[type].insert(i); |
908 | if (attr.isModule()) { |
909 | // See [Note: Inlining interfaces strategy] |
910 | TORCH_CHECK( |
911 | !type->getAttribute(i)->cast<InterfaceType>(), |
912 | "Unexpected interface attribute '" + name + "' during freezing" ); |
913 | |
914 | auto attrModule = attr.toModule(); |
915 | handleSharedClassType(attrModule, graph); |
916 | } |
917 | } |
918 | } |
919 | } |
920 | |
// Remove unused attributes and methods from each submodule of the frozen
// module. This function iterates over the ClassTypes of its submodule
// attributes, including its own type.
924 | void removeUnusedAttrs() { |
925 | std::vector<std::string> attrsToRemove; |
926 | std::vector<Function*> funcsToRemove; |
927 | for (auto& it : attrsToKeep_) { |
928 | auto& type = it.first; |
929 | size_t N = type->numAttributes(); |
930 | if (it.second.count(N)) { |
931 | continue; |
932 | } |
933 | for (const auto i : c10::irange(N)) { |
934 | if (it.second.count(i) == 0) { |
935 | attrsToRemove.push_back(type->getAttributeName(i)); |
936 | } |
937 | } |
938 | for (auto& fn : type->methods()) { |
939 | if (preservedMethods_.count(fn)) { |
940 | continue; |
941 | } |
942 | funcsToRemove.push_back(fn); |
943 | } |
944 | |
945 | for (auto& name : attrsToRemove) { |
946 | for (auto& val : SharedTypeSubModules_[type]) { |
947 | auto mod = val.toModule(); |
948 | mod._ivalue()->unsafeRemoveAttr(name); |
949 | } |
950 | type->unsafeRemoveAttribute(name); |
951 | } |
952 | for (auto fn : funcsToRemove) { |
953 | type->unsafeRemoveMethod(fn->name()); |
954 | auto mod = SharedTypeSubModules_[type].begin()->toModule(); |
955 | mod._ivalue()->compilation_unit()->unsafeRemoveMethod(fn->qualname()); |
956 | } |
957 | |
958 | attrsToRemove.clear(); |
959 | funcsToRemove.clear(); |
960 | } |
961 | } |
962 | |
963 | // Contains attributes that can't be folded or user directs to keep them. |
964 | IValue::HashAliasedIValues preservedAttrs_; |
965 | // Tracked immutable types (Scalars) by their attribute names not |
966 | // IValues. |
967 | std::unordered_map<ModulePtr, std::unordered_set<std::string>> |
968 | preservedScalarAttrs_; |
969 | |
970 | // Contains user specified methods to be preserved in frozen module. |
971 | std::unordered_set<Function*> preservedMethods_; |
972 | |
973 | // Contains user specified sub module to be preserve in frozen module. |
974 | std::unordered_set<ModulePtr> preservedSubModule_; |
975 | |
976 | // Track all used attributes ivalues that can be aliased. |
977 | IValue::HashAliasedIValues usedAttrs_; |
978 | |
979 | // Contains the attribute slots that need to be preserved for each ClassType. |
980 | std::unordered_map<ClassTypePtr, std::unordered_set<size_t>> attrsToKeep_; |
981 | |
982 | // Contains the sub modules that share the same ClassType. |
983 | std::unordered_map<ClassTypePtr, IValue::HashAliasedIValues> |
984 | SharedTypeSubModules_; |
985 | |
986 | Module& module_; |
987 | |
988 | // Allow to freeze modules containing interfaces. |
989 | bool freezeInterfaces_; |
990 | |
991 | // Preserve module parameters |
992 | bool preserveParameters_; |
993 | |
994 | // Contains the attributes names (e.g. {"self", "subModule", "a"} |
995 | std::deque<std::string> names_; |
996 | |
997 | // see [Constant Object Weak CompilationUnit Reference] |
998 | std::unordered_map< |
999 | c10::intrusive_ptr<at::ivalue::Object>, |
1000 | c10::intrusive_ptr<at::ivalue::Object>> |
1001 | object_memo_; |
1002 | |
1003 | // Contains names of attributes that the user wants to preserve with |
1004 | // their owning modules. |
1005 | std::unordered_map<ModulePtr, std::unordered_set<std::string>> |
1006 | userPreservedAttrs_; |
1007 | |
1008 | }; // class AttributePropagator |
1009 | |
1010 | void checkModuleDoesNotReturnSelf(const Module& module) { |
1011 | if (module.find_method("forward" )) { |
1012 | Method method = module.get_method("forward" ); |
1013 | // Check that module does not return itself. |
1014 | for (auto& output : method.graph()->outputs()) { |
1015 | TORCH_CHECK( |
1016 | output->type() != module.type(), |
1017 | "attempted to freeze a module that return itself" ); |
1018 | } |
1019 | } |
1020 | } |
1021 | } // namespace |
1022 | |
1023 | Module freeze_module( |
1024 | const Module& module, |
1025 | std::vector<std::string> preservedAttrs, |
1026 | bool freezeInterfaces, |
1027 | bool preserveParameters) { |
1028 | checkModuleDoesNotReturnSelf(module); |
1029 | |
1030 | auto moduleClone = module.clone(true); |
1031 | AttributePropagator attrPropagator( |
1032 | moduleClone, preservedAttrs, freezeInterfaces, preserveParameters); |
1033 | attrPropagator.run(); |
1034 | return moduleClone; |
1035 | } |
1036 | |
1037 | void freeze_module_inplace( |
1038 | Module* module, |
1039 | std::vector<std::string> preservedAttrs, |
1040 | bool freezeInterfaces, |
1041 | bool preserveParameters) { |
1042 | TORCH_CHECK(module != nullptr, "module cannot be nullptr" ); |
1043 | checkModuleDoesNotReturnSelf(*module); |
1044 | AttributePropagator attrPropagator( |
1045 | *module, preservedAttrs, freezeInterfaces, preserveParameters); |
1046 | attrPropagator.run(); |
1047 | } |
1048 | |
1049 | } // namespace jit |
1050 | } // namespace torch |
1051 | |