#include <torch/csrc/jit/passes/lower_graph.h>

#include <torch/csrc/jit/api/object.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/custom_class.h>
#include <unordered_map>

namespace torch {
namespace jit {

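// A Slot identifies a single attribute of a module instance: the owning
// object plus the attribute's index into that object's slot table.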
struct Slot {
  c10::intrusive_ptr<c10::ivalue::Object> obj;
  size_t offset;
  bool operator==(const Slot& other) const {
    return (this->obj == other.obj && this->offset == other.offset);
  }
};

// Removes the first module argument, replacing any access of its
// parameters/attributes with extra_ivalues input Slots that record which
// value to pass into the graph. Used for ONNX export to eliminate
// first-class modules so the exporter deals purely with parameters and
// inputs.
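//
// For example (schematic IR; value names are illustrative): an access like
//   %weight : Tensor = prim::GetAttr[name="weight"](%self)
// is rewritten so that uses of %weight read from a fresh graph input, and
// %self is erased once all such accesses have been rewritten.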
std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
    const ModulePtr& self,
    Graph& g_,
    size_t self_offset = 0) {
  std::shared_ptr<Graph> g = g_.copy();
  // Inline to remove method/function calls
  Inline(*g);

  // Slots whose values must be passed as extra inputs to the lowered graph,
  // in the order the corresponding inputs were added.
  std::vector<Slot> extra_ivalues;

  struct SlotHash {
    std::size_t operator()(const Slot& slot) const {
      auto obj_hash = std::hash<c10::ivalue::Object*>{}(slot.obj.get());
      auto offset_hash = std::hash<size_t>{}(slot.offset);
      return c10::hash_combine(obj_hash, offset_hash);
    }
  };
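  // Maps each distinct Slot to its index in extra_ivalues so that repeated
  // accesses of the same attribute share a single graph input.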
  std::unordered_map<Slot, size_t, SlotHash> slot_to_offset;
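  // A pending use of a module value: the module instance, the node that uses
  // it, and the offset of the use within that node's inputs.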
  struct ToScan {
    ModulePtr mod;
    Node* n;
    size_t offset;
  };
  std::vector<ToScan> to_scan;
  std::vector<Node*> to_clean; // nodes that should be dead at the end

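  // Returns the graph input carrying the value of `slot`, appending a new
  // input (typed like the slot's current value) on first access.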
  auto getOrAddSlot = [&](const Slot& slot) -> Value* {
    auto it = slot_to_offset.find(slot);
    if (it != slot_to_offset.end()) {
      size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
      return g->inputs().at(ivalues_start + it->second);
    }
    extra_ivalues.emplace_back(slot);
    slot_to_offset[slot] = extra_ivalues.size() - 1;
    return g->addInput()->setType(slot.obj->getSlot(slot.offset).type());
  };

  auto self_value = g->inputs().at(self_offset);

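  // Seed the worklist with every direct use of the module's self argument.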
  for (Use use : self_value->uses()) {
    to_scan.emplace_back(ToScan{self, use.user, use.offset});
  }
  while (!to_scan.empty()) {
    auto e = to_scan.back();
    to_scan.pop_back();

    // When forks are lambda-lifted, first-class modules may be passed across
    // the fork boundary. Recursively lower the module inside the forked
    // subgraph, then forward the resulting Slots as fork inputs.
    if (e.n->kind() == prim::fork) {
      auto subgraph = e.n->g(attr::Subgraph);
      std::vector<Slot> new_slots;
      std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
      e.n->g_(attr::Subgraph, subgraph);
      for (const Slot& slot : new_slots) {
        e.n->addInput(getOrAddSlot(slot));
      }
      e.n->removeInput(e.offset);
      continue;
    }
    if (e.n->kind() == prim::PythonOp) {
      throw ErrorReport(e.n->sourceRange()) << "Couldn't export Python method.";
    }
    if (e.n->kind() != prim::GetAttr) {
      throw ErrorReport(e.n->sourceRange())
          << "temporary: the only valid use of a module is looking up an "
             "attribute but found "
          << *e.n;
    }
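    // The use is a GetAttr: resolve the attribute name to its slot index on
    // the module's type and fetch the current value.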
    size_t slot_idx = e.mod->type()->getAttributeSlot(e.n->s(attr::name));
    auto iv = e.mod->getSlot(slot_idx);
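    // Submodule access: instead of lowering the submodule itself to an
    // input, enqueue the uses of its object and clean up the GetAttr later.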
    if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
      if (c->is_module()) {
        for (Use use : e.n->output()->uses()) {
          to_scan.emplace_back(ToScan{iv.toObject(), use.user, use.offset});
        }
        to_clean.emplace_back(e.n);
        continue;
      }
    }
    e.n->output()->replaceAllUsesWith(getOrAddSlot({e.mod, slot_idx}));
    e.n->destroy();
  }

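  // Submodule GetAttr nodes were kept alive while their outputs' uses were
  // being scanned; every such use has been rewritten or destroyed by now, so
  // destroying them in reverse order is safe.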
  while (!to_clean.empty()) {
    Node* n = to_clean.back();
    AT_ASSERT(!n->hasUses());
    n->destroy();
    to_clean.pop_back();
  }
  AT_ASSERT(!self_value->hasUses());
  g->eraseInput(self_offset);

  return std::make_pair(std::move(g), std::move(extra_ivalues));
}

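// Materializes the IValue behind each Slot. Plain tensors pass through
// unchanged; packed quantized parameters are unpacked via __getstate__ so
// the ONNX exporter only sees exportable values.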
static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
  std::vector<IValue> result;
  result.reserve(slots.size());
  for (const Slot& slot : slots) {
    auto obj = slot.obj->getSlot(slot.offset);
    if (obj.isTensor()) {
      result.emplace_back(obj.toTensor());
    } else {
      // Unpack quantization packed tensor
      auto type = obj.type();
      TORCH_CHECK(
          (type ==
           getCustomClass(
               "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
              (type ==
               getCustomClass(
                   "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
              (type ==
               getCustomClass(
                   "__torch__.torch.classes.quantized.LinearPackedParamsBase")),
          "Unknown type ",
          type->repr_str(),
          " encountered in graph lowering. This type is not supported in ONNX export.");
      result.emplace_back(
          script::Object(obj.toObject()).run_method("__getstate__"));
    }
  }
  return result;
}

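// Entry point: lowers `graph` against the module `self` and returns the
// lowered graph together with the IValues that must be appended as extra
// inputs when the graph is invoked.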
std::pair<std::shared_ptr<Graph>, std::vector<IValue>> LowerGraph(
    Graph& graph,
    const ModulePtr& self) {
  auto result = lower_graph(self, graph);
  return std::make_pair(result.first, loadTensors(result.second));
}

} // namespace jit
} // namespace torch