#include <torch/csrc/jit/passes/peephole.h>

#include <algorithm>

#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {

namespace {
/**
 * Check whether the arithmetic node is a binary op between integers, and
 * return the constant int value if one exists.
 *
 * @pre node is integer arithmetic.
 * @post if there is one constant among the two operands, it is the second
 * operand.
 */
c10::optional<int64_t> checkArithNode(Node& node) {
  if (node.inputs().size() != 2 || node.input(0)->type() != IntType::get() ||
      node.input(1)->type() != IntType::get()) {
    return {};
  }

  if (node.kind() == aten::mul || node.kind() == aten::add) {
    // mul and add are commutative, so canonicalize by moving a constant
    // first operand into the second slot.
    if (auto i = constant_as<int64_t>(node.input(0))) {
      node.permuteInputs({1, 0});
      return i;
    }
  }

  return constant_as<int64_t>(node.input(1));
}
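// Illustrative example of the canonicalization above: given
//   %2 : int = aten::mul(%c, %x)   (%c a constant int, %x not)
// checkArithNode permutes the inputs to
//   %2 : int = aten::mul(%x, %c)
// and returns the value of %c.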

/**
 * Remove a mul/floordiv/div node if it multiplies or divides by 1.
 *
 * @pre node is either aten::mul, aten::floordiv or aten::div
 */
bool trySimplifyMulOrDiv(Node& node) {
  auto constant = checkArithNode(node);
  if (!constant || *constant != 1) {
    return false;
  }

  node.output()->replaceAllUsesWith(node.inputs()[0]);
  return true;
}
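// Illustrative example: %y : int = aten::mul(%x, %one), where %one is the
// constant 1, is simplified by rewiring all uses of %y to %x; the now-dead
// node is left for later cleanup (e.g. dead code elimination).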

/**
 * Simplify an add/sub node by merging its constant operand with the constant
 * operand of the add/sub node that produces its first input.
 *
 * @pre node is either aten::add or aten::sub
 */
bool trySimplifyAddOrSub(Node& node) {
  auto constant = checkArithNode(node);
  if (!constant) {
    return false;
  }

  if (constant == 0) {
    // Adding or subtracting 0 is a no-op.
    node.output()->replaceAllUsesWith(node.input(0));
    return true;
  }

  auto& dep = *node.inputs()[0]->node();
  if (dep.kind() != aten::add && dep.kind() != aten::sub) {
    return false;
  }

  auto delta = checkArithNode(dep);
  if (!delta) {
    return false;
  }
  // If the two nodes have the same kind the constants accumulate; otherwise
  // the inner constant is subtracted from the outer one.
  auto merged =
      dep.kind() == node.kind() ? *constant + *delta : *constant - *delta;

  if (merged == 0) {
    node.output()->replaceAllUsesWith(dep.inputs()[0]);
  } else {
    WithInsertPoint g(&node);
    node.replaceInput(0, dep.inputs()[0]);
    node.replaceInput(1, node.owningGraph()->insertConstant(merged));
  }
  return true;
}
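// Illustrative example: with %a = aten::add(%x, %two) and
// %b = aten::add(%a, %three) (%two and %three constant ints), %b is
// rewritten in place to aten::add(%x, %five), where %five is a freshly
// inserted constant 5. With aten::sub as one of the two kinds, the
// constants partially cancel instead of accumulating.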

} // namespace

struct PeepholeOptimizeNonTensorImpl {
  // NOLINTNEXTLINE(modernize-pass-by-value)
  PeepholeOptimizeNonTensorImpl(const std::shared_ptr<Graph>& graph)
      : graph_(graph) {}

  bool run() {
    return optimizeBlock(graph_->block());
  }

  bool optimizeBlock(Block* block) {
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto* node = *it;

      for (Block* sub_block : node->blocks()) {
        changed |= optimizeBlock(sub_block);
      }

      if (node->kind() != prim::Constant) {
        WithInsertPoint guard(node);
        // Any Value whose type is None should be replaced with a Constant.
        // This can occur if a module has an optional attribute that was
        // initialized to None.
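        // For example, an output typed NoneType is rewired below to a fresh
        // prim::Constant holding None (an empty IValue).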
        for (Value* output : node->outputs()) {
          if (output->type()->cast<NoneType>()) {
            output->replaceAllUsesWith(graph_->insertConstant(IValue()));
            changed = true;
          }
        }
      }
      // XXX: remember that if you want to simplify an expression by combining
      // multiple nodes into a different one, then you need to check that they
      // all belong to the given block
      // TODO: this doesn't work with Scalar-Tensor ops! We should
      // canonicalize those
      if (node->kind() == prim::If) {
        IfView n(node);
        // this handles redundant short circuits like "x and True" or "x or
        // False"
        for (const auto i : c10::irange(n.outputs().size())) {
          if (n.outputs().at(i)->type() != BoolType::get()) {
            continue;
          }
          // The defaults are chosen so that a non-constant branch output can
          // never trigger the rewrite below.
          bool true_val =
              constant_as<bool>(n.thenOutputs().at(i)).value_or(false);
          bool false_val =
              constant_as<bool>(n.elseOutputs().at(i)).value_or(true);
          // if an if node's output equals its condition, replace the output
          // with the condition
          if (true_val && !false_val) {
            GRAPH_UPDATE(
                "Replacing ",
                n.outputs().at(i)->debugName(),
                " (True or False) with ",
                n.cond()->debugName());
            n.outputs().at(i)->replaceAllUsesWith(n.cond());
            changed = true;
          }
        }

        // check for types that can be refined
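        // For example, if both branch outputs are typed int but the If
        // output is typed int?, the output type can be narrowed to int.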
        for (const auto i : c10::irange(n.outputs().size())) {
          // common case of optional for now
          bool inputs_non_optional =
              !n.thenOutputs().at(i)->type()->cast<OptionalType>() &&
              !n.elseOutputs().at(i)->type()->cast<OptionalType>();
          auto output_optional =
              n.outputs().at(i)->type()->cast<OptionalType>();
          if (inputs_non_optional && output_optional) {
            if (auto unif = unifyTypes(
                    n.thenOutputs().at(i)->type(),
                    n.elseOutputs().at(i)->type())) {
              n.outputs().at(i)->setType(*unif);
              changed = true;
            }
          }
        }
      } else if (
          node->kind() == aten::__is__ || node->kind() == aten::__isnot__) {
        // If we are comparing a None value with a value that can't be None,
        // replace the output with True if the node is __isnot__ and False if
        // it is __is__.
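        // For example, `x is None` where %x : Tensor folds to the constant
        // False, since a Tensor-typed value can never be None.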
        AT_ASSERT(node->inputs().size() == 2);
        for (size_t check_none_index : {0, 1}) {
          bool input_must_be_none =
              node->inputs().at(check_none_index)->mustBeNone();
          bool other_must_not_be_none =
              node->inputs().at(1 - check_none_index)->mustNotBeNone();
          if (input_must_be_none && other_must_not_be_none) {
            WithInsertPoint guard(node);
            auto output = node->owningGraph()->insertConstant(
                node->kind() == aten::__isnot__);
            GRAPH_UPDATE(
                "Folding ", getHeader(node), " to ", output->debugName());
            node->output()->replaceAllUsesWith(output);
            changed = true;
          }
        }
      } else if (
          node->kind() == prim::unchecked_unwrap_optional ||
          node->kind() == aten::_unwrap_optional) {
        // we are unwrapping an input that can't be None, remove the unwrap
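        // For example, aten::_unwrap_optional(%x) where %x : int (not int?)
        // is redundant, so uses of its output are rewired to %x directly.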
        auto input = node->input();
        if (input->mustNotBeNone()) {
          GRAPH_UPDATE(
              "Unwrapping ",
              getHeader(node),
              " as ",
              node->input(),
              " can't be optional");
          node->output()->replaceAllUsesWith(node->input());
          changed = true;
        }
      } else if (node->kind() == prim::unchecked_cast) {
        // unchecked_cast is not generated for tensor properties, so we are not
        // losing anything by calling unshapedType here
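        // For example, a prim::unchecked_cast from int to Scalar can be
        // removed, since int is already a subtype of Scalar.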
        auto input_type = unshapedType(node->input()->type());
        auto output_type = unshapedType(node->output()->type());
        if (input_type->isSubtypeOf(*output_type)) {
          GRAPH_UPDATE(
              "Removing ",
              getHeader(node),
              " as input type subtypes output type");
          node->output()->replaceAllUsesWith(node->input());
          changed = true;
        }
      } else if (
          (node->kind() == aten::Int || node->kind() == aten::ceil) &&
          node->inputs().size() == 1 &&
          node->input()->type()->cast<IntType>()) {
        GRAPH_UPDATE(
            "Removing ", getHeader(node), " as input is already an integer");
        node->output()->replaceAllUsesWith(node->input());
        changed = true;
      } else if (node->kind() == aten::ne || node->kind() == aten::eq) {
        if (node->inputs().size() != 2 ||
            node->inputs().at(0) != node->inputs().at(1)) {
          continue;
        }
        auto inp_type = node->inputs().at(0)->type();
        // only handling common immutable types here because other types like
        // Tensor or list of Tensor might throw on aten::eq
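        // For example, %y = aten::eq(%x, %x) with %x : int always folds to
        // the constant True, and the matching aten::ne folds to False.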
        auto immut_type = [&](const TypePtr& type) {
          auto kind = type->kind();
          static const std::vector<TypeKind> handled_immutable_types = {
              TypeKind::BoolType,
              TypeKind::IntType,
              TypeKind::FloatType,
              TypeKind::NoneType};
          return (
              std::find(
                  handled_immutable_types.begin(),
                  handled_immutable_types.end(),
                  kind) != handled_immutable_types.end());
        };
        bool non_throwing_type = false;
        if (auto li_type = inp_type->cast<ListType>()) {
          non_throwing_type = immut_type(li_type->getElementType());
        } else if (auto di_type = inp_type->cast<DictType>()) {
          non_throwing_type =
              (immut_type(di_type->getKeyType()) &&
               immut_type(di_type->getValueType()));
        } else {
          non_throwing_type = immut_type(inp_type);
        }
        if (non_throwing_type) {
          WithInsertPoint guard(node);
          node->output()->replaceAllUsesWith(
              graph_->insertConstant(node->kind() == aten::eq));
          changed = true;
        }
      } else if (
          node->kind() == aten::mul || node->kind() == aten::floordiv ||
          node->kind() == aten::div) {
        changed |= trySimplifyMulOrDiv(*node);
      } else if (node->kind() == aten::add || node->kind() == aten::sub) {
        changed |= trySimplifyAddOrSub(*node);
      }
    }
    return changed;
  }

 private:
  std::shared_ptr<Graph> graph_;
};

bool PeepholeOptimizeNonTensor(const std::shared_ptr<Graph>& graph) {
  PeepholeOptimizeNonTensorImpl peephole(graph);
  bool changed = peephole.run();
  GRAPH_DUMP("After PeepholeOptimize: ", graph);
  return changed;
}
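
// Usage sketch (illustrative, not part of this file): the pass is typically
// run on a parsed or traced graph, e.g.
//
//   #include <torch/csrc/jit/ir/irparser.h>
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(R"IR(
//     graph(%x : int):
//       %one : int = prim::Constant[value=1]()
//       %y : int = aten::mul(%x, %one)
//       return (%y))IR", graph.get());
//   bool changed = PeepholeOptimizeNonTensor(graph);  // rewires %y to %x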

} // namespace jit
} // namespace torch