1#include <torch/csrc/jit/passes/freeze_module.h>
2
3#include <torch/csrc/jit/jit_log.h>
4
5#include <c10/util/irange.h>
6#include <torch/csrc/jit/api/function_impl.h>
7#include <torch/csrc/jit/ir/alias_analysis.h>
8#include <torch/csrc/jit/passes/autocast.h>
9#include <torch/csrc/jit/passes/clear_profiling.h>
10#include <torch/csrc/jit/passes/eliminate_no_ops.h>
11#include <torch/csrc/jit/passes/inliner.h>
12#include <torch/csrc/jit/passes/lower_tuples.h>
13#include <torch/csrc/jit/passes/remove_mutation.h>
14#include <torch/csrc/jit/runtime/graph_executor_impl.h>
15
16#include <stack>
17#include <utility>
18
19namespace torch {
20namespace jit {
21
22namespace {
23
// Splits a dot-qualified name such as "a.b.c" into its components.
// An empty input yields no components, and a single trailing '.' does not
// produce a final empty component ("a." -> {"a"}); empty components in any
// other position are kept ("a..b" -> {"a", "", "b"}).
std::vector<std::string> splitName(const std::string& name) {
  std::vector<std::string> pieces;
  std::string::size_type start = 0;
  while (start < name.size()) {
    const std::string::size_type dot = name.find('.', start);
    if (dot == std::string::npos) {
      // Last component: everything up to the end of the string.
      pieces.push_back(name.substr(start));
      break;
    }
    pieces.push_back(name.substr(start, dot - start));
    start = dot + 1;
  }
  return pieces;
}
33
// Joins the strings in [begin, end) with '.' between them.
// A separator is only emitted once the accumulated result is non-empty, so
// leading empty components do not produce a leading '.'.
template <typename Iter>
std::string concatName(const Iter& begin, const Iter& end) {
  std::string joined;
  for (Iter cursor = begin; cursor != end; ++cursor) {
    const std::string& piece = *cursor;
    joined.append(joined.empty() ? "" : ".");
    joined.append(piece);
  }
  return joined;
}
46
47class AttributePropagator {
48 public:
49 AttributePropagator(
50 Module& module,
51 std::vector<std::string>& preservedAttrs,
52 bool freezeInterfaces,
53 bool preserveParameters)
54 : module_(module),
55 freezeInterfaces_(freezeInterfaces),
56 preserveParameters_(preserveParameters) {
57 auto checkName = [this](std::string& name) {
58 const auto resolved_name = resolveName(name);
59
60 if (resolved_name) {
61 const auto& parent_module = resolved_name->first;
62 const auto& attr_name = resolved_name->second;
63 if (parent_module.hasattr(attr_name)) {
64 auto value = parent_module.attr(attr_name);
65 // Freezing client wants to presever this submodule. When cleaning
66 // the frozen module, make sure it will be preserved entirely.
67 if (value.isModule()) {
68 preservedSubModule_.insert(value.toModule()._ivalue());
69 }
70 insertMutableAttr(attr_name, value, parent_module._ivalue());
71 } else {
72 auto fn = parent_module.get_method(attr_name);
73 preservedMethods_.insert(&fn.function());
74 }
75 return true;
76 }
77
78 return false;
79 };
80
81 // forward is preserved by default, but
82 // not all modules have a forward function defined
83 if (module_.find_method("forward")) {
84 auto method = module_.get_method("forward");
85 preservedMethods_.insert(&method.function());
86 }
87
88 for (auto name : preservedAttrs) {
89 TORCH_CHECK(checkName(name), "Unknown name: " + name);
90 }
91 }
92
93 void optimizeSubGraphs(
94 std::shared_ptr<Graph>& graph,
95 const std::function<void(std::shared_ptr<Graph>&)>& func) {
96 func(graph);
97 std::stack<Block*> blocks({graph->block()});
98 while (!blocks.empty()) {
99 Block* block = blocks.top();
100 blocks.pop();
101 for (auto n : block->nodes()) {
102 for (Block* sub_block : n->blocks()) {
103 blocks.push(sub_block);
104 }
105 if (n->kind() == prim::fork) {
106 auto subgraph = n->g(attr::Subgraph);
107 optimizeSubGraphs(subgraph, func);
108 }
109 }
110 }
111 }
112
113 void run() {
114 auto applyInline = [](std::shared_ptr<Graph>& subgraph) {
115 Inline(*subgraph);
116 ClearProfilingInformation(subgraph);
117 };
118 auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
119#ifndef C10_MOBILE
120 Autocast(subgraph);
121#endif
122 runOptimization(
123 subgraph,
124 /* unroll_non_constant_loops? */ false,
125 /* const_prop_user_classes? */ false);
126 EliminateNoOps(subgraph);
127 LowerSimpleTuples(subgraph);
128 };
129
130 std::unordered_map<std::string, std::unordered_set<std::string>>
131 interfacesToReassignType;
132
133 for (auto function : preservedMethods_) {
134 GRAPH_DEBUG("Analyzing function: " + function->name());
135 auto graph = toGraphFunction(*function).graph();
136 optimizeSubGraphs(graph, applyInline);
137 if (freezeInterfaces_) {
138 inlineInterfaceCalls(graph, interfacesToReassignType);
139 }
140 }
141
142 reassignInterfaceTypes(interfacesToReassignType);
143
144 for (auto function : preservedMethods_) {
145 GRAPH_DEBUG("Recording mutable attrs for function: " + function->name());
146 auto graph = toGraphFunction(*function).graph();
147 // Record Attributes that are explicitly set in the module.
148 // They cannot be folded.
149 recordMutableAttrs(graph);
150 }
151
152 for (auto function : preservedMethods_) {
153 GRAPH_DEBUG("Propagating function: " + function->name());
154 auto graph = toGraphFunction(*function).graph();
155 propagateAttributes(graph);
156 optimizeSubGraphs(graph, applyOptimizations);
157 }
158 GRAPH_DEBUG("Cleaning up module");
159 cleanupFrozenModule();
160 }
161
162 private:
163 using ResolvedName = std::pair<Module, std::string>;
164
165 // Try to resolve qualified names (submodule1.submodule2.foo). If
166 // the qualified name exists in the root module, return the unqualified
167 // attribute/function name and the parent module. Else, return nullopt.
168 // Examples:
169 // submodule1.submodule2.foo -> {submodule2, "foo"}
170 // submodule1.non_existent_module.foo -> nullopt
171 c10::optional<ResolvedName> resolveName(const std::string& name) {
172 auto sub_names = splitName(name);
173 if (sub_names.empty()) {
174 return c10::nullopt;
175 }
176 auto& attr_name = sub_names.back();
177 auto cur_module = module_;
178 std::vector<ResolvedName> attr_infos;
179 attr_infos.reserve(sub_names.size() - 1);
180
181 for (size_t i = 0; i < sub_names.size() - 1; ++i) {
182 bool found = false;
183 const auto& sub_name = sub_names[i];
184 for (const auto& child_module : cur_module.named_children()) {
185 if (child_module.name == sub_name) {
186 attr_infos.emplace_back(cur_module._ivalue(), child_module.name);
187 cur_module = child_module.value;
188 found = true;
189 break;
190 }
191 }
192 if (!found) {
193 return c10::nullopt;
194 }
195 }
196
197 if (cur_module.hasattr(attr_name) || cur_module.find_method(attr_name)) {
198 // We don't want to mark these modules as mutable yet; that could
199 // interfere with the inlining procedure. Instead, we'll record
200 // the fact that the user wants to preserve them. They will be
201 // processed during clean-up preparation (recordReferenceAttrs)
202 for (auto& attr_info : attr_infos) {
203 const auto& parent_module = attr_info.first;
204 auto& sub_name = attr_info.second;
205 userPreservedAttrs_[parent_module._ivalue()].insert(
206 std::move(sub_name));
207 }
208 return std::make_pair(std::move(cur_module), std::move(attr_name));
209 }
210
211 return c10::nullopt;
212 }
213
214 bool _loadModulePath(Value* input, std::shared_ptr<Graph>& graph) {
215 Node* node = input->node();
216 names_.clear();
217 while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) {
218 if (node->kind() == prim::GetAttr) {
219 names_.push_front(node->s(attr::name));
220 node = node->inputs()[0]->node();
221 } else {
222 return false;
223 }
224 }
225
226 return true;
227 }
228
229 c10::optional<std::deque<std::string>> getModulePath(
230 Value* input,
231 std::shared_ptr<Graph>& graph) {
232 bool success = _loadModulePath(input, graph);
233 if (!success) {
234 return c10::nullopt;
235 }
236 return names_;
237 }
238
239 template <typename Iter>
240 bool getModuleFromPath(
241 Module& attrModule,
242 const Iter& begin,
243 const Iter& end) {
244 for (Iter it = begin; it != end; ++it) {
245 const std::string& moduleName = *it;
246 if (preservedAttrs_.count(attrModule.attr(moduleName))) {
247 return false;
248 }
249 attrModule = attrModule.attr(moduleName).toModule();
250 }
251 return true;
252 }
253
254 // findConstantAttr function locates the sub Module where attributes are
255 // defined. The algorithm chases getAttr chains to locate the submodules.
256 // For example:
257 // module M {
258 // attributes {
259 // A = <SubModule at ...>
260 // }
261 // ...
262 // %A = prim::GetAttr[name="A"](%self)
263 // ...
264 // %B = prim::GetAttr[name="B"](%A)
265 // ...
266 // %weight = prim::GetAttr[name="scale"](%B)
267 // ...
268 // submodules {
269 // module SubModule {
270 // attributes {
271 // B = <SubModule2 at ...>
272 // }
273 // submodules {
274 // module SubModule2 {
275 // attributes {
276 // scale = 2
277 // }
278 // }
279 // }
280 // }
281 // }
282 //
283 // findConstantAttr(%B, "scale", M) returns true because there are no
284 // explicit SetAttr that modifies %B. attrModule points to the module where
285 // attribute lives (in this example it is <SubModule2 at ...>).
286 //
287 // Note inplace mutations to attributes are checked later using alias
288 // analysis.
289 //
290 // We can use a more efficient algorithm to hash each constant GetAttr to its
291 // corresponding value. Based on initial test on resnet50 and other torch
292 // vision tests. GetAttrs are not too frequent so it is ok to chase GetAttr
293 // chain to retrieve their values.
294 bool findConstantAttr(
295 Value* input,
296 std::string& name,
297 Module& attrModule,
298 std::shared_ptr<Graph>& graph) {
299 if (!input->type()->cast<InterfaceType>() &&
300 !input->type()->expectRef<ClassType>().is_module()) {
301 return false;
302 }
303
304 // loads the path into this->names_
305 if (!_loadModulePath(input, graph)) {
306 return false;
307 }
308
309 // reassigns attrModule to the module in names_
310 if (!getModuleFromPath(attrModule, names_.begin(), names_.end())) {
311 return false;
312 }
313
314 auto attr = attrModule.attr(name);
315 if (!AliasDb::isMutableType(attr.type())) {
316 auto it = preservedScalarAttrs_.find(attrModule._ivalue());
317 return it == preservedScalarAttrs_.end() || !it->second.count(name);
318 }
319
320 if (preservedAttrs_.count(attr)) {
321 return false;
322 }
323 if (!attr.type()->cast<ClassType>()) {
324 for (auto& ivalue : preservedAttrs_) {
325 if (!ivalue.isObject() && ivalue.overlaps(attr)) {
326 return false;
327 }
328 }
329 }
330 return true;
331 }
332
333 void insertMutableAttr(
334 const std::string& name,
335 const IValue& attr,
336 const ModulePtr& attrModule) {
337 if (AliasDb::isMutableType(attr.type())) {
338 preservedAttrs_.insert(attr);
339 } else {
340 preservedScalarAttrs_[attrModule].insert(name);
341 }
342 }
343
344 void recordMutableAttrs(std::shared_ptr<Graph>& graph) {
345 std::stack<Block*> blocks({graph->block()});
346 std::unique_ptr<AliasDb> aliasDb =
347 torch::make_unique<AliasDb>(graph, /* isFrozen */ true);
348 while (!blocks.empty()) {
349 Block* block = blocks.top();
350 blocks.pop();
351 for (auto n : block->nodes()) {
352 for (Block* sub_block : n->blocks()) {
353 blocks.push(sub_block);
354 }
355
356 // Modules with prim::ModuleContainerIndex cannot be frozen because they
357 // return InterfaceTypes.
358 TORCH_CHECK(
359 n->kind() != prim::ModuleContainerIndex,
360 "Freezing modules containing prim::ModuleContainerIndex is not supported");
361
362 if (n->kind() == prim::SetAttr || n->kind() == prim::GetAttr) {
363 // By default if interface attributes are present then fail freezing.
364 // If freezingInterfaces is on then Interfaces are folded similarly
365 // to other attributes.
366 TORCH_CHECK(
367 freezeInterfaces_ ||
368 !(n->kind() == prim::GetAttr &&
369 n->output()->type()->cast<InterfaceType>()),
370 "attempted to freeze a module that uses interface attributes");
371 auto name = n->s(attr::name);
372 auto attrModule = module_;
373 if (!findConstantAttr(n->inputs()[0], name, attrModule, graph)) {
374 continue;
375 }
376
377 auto attr = attrModule.attr(name);
378 if (n->kind() == prim::GetAttr) {
379 auto type = n->output()->type();
380 // Do not record submodules. Their attributes are tracked
381 // individually.
382 if (attr.isObject() || !AliasDb::isMutableType(attr.type())) {
383 continue;
384 }
385 usedAttrs_.insert(attr);
386 }
387
388 if (n->kind() == prim::SetAttr || aliasDb->hasOutputWriters(n)) {
389 GRAPH_DEBUG(
390 n->kind() == prim::GetAttr ? "attribute: " + name + " in %" +
391 n->output()->debugName() + " has inplace writer"
392 : "attribute: " + name + " is set");
393 auto mptr = attrModule._ivalue();
394 insertMutableAttr(name, attr, mptr);
395 }
396 } else if (n->kind() == prim::fork) {
397 applyToForkSubgraph(
398 n,
399 graph,
400 // NOLINTNEXTLINE(modernize-avoid-bind)
401 std::bind(
402 &AttributePropagator::recordMutableAttrs,
403 *this,
404 std::placeholders::_1));
405 }
406 }
407 }
408 // FIXME: Current Alias analysis fails to track subvalues.
409 // This is not a common scenario, for freezing, detect and error out.
410 IValue::HashAliasedIValues seen;
411 for (auto& val : usedAttrs_) {
412 IValue::HashAliasedIValues subValues;
413 val.getSubValues(subValues);
414 TORCH_CHECK(
415 std::all_of(
416 subValues.begin(),
417 subValues.end(),
418 [&seen](const IValue& v) { return seen.count(v) == 0; }),
419 "module contains attributes values that overlaps ",
420 val);
421 seen.insert(subValues.begin(), subValues.end());
422 }
423 }
424
425 IValue overrideGradient(IValue attr) {
426 if (attr.isTensor()) {
427 auto& t = attr.toTensor();
428 if (t.requires_grad()) {
429 auto detached = t.detach();
430 detached.set_requires_grad(false);
431 attr = IValue(std::move(detached));
432 }
433 } else if (attr.isTuple()) {
434 auto tuple = std::move(attr).toTuple();
435 const auto& elems = tuple->elements();
436 for (const auto idx : c10::irange(elems.size())) {
437 tuple->unsafeSetElement(idx, overrideGradient(elems[idx]));
438 }
439 attr = std::move(tuple);
440 } else if (attr.isList()) {
441 c10::List<IValue> elems = std::move(attr).toList();
442 for (const auto i : c10::irange(elems.size())) {
443 elems.set(i, overrideGradient(elems.extract(i)));
444 }
445 attr = elems;
446 } else if (attr.isGenericDict()) {
447 auto dict = std::move(attr).toGenericDict();
448 for (const auto& pair : dict) {
449 auto val = pair.value();
450 val = overrideGradient(std::move(val));
451 }
452 attr = dict;
453 } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) {
454 auto obj_type = attr.type()->expect<ClassType>();
455 auto obj_value = std::move(attr).toObject();
456 auto sub_attributes = obj_type->getAttributes();
457 for (const auto& sub_attr : sub_attributes) {
458 auto sub_attr_val = obj_value->getAttr(sub_attr.getName());
459 sub_attr_val = overrideGradient(std::move(sub_attr_val));
460 }
461 return obj_value;
462 }
463
464 return attr;
465 }
466
467 // This method is invoked only when 'freezeInterfaces' parameter is on.
468 // The module associated with Interface is retrieved and the invoked method
469 // is inlined.
470 bool inlineInterfaceCall(Node* n, const IValue& attr) {
471 auto class_type = attr.type()->expect<ClassType>();
472 bool inlined = false;
473 for (auto use : n->output()->uses()) {
474 auto user_node = use.user;
475 if (user_node->kind() == prim::CallMethod) {
476 const std::string& methodName = user_node->s(attr::name);
477 Function& function = class_type->getMethod(methodName);
478 if (auto graphFunction = tryToGraphFunction(function)) {
479 GRAPH_UPDATE(
480 "Inlining interface method '",
481 function.name(),
482 "' to ",
483 *user_node);
484
485 GRAPH_UPDATE("Function body: ", graphFunction->optimized_graph());
486 inlineCallTo(user_node, graphFunction);
487 inlined = true;
488 }
489 }
490 }
491 return inlined;
492 }
493
494 // [Note: Inlining interfaces strategy]
495 // There's two structures that are relevant to freezing:
496 // - the graph describing the computation in a method
497 // - the module describing the data structure of the module instance.
498 //
499 // First, in inlineInterfaceCalls, we inline interfaces. This is done in a
500 // separate step from normal inlining because CallMethod on an interface type
501 // requires extra steps compared to inlining a normal CallMethod.
502 //
503 // Next we need to simplify the structure of the module data structure, which
504 // is done for the most part by the usual steps in cleanupFrozenModule.
505 //
506 // However, there's a complication that comes from the fact that within a
507 // method, you can change the value of an interface to another module that
508 // implements that interface.
509 //
510 // For example:
511 //
512 // impl: MyInterface
513 // ...
514 // def forward(self, x):
515 // if x > 0:
516 // self.impl = my_interface_impl
517 //
518 // This is disallowed in freezing, because in this case we can't flatten out
519 // the module structure, since the type of self.impl will change.
520 //
521 // To handle this, we do the following:
522 // 1. inlineInterfaceCalls:
523 // a. inline the graph, and in the process record all interfaces
524 // b. simultaneously, check (throw error) for disallowed SetAttr calls.
525 // 2. call reassignInterfaceTypes, which reassigns interface types to their
526 // concrete types. This is done in a separate step to avoid interfering
527 // with inlineInterfaceCalls (note: this may not need to be done as a
528 // separate step)
529 // 3. eventually cleanupFrozenModule will reorder the module data structure
530 // and it will expect that all interface types have been removed.
531 void inlineInterfaceCalls(
532 std::shared_ptr<Graph>& graph,
533 std::unordered_map<std::string, std::unordered_set<std::string>>&
534 interfacesToRetype) {
535 auto block = graph->block();
536 std::stack<Block*> blocks({block});
537
538 while (!blocks.empty()) {
539 Block* block = blocks.top();
540 blocks.pop();
541 for (auto n : block->nodes()) {
542 for (Block* sub_block : n->blocks()) {
543 blocks.push(sub_block);
544 }
545 if (n->kind() == prim::GetAttr) {
546 if (!n->output()->type()->cast<InterfaceType>()) {
547 continue;
548 }
549 auto name = n->s(attr::name);
550 auto attrModule = module_;
551 auto input = n->inputs()[0];
552 TORCH_CHECK(
553 findConstantAttr(input, name, attrModule, graph),
554 "failed to freeze interface attribute '" + name + "'");
555 TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
556 auto attr = attrModule.attr(name);
557 inlineInterfaceCall(n, attr);
558 // Reset the GetAttr to concrete module type.
559 n->output()->setType(attr.type());
560
561 // Record this so that we can reassign the type later
562 // in reassignInterfaceTypes()
563 // See [Note: Inlining interfaces strategy]
564 auto path = getModulePath(input, graph);
565 TORCH_INTERNAL_ASSERT(path.has_value());
566 auto path_str = concatName(path->begin(), path->end());
567 interfacesToRetype[path_str].insert(name);
568 } else if (n->kind() == prim::SetAttr) {
569 // Check to make sure we're not assigning the value of any parameters
570 // that are interface types.
571 // See [Note: Inlining interfaces strategy]
572 auto name = n->s(attr::name);
573 auto attrModule = module_;
574 auto input = n->inputs()[0];
575
576 if (!input->type()->cast<InterfaceType>() &&
577 !input->type()->expectRef<ClassType>().is_module()) {
578 // we only care if we're setattr["thing"](%mod) if %mod
579 continue;
580 }
581
582 // note: this will modify attrModule until it is the parent of the
583 // "name" attr. In other words, attrModule is now the module that
584 // matches "input".
585 // We can't use findConstantAttr in case the base item is an object,
586 // instead of a module/interface.
587 auto path = getModulePath(input, graph);
588 TORCH_INTERNAL_ASSERT(path.has_value());
589 getModuleFromPath(attrModule, path->begin(), path->end());
590
591 const auto& attrType = attrModule.type()->getAttribute(name);
592 TORCH_INTERNAL_ASSERT(
593 !attrType->cast<InterfaceType>(),
594 "Freezing does not support SetAttr on an interface type. ",
595 "SetAttr is attempted on '",
596 name,
597 "'");
598 } else if (n->kind() == prim::fork) {
599 applyToForkSubgraph(
600 n,
601 graph,
602 // NOLINTNEXTLINE(modernize-avoid-bind)
603 std::bind(
604 &AttributePropagator::inlineInterfaceCalls,
605 *this,
606 std::placeholders::_1,
607 interfacesToRetype));
608 }
609 }
610 }
611 }
612
613 // See [Note: Inlining interfaces strategy]
614 // This modifies the internal structure of module types to reassign the
615 // type from an interface type to its concrete type.
616 void reassignInterfaceTypes(
617 const std::unordered_map<std::string, std::unordered_set<std::string>>&
618 interfacesToRetype) {
619 for (const auto& it : interfacesToRetype) {
620 const std::string& modulePath = it.first;
621 const std::vector<std::string>& splitPath = splitName(modulePath);
622 Module attrModule = module_;
623 getModuleFromPath(attrModule, splitPath.begin(), splitPath.end());
624
625 for (const std::string& name : it.second) {
626 auto subvalue = attrModule.attr(name);
627 auto subvalueType = subvalue.type();
628 attrModule.type()->unsafeChangeAttributeType(name, subvalueType);
629 }
630 }
631 }
632
633 void propagateAttributes(std::shared_ptr<Graph>& graph) {
634 std::unordered_map<ModulePtr, std::unordered_map<std::string, Value*>>
635 attrValues;
636 auto isEval = !module_.hasattr("training") || !module_.is_training();
637 GRAPH_DEBUG("Freezing Module: ", module_.type()->name()->name());
638 auto block = graph->block();
639 std::stack<Block*> blocks({block});
640
641 Node* m = *block->nodes().begin();
642 WithInsertPoint guard(m);
643 while (!blocks.empty()) {
644 Block* block = blocks.top();
645 blocks.pop();
646 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
647 Node* n = *it;
648 it++; // advance iterator bc the current node may be destroyed
649
650 for (Block* sub_block : n->blocks()) {
651 blocks.push(sub_block);
652 }
653 if (n->kind() == prim::GetAttr) {
654 auto name = n->s(attr::name);
655 auto attrModule = module_;
656 auto input = n->inputs()[0];
657 if (!findConstantAttr(input, name, attrModule, graph)) {
658 GRAPH_DEBUG(
659 input->type()->cast<InterfaceType>() ||
660 input->type()->expectRef<ClassType>().is_module()
661 ? "attribute: " + name + " is mutable."
662 : "");
663 continue;
664 }
665 TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
666 Value* paramConst = nullptr;
667 auto iter = attrValues.find(attrModule._ivalue());
668 if (iter != attrValues.end()) {
669 auto iter2 = iter->second.find(name);
670 if (iter2 != iter->second.end())
671 paramConst = iter2->second;
672 }
673 if (!paramConst) {
674 auto attr = attrModule.attr(name);
675 if (!isEval || preserveParameters_) {
676 auto type = attrModule.type();
677 auto slot = *type->findAttributeSlot(name);
678 if (type->is_parameter(slot) || type->is_buffer(slot) ||
679 (attr.isObject() &&
680 !attr.toObjectRef().type()->is_module())) {
681 continue;
682 } else {
683 attr = overrideGradient(attr);
684 }
685 if (!isEval && name == "training") {
686 continue;
687 }
688 } else {
689 attr = overrideGradient(attr);
690 }
691 if (attr.isObject()) {
692 if (object_memo_.count(attr.toObject())) {
693 attr = object_memo_[attr.toObject()];
694 } else {
695 auto weak_class_obj =
696 attr.toObject()->copy_to_weak_compilation_ref();
697 object_memo_[attr.toObject()] = weak_class_obj;
698 attr = weak_class_obj;
699 }
700 }
701 if (auto attrVal = tryInsertConstant(*graph, attr)) {
702 paramConst = *attrVal;
703 } else {
704 GRAPH_DEBUG(
705 attr.type()->cast<ClassType>() ? "" : "attribute: ",
706 name,
707 " is not materializable.");
708 continue;
709 }
710 std::string fullName("self.");
711 for (auto& name : names_) {
712 fullName += name + '.';
713 }
714 fullName += name;
715 paramConst->setDebugName(fullName);
716 attrValues[attrModule._ivalue()][name] = paramConst;
717 }
718 GRAPH_UPDATE(
719 "Folding GetAttr %",
720 n->outputs()[0]->debugName(),
721 " with ",
722 paramConst->debugName());
723 n->outputs().at(0)->replaceAllUsesWith(paramConst);
724 n->removeAllInputs();
725 } else if (n->kind() == prim::fork) {
726 applyToForkSubgraph(
727 n,
728 graph,
729 // NOLINTNEXTLINE(modernize-avoid-bind)
730 std::bind(
731 &AttributePropagator::propagateAttributes,
732 *this,
733 std::placeholders::_1));
734 }
735 }
736 }
737 }
738
739 void applyToForkSubgraph(
740 Node* n,
741 std::shared_ptr<Graph>& graph,
742 const std::function<void(std::shared_ptr<Graph>&)>& func) {
743 TORCH_CHECK(n->kind() == prim::fork);
744 auto attrModule = module_;
745 auto node = n->inputs()[0]->node();
746 // Check if first parameter of fork is a module. This module is used
747 // as the base module (similar to 'self' in forward) to resolve GetAttrs.
748 // Otherwise freezing is applied using module_
749 if (node->kind() == prim::GetAttr &&
750 node->output()->type()->cast<ClassType>()) {
751 auto name = node->s(attr::name);
752 auto input = node->inputs()[0];
753 if (!findConstantAttr(input, name, attrModule, graph)) {
754 // Module needs to be preserved.
755 return;
756 }
757 attrModule = attrModule.attr(name).toModule();
758 std::swap(module_, attrModule);
759 }
760
761 auto subgraph = n->g(attr::Subgraph);
762 func(subgraph);
763 module_ = attrModule;
764 }
765
766 bool moduleEscapes(Module& subModule, std::shared_ptr<Graph>& graph) {
767 for (auto& output : graph->outputs()) {
768 if (subModule.type()->isSubtypeOf(*output->type())) {
769 return true;
770 }
771 }
772 return preservedSubModule_.count(subModule._ivalue());
773 }
774
775 void removeExtraWaitCalls(Block* b) {
776 auto nodes = b->nodes();
777 for (auto it = nodes.begin(); it != nodes.end(); it++) {
778 auto node = *it;
779 if (node->kind() != aten::wait) {
780 continue;
781 }
782 TORCH_INTERNAL_ASSERT(node->inputs().size() == 1);
783 TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
784 // If input type is not a from aten::fork call then the
785 // aten::wait operator can be deleted.
786 if (node->input()->type()->kind() != TypeKind::FutureType) {
787 node->output()->replaceAllUsesWith(node->input());
788 it.destroyCurrent();
789 }
790 }
791 // For the remaining nodes, recurse.
792 for (auto it = nodes.begin(); it != nodes.end(); it++) {
793 auto node = *it;
794 for (auto sub_b : node->blocks()) {
795 removeExtraWaitCalls(sub_b);
796 }
797 }
798 }
799
800 // cleanupFrozenModule function cleans up the Frozen module. It performs the
801 // following:
802 // 1) Remove unused attributes.
803 // 2) Remove unreferenced submodules
804 // 3) Remove non public unreferenced methods.
805 void cleanupFrozenModule() {
806 for (auto function : preservedMethods_) {
807 auto graph = toGraphFunction(*function).graph();
808 recordReferencedAttrs(graph);
809 handleSharedClassType(module_, graph);
810 removeExtraWaitCalls(graph->block());
811 toGraphFunction(*function).clear_optimized_graphs();
812 }
813 removeUnusedAttrs();
814 }
815
816 // Prepraring for clean up phase. At this point, record all subModules that
817 // contains mutable attributes.
818 void recordReferencedAttrs(std::shared_ptr<Graph>& graph) {
819 std::stack<Block*> blocks({graph->block()});
820 std::set<ModulePtr> modules({module_._ivalue()});
821 while (!blocks.empty()) {
822 Block* block = blocks.top();
823 blocks.pop();
824 for (auto n : block->nodes()) {
825 for (Block* subBlock : n->blocks()) {
826 blocks.push(subBlock);
827 }
828 if (n->kind() == prim::GetAttr) {
829 auto& name = n->s(attr::name);
830 // For now, use all module ivalues which are the same type
831 // and could be the module that this GetAttr resolves to
832 // TODO: we could attempt to follow the GetAttr chain and
833 // find the exact ivalue, we would have to be careful
834 // that the chain does not contain any attributes which
835 // get written to (setAttr calls)
836 for (auto& mptr : modules) {
837 auto module = Module(mptr);
838 if (module.type() == n->inputs()[0]->type()) {
839 TORCH_INTERNAL_ASSERT(module.hasattr(name));
840 auto module = Module(mptr);
841 auto attr = module.attr(name);
842 // TODO: this could be insertReferencedAttr to be more clear,
843 // these are attributes we could not inline, which include
844 // other reasons besides mutation (unsupported constant,
845 // getAttr resolving to non-getAttr node, etc)
846 insertMutableAttr(name, attr, mptr);
847 if (attr.isModule()) {
848 modules.insert(attr.toModule()._ivalue());
849 }
850 }
851 }
852 } else if (n->kind() == prim::fork) {
853 applyToForkSubgraph(
854 n,
855 graph,
856 // NOLINTNEXTLINE(modernize-avoid-bind)
857 std::bind(
858 &AttributePropagator::recordReferencedAttrs,
859 *this,
860 std::placeholders::_1));
861 }
862 }
863 }
864 // We have to process the attributes that the user wants to preserve
865 // separately since it's possible that the user-preserved module is
866 // never referenced in the graph.
867 for (const auto& attr_info : userPreservedAttrs_) {
868 const auto& parent_module = attr_info.first;
869 for (const auto& attr_name : attr_info.second) {
870 const auto value = parent_module->getAttr(attr_name);
871 insertMutableAttr(attr_name, value, parent_module);
872 }
873 }
874 }
875
876 // This function recursively iterates over submodules to identify
877 // for each class type the attribute slots that need to be preserved.
878 //
879 // Note 'attrsToKeep[type].insert(type->numAttributes())' means all
880 // attribute slots of 'type' and its methods are preserved. A submodule is
881 // preserved when it escapes (meaning it is returned).
882 void handleSharedClassType(Module& module, std::shared_ptr<Graph>& graph) {
883 auto type = module.type();
884 size_t N = type->numAttributes();
885 if (moduleEscapes(module, graph)) {
886 // Perserve all its attributes and methods.
887 attrsToKeep_[type].insert(N);
888 return;
889 }
890 auto it2 = preservedScalarAttrs_.find(module._ivalue());
891 SharedTypeSubModules_[type].insert(module._ivalue());
892 attrsToKeep_[type].insert({});
893 for (const auto i : c10::irange(N)) {
894 auto name = type->getAttributeName(i);
895 auto attr = module.attr(name);
896 auto attrTy = attr.type();
897
898 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
899 bool isMutable;
900 if (AliasDb::isMutableType(attrTy)) {
901 isMutable = preservedAttrs_.count(attr);
902 } else {
903 isMutable =
904 it2 != preservedScalarAttrs_.end() && it2->second.count(name);
905 }
906 if (isMutable) {
907 attrsToKeep_[type].insert(i);
908 if (attr.isModule()) {
909 // See [Note: Inlining interfaces strategy]
910 TORCH_CHECK(
911 !type->getAttribute(i)->cast<InterfaceType>(),
912 "Unexpected interface attribute '" + name + "' during freezing");
913
914 auto attrModule = attr.toModule();
915 handleSharedClassType(attrModule, graph);
916 }
917 }
918 }
919 }
920
921 // Remove unused attributes and methods for each sub module of the frozen
922 // module. This function iterates over the Calsstypes of its submodule
923 // attributes including its own type.
924 void removeUnusedAttrs() {
925 std::vector<std::string> attrsToRemove;
926 std::vector<Function*> funcsToRemove;
927 for (auto& it : attrsToKeep_) {
928 auto& type = it.first;
929 size_t N = type->numAttributes();
930 if (it.second.count(N)) {
931 continue;
932 }
933 for (const auto i : c10::irange(N)) {
934 if (it.second.count(i) == 0) {
935 attrsToRemove.push_back(type->getAttributeName(i));
936 }
937 }
938 for (auto& fn : type->methods()) {
939 if (preservedMethods_.count(fn)) {
940 continue;
941 }
942 funcsToRemove.push_back(fn);
943 }
944
945 for (auto& name : attrsToRemove) {
946 for (auto& val : SharedTypeSubModules_[type]) {
947 auto mod = val.toModule();
948 mod._ivalue()->unsafeRemoveAttr(name);
949 }
950 type->unsafeRemoveAttribute(name);
951 }
952 for (auto fn : funcsToRemove) {
953 type->unsafeRemoveMethod(fn->name());
954 auto mod = SharedTypeSubModules_[type].begin()->toModule();
955 mod._ivalue()->compilation_unit()->unsafeRemoveMethod(fn->qualname());
956 }
957
958 attrsToRemove.clear();
959 funcsToRemove.clear();
960 }
961 }
962
963 // Contains attributes that can't be folded or user directs to keep them.
964 IValue::HashAliasedIValues preservedAttrs_;
965 // Tracked immutable types (Scalars) by their attribute names not
966 // IValues.
967 std::unordered_map<ModulePtr, std::unordered_set<std::string>>
968 preservedScalarAttrs_;
969
970 // Contains user specified methods to be preserved in frozen module.
971 std::unordered_set<Function*> preservedMethods_;
972
  // Contains user-specified submodules to be preserved in the frozen module.
974 std::unordered_set<ModulePtr> preservedSubModule_;
975
976 // Track all used attributes ivalues that can be aliased.
977 IValue::HashAliasedIValues usedAttrs_;
978
979 // Contains the attribute slots that need to be preserved for each ClassType.
980 std::unordered_map<ClassTypePtr, std::unordered_set<size_t>> attrsToKeep_;
981
982 // Contains the sub modules that share the same ClassType.
983 std::unordered_map<ClassTypePtr, IValue::HashAliasedIValues>
984 SharedTypeSubModules_;
985
986 Module& module_;
987
988 // Allow to freeze modules containing interfaces.
989 bool freezeInterfaces_;
990
991 // Preserve module parameters
992 bool preserveParameters_;
993
  // Contains the attribute names (e.g. {"self", "subModule", "a"}).
995 std::deque<std::string> names_;
996
997 // see [Constant Object Weak CompilationUnit Reference]
998 std::unordered_map<
999 c10::intrusive_ptr<at::ivalue::Object>,
1000 c10::intrusive_ptr<at::ivalue::Object>>
1001 object_memo_;
1002
1003 // Contains names of attributes that the user wants to preserve with
1004 // their owning modules.
1005 std::unordered_map<ModulePtr, std::unordered_set<std::string>>
1006 userPreservedAttrs_;
1007
1008}; // class AttributePropagator
1009
1010void checkModuleDoesNotReturnSelf(const Module& module) {
1011 if (module.find_method("forward")) {
1012 Method method = module.get_method("forward");
1013 // Check that module does not return itself.
1014 for (auto& output : method.graph()->outputs()) {
1015 TORCH_CHECK(
1016 output->type() != module.type(),
1017 "attempted to freeze a module that return itself");
1018 }
1019 }
1020}
1021} // namespace
1022
1023Module freeze_module(
1024 const Module& module,
1025 std::vector<std::string> preservedAttrs,
1026 bool freezeInterfaces,
1027 bool preserveParameters) {
1028 checkModuleDoesNotReturnSelf(module);
1029
1030 auto moduleClone = module.clone(true);
1031 AttributePropagator attrPropagator(
1032 moduleClone, preservedAttrs, freezeInterfaces, preserveParameters);
1033 attrPropagator.run();
1034 return moduleClone;
1035}
1036
1037void freeze_module_inplace(
1038 Module* module,
1039 std::vector<std::string> preservedAttrs,
1040 bool freezeInterfaces,
1041 bool preserveParameters) {
1042 TORCH_CHECK(module != nullptr, "module cannot be nullptr");
1043 checkModuleDoesNotReturnSelf(*module);
1044 AttributePropagator attrPropagator(
1045 *module, preservedAttrs, freezeInterfaces, preserveParameters);
1046 attrPropagator.run();
1047}
1048
1049} // namespace jit
1050} // namespace torch
1051