1#include <torch/csrc/jit/ir/alias_analysis.h>
2#include <torch/csrc/jit/passes/peephole_dict_idioms.h>
3
4namespace torch {
5namespace jit {
6
7namespace {
8
9class DictNodeImplBase {
10 public:
11 virtual ~DictNodeImplBase() = default;
12
13 virtual bool contains(const IValue&) const = 0;
14 virtual size_t size() const = 0;
15 virtual Value* get(const IValue&) const = 0;
16
17 bool canOptimize() {
18 return !has_overlap_ && !has_non_const_key_;
19 }
20
21 protected:
22 bool has_overlap_ = false;
23 bool has_non_const_key_ = false;
24};
25
26template <class KeyType>
27class DictNodeImpl : public DictNodeImplBase {
28 public:
29 DictNodeImpl(
30 std::function<KeyType(const IValue&)> ivalue_converter,
31 Node* dict_creation_node)
32 : ivalue_converter_(std::move(ivalue_converter)) {
33 for (size_t i = 0; i < dict_creation_node->inputs().size(); i += 2) {
34 auto key_opt = toIValue(dict_creation_node->input(i));
35
36 // Key is not constant if we cannot convert to IValue
37 if (key_opt == c10::nullopt) {
38 has_non_const_key_ = true;
39 continue;
40 }
41
42 KeyType key = ivalue_converter_(*key_opt);
43 if (dict_.find(key) == dict_.end()) {
44 dict_.emplace(key, dict_creation_node->input(i + 1));
45 } else {
46 has_overlap_ = true;
47 }
48 }
49 }
50
51 bool contains(const IValue& ivalue) const override {
52 auto key = ivalue_converter_(ivalue);
53 return dict_.find(key) != dict_.end();
54 }
55
56 size_t size() const override {
57 return dict_.size();
58 }
59
60 Value* get(const IValue& ivalue) const override {
61 auto val = ivalue_converter_(ivalue);
62 auto loc = dict_.find(val);
63 if (loc != dict_.end()) {
64 return loc->second;
65 }
66 TORCH_CHECK(false, "Cannot get non-existent key");
67 }
68
69 private:
70 std::unordered_map<KeyType, Value*> dict_;
71 std::function<KeyType(const IValue&)> ivalue_converter_;
72};
73
74class DictNode {
75 public:
76 explicit DictNode(Node* dict_creation_node) {
77 auto dict_type = dict_creation_node->output()->type();
78 auto key_value_types = dict_type->containedTypes();
79 TORCH_CHECK(
80 key_value_types.size() == 2, "Dict must have 2 contained types");
81 const auto& key_type = key_value_types[0];
82
83 switch (key_type->kind()) {
84 case TypeKind::IntType: {
85 auto ivalue_converter = [](const IValue& ival) { return ival.toInt(); };
86 impl_ = std::make_unique<DictNodeImpl<int64_t>>(
87 std::move(ivalue_converter), dict_creation_node);
88 break;
89 }
90
91 case TypeKind::FloatType: {
92 auto ivalue_converter = [](const IValue& ival) {
93 return ival.toDouble();
94 };
95 impl_ = std::make_unique<DictNodeImpl<double>>(
96 std::move(ivalue_converter), dict_creation_node);
97 break;
98 }
99
100 case TypeKind::StringType: {
101 auto ivalue_converter = [](const IValue& ival) {
102 return *ival.toString();
103 };
104 impl_ = std::make_unique<DictNodeImpl<std::string>>(
105 std::move(ivalue_converter), dict_creation_node);
106 break;
107 }
108
109 default:
110 impl_ = nullptr;
111 }
112 }
113
114 bool canOptimize() const {
115 if (impl_) {
116 return impl_->canOptimize();
117 }
118 return false;
119 }
120
121 size_t size() const {
122 if (impl_) {
123 return impl_->size();
124 }
125 return 0;
126 }
127
128 c10::optional<Value*> getOrNullopt(const IValue& key) const {
129 if (impl_ && impl_->contains(key)) {
130 return impl_->get(key);
131 }
132 return c10::nullopt;
133 }
134
135 private:
136 std::unique_ptr<DictNodeImplBase> impl_;
137};
138
139bool isDict(Value* v) {
140 return v->type()->castRaw<DictType>() != nullptr;
141}
142
143class PeepholeOptimizeDictIdiomsImpl {
144 public:
145 explicit PeepholeOptimizeDictIdiomsImpl(std::shared_ptr<Graph> graph)
146 : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}
147
148 bool run() {
149 collectMutatedDicts(graph_->block());
150 return runBlock(graph_->block());
151 }
152
153 private:
154 void checkForMutatedDicts(Value* v) {
155 if (isDict(v) && aliasDb_->hasWriters(v)) {
156 mutated_dicts_.insert(v);
157 }
158 }
159
160 void collectMutatedDicts(Block* b) {
161 for (Value* v : b->inputs()) {
162 checkForMutatedDicts(v);
163 }
164 for (Node* n : b->nodes()) {
165 for (Value* v : n->outputs()) {
166 checkForMutatedDicts(v);
167 }
168 for (Block* block : n->blocks()) {
169 collectMutatedDicts(block);
170 }
171 }
172 }
173
174 const DictNode& getDictNode(Node* creation_node) {
175 auto cached = dict_cache_.find(creation_node);
176 if (cached == dict_cache_.end()) {
177 cached =
178 dict_cache_.emplace(creation_node, DictNode(creation_node)).first;
179 }
180
181 return cached->second;
182 }
183
184 c10::optional<Value*> getValueFromDict(Node* dict_creation_node, Value* key) {
185 const DictNode& dict_node = getDictNode(dict_creation_node);
186 auto key_opt = toIValue(key);
187 // Key is not constant if we cannot convert to IValue
188 if (key_opt == c10::nullopt) {
189 return c10::nullopt;
190 }
191 IValue key_ival = *key_opt;
192 if (dict_node.canOptimize()) {
193 return dict_node.getOrNullopt(key_ival);
194 }
195 return c10::nullopt;
196 }
197
198 c10::optional<int64_t> computeLen(Node* dict_creation_node) {
199 const DictNode& dict_node = getDictNode(dict_creation_node);
200 if (dict_node.canOptimize()) {
201 return static_cast<int64_t>(dict_node.size());
202 }
203 return c10::nullopt;
204 }
205
206 bool optimizeLen(Node* len_node, Node* creation_node) {
207 if (creation_node->kind() == prim::DictConstruct) {
208 auto len = computeLen(creation_node);
209 if (len != c10::nullopt) {
210 WithInsertPoint guard(len_node);
211 len_node->output()->replaceAllUsesWith(graph_->insertConstant(len));
212 return true;
213 }
214 }
215 return false;
216 }
217
218 bool optimizeGetItem(Node* getitem_node, Node* creation_node) {
219 if (creation_node->kind() == prim::DictConstruct) {
220 auto key = getitem_node->input(1);
221 auto value = getValueFromDict(creation_node, key);
222 if (value != c10::nullopt) {
223 getitem_node->output()->replaceAllUsesWith(*value);
224 return true;
225 }
226 }
227 return false;
228 }
229
230 bool runBlock(Block* block) {
231 bool changed = false;
232 for (Node* node : block->nodes()) {
233 for (Block* b : node->blocks()) {
234 changed |= runBlock(b);
235 }
236
237 // only optimizing dict ops
238 if (node->inputs().empty() || !isDict(node->input(0))) {
239 continue;
240 }
241
242 auto first_input = node->input(0);
243
244 // only optimizing ops with unmutated inputs
245 if (mutated_dicts_.count(first_input)) {
246 continue;
247 }
248
249 if (node->kind() == aten::len) {
250 changed |= optimizeLen(node, first_input->node());
251 } else if (node->kind() == aten::__getitem__) {
252 changed |= optimizeGetItem(node, first_input->node());
253 }
254 }
255 return changed;
256 }
257
258 std::shared_ptr<Graph> graph_;
259 std::unordered_set<Value*> mutated_dicts_;
260 std::unique_ptr<AliasDb> aliasDb_;
261 std::unordered_map<Node*, DictNode> dict_cache_;
262};
263
264} // namespace
265
266bool PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph>& graph) {
267 PeepholeOptimizeDictIdiomsImpl opt(graph);
268 return opt.run();
269}
270
271} // namespace jit
272} // namespace torch
273