1#include <torch/csrc/jit/passes/decompose_ops.h>
2
3#include <torch/csrc/jit/frontend/ir_emitter.h>
4#include <torch/csrc/jit/passes/constant_propagation.h>
5#include <torch/csrc/jit/passes/dead_code_elimination.h>
6#include <torch/csrc/jit/passes/shape_analysis.h>
7#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
8#include <torch/csrc/jit/runtime/custom_operator.h>
9#include <torch/csrc/jit/runtime/operator.h>
10
11#include <ATen/core/symbol.h>
12
13namespace torch {
14namespace jit {
15
16namespace {
17c10::AliasAnalysisKind aliasAnalysisFromSchema() {
18 return c10::AliasAnalysisKind::FROM_SCHEMA;
19}
20} // namespace
21
22// helper to determine if an optional tensor argument/value passed in is
23// statically defined (neither a None constant nor a Optional[Tensor] type)
24// return yes, no, or no value if we can't tell
25c10::optional<bool> isDefined(Value* tensor) {
26 if (tensor->type()->isSubtypeOf(*TensorType::get())) {
27 return true;
28 }
29 if (tensor->node()->mustBeNone()) {
30 return false;
31 }
32 return {};
33}
34
35bool isDecomposableNorm(Node* normalize_op) {
36 static const OperatorSet decomposable_normalization_ops = {
37 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
38 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
39 };
40 Value* input = normalize_op->namedInput(attr::input);
41 if (!input->type()->isSubtypeOf(*TensorType::get())) {
42 return false;
43 }
44 auto device = input->type()->expectRef<TensorType>().device();
45 // As of now, we do the decomposition for batchnorm/layernorm on GPU device
46 // only
47 if (!device || !(*device).is_cuda()) {
48 return false;
49 }
50
51 if (normalize_op->isMemberOf(decomposable_normalization_ops)) {
52 // If we can't determine if weight and bias is defined statically there's
53 // really no point in decomposing normalization into simpler ops, since it
54 // won't get fused into a single kernel.
55 return isDefined(normalize_op->namedInput(attr::weight)).has_value() &&
56 isDefined(normalize_op->namedInput(attr::bias)).has_value();
57 }
58 return false;
59}
60
61RegisterOperators reg_ops(
62 {Operator(
63 "aten::_ncf_unsqueeze(Tensor(a) self, int ndim) -> Tensor(a)",
64 [](Stack& stack) {
65 const int64_t ndim = pop(stack).toInt();
66 auto self = pop(stack).toTensor();
67 c10::SmallVector<int64_t, 8> sizes(ndim, 1);
68 AT_ASSERT(self.dim() == 1);
69 sizes.at(1) = self.size(0);
70 push(stack, self.reshape(sizes));
71 },
72 aliasAnalysisFromSchema()),
73 Operator(
74 "aten::_ncf_view(Tensor(a) self, int[] input_shape, int normalized_ndim) -> Tensor(a)",
75 [](Stack& stack) {
76 const int64_t normalized_ndim = pop(stack).toInt();
77 auto input_shape = pop(stack).toIntList();
78 auto self = pop(stack).toTensor();
79 const int64_t input_ndim = input_shape.size();
80 c10::SmallVector<int64_t, 8> sizes(input_ndim, 1);
81 for (int i = 0; i < input_ndim - normalized_ndim; ++i) {
82 sizes.at(i) = input_shape.get(i);
83 }
84 push(stack, self.reshape(sizes));
85 },
86 aliasAnalysisFromSchema())});
87
88bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
89 bool decomposed = false;
90 for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
91 ++it) {
92 for (auto sub : it->blocks()) {
93 DecomposeOps(sub, decompose_funcs);
94 }
95
96 if (it->matches(
97 "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
98 /*const_inputs=*/{attr::beta, attr::alpha})) {
99 // For the case where we have an addmm where alpha and beta are Attributes
100 // and both of those scalars are equal to 1.0, decompose this into an mm
101 // followed by an add so that it can go through the existing optimization
102 // (batchmm)
103 if (it->get<at::Scalar>(attr::alpha)->toComplexDouble() != 1.0 ||
104 it->get<at::Scalar>(attr::beta)->toComplexDouble() != 1.0) {
105 continue;
106 }
107
108 decomposed = true;
109 WithInsertPoint guard(*it);
110 std::shared_ptr<Graph> d_graph =
111 toGraphFunction(decompose_funcs.get_function("addmm")).graph();
112 Value* new_output =
113 insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
114 // Set the output of the decomposed graph to have the same output type as
115 // the original op otherwise the canonicalized graph will have TensorType
116 // as the output of this node which is incorrect
117 new_output->setType(it->output()->type());
118 it->output()->replaceAllUsesWith(new_output);
119 it.destroyCurrent();
120 } else if (
121 it->matches(
122 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
123 if (!isDecomposableNorm(*it)) {
124 continue;
125 }
126 decomposed = true;
127 WithInsertPoint insert_guard{*it};
128 Graph* graph = it->owningGraph();
129 Value* input = it->namedInput(attr::input);
130 Value* input_dim = graph->insert(aten::dim, {input});
131 std::vector<Value*> inputs{
132 input,
133 it->namedInput(attr::running_mean),
134 it->namedInput(attr::running_var),
135 it->namedInput(attr::training),
136 it->namedInput(attr::momentum),
137 it->namedInput(attr::eps)};
138
139 // inline the compiled decomposed batchnorm
140 std::shared_ptr<Graph> d_graph =
141 toGraphFunction(decompose_funcs.get_function("batch_norm")).graph();
142 Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
143
144 // post processing the graph
145 Value* weight = it->namedInput(attr::weight);
146 Value* bias = it->namedInput(attr::bias);
147 if (isDefined(weight).value()) {
148 Value* expanded_weight =
149 graph->insert(aten::_ncf_unsqueeze, {weight, input_dim});
150 new_output = graph->insert(aten::mul, {new_output, expanded_weight});
151 }
152 if (isDefined(bias).value()) {
153 Value* expanded_bias =
154 graph->insert(aten::_ncf_unsqueeze, {bias, input_dim});
155 new_output = graph->insert(aten::add, {new_output, expanded_bias});
156 }
157 it->output()->replaceAllUsesWith(new_output);
158 it.destroyCurrent();
159 } else if (
160 it->matches(
161 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor")) {
162 if (!isDecomposableNorm(*it)) {
163 continue;
164 }
165 decomposed = true;
166 WithInsertPoint insert_guard{*it};
167 Graph* graph = it->owningGraph();
168 std::vector<Value*> inputs{
169 it->namedInput(attr::input),
170 it->namedInput(attr::normalized_shape),
171 it->namedInput(attr::eps),
172 it->namedInput(attr::cudnn_enable)};
173
174 // inline the compiled decomposed layernorm
175 std::shared_ptr<Graph> d_graph =
176 toGraphFunction(decompose_funcs.get_function("layer_norm")).graph();
177 Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
178
179 // post processing the graph
180 Value* weight = it->namedInput(attr::weight);
181 Value* bias = it->namedInput(attr::bias);
182 if (isDefined(weight).value()) {
183 new_output = graph->insert(aten::mul, {new_output, weight});
184 }
185 if (isDefined(bias).value()) {
186 new_output = graph->insert(aten::add, {new_output, bias});
187 }
188 it->output()->replaceAllUsesWith(new_output);
189 it.destroyCurrent();
190 }
191 }
192 return decomposed;
193}
194
195void DecomposeOps(std::shared_ptr<Graph>& graph) {
196 static CompilationUnit decompose_funcs(R"SCRIPT(
197 def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: number = 1.0, alpha: number = 1.0):
198 return self + mat1.mm(mat2)
199
200 def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
201 if training:
202 norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
203 else:
204 norm_mean = torch._unwrap_optional(running_mean)
205 norm_var = torch._unwrap_optional(running_var)
206 norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
207 norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
208 norm_invstd = 1 / (torch.sqrt(norm_var + eps))
209 return ((input - norm_mean) * norm_invstd)
210
211 def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor:
212 input_ndim = input.dim()
213 normalized_ndim = len(normalized_shape)
214 n = 1
215 for i in range(input_ndim - normalized_ndim):
216 n *= input.size(i)
217 input_reshape = input.contiguous().view(1, n, -1)
218 mean, invstd = torch.batch_norm_stats(input_reshape, eps)
219 input_shape = input.size()
220 mean = torch._ncf_view(mean, input_shape, normalized_ndim)
221 invstd = torch._ncf_view(invstd, input_shape, normalized_ndim)
222
223 return (input - mean) * invstd
224 )SCRIPT");
225 bool is_decomposed = DecomposeOps(graph->block(), decompose_funcs);
226 if (is_decomposed) {
227 // we only re-run those passes when the graph get decomposed
228 PropagateInputShapes(graph);
229 ConstantPropagation(graph);
230 EliminateDeadCode(graph);
231 }
232}
233
234} // namespace jit
235} // namespace torch
236