1 | #include <torch/csrc/jit/ir/alias_analysis.h> |
2 | #include <torch/csrc/jit/passes/peephole_dict_idioms.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | |
7 | namespace { |
8 | |
9 | class DictNodeImplBase { |
10 | public: |
11 | virtual ~DictNodeImplBase() = default; |
12 | |
13 | virtual bool contains(const IValue&) const = 0; |
14 | virtual size_t size() const = 0; |
15 | virtual Value* get(const IValue&) const = 0; |
16 | |
17 | bool canOptimize() { |
18 | return !has_overlap_ && !has_non_const_key_; |
19 | } |
20 | |
21 | protected: |
22 | bool has_overlap_ = false; |
23 | bool has_non_const_key_ = false; |
24 | }; |
25 | |
26 | template <class KeyType> |
27 | class DictNodeImpl : public DictNodeImplBase { |
28 | public: |
29 | DictNodeImpl( |
30 | std::function<KeyType(const IValue&)> ivalue_converter, |
31 | Node* dict_creation_node) |
32 | : ivalue_converter_(std::move(ivalue_converter)) { |
33 | for (size_t i = 0; i < dict_creation_node->inputs().size(); i += 2) { |
34 | auto key_opt = toIValue(dict_creation_node->input(i)); |
35 | |
36 | // Key is not constant if we cannot convert to IValue |
37 | if (key_opt == c10::nullopt) { |
38 | has_non_const_key_ = true; |
39 | continue; |
40 | } |
41 | |
42 | KeyType key = ivalue_converter_(*key_opt); |
43 | if (dict_.find(key) == dict_.end()) { |
44 | dict_.emplace(key, dict_creation_node->input(i + 1)); |
45 | } else { |
46 | has_overlap_ = true; |
47 | } |
48 | } |
49 | } |
50 | |
51 | bool contains(const IValue& ivalue) const override { |
52 | auto key = ivalue_converter_(ivalue); |
53 | return dict_.find(key) != dict_.end(); |
54 | } |
55 | |
56 | size_t size() const override { |
57 | return dict_.size(); |
58 | } |
59 | |
60 | Value* get(const IValue& ivalue) const override { |
61 | auto val = ivalue_converter_(ivalue); |
62 | auto loc = dict_.find(val); |
63 | if (loc != dict_.end()) { |
64 | return loc->second; |
65 | } |
66 | TORCH_CHECK(false, "Cannot get non-existent key" ); |
67 | } |
68 | |
69 | private: |
70 | std::unordered_map<KeyType, Value*> dict_; |
71 | std::function<KeyType(const IValue&)> ivalue_converter_; |
72 | }; |
73 | |
74 | class DictNode { |
75 | public: |
76 | explicit DictNode(Node* dict_creation_node) { |
77 | auto dict_type = dict_creation_node->output()->type(); |
78 | auto key_value_types = dict_type->containedTypes(); |
79 | TORCH_CHECK( |
80 | key_value_types.size() == 2, "Dict must have 2 contained types" ); |
81 | const auto& key_type = key_value_types[0]; |
82 | |
83 | switch (key_type->kind()) { |
84 | case TypeKind::IntType: { |
85 | auto ivalue_converter = [](const IValue& ival) { return ival.toInt(); }; |
86 | impl_ = std::make_unique<DictNodeImpl<int64_t>>( |
87 | std::move(ivalue_converter), dict_creation_node); |
88 | break; |
89 | } |
90 | |
91 | case TypeKind::FloatType: { |
92 | auto ivalue_converter = [](const IValue& ival) { |
93 | return ival.toDouble(); |
94 | }; |
95 | impl_ = std::make_unique<DictNodeImpl<double>>( |
96 | std::move(ivalue_converter), dict_creation_node); |
97 | break; |
98 | } |
99 | |
100 | case TypeKind::StringType: { |
101 | auto ivalue_converter = [](const IValue& ival) { |
102 | return *ival.toString(); |
103 | }; |
104 | impl_ = std::make_unique<DictNodeImpl<std::string>>( |
105 | std::move(ivalue_converter), dict_creation_node); |
106 | break; |
107 | } |
108 | |
109 | default: |
110 | impl_ = nullptr; |
111 | } |
112 | } |
113 | |
114 | bool canOptimize() const { |
115 | if (impl_) { |
116 | return impl_->canOptimize(); |
117 | } |
118 | return false; |
119 | } |
120 | |
121 | size_t size() const { |
122 | if (impl_) { |
123 | return impl_->size(); |
124 | } |
125 | return 0; |
126 | } |
127 | |
128 | c10::optional<Value*> getOrNullopt(const IValue& key) const { |
129 | if (impl_ && impl_->contains(key)) { |
130 | return impl_->get(key); |
131 | } |
132 | return c10::nullopt; |
133 | } |
134 | |
135 | private: |
136 | std::unique_ptr<DictNodeImplBase> impl_; |
137 | }; |
138 | |
139 | bool isDict(Value* v) { |
140 | return v->type()->castRaw<DictType>() != nullptr; |
141 | } |
142 | |
143 | class PeepholeOptimizeDictIdiomsImpl { |
144 | public: |
145 | explicit PeepholeOptimizeDictIdiomsImpl(std::shared_ptr<Graph> graph) |
146 | : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {} |
147 | |
148 | bool run() { |
149 | collectMutatedDicts(graph_->block()); |
150 | return runBlock(graph_->block()); |
151 | } |
152 | |
153 | private: |
154 | void checkForMutatedDicts(Value* v) { |
155 | if (isDict(v) && aliasDb_->hasWriters(v)) { |
156 | mutated_dicts_.insert(v); |
157 | } |
158 | } |
159 | |
160 | void collectMutatedDicts(Block* b) { |
161 | for (Value* v : b->inputs()) { |
162 | checkForMutatedDicts(v); |
163 | } |
164 | for (Node* n : b->nodes()) { |
165 | for (Value* v : n->outputs()) { |
166 | checkForMutatedDicts(v); |
167 | } |
168 | for (Block* block : n->blocks()) { |
169 | collectMutatedDicts(block); |
170 | } |
171 | } |
172 | } |
173 | |
174 | const DictNode& getDictNode(Node* creation_node) { |
175 | auto cached = dict_cache_.find(creation_node); |
176 | if (cached == dict_cache_.end()) { |
177 | cached = |
178 | dict_cache_.emplace(creation_node, DictNode(creation_node)).first; |
179 | } |
180 | |
181 | return cached->second; |
182 | } |
183 | |
184 | c10::optional<Value*> getValueFromDict(Node* dict_creation_node, Value* key) { |
185 | const DictNode& dict_node = getDictNode(dict_creation_node); |
186 | auto key_opt = toIValue(key); |
187 | // Key is not constant if we cannot convert to IValue |
188 | if (key_opt == c10::nullopt) { |
189 | return c10::nullopt; |
190 | } |
191 | IValue key_ival = *key_opt; |
192 | if (dict_node.canOptimize()) { |
193 | return dict_node.getOrNullopt(key_ival); |
194 | } |
195 | return c10::nullopt; |
196 | } |
197 | |
198 | c10::optional<int64_t> computeLen(Node* dict_creation_node) { |
199 | const DictNode& dict_node = getDictNode(dict_creation_node); |
200 | if (dict_node.canOptimize()) { |
201 | return static_cast<int64_t>(dict_node.size()); |
202 | } |
203 | return c10::nullopt; |
204 | } |
205 | |
206 | bool optimizeLen(Node* len_node, Node* creation_node) { |
207 | if (creation_node->kind() == prim::DictConstruct) { |
208 | auto len = computeLen(creation_node); |
209 | if (len != c10::nullopt) { |
210 | WithInsertPoint guard(len_node); |
211 | len_node->output()->replaceAllUsesWith(graph_->insertConstant(len)); |
212 | return true; |
213 | } |
214 | } |
215 | return false; |
216 | } |
217 | |
218 | bool optimizeGetItem(Node* getitem_node, Node* creation_node) { |
219 | if (creation_node->kind() == prim::DictConstruct) { |
220 | auto key = getitem_node->input(1); |
221 | auto value = getValueFromDict(creation_node, key); |
222 | if (value != c10::nullopt) { |
223 | getitem_node->output()->replaceAllUsesWith(*value); |
224 | return true; |
225 | } |
226 | } |
227 | return false; |
228 | } |
229 | |
230 | bool runBlock(Block* block) { |
231 | bool changed = false; |
232 | for (Node* node : block->nodes()) { |
233 | for (Block* b : node->blocks()) { |
234 | changed |= runBlock(b); |
235 | } |
236 | |
237 | // only optimizing dict ops |
238 | if (node->inputs().empty() || !isDict(node->input(0))) { |
239 | continue; |
240 | } |
241 | |
242 | auto first_input = node->input(0); |
243 | |
244 | // only optimizing ops with unmutated inputs |
245 | if (mutated_dicts_.count(first_input)) { |
246 | continue; |
247 | } |
248 | |
249 | if (node->kind() == aten::len) { |
250 | changed |= optimizeLen(node, first_input->node()); |
251 | } else if (node->kind() == aten::__getitem__) { |
252 | changed |= optimizeGetItem(node, first_input->node()); |
253 | } |
254 | } |
255 | return changed; |
256 | } |
257 | |
258 | std::shared_ptr<Graph> graph_; |
259 | std::unordered_set<Value*> mutated_dicts_; |
260 | std::unique_ptr<AliasDb> aliasDb_; |
261 | std::unordered_map<Node*, DictNode> dict_cache_; |
262 | }; |
263 | |
264 | } // namespace |
265 | |
266 | bool PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph>& graph) { |
267 | PeepholeOptimizeDictIdiomsImpl opt(graph); |
268 | return opt.run(); |
269 | } |
270 | |
271 | } // namespace jit |
272 | } // namespace torch |
273 | |