1#include <torch/csrc/jit/passes/lower_graph.h>
2
3#include <torch/csrc/jit/api/object.h>
4#include <torch/csrc/jit/frontend/error_report.h>
5#include <torch/csrc/jit/passes/inliner.h>
6#include <torch/custom_class.h>
7#include <unordered_map>
8
9namespace torch {
10namespace jit {
11
12struct Slot {
13 c10::intrusive_ptr<c10::ivalue::Object> obj;
14 size_t offset;
15 bool operator==(const Slot& other) const {
16 return (this->obj == other.obj && this->offset == other.offset);
17 }
18};
19
20// remove the first module argument, replacing any access of its
21// parameters/attributes with extra_ivalue input Slots that hold what value to
22// pass into the graph. Used for ONNX export to remove first-class modules
23// so it can deal purely with parameters and inputs
24std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
25 const ModulePtr& self,
26 Graph& g_,
27 size_t self_offset = 0) {
28 std::shared_ptr<Graph> g = g_.copy();
29 // Inline to remove method/function calls
30 Inline(*g);
31
32 std::vector<Slot> extra_ivalues;
33
34 struct SlotHash {
35 std::size_t operator()(const Slot& slot) const {
36 auto obj_hash = std::hash<c10::ivalue::Object*>{}(slot.obj.get());
37 auto offset_hash = std::hash<size_t>{}(slot.offset);
38 return c10::hash_combine(obj_hash, offset_hash);
39 }
40 };
41 std::unordered_map<Slot, size_t, SlotHash> slot_to_offset;
42 struct ToScan {
43 ModulePtr mod;
44 Node* n;
45 size_t offset;
46 };
47 std::vector<ToScan> to_scan;
48 std::vector<Node*> to_clean; // nodes that should be dead at the end
49
50 auto getOrAddSlot = [&](const Slot& slot) -> Value* {
51 auto it = slot_to_offset.find(slot);
52 if (it != slot_to_offset.end()) {
53 size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
54 return g->inputs().at(ivalues_start + it->second);
55 }
56 extra_ivalues.emplace_back(slot);
57 slot_to_offset[slot] = extra_ivalues.size() - 1;
58 return g->addInput()->setType(slot.obj->getSlot(slot.offset).type());
59 };
60
61 auto self_value = g->inputs().at(self_offset);
62
63 for (Use use : self_value->uses()) {
64 to_scan.emplace_back(ToScan{self, use.user, use.offset});
65 }
66 while (!to_scan.empty()) {
67 auto e = to_scan.back();
68 to_scan.pop_back();
69
70 // when we lambda lift forks, first-class modules may be passed across
71 // forks. This code recursively lowers the module in the fork call.
72 if (e.n->kind() == prim::fork) {
73 auto subgraph = e.n->g(attr::Subgraph);
74 std::vector<Slot> new_slots;
75 std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
76 e.n->g_(attr::Subgraph, subgraph);
77 for (const Slot& slot : new_slots) {
78 e.n->addInput(getOrAddSlot(slot));
79 }
80 e.n->removeInput(e.offset);
81 continue;
82 }
83 if (e.n->kind() == prim::PythonOp) {
84 throw ErrorReport(e.n->sourceRange()) << "Couldn't export Python method.";
85 }
86 if (e.n->kind() != prim::GetAttr) {
87 throw ErrorReport(e.n->sourceRange())
88 << "temporary: the only valid use of a module is looking up an "
89 "attribute but found "
90 << *e.n;
91 }
92 size_t slot_idx = e.mod->type()->getAttributeSlot(e.n->s(attr::name));
93 auto iv = e.mod->getSlot(slot_idx);
94 if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
95 if (c->is_module()) {
96 for (Use use : e.n->output()->uses()) {
97 to_scan.emplace_back(ToScan{iv.toObject(), use.user, use.offset});
98 }
99 to_clean.emplace_back(e.n);
100 continue;
101 }
102 }
103 e.n->output()->replaceAllUsesWith(getOrAddSlot({e.mod, slot_idx}));
104 e.n->destroy();
105 }
106
107 while (!to_clean.empty()) {
108 Node* n = to_clean.back();
109 AT_ASSERT(!n->hasUses());
110 n->destroy();
111 to_clean.pop_back();
112 }
113 AT_ASSERT(!self_value->hasUses());
114 g->eraseInput(self_offset);
115
116 return std::make_pair(std::move(g), std::move(extra_ivalues));
117}
118
119static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
120 std::vector<IValue> result;
121 result.reserve(slots.size());
122 for (const Slot& slot : slots) {
123 auto obj = slot.obj->getSlot(slot.offset);
124 if (obj.isTensor()) {
125 result.emplace_back(obj.toTensor());
126 } else {
127 // Unpack quantization packed tensor
128 auto type = obj.type();
129 TORCH_CHECK(
130 (type ==
131 getCustomClass(
132 "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
133 (type ==
134 getCustomClass(
135 "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
136 (type ==
137 getCustomClass(
138 "__torch__.torch.classes.quantized.LinearPackedParamsBase")),
139 "Unknown type ",
140 type->repr_str(),
141 " encountered in graph lowering. This type is not supported in ONNX export.");
142 result.emplace_back(
143 script::Object(obj.toObject()).run_method("__getstate__"));
144 }
145 }
146 return result;
147}
148
149std::pair<std::shared_ptr<Graph>, std::vector<IValue>> LowerGraph(
150 Graph& graph,
151 const ModulePtr& self) {
152 auto result = lower_graph(self, graph);
153 return std::make_pair(result.first, loadTensors(result.second));
154}
155
156} // namespace jit
157} // namespace torch
158