1 | #include <torch/csrc/jit/passes/peephole.h> |
2 | |
3 | #include <ATen/core/jit_type.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/jit/ir/ir_views.h> |
6 | #include <torch/csrc/jit/jit_log.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | namespace { |
12 | |
13 | /** |
14 | * Check whether the arithmetic node is binary between integers, and return a |
15 | * constant int value if there exists one. |
16 | * |
17 | * @pre node is integer arithmetic. |
18 | * @post if there's one constant in two operands, then the second operand is |
19 | * constant. |
20 | */ |
21 | c10::optional<int64_t> checkArithNode(Node& node) { |
22 | if (node.inputs().size() != 2 || node.input(0)->type() != IntType::get() || |
23 | node.input(1)->type() != IntType::get()) { |
24 | return {}; |
25 | } |
26 | |
27 | if (node.kind() == aten::mul || node.kind() == aten::add) { |
28 | if (auto i = constant_as<int64_t>(node.input(0))) { |
29 | node.permuteInputs({1, 0}); |
30 | return i; |
31 | } |
32 | } |
33 | |
34 | return constant_as<int64_t>(node.input(1)); |
35 | } |
36 | |
37 | /** |
38 | * Remove a mul/floordiv node if it is multiplication or division by 1. |
39 | * |
40 | * @pre node is either aten::mul, aten::floordiv or aten::div |
41 | */ |
42 | bool trySimplifyMulOrDiv(Node& node) { |
43 | auto constant = checkArithNode(node); |
44 | if (!constant || *constant != 1) { |
45 | return false; |
46 | } |
47 | |
48 | node.output()->replaceAllUsesWith(node.inputs()[0]); |
49 | return true; |
50 | } |
51 | |
52 | /** |
53 | * Simplify an add/sub node with its input node, i.e. merge the constant parts |
54 | * together. |
55 | * |
56 | * @pre node is either aten::add or aten::sub |
57 | */ |
58 | bool trySimplifyAddOrSub(Node& node) { |
59 | auto constant = checkArithNode(node); |
60 | if (!constant) { |
61 | return false; |
62 | } |
63 | |
64 | if (constant == 0) { |
65 | node.output()->replaceAllUsesWith(node.input(0)); |
66 | return true; |
67 | } |
68 | |
69 | auto& dep = *node.inputs()[0]->node(); |
70 | if (dep.kind() != aten::add && dep.kind() != aten::sub) { |
71 | return false; |
72 | } |
73 | |
74 | auto delta = checkArithNode(dep); |
75 | if (!delta) { |
76 | return false; |
77 | } |
78 | auto merged = |
79 | dep.kind() == node.kind() ? *constant + *delta : *constant - *delta; |
80 | |
81 | if (merged == 0) { |
82 | node.output()->replaceAllUsesWith(dep.inputs()[0]); |
83 | } else { |
84 | WithInsertPoint g(&node); |
85 | node.replaceInput(0, dep.inputs()[0]); |
86 | node.replaceInput(1, node.owningGraph()->insertConstant(merged)); |
87 | } |
88 | return true; |
89 | } |
90 | |
91 | } // namespace |
92 | |
93 | struct PeepholeOptimizeNonTensorImpl { |
94 | // NOLINTNEXTLINE(modernize-pass-by-value) |
95 | PeepholeOptimizeNonTensorImpl(const std::shared_ptr<Graph>& graph) |
96 | : graph_(graph) {} |
97 | |
98 | bool run() { |
99 | return optimizeBlock(graph_->block()); |
100 | } |
101 | |
102 | bool optimizeBlock(Block* block) { |
103 | bool changed = false; |
104 | for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { |
105 | auto* node = *it; |
106 | |
107 | for (Block* sub_block : node->blocks()) { |
108 | changed |= optimizeBlock(sub_block); |
109 | } |
110 | |
111 | if (node->kind() != prim::Constant) { |
112 | WithInsertPoint guard(node); |
113 | // Any Value whose type is None should be replaced with a Constant |
114 | // This can occur if a module has an optional attribute, and it is |
115 | // initialized as None. |
116 | for (Value* output : node->outputs()) { |
117 | if (output->type()->cast<NoneType>()) { |
118 | output->replaceAllUsesWith(graph_->insertConstant(IValue())); |
119 | changed = true; |
120 | } |
121 | } |
122 | } |
123 | // XXX: remember that if you want to simplify an expression by combining |
124 | // multiple nodes into a different one, then you need to check that they |
125 | // all belong to the given block |
126 | // TODO: this doesn't work with Scalar-Tensor ops! We should |
127 | // canonicalize those |
128 | if (node->kind() == prim::If) { |
129 | IfView n(node); |
130 | // this handles redundant short circuits like "x and True" or "x or |
131 | // False" |
132 | for (const auto i : c10::irange(n.outputs().size())) { |
133 | if (n.outputs().at(i)->type() != BoolType::get()) { |
134 | continue; |
135 | } |
136 | bool true_val = |
137 | constant_as<bool>(n.thenOutputs().at(i)).value_or(false); |
138 | bool false_val = |
139 | constant_as<bool>(n.elseOutputs().at(i)).value_or(true); |
140 | // if an if node's output equals its condition replace output with |
141 | // condition |
142 | if (true_val && !false_val) { |
143 | GRAPH_UPDATE( |
144 | "Replacing " , |
145 | n.outputs().at(i)->debugName(), |
146 | " (True or False) with " , |
147 | n.cond()->debugName()); |
148 | n.outputs().at(i)->replaceAllUsesWith(n.cond()); |
149 | changed = true; |
150 | } |
151 | } |
152 | |
153 | // check for types that can be refined |
154 | for (size_t i = 0; i < n.outputs().size(); ++i) { |
155 | // common case of optional for now |
156 | bool inputs_non_optional = |
157 | !n.thenOutputs().at(i)->type()->cast<OptionalType>() && |
158 | !n.elseOutputs().at(i)->type()->cast<OptionalType>(); |
159 | auto output_optional = |
160 | n.outputs().at(i)->type()->cast<OptionalType>(); |
161 | if (inputs_non_optional && output_optional) { |
162 | if (auto unif = unifyTypes( |
163 | n.thenOutputs().at(i)->type(), |
164 | n.elseOutputs().at(i)->type())) { |
165 | n.outputs().at(i)->setType(*unif); |
166 | changed = true; |
167 | } |
168 | } |
169 | } |
170 | } else if ( |
171 | node->kind() == aten::__is__ || node->kind() == aten::__isnot__) { |
172 | // if we are comparing a None value with a value that can't be None |
173 | // replace the output with true if node is __isnot__ or false if node is |
174 | // __is__ |
175 | AT_ASSERT(node->inputs().size() == 2); |
176 | for (size_t check_none_index : {0, 1}) { |
177 | bool input_must_be_none = |
178 | node->inputs().at(check_none_index)->mustBeNone(); |
179 | bool other_must_not_be_none = |
180 | node->inputs().at(1 - check_none_index)->mustNotBeNone(); |
181 | if (input_must_be_none && other_must_not_be_none) { |
182 | WithInsertPoint guard(node); |
183 | auto output = node->owningGraph()->insertConstant( |
184 | node->kind() == aten::__isnot__); |
185 | GRAPH_UPDATE( |
186 | "Folding " , getHeader(node), " to " , output->debugName()); |
187 | node->output()->replaceAllUsesWith(output); |
188 | changed = true; |
189 | } |
190 | } |
191 | } else if ( |
192 | node->kind() == prim::unchecked_unwrap_optional || |
193 | node->kind() == aten::_unwrap_optional) { |
194 | // we are unwrapping an input that can't be None, remove the unwrap |
195 | auto input = node->input(); |
196 | if (input->mustNotBeNone()) { |
197 | GRAPH_UPDATE( |
198 | "Unwrapping " , |
199 | getHeader(node), |
200 | " as " , |
201 | node->input(), |
202 | " can't be optional" ); |
203 | node->output()->replaceAllUsesWith(node->input()); |
204 | changed = true; |
205 | } |
206 | } else if (node->kind() == prim::unchecked_cast) { |
207 | // unchecked_cast is not generated for tensor properties, so we are not |
208 | // losing anything by calling unshapedType here |
209 | auto input_type = unshapedType(node->input()->type()); |
210 | auto output_type = unshapedType(node->output()->type()); |
211 | if (input_type->isSubtypeOf(*output_type)) { |
212 | GRAPH_UPDATE( |
213 | "Removing " , |
214 | getHeader(node), |
215 | " as input type subtypes output type" ); |
216 | node->output()->replaceAllUsesWith(node->input()); |
217 | changed = true; |
218 | } |
219 | } else if ( |
220 | (node->kind() == aten::Int || node->kind() == aten::ceil) && |
221 | node->inputs().size() == 1 && |
222 | node->input()->type()->cast<IntType>()) { |
223 | GRAPH_UPDATE( |
224 | "Removing " , getHeader(node), " as input is already an integer" ); |
225 | node->output()->replaceAllUsesWith(node->input()); |
226 | changed = true; |
227 | } else if (node->kind() == aten::ne || node->kind() == aten::eq) { |
228 | if (node->inputs().size() != 2 || |
229 | node->inputs().at(0) != node->inputs().at(1)) { |
230 | continue; |
231 | } |
232 | auto inp_type = node->inputs().at(0)->type(); |
233 | // only handling common immutable types here because other types like |
234 | // Tensor or list of Tensor might throw on aten::eq |
235 | auto immut_type = [&](const TypePtr& type) { |
236 | auto kind = type->kind(); |
237 | static const std::vector<TypeKind> handled_immutable_types = { |
238 | TypeKind::BoolType, |
239 | TypeKind::IntType, |
240 | TypeKind::FloatType, |
241 | TypeKind::NoneType}; |
242 | return ( |
243 | std::find( |
244 | handled_immutable_types.begin(), |
245 | handled_immutable_types.end(), |
246 | kind) != handled_immutable_types.end()); |
247 | }; |
248 | bool non_throwing_type = false; |
249 | if (auto li_type = inp_type->cast<ListType>()) { |
250 | non_throwing_type = immut_type(li_type->getElementType()); |
251 | } else if (auto di_type = inp_type->cast<DictType>()) { |
252 | non_throwing_type = |
253 | (immut_type(di_type->getKeyType()) && |
254 | immut_type(di_type->getValueType())); |
255 | } else { |
256 | non_throwing_type = immut_type(inp_type); |
257 | } |
258 | if (non_throwing_type) { |
259 | WithInsertPoint guard(node); |
260 | node->output()->replaceAllUsesWith( |
261 | graph_->insertConstant(node->kind() == aten::eq)); |
262 | changed = true; |
263 | } |
264 | } else if ( |
265 | node->kind() == aten::mul || node->kind() == aten::floordiv || |
266 | node->kind() == aten::div) { |
267 | changed |= trySimplifyMulOrDiv(*node); |
268 | } else if (node->kind() == aten::add || node->kind() == aten::sub) { |
269 | changed |= trySimplifyAddOrSub(*node); |
270 | } |
271 | } |
272 | return changed; |
273 | } |
274 | |
275 | private: |
276 | std::shared_ptr<Graph> graph_; |
277 | }; |
278 | |
279 | bool PeepholeOptimizeNonTensor(const std::shared_ptr<Graph>& graph) { |
280 | PeepholeOptimizeNonTensorImpl peephole(graph); |
281 | bool changed = peephole.run(); |
282 | GRAPH_DUMP("After PeepholeOptimize: " , graph); |
283 | return changed; |
284 | } |
285 | |
286 | } // namespace jit |
287 | } // namespace torch |
288 | |