1 | #include <torch/csrc/jit/passes/normalize_ops.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | namespace { |
9 | |
10 | // having multiple ops in our IR that do the same thing makes the IR more |
11 | // difficult to consumer for downstream user of the IR, such as our own |
12 | // optimization passes here, we convert op aliases into a standard form |
13 | bool normalizeOpAliases(graph_node_list_iterator& iter) { |
14 | auto alias = getOperatorAliasMap().find(iter->kind()); |
15 | if (alias != getOperatorAliasMap().end()) { |
16 | iter->replaceWithNewSymbol(alias->second); |
17 | iter.destroyCurrent(); |
18 | return true; |
19 | } |
20 | return false; |
21 | } |
22 | |
23 | // Normalize rsub such that `rsub(x,y) = sub(x,y)` |
24 | bool normalizeRSub(graph_node_list_iterator& iter) { |
25 | if (iter->matches( |
26 | "aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" )) { |
27 | ArrayRef<Value*> args = iter->inputs(); |
28 | Node* newSub = iter->replaceWithNewSymbol(aten::sub); |
29 | newSub->replaceInput(0, args[1]); |
30 | newSub->replaceInput(1, args[0]); |
31 | iter.destroyCurrent(); |
32 | return true; |
33 | } |
34 | return false; |
35 | } |
36 | |
37 | // Normalizes a `__is__` comparison with a bool to `eq` (and same with |
38 | // `__isnot__`) |
39 | bool normalizeIsBool(graph_node_list_iterator& iter) { |
40 | ArrayRef<Value*> args = iter->inputs(); |
41 | if (args.size() == 2 && args[0]->type() == BoolType::get() && |
42 | args[1]->type() == BoolType::get()) { |
43 | if (iter->kind() == aten::__is__) { |
44 | iter->replaceWithNewSymbol(aten::eq); |
45 | iter.destroyCurrent(); |
46 | return true; |
47 | } |
48 | if (iter->kind() == aten::__isnot__) { |
49 | iter->replaceWithNewSymbol(aten::ne); |
50 | iter.destroyCurrent(); |
51 | return true; |
52 | } |
53 | } |
54 | return false; |
55 | } |
56 | |
57 | void NormalizeOps(Block* block) { |
58 | for (auto it = block->nodes().begin(), end = block->nodes().end(); |
59 | it != end;) { |
60 | for (auto sub : it->blocks()) { |
61 | NormalizeOps(sub); |
62 | } |
63 | |
64 | if (normalizeRSub(it)) { |
65 | continue; |
66 | } |
67 | |
68 | if (normalizeOpAliases(it)) { |
69 | continue; |
70 | } |
71 | |
72 | if (normalizeIsBool(it)) { |
73 | continue; |
74 | } |
75 | |
76 | it++; |
77 | } |
78 | } |
79 | |
80 | } // namespace |
81 | |
82 | const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() { |
83 | // map from op alias -> normalized op |
84 | static const std::unordered_map<Symbol, Symbol> alias_map = { |
85 | {aten::absolute, aten::abs}, |
86 | {aten::absolute_, aten::abs_}, |
87 | {aten::clip, aten::clamp}, |
88 | {aten::clip_, aten::clamp_}, |
89 | {aten::det, aten::linalg_det}, |
90 | {aten::matrix_power, aten::linalg_matrix_power}, |
91 | {aten::matrix_exp, aten::linalg_matrix_exp}, |
92 | {aten::ger, aten::outer}, |
93 | {aten::arccos, aten::acos}, |
94 | {aten::arccos_, aten::acos_}, |
95 | {aten::arcsin, aten::asin}, |
96 | {aten::arcsin_, aten::asin_}, |
97 | {aten::arctan, aten::atan}, |
98 | {aten::arctan_, aten::atan_}, |
99 | {aten::arctan2, aten::atan2}, |
100 | {aten::arctan2_, aten::atan2_}, |
101 | {aten::arccosh, aten::acosh}, |
102 | {aten::arccosh_, aten::acosh_}, |
103 | {aten::arcsinh, aten::asinh}, |
104 | {aten::arcsinh_, aten::asinh_}, |
105 | {aten::arctanh, aten::atanh}, |
106 | {aten::arctanh_, aten::atanh_}, |
107 | {aten::fix, aten::trunc}, |
108 | {aten::fix_, aten::trunc_}, |
109 | {aten::negative, aten::neg}, |
110 | {aten::negative_, aten::neg_}, |
111 | {aten::subtract, aten::sub}, |
112 | {aten::subtract_, aten::sub_}, |
113 | {aten::greater_equal, aten::ge}, |
114 | {aten::greater_equal_, aten::ge_}, |
115 | {aten::greater, aten::gt}, |
116 | {aten::greater_, aten::gt_}, |
117 | {aten::less_equal, aten::le}, |
118 | {aten::less_equal_, aten::le_}, |
119 | {aten::less, aten::lt}, |
120 | {aten::less_, aten::lt_}, |
121 | {aten::not_equal, aten::ne}, |
122 | {aten::not_equal_, aten::ne_}, |
123 | {aten::divide, aten::div}, |
124 | {aten::divide_, aten::div_}, |
125 | {aten::multiply, aten::mul}, |
126 | {aten::multiply_, aten::mul_}, |
127 | {aten::linalg_matmul, aten::matmul}, |
128 | {aten::inverse, aten::linalg_inv}, |
129 | {aten::true_divide, aten::div}, |
130 | {aten::true_divide_, aten::div_}, |
131 | {aten::concat, aten::cat}, |
132 | {aten::concatenate, aten::cat}, |
133 | {aten::row_stack, aten::vstack}, |
134 | {aten::swapdims, aten::transpose}, |
135 | {aten::swapdims_, aten::transpose_}, |
136 | {aten::swapaxes, aten::transpose}, |
137 | {aten::swapaxes_, aten::transpose_}, |
138 | {aten::moveaxis, aten::movedim}, |
139 | {aten::special_erf, aten::erf}, |
140 | {aten::special_erfc, aten::erfc}, |
141 | {aten::special_erfinv, aten::erfinv}, |
142 | {aten::special_expit, aten::sigmoid}, |
143 | {aten::special_exp2, aten::exp2}, |
144 | {aten::special_expm1, aten::expm1}, |
145 | {aten::special_logit, aten::logit}, |
146 | {aten::special_logsumexp, aten::logsumexp}, |
147 | {aten::special_round, aten::round}, |
148 | {aten::special_log1p, aten::log1p}, |
149 | {aten::special_sinc, aten::sinc}, |
150 | {aten::special_digamma, aten::digamma}, |
151 | {aten::special_psi, aten::digamma}, |
152 | {aten::special_i0, aten::i0}, |
153 | {aten::special_xlogy, aten::xlogy}, |
154 | {aten::special_log_softmax, aten::log_softmax}, |
155 | {aten::orgqr, aten::linalg_householder_product}, |
156 | {aten::adjoint, aten::mH}, |
157 | {aten::special_multigammaln, aten::mvlgamma}, |
158 | {aten::special_polygamma, aten::polygamma}, |
159 | {aten::special_softmax, aten::softmax}, |
160 | {aten::special_gammainc, aten::igamma}, |
161 | {aten::special_gammaincc, aten::igammac}, |
162 | {aten::special_gammaln, aten::lgamma}}; |
163 | return alias_map; |
164 | } |
165 | |
166 | void NormalizeOps(const std::shared_ptr<Graph>& graph) { |
167 | NormalizeOps(graph->block()); |
168 | } |
169 | |
170 | } // namespace jit |
171 | } // namespace torch |
172 | |