1 | #include <torch/csrc/jit/passes/decompose_ops.h> |
2 | |
3 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
4 | #include <torch/csrc/jit/passes/constant_propagation.h> |
5 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
6 | #include <torch/csrc/jit/passes/shape_analysis.h> |
7 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
8 | #include <torch/csrc/jit/runtime/custom_operator.h> |
9 | #include <torch/csrc/jit/runtime/operator.h> |
10 | |
11 | #include <ATen/core/symbol.h> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | |
16 | namespace { |
17 | c10::AliasAnalysisKind aliasAnalysisFromSchema() { |
18 | return c10::AliasAnalysisKind::FROM_SCHEMA; |
19 | } |
20 | } // namespace |
21 | |
22 | // helper to determine if an optional tensor argument/value passed in is |
23 | // statically defined (neither a None constant nor a Optional[Tensor] type) |
24 | // return yes, no, or no value if we can't tell |
25 | c10::optional<bool> isDefined(Value* tensor) { |
26 | if (tensor->type()->isSubtypeOf(*TensorType::get())) { |
27 | return true; |
28 | } |
29 | if (tensor->node()->mustBeNone()) { |
30 | return false; |
31 | } |
32 | return {}; |
33 | } |
34 | |
35 | bool isDecomposableNorm(Node* normalize_op) { |
36 | static const OperatorSet decomposable_normalization_ops = { |
37 | "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor" , |
38 | "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor" , |
39 | }; |
40 | Value* input = normalize_op->namedInput(attr::input); |
41 | if (!input->type()->isSubtypeOf(*TensorType::get())) { |
42 | return false; |
43 | } |
44 | auto device = input->type()->expectRef<TensorType>().device(); |
45 | // As of now, we do the decomposition for batchnorm/layernorm on GPU device |
46 | // only |
47 | if (!device || !(*device).is_cuda()) { |
48 | return false; |
49 | } |
50 | |
51 | if (normalize_op->isMemberOf(decomposable_normalization_ops)) { |
52 | // If we can't determine if weight and bias is defined statically there's |
53 | // really no point in decomposing normalization into simpler ops, since it |
54 | // won't get fused into a single kernel. |
55 | return isDefined(normalize_op->namedInput(attr::weight)).has_value() && |
56 | isDefined(normalize_op->namedInput(attr::bias)).has_value(); |
57 | } |
58 | return false; |
59 | } |
60 | |
61 | RegisterOperators reg_ops( |
62 | {Operator( |
63 | "aten::_ncf_unsqueeze(Tensor(a) self, int ndim) -> Tensor(a)" , |
64 | [](Stack& stack) { |
65 | const int64_t ndim = pop(stack).toInt(); |
66 | auto self = pop(stack).toTensor(); |
67 | c10::SmallVector<int64_t, 8> sizes(ndim, 1); |
68 | AT_ASSERT(self.dim() == 1); |
69 | sizes.at(1) = self.size(0); |
70 | push(stack, self.reshape(sizes)); |
71 | }, |
72 | aliasAnalysisFromSchema()), |
73 | Operator( |
74 | "aten::_ncf_view(Tensor(a) self, int[] input_shape, int normalized_ndim) -> Tensor(a)" , |
75 | [](Stack& stack) { |
76 | const int64_t normalized_ndim = pop(stack).toInt(); |
77 | auto input_shape = pop(stack).toIntList(); |
78 | auto self = pop(stack).toTensor(); |
79 | const int64_t input_ndim = input_shape.size(); |
80 | c10::SmallVector<int64_t, 8> sizes(input_ndim, 1); |
81 | for (int i = 0; i < input_ndim - normalized_ndim; ++i) { |
82 | sizes.at(i) = input_shape.get(i); |
83 | } |
84 | push(stack, self.reshape(sizes)); |
85 | }, |
86 | aliasAnalysisFromSchema())}); |
87 | |
88 | bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) { |
89 | bool decomposed = false; |
90 | for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; |
91 | ++it) { |
92 | for (auto sub : it->blocks()) { |
93 | DecomposeOps(sub, decompose_funcs); |
94 | } |
95 | |
96 | if (it->matches( |
97 | "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor" , |
98 | /*const_inputs=*/{attr::beta, attr::alpha})) { |
99 | // For the case where we have an addmm where alpha and beta are Attributes |
100 | // and both of those scalars are equal to 1.0, decompose this into an mm |
101 | // followed by an add so that it can go through the existing optimization |
102 | // (batchmm) |
103 | if (it->get<at::Scalar>(attr::alpha)->toComplexDouble() != 1.0 || |
104 | it->get<at::Scalar>(attr::beta)->toComplexDouble() != 1.0) { |
105 | continue; |
106 | } |
107 | |
108 | decomposed = true; |
109 | WithInsertPoint guard(*it); |
110 | std::shared_ptr<Graph> d_graph = |
111 | toGraphFunction(decompose_funcs.get_function("addmm" )).graph(); |
112 | Value* new_output = |
113 | insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0); |
114 | // Set the output of the decomposed graph to have the same output type as |
115 | // the original op otherwise the canonicalized graph will have TensorType |
116 | // as the output of this node which is incorrect |
117 | new_output->setType(it->output()->type()); |
118 | it->output()->replaceAllUsesWith(new_output); |
119 | it.destroyCurrent(); |
120 | } else if ( |
121 | it->matches( |
122 | "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor" )) { |
123 | if (!isDecomposableNorm(*it)) { |
124 | continue; |
125 | } |
126 | decomposed = true; |
127 | WithInsertPoint insert_guard{*it}; |
128 | Graph* graph = it->owningGraph(); |
129 | Value* input = it->namedInput(attr::input); |
130 | Value* input_dim = graph->insert(aten::dim, {input}); |
131 | std::vector<Value*> inputs{ |
132 | input, |
133 | it->namedInput(attr::running_mean), |
134 | it->namedInput(attr::running_var), |
135 | it->namedInput(attr::training), |
136 | it->namedInput(attr::momentum), |
137 | it->namedInput(attr::eps)}; |
138 | |
139 | // inline the compiled decomposed batchnorm |
140 | std::shared_ptr<Graph> d_graph = |
141 | toGraphFunction(decompose_funcs.get_function("batch_norm" )).graph(); |
142 | Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0); |
143 | |
144 | // post processing the graph |
145 | Value* weight = it->namedInput(attr::weight); |
146 | Value* bias = it->namedInput(attr::bias); |
147 | if (isDefined(weight).value()) { |
148 | Value* expanded_weight = |
149 | graph->insert(aten::_ncf_unsqueeze, {weight, input_dim}); |
150 | new_output = graph->insert(aten::mul, {new_output, expanded_weight}); |
151 | } |
152 | if (isDefined(bias).value()) { |
153 | Value* expanded_bias = |
154 | graph->insert(aten::_ncf_unsqueeze, {bias, input_dim}); |
155 | new_output = graph->insert(aten::add, {new_output, expanded_bias}); |
156 | } |
157 | it->output()->replaceAllUsesWith(new_output); |
158 | it.destroyCurrent(); |
159 | } else if ( |
160 | it->matches( |
161 | "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor" )) { |
162 | if (!isDecomposableNorm(*it)) { |
163 | continue; |
164 | } |
165 | decomposed = true; |
166 | WithInsertPoint insert_guard{*it}; |
167 | Graph* graph = it->owningGraph(); |
168 | std::vector<Value*> inputs{ |
169 | it->namedInput(attr::input), |
170 | it->namedInput(attr::normalized_shape), |
171 | it->namedInput(attr::eps), |
172 | it->namedInput(attr::cudnn_enable)}; |
173 | |
174 | // inline the compiled decomposed layernorm |
175 | std::shared_ptr<Graph> d_graph = |
176 | toGraphFunction(decompose_funcs.get_function("layer_norm" )).graph(); |
177 | Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0); |
178 | |
179 | // post processing the graph |
180 | Value* weight = it->namedInput(attr::weight); |
181 | Value* bias = it->namedInput(attr::bias); |
182 | if (isDefined(weight).value()) { |
183 | new_output = graph->insert(aten::mul, {new_output, weight}); |
184 | } |
185 | if (isDefined(bias).value()) { |
186 | new_output = graph->insert(aten::add, {new_output, bias}); |
187 | } |
188 | it->output()->replaceAllUsesWith(new_output); |
189 | it.destroyCurrent(); |
190 | } |
191 | } |
192 | return decomposed; |
193 | } |
194 | |
195 | void DecomposeOps(std::shared_ptr<Graph>& graph) { |
196 | static CompilationUnit decompose_funcs(R"SCRIPT( |
197 | def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: number = 1.0, alpha: number = 1.0): |
198 | return self + mat1.mm(mat2) |
199 | |
200 | def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor: |
201 | if training: |
202 | norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum) |
203 | else: |
204 | norm_mean = torch._unwrap_optional(running_mean) |
205 | norm_var = torch._unwrap_optional(running_var) |
206 | norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim()) |
207 | norm_var = torch._ncf_unsqueeze(norm_var, input.dim()) |
208 | norm_invstd = 1 / (torch.sqrt(norm_var + eps)) |
209 | return ((input - norm_mean) * norm_invstd) |
210 | |
211 | def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor: |
212 | input_ndim = input.dim() |
213 | normalized_ndim = len(normalized_shape) |
214 | n = 1 |
215 | for i in range(input_ndim - normalized_ndim): |
216 | n *= input.size(i) |
217 | input_reshape = input.contiguous().view(1, n, -1) |
218 | mean, invstd = torch.batch_norm_stats(input_reshape, eps) |
219 | input_shape = input.size() |
220 | mean = torch._ncf_view(mean, input_shape, normalized_ndim) |
221 | invstd = torch._ncf_view(invstd, input_shape, normalized_ndim) |
222 | |
223 | return (input - mean) * invstd |
224 | )SCRIPT" ); |
225 | bool is_decomposed = DecomposeOps(graph->block(), decompose_funcs); |
226 | if (is_decomposed) { |
227 | // we only re-run those passes when the graph get decomposed |
228 | PropagateInputShapes(graph); |
229 | ConstantPropagation(graph); |
230 | EliminateDeadCode(graph); |
231 | } |
232 | } |
233 | |
234 | } // namespace jit |
235 | } // namespace torch |
236 | |