1 | #include <torch/csrc/jit/passes/graph_fuser.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/jit/codegen/fuser/interface.h> |
6 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
7 | #include <torch/csrc/jit/ir/alias_analysis.h> |
8 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
9 | #include <torch/csrc/jit/passes/constant_pooling.h> |
10 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
11 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
12 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
13 | #include <torch/csrc/jit/runtime/autodiff.h> |
14 | #include <torch/csrc/jit/runtime/custom_operator.h> |
15 | #include <torch/csrc/jit/runtime/operator.h> |
16 | |
17 | #include <queue> |
18 | #include <unordered_map> |
19 | #include <utility> |
20 | |
21 | namespace torch { |
22 | namespace jit { |
23 | |
24 | namespace { |
25 | |
26 | // What is a simple mappable operator? It: |
27 | // - Has a single tensor output |
28 | // - Output and all tensor inputs have the same shape |
29 | // - Output and all tensor inputs have the same scalar type |
30 | // or all tensor inputs have the same scalar type and |
31 | // output is identified in PropagateInputShapes |
32 | // - Output and all tensor inputs should be on the same device |
33 | // - Produces dense non-overlapping outputs |
34 | // Some of these restrictions may be relaxable, but you should |
35 | // carefully read the code first, as we rely on these assumptions. |
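// For illustration only (a hypothetical IR sketch, not taken from a real
// trace): a chain such as
//   %2 : Tensor = aten::mul(%0, %1)
//   %3 : Tensor = aten::tanh(%2)
// where %0 and %1 share shape, scalar type, and device satisfies these
// constraints, so both nodes are candidates for a single fusion group.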
36 | bool isSimpleMap(Node* node) { |
37 | static OperatorSet simple_mappable{{ |
38 | "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor" , |
39 | |
40 | "aten::abs(Tensor self) -> Tensor" , |
41 | "aten::acos(Tensor self) -> Tensor" , |
42 | "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" , |
43 | "aten::asin(Tensor self) -> Tensor" , |
44 | "aten::atan(Tensor self) -> Tensor" , |
45 | "aten::atan2(Tensor self, Tensor other) -> Tensor" , |
46 | "aten::ceil(Tensor self) -> Tensor" , |
47 | "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor" , |
48 | "aten::cos(Tensor self) -> Tensor" , |
49 | "aten::cosh(Tensor self) -> Tensor" , |
50 | "aten::div(Tensor self, Tensor other) -> Tensor" , |
51 | "aten::exp(Tensor self) -> Tensor" , |
52 | "aten::expm1(Tensor self) -> Tensor" , |
53 | "aten::erf(Tensor self) -> Tensor" , |
54 | "aten::erfc(Tensor self) -> Tensor" , |
55 | "aten::floor(Tensor self) -> Tensor" , |
56 | "aten::fmod(Tensor self, Tensor other) -> Tensor" , |
57 | "aten::frac(Tensor self) -> Tensor" , |
58 | "aten::lgamma(Tensor self) -> Tensor" , |
59 | "aten::log(Tensor self) -> Tensor" , |
60 | "aten::log10(Tensor self) -> Tensor" , |
61 | "aten::log1p(Tensor self) -> Tensor" , |
62 | "aten::log2(Tensor self) -> Tensor" , |
63 | "aten::logit(Tensor self, float? eps=None) -> Tensor" , |
64 | "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor" , |
65 | "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor" , |
66 | "aten::max(Tensor self, Tensor other) -> Tensor" , |
67 | "aten::min(Tensor self, Tensor other) -> Tensor" , |
68 | "aten::mul(Tensor self, Tensor other) -> Tensor" , |
69 | "aten::neg(Tensor self) -> Tensor" , |
70 | "aten::pow(Tensor self, Tensor exponent) -> Tensor" , |
71 | "aten::pow(Tensor self, Scalar exponent) -> Tensor" , |
72 | "aten::pow(Scalar self, Tensor exponent) -> Tensor" , |
73 | "aten::reciprocal(Tensor self) -> Tensor" , |
74 | "aten::relu(Tensor self) -> Tensor" , |
75 | "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor" , |
76 | "aten::remainder(Tensor self, Tensor other) -> Tensor" , |
77 | "aten::round(Tensor self) -> Tensor" , |
78 | "aten::rsqrt(Tensor self) -> Tensor" , |
79 | "aten::sigmoid(Tensor self) -> Tensor" , |
80 | "aten::sin(Tensor self) -> Tensor" , |
81 | "aten::sinh(Tensor self) -> Tensor" , |
82 | "aten::sqrt(Tensor self) -> Tensor" , |
83 | "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" , |
84 | "aten::tan(Tensor self) -> Tensor" , |
85 | "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
86 | "aten::tanh(Tensor self) -> Tensor" , |
87 | "aten::trunc(Tensor self) -> Tensor" , |
88 | "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
89 | "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
90 | "aten::mul(Tensor self, Scalar other) -> Tensor" , |
91 | "aten::div(Tensor self, Scalar other) -> Tensor" , |
92 | |
93 | "aten::eq(Tensor self, Tensor other) -> Tensor" , |
94 | "aten::eq(Tensor self, Scalar other) -> Tensor" , |
95 | "aten::ne(Tensor self, Tensor other) -> Tensor" , |
96 | "aten::ne(Tensor self, Scalar other) -> Tensor" , |
97 | "aten::ge(Tensor self, Tensor other) -> Tensor" , |
98 | "aten::ge(Tensor self, Scalar other) -> Tensor" , |
99 | "aten::gt(Tensor self, Tensor other) -> Tensor" , |
100 | "aten::gt(Tensor self, Scalar other) -> Tensor" , |
101 | "aten::le(Tensor self, Tensor other) -> Tensor" , |
102 | "aten::le(Tensor self, Scalar other) -> Tensor" , |
103 | "aten::lt(Tensor self, Tensor other) -> Tensor" , |
104 | "aten::lt(Tensor self, Scalar other) -> Tensor" , |
105 | |
106 | "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor" , |
107 | "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor" , |
108 | |
109 | "aten::type_as(Tensor self, Tensor other) -> Tensor" , |
110 | }}; |
111 | if (!node->isMemberOf(simple_mappable)) { |
112 | return false; |
113 | } |
114 | for (Value* input : node->inputs()) { |
115 | if (input->type()->isSubtypeOf(*TensorType::get()) || |
116 | input->type()->isSubtypeOf(*FloatType::get())) { |
117 | continue; |
118 | } |
119 | if (input->node()->kind() != prim::Constant) { |
120 | return false; |
121 | } |
122 | } |
123 | return true; |
124 | } |
125 | |
126 | struct GraphFuser { |
127 | using FusionCallback = std::function<bool(GraphFuser*, Node*)>; |
128 | |
129 | Block* block_; |
130 | AliasDb* aliasDb_; |
131 | std::shared_ptr<Graph> graph_; |
132 | FusionCallback callback_ = [](GraphFuser* gf, Node* n) { |
133 | return gf->isFusableDefault(n, gf->strict_fuser_check_); |
134 | }; |
135 | Symbol kind_ = prim::FusionGroup; |
136 | bool strict_fuser_check_ = false; |
137 | |
138 | // nvrtc has a limit on the number of arguments allowed in a CUDA kernel. |
// The specific limit depends on the constant memory size, how much of it is
// available for passing arguments, and other implementation details, so we
// select a conservative limit here.
142 | // This limit is also applied to other devices in the fuser by default. |
143 | // Change with setInputArgLimit |
144 | size_t subgraph_arg_limit_ = 128; |
145 | |
146 | GraphFuser(AliasDb* aliasDb, Block* block, bool strict_fuser_check) |
147 | : block_(block), |
148 | aliasDb_(aliasDb), |
149 | strict_fuser_check_(strict_fuser_check) {} |
150 | |
// Custom passes require kind to be specified
152 | GraphFuser( |
153 | AliasDb* aliasDb, |
154 | Block* block, |
155 | FusionCallback callback, |
156 | Symbol kind, |
157 | bool strict_fuser_check = false) |
158 | : block_(block), |
159 | aliasDb_(aliasDb), |
160 | callback_(std::move(callback)), |
161 | kind_(kind), |
162 | strict_fuser_check_(strict_fuser_check) {} |
163 | |
164 | void setInputArgLimit(size_t limit) { |
165 | subgraph_arg_limit_ = limit; |
166 | } |
167 | |
168 | value_list tensorInputs(Node* node) { |
169 | return filter(node->inputs(), [](Value* v) { |
170 | return v->type()->isSubtypeOf(*TensorType::get()); |
171 | }); |
172 | } |
173 | |
174 | bool isFusable(Node* node) { |
175 | return callback_(this, node); |
176 | } |
177 | |
178 | bool isFusableDevice(Value* v, bool strict_fuser_check) { |
179 | if (!v->type()->isSubtypeOf(*TensorType::get())) { |
180 | return true; |
181 | } |
182 | auto device = v->type()->expectRef<TensorType>().device(); |
183 | if (!device) { |
184 | return !strict_fuser_check; |
185 | } |
186 | if ((*device).is_cpu()) { |
187 | return canFuseOnCPULegacy(); |
188 | } else if ((*device).is_cuda()) { |
189 | return canFuseOnGPU(); |
190 | } else if ((*device).is_xpu()) { |
191 | return false; |
192 | } else { |
193 | TORCH_CHECK_NOT_IMPLEMENTED(false, "Unknown device for graph fuser" ); |
194 | } |
195 | } |
196 | |
197 | // Default fusability check - used when the user doesn't pass in |
198 | // a callback. |
199 | bool isFusableDefault(Node* node, bool strict_fuser_check) { |
200 | bool fusableDevice = true; |
201 | for (const auto& output : node->outputs()) { |
202 | if (!output->uses().empty()) { |
203 | fusableDevice &= isFusableDevice(output, strict_fuser_check); |
204 | } |
205 | } |
206 | return fusableDevice && isFusableMap(node); |
207 | } |
208 | |
209 | bool isFusableMap(Node* node) { |
210 | // We don't want to bother with cross-block node movements, as they |
211 | // are not necessarily correct. |
212 | if (node->owningBlock() != block_) |
213 | return false; |
214 | return node->kind() == prim::FusionGroup || isSimpleMap(node); |
215 | } |
216 | |
217 | bool isFusableCatNode(Node* node) { |
218 | if (node->kind() != aten::cat) |
219 | return false; |
220 | if (!node->is_constant(attr::dim)) |
221 | return false; |
222 | |
223 | auto tensors_node = node->namedInput(attr::tensors)->node(); |
224 | if ((tensors_node->inputs().size() + node->outputs().size()) > |
225 | subgraph_arg_limit_) { |
226 | return false; |
227 | } |
228 | if (tensors_node->kind() != prim::ListConstruct) |
229 | return false; |
230 | // NB: Note that technically other uses of the list aren't a big problem for |
231 | // us. It would be enough to place the prim::FusedConcat before the |
232 | // prim::ListConstruct, and allUsersAreThisConsumerOrOccurAfterIt would |
233 | // still be satisfied. However, I don't expect this to be necessary any time |
234 | // soon, and so we're simply assuming that we don't have to deal with it. |
235 | if (tensors_node->output()->uses().size() > 1) |
236 | return false; |
237 | return true; |
238 | } |
239 | |
240 | bool calculatesSize(Node* node) { |
241 | return node->matches("aten::size(Tensor self) -> int[]" ); |
242 | } |
243 | |
244 | bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) { |
245 | auto defining_node = producer->node(); |
246 | for (auto o : defining_node->outputs()) { |
247 | for (auto u : o->uses()) { |
248 | if (u.user != consumer && !calculatesSize(u.user)) |
249 | return false; |
250 | } |
251 | } |
252 | return true; |
253 | } |
254 | |
255 | Graph& getSubgraph(Node* n) { |
256 | AT_ASSERT(n->kind() == kind_); |
257 | return *n->g(attr::Subgraph); |
258 | } |
259 | |
260 | void mergeFusionGroups(Node* consumer_group, Node* producer_group) { |
261 | // Now we have two fusion groups! |
262 | // Revert the fusion - place all inner nodes of producer back in the outer |
263 | // graph. |
264 | std::vector<Node*> temporary_nodes; |
265 | auto producer_subgraph = &getSubgraph(producer_group); |
266 | |
267 | // Initialize a map of inner graph values to outer graph values |
268 | std::unordered_map<Value*, Value*> inner_to_outer; |
269 | auto inner_inputs = producer_subgraph->inputs(); |
270 | auto outer_inputs = producer_group->inputs(); |
271 | for (const auto i : c10::irange(inner_inputs.size())) { |
272 | inner_to_outer[inner_inputs[i]] = outer_inputs[i]; |
273 | } |
274 | |
275 | // Clone all nodes |
276 | for (auto inner : producer_subgraph->nodes()) { |
277 | Node* outer = block_->owningGraph()->createClone( |
278 | inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); }); |
279 | outer->insertBefore(producer_group); |
280 | temporary_nodes.emplace_back(outer); |
281 | auto inner_outputs = inner->outputs(); |
282 | auto outer_outputs = outer->outputs(); |
283 | for (const auto i : c10::irange(inner_outputs.size())) { |
284 | inner_to_outer[inner_outputs[i]] = outer_outputs[i]; |
285 | } |
286 | } |
287 | |
288 | // Replace uses of producer_group outputs and destroy the producer |
289 | auto subgraph_outputs = producer_subgraph->outputs(); |
290 | for (const auto i : c10::irange(subgraph_outputs.size())) { |
291 | auto outer_output = inner_to_outer.at(subgraph_outputs[i]); |
292 | producer_group->outputs()[i]->replaceAllUsesWith(outer_output); |
293 | // new producer outputs have same aliasing properties as outer_output |
294 | aliasDb_->replaceWithNewValue(producer_group->outputs()[i], outer_output); |
295 | } |
296 | producer_group->destroy(); |
297 | producer_group = |
298 | nullptr; // Just to get a clear error in case someone uses it |
299 | |
300 | // Inline the temporary nodes into the first group |
301 | auto consumer_subgraph = &getSubgraph(consumer_group); |
302 | for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend(); |
303 | ++it) { |
304 | Node* node = *it; |
305 | Node* merged = mergeNodeIntoGroup(consumer_group, node); |
// If any of the outputs are still used then we need to add them
// as outputs of the consumer fusion group
307 | auto outputs = node->outputs(); |
308 | for (const auto i : c10::irange(outputs.size())) { |
309 | auto output = outputs[i]; |
310 | if (output->uses().empty()) |
311 | continue; |
312 | consumer_subgraph->registerOutput(merged->outputs()[i]); |
313 | auto new_output = consumer_group->addOutput(); |
314 | output->replaceAllUsesWith(new_output); |
315 | aliasDb_->replaceWithNewValue(output, new_output); |
316 | new_output->setType(output->type()); |
317 | } |
318 | node->destroy(); |
319 | } |
320 | } |
321 | |
// Insert a producer node into a consuming fusion group.
// DOES NOT WORK if n is a consumer of an output of the fusion group.
// Returns the node _inside_ the group that represents the merged node.
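//
// Illustrative sketch (hypothetical IR): merging
//   %x = aten::mul(%a, %b)
// into %g = prim::FusionGroup_0(%a, %c) adds %b as a new group input,
// clones the mul into the subgraph, and returns that clone; the group then
// reads as prim::FusionGroup_0(%a, %b, %c).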
325 | Node* mergeNodeIntoGroup(Node* group, Node* n) { |
326 | AT_ASSERT(n->kind() != kind_); |
327 | auto& subgraph = getSubgraph(group); |
328 | // map from nodes in the surrounding graph to parameters in the fusion |
329 | // group's subgraph that correspond to them |
330 | std::unordered_map<Value*, Value*> inputs_map; |
331 | size_t i = 0; |
332 | size_t tensor_insert_idx = 0; |
333 | AT_ASSERT(group->inputs().size() == subgraph.inputs().size()); |
334 | for (auto input : group->inputs()) { |
335 | inputs_map[input] = subgraph.inputs()[i++]; |
336 | if (input->type()->isSubtypeOf(*TensorType::get())) |
337 | tensor_insert_idx = i; |
338 | } |
// Add n's inputs to the fusion group's input list if we don't already have
// them. We insert tensors first because the fuser assumes that to be the
// case (a legacy of the time when only tensor inputs were supported).
343 | WithInsertPoint guard(*subgraph.nodes().begin()); |
344 | for (auto input : n->inputs()) { |
345 | if (inputs_map.count(input) == 0) { |
346 | if (input->type()->isSubtypeOf(*TensorType::get())) { |
347 | auto in_group = subgraph.insertInput(tensor_insert_idx); |
348 | in_group->setType(input->type()); |
349 | inputs_map[input] = in_group; |
350 | group->insertInput(tensor_insert_idx, input); |
351 | tensor_insert_idx++; |
352 | } else if ( |
353 | (input->type()->isSubtypeOf(*FloatType::get()) && |
354 | input->node()->kind() != prim::Constant) || |
355 | (n->kind() == aten::_grad_sum_to_size && |
356 | input->type()->isSubtypeOf(*ListType::ofInts()))) { |
357 | auto in_group = subgraph.addInput(); |
358 | in_group->setType(input->type()); |
359 | inputs_map[input] = in_group; |
360 | group->addInput(input); |
361 | } else { |
362 | // We don't support passing in scalars as arguments to fused kernels, |
363 | // so we generally don't allow fusing tensor-scalar operations unless |
364 | // the scalar is constant. In those cases we inline the constants |
365 | // directly in the body of the fused group. |
366 | AT_ASSERT(input->node()->kind() == prim::Constant); |
367 | Node* in_const = |
368 | subgraph.createClone(input->node(), [](Value*) -> Value* { |
369 | throw std::runtime_error("unexpected input" ); |
370 | }); |
371 | subgraph.insertNode(in_const); |
372 | inputs_map[input] = in_const->output(); |
373 | } |
374 | } |
375 | } |
376 | // copy n into the graph, remapping its inputs to internal nodes |
377 | Node* in_graph = subgraph.createClone( |
378 | n, [&](Value* k) -> Value* { return inputs_map[k]; }); |
379 | // if n's outputs are already inputs to the fusion group, |
380 | // we need to remove them because n is now inside the fusion group. |
381 | // |
382 | // i.e., |
383 | // x = f(w); group(x, y, z) becomes group(w, y, z). |
384 | // x, y, z = f(w); group(x, y, z) becomes group(w). |
385 | // |
// Uses of those former group inputs are remapped to the outputs of the
// newly-merged node. Note that n cannot be such an input when the fusion
// group is empty.
388 | auto inputs = group->inputs(); |
389 | for (size_t i = 0; i < n->outputs().size(); ++i) { |
390 | auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]); |
391 | if (it != inputs.end()) { |
392 | size_t p = it - inputs.begin(); |
393 | group->removeInput(p); |
394 | subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]); |
395 | subgraph.eraseInput(p); |
396 | } |
397 | } |
398 | return subgraph.insertNode(in_graph); |
399 | } |
400 | |
401 | // turn consumer node n into a fusion group with just n inside |
402 | // to prepare for fusion and replace uses of n with the new group |
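//
// Illustrative sketch (hypothetical IR): for a lone node
//   %y = aten::tanh(%x)
// this produces %y = prim::FusionGroup_0(%x), whose subgraph contains only
// the cloned tanh and is ready to absorb fusable producers.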
403 | Node* createSingletonFusionGroup(Node* n) { |
404 | auto group = block_->owningGraph()->createWithSubgraph(kind_); |
// propagate position information for the new node so we can always
406 | // have a valid mapping |
407 | group->insertBefore(n); |
408 | Node* mergedNode = mergeNodeIntoGroup(group, n); |
409 | getSubgraph(group).registerOutput(mergedNode->output()); |
410 | auto sel = group->addOutput(); |
411 | sel->copyMetadata(n->output()); |
412 | aliasDb_->replaceWithNewValue(n->output(), sel); |
413 | n->replaceAllUsesWith(group); |
414 | n->destroy(); |
415 | return group; |
416 | } |
417 | |
418 | at::optional<Node*> tryFuse(Node* consumer, Value* producer) { |
// This handles cases where the producer can be moved _into_ the fusion
// group of the consumer.
// TODO: extend to fusion of the consumer into the _producer's_ fusion
// blob: if allInputsAreThisProducer(consumer, producer) holds, we could
// move the consumer up into the producer, but that requires better
// handling of merging fusion groups, so it is not done now.
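//
// Illustrative sketch (hypothetical IR): with
//   %p = aten::mul(%a, %b)
//   %c = aten::tanh(%p)
// fusing the producer %p into the consumer's group yields
//   %c = prim::FusionGroup_0(%a, %b)
// and, if %p still has uses outside the group, %p is re-routed to a new
// output of the group.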
426 | bool shouldFuse = isFusable(producer->node()) && |
427 | // Rearrange nodes such that all uses of producer are after the |
428 | // consumer. Fusion will rewrite those later uses to use the version of |
429 | // producer generated by the fused blob. In this case, producer becomes |
430 | // an output of the fusion group. |
431 | aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer); |
432 | |
433 | if (!shouldFuse) { |
434 | return at::nullopt; |
435 | } |
436 | |
437 | if ((consumer->inputs().size() + consumer->outputs().size() + |
438 | producer->node()->inputs().size() + |
439 | producer->node()->outputs().size()) > subgraph_arg_limit_) { |
440 | return at::nullopt; |
441 | } |
442 | |
443 | auto group = consumer; |
444 | if (consumer->kind() != kind_) { |
445 | group = createSingletonFusionGroup(consumer); |
446 | } |
447 | |
448 | if (producer->node()->kind() == kind_) { |
449 | mergeFusionGroups(group, producer->node()); |
450 | return group; |
451 | } |
452 | AT_ASSERT(producer->node()->outputs().size() == 1); |
453 | Node* merged = mergeNodeIntoGroup(group, producer->node()); |
454 | // remaining uses of this producer can occur because we allow |
455 | // fusion in cases where uses remain after the consumer |
456 | // if these exist, re-route them to the version of producer |
457 | // created in FusionGroup |
458 | if (!producer->uses().empty()) { |
459 | getSubgraph(group).registerOutput(merged->output()); |
460 | Value* new_producer = group->addOutput(); |
461 | new_producer->copyMetadata(producer); |
462 | aliasDb_->replaceWithNewValue(producer, new_producer); |
463 | producer->replaceAllUsesWith(new_producer); |
464 | } |
465 | producer->node()->destroy(); |
466 | return group; |
467 | } |
468 | |
469 | bool canFuseChunk(Node* consumer, Value* producer) { |
470 | if (consumer->kind() != prim::FusionGroup) { |
471 | return false; |
472 | } |
473 | // Does the chunk have constant chunks/dim? |
474 | auto* chunk = producer->node(); |
475 | if (chunk->kind() != prim::ConstantChunk) |
476 | return false; |
477 | // And all uses of the chunk are in this consumer |
478 | for (auto s : chunk->outputs()) { |
479 | for (auto u : s->uses()) { |
480 | if (u.user != consumer) { |
481 | return false; |
482 | } |
483 | } |
484 | } |
485 | // And isn't a no-op chunk (chunks == 1). Have CSE clean this up. |
486 | // We could fuse this but it's better to just delete the node. |
487 | if (chunk->i(attr::chunks) == 1) { |
488 | return false; |
489 | } |
490 | return true; |
491 | } |
492 | |
493 | c10::optional<Node*> findFusedChunk(Node* group, Value* input) { |
494 | AT_ASSERT(group->kind() == prim::FusionGroup); |
495 | auto it = std::find(group->inputs().begin(), group->inputs().end(), input); |
496 | if (it == group->inputs().end()) { |
497 | return c10::nullopt; |
498 | } |
499 | size_t input_index = it - group->inputs().begin(); |
500 | auto& subgraph = getSubgraph(group); |
501 | auto* subgraph_input = subgraph.inputs().at(input_index); |
502 | // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use |
503 | auto* node = subgraph_input->uses().at(0).user; |
504 | if (node->kind() == prim::ConstantChunk) { |
505 | AT_ASSERT(subgraph_input->uses().size() == 1); |
506 | return node; |
507 | } |
508 | return c10::nullopt; |
509 | } |
510 | |
511 | void fuseChunkByReusingExistingFusedChunk( |
512 | Node* group, |
513 | Node* chunk, |
514 | Node* existingFusedChunk) { |
515 | if (chunk->outputs().size() != existingFusedChunk->outputs().size()) { |
516 | return; |
517 | } |
518 | auto& subgraph = getSubgraph(group); |
519 | for (size_t i = 0; i < chunk->outputs().size(); ++i) { |
520 | // Find the input to the FusionGroup (group) |
521 | auto* replacement_val = existingFusedChunk->outputs().at(i); |
522 | auto* val = chunk->outputs().at(i); |
523 | auto it = std::find(group->inputs().begin(), group->inputs().end(), val); |
524 | auto input_index = it - group->inputs().begin(); |
525 | |
526 | // Rewrite the graph to use replacement_val |
527 | auto group_input = subgraph.inputs().at(input_index); |
528 | group_input->replaceAllUsesWith(replacement_val); |
529 | |
530 | // Remove the input, it's no longer needed |
531 | group->removeInput(input_index); |
532 | subgraph.eraseInput(input_index); |
533 | } |
534 | chunk->destroy(); |
535 | } |
536 | |
// There are two invariants for prim::ConstantChunk:
// (1) the tensor input to prim::ConstantChunk must be an input to the
//     fusion group
// (2) no two ConstantChunks in the same FusionGroup can share a tensor
//     input.
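//
// Illustrative sketch (hypothetical IR): if a tensor %t is already chunked
// by a prim::ConstantChunk inside the group, a second chunk of %t feeding
// the same group is not merged as a new node; its outputs are re-routed to
// the existing fused chunk's outputs instead, preserving invariant (2).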
541 | graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) { |
542 | auto* chunk = producer->node(); |
543 | AT_ASSERT(consumer->kind() == prim::FusionGroup); |
544 | AT_ASSERT(chunk->kind() == prim::ConstantChunk); |
545 | |
546 | // if producer's input is already an input to a prim::ConstantChunk node, |
547 | // we cannot add a new prim::ConstantChunk node because of invariant (2). |
548 | auto* chunked_tensor = producer->node()->input(); |
549 | if (auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) { |
550 | fuseChunkByReusingExistingFusedChunk( |
551 | consumer, chunk, *existingFusedChunk); |
552 | return consumer->reverseIterator(); |
553 | } |
554 | |
555 | // Move prim::ConstantChunk into the FusionGroup |
556 | mergeNodeIntoGroup(consumer, chunk); |
557 | chunk->destroy(); |
558 | return consumer->reverseIterator(); |
559 | } |
560 | |
561 | value_list sortReverseTopological(ArrayRef<Value*> inputs) { |
562 | value_list result; |
563 | for (auto i : inputs) { |
564 | if (i->node()->owningBlock() == block_) { |
565 | result.push_back(i); |
566 | } |
567 | } |
568 | // Sort in reverse topological order |
569 | std::sort(result.begin(), result.end(), [&](Value* a, Value* b) { |
570 | return a->node()->isAfter(b->node()); |
571 | }); |
572 | return result; |
573 | } |
574 | |
575 | graph_node_list::iterator scanNodeForChunks(Node* consumer) { |
576 | if (consumer->kind() == prim::FusionGroup) { |
577 | auto inputs = sortReverseTopological(consumer->inputs()); |
578 | for (auto producer : inputs) { |
579 | if (!canFuseChunk(consumer, producer)) { |
580 | continue; |
581 | } |
582 | return fuseChunk(consumer, producer); |
583 | } |
584 | } |
585 | return ++consumer->reverseIterator(); |
586 | } |
587 | |
588 | at::ArrayRef<Value*> broadcast_tensors(value_list inputs) { |
589 | AT_ASSERT(!inputs.empty()); |
590 | auto* g = inputs[0]->owningGraph(); |
591 | auto* input_list = |
592 | g->insertNode(g->createList(TensorType::get(), inputs))->output(); |
593 | aliasDb_->createValue(input_list); |
594 | auto* output_list = g->insert(aten::broadcast_tensors, {input_list}); |
595 | aliasDb_->createValue(output_list); |
596 | auto* unpack_node = g->insertNode( |
597 | g->create(prim::ListUnpack, {output_list}, inputs.size())); |
598 | |
599 | // We are doing: |
600 | // input_list = listConstruct(a, b, ...) |
601 | // output_list = broadcast_tensors(input_list) |
602 | // a_broadcasted, b_broadcasted = listUnpack(output_list) |
603 | // `a_broadcasted` should receive the same aliasing info as `a` |
604 | TORCH_INTERNAL_ASSERT(unpack_node->outputs().size() == inputs.size()); |
605 | for (const auto i : c10::irange(inputs.size())) { |
606 | Value* original_input = inputs[i]; |
607 | Value* broadcasted_output = unpack_node->outputs()[i]; |
608 | aliasDb_->copyValue(original_input, broadcasted_output); |
609 | } |
610 | |
611 | return unpack_node->outputs(); |
612 | } |
613 | |
614 | void insertExplicitBroadcast(Node* node) { |
615 | WithInsertPoint insert_guard{node}; |
616 | auto tensors = tensorInputs(node); |
617 | auto new_tensors = broadcast_tensors(std::move(tensors)); |
618 | |
619 | // Replace tensors inputs with broadcasted values |
620 | auto new_tensors_it = new_tensors.begin(); |
621 | for (size_t i = 0; i < node->inputs().size(); ++i) { |
622 | if (node->inputs()[i]->type()->isSubtypeOf(*TensorType::get())) { |
623 | AT_ASSERT(new_tensors_it != new_tensors.end()); |
624 | node->replaceInput(i, *(new_tensors_it++)); |
625 | } |
626 | } |
627 | } |
628 | |
629 | Node* promoteChunkToBroadcastingChunk(Node* chunk) { |
630 | AT_ASSERT(chunk->kind() == prim::ConstantChunk); |
631 | |
632 | size_t nchunks = chunk->i(attr::chunks); |
633 | Node* bchunk = |
634 | chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks); |
635 | bchunk->addInput(chunk->input()); |
636 | for (const auto i : c10::irange(nchunks)) { |
637 | auto* old_output = chunk->outputs().at(i); |
638 | auto* new_output = bchunk->outputs().at(i); |
639 | new_output->copyMetadata(old_output); |
640 | aliasDb_->replaceWithNewValue(old_output, new_output); |
641 | old_output->replaceAllUsesWith(new_output); |
642 | } |
643 | bchunk->copyAttributes(*chunk); |
644 | bchunk->insertAfter(chunk); |
645 | chunk->destroy(); |
646 | return bchunk; |
647 | } |
648 | |
// In places where an op could be fused into a consumer but a chunk is in
// the way, distribute the chunk to the op's operands:
651 | // replace a,b = chunk(op(x,y,z)) with: |
652 | // x', y', z' = broadcast_tensors([x, y, z]) |
653 | // x0,x1 = chunk(x') (x0 has a's type, x1 has b's type) |
654 | // y0,y1 = chunk(y') (y0 has a's type, y1 has b's type) |
655 | // z0,z1 = chunk(z') (z0 has a's type, z1 has b's type) |
// a = op(x0,y0,z0) (a and b keep their original sizes but are now contiguous)
// b = op(x1,y1,z1)
658 | // |
659 | // The graph fuser uses an intermediate prim::BroadcastingChunk node to |
660 | // represent this behavior concisely. BroadcastingChunk(x, y, z) broadcasts |
661 | // all of its inputs and then chunks each input, in order, the same way. |
662 | // The above graph is equivalent to: |
663 | // x0, x1, y0, y1, z0, z1 = BroadcastingChunk(x, y, z) |
664 | // a = op(x0,y0,z0) |
// b = op(x1,y1,z1)
666 | // |
667 | // NB: The explicit broadcast is important for correctness. |
668 | // Let's say we have: |
669 | // %z = aten::mul(%x, %y) |
670 | // %z.1, %z.2 = aten::chunk(%z, ...) |
671 | // ... = prim::FusionGroup(%z.1, %z.2, ...) |
672 | // It's possible that %x and %y do not have the same size as %z and |
673 | // need to be expanded first so that they can be chunked like %z |
674 | // |
675 | // NB: Chunk motion only occurs with fusable consumers, which implies |
676 | // that there is always some other operation, e.g., a+b, that happens |
677 | // after the chunk, and will be put into the fusion group. This is |
678 | // important, because distributing the chunk changes the contiguity |
679 | // of a and b, and so the results would be invalid, except that we know |
680 | // that simple_mappable operations will restore contiguity before |
681 | // we exit the fusion group. |
682 | // |
683 | // NB: The intermediate BroadcastingChunk is important for moving chunks past |
684 | // more than one operation: the graph fuser is not able to easily move |
685 | // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusible |
686 | // ops |
687 | // x = f(v, w) |
688 | // z = g(x, y) |
689 | // a, b = chunk(z) |
690 | // c = h(a, b) |
691 | // becomes (with the broadcast_tensors + chunk approach): |
692 | // x = f(v, w) |
693 | // x', y' = broadcast_tensors([x, y]) |
694 | // ax, bx = chunk(x') |
695 | // ay, by = chunk(y') |
696 | // a = g(ax, ay) |
697 | // b = g(bx, by) |
698 | // c = h(a, b) |
699 | // The broadcast_tensors node makes it harder to move f into the resulting |
700 | // FusionGroup of g, g, and h. Keeping the broadcasting and chunk behavior |
701 | // together results in: |
702 | // x = f(v, w) |
703 | // ax, bx, ay, by = BroadcastingChunk(x, y) |
704 | // a = g(ax, ay) |
705 | // b = g(bx, by) |
706 | // c = h(a, b) |
707 | // making it easier to move f after the BroadcastingChunk: |
708 | // ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w) |
709 | // ax = f(av, aw) |
// bx = f(bv, bw)
711 | // a = g(ax, ay) |
712 | // b = g(bx, by) |
713 | // c = h(a, b) |
714 | |
715 | bool tryToMoveChunk(Node* consumer, Value* producer) { |
716 | // is the output from a chunk/bchunk node? |
717 | auto* chunk = producer->node(); |
718 | if (chunk->kind() != prim::ConstantChunk && |
719 | chunk->kind() != prim::BroadcastingChunk) |
720 | return false; |
721 | |
722 | // try to find a producer to move after the chunk/bchunk. The producer must |
723 | // be fusible into the consumer. |
724 | auto it = std::find_if( |
725 | chunk->inputs().begin(), |
726 | chunk->inputs().end(), |
727 | [&](Value* producer_for_chunk) { |
728 | return isFusableMap(producer_for_chunk->node()) && |
729 | allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk); |
730 | }); |
731 | if (it == chunk->inputs().end()) { |
732 | return false; |
733 | } |
734 | Value* producer_for_chunk = *it; |
735 | size_t producer_index = it - chunk->inputs().begin(); |
736 | |
737 | // all uses of the chunk must be in this consumer |
738 | for (auto s : chunk->outputs()) { |
739 | for (auto u : s->uses()) { |
740 | if (u.user != consumer) |
741 | return false; |
742 | } |
743 | } |
// the producer must have a single output (we do not handle multiple-return
// operators here)
745 | Node* producer_for_chunk_node = producer_for_chunk->node(); |
746 | AT_ASSERT(producer_for_chunk_node->outputs().size() == 1); |
747 | |
748 | // Convert chunk to bchunk, if it isn't one already. The bchunk represents a |
749 | // broadcast and one or more chunk operations. |
750 | auto* bchunk = chunk; |
751 | if (chunk->kind() == prim::ConstantChunk) { |
752 | bchunk = promoteChunkToBroadcastingChunk(chunk); |
753 | } |
754 | size_t nchunks = bchunk->i(attr::chunks); |
755 | WithInsertPoint guard(bchunk->next()); |
756 | |
757 | std::vector<Value*> producer_chunk_outputs; |
758 | for (const auto i : c10::irange(nchunks)) { |
759 | producer_chunk_outputs.push_back( |
760 | bchunk->output(nchunks * producer_index + i)); |
761 | } |
762 | |
763 | // Add each of op's operands to the bchunk node. |
764 | // chunked_inputs[input_nr][chunk_output_idx] |
765 | // = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr]) |
766 | std::vector<std::vector<Value*>> chunked_inputs; |
767 | |
768 | for (auto input : producer_for_chunk_node->inputs()) { |
// XXX: we only work with pointwise ops in here, so we know it is valid to
// push the chunk only through tensor arguments (and all other args can
// be safely ignored).
772 | if (!input->type()->isSubtypeOf(*TensorType::get())) |
773 | continue; |
774 | |
775 | // if 'input' is already an input to the bchunk, reuse it. |
776 | auto bchunk_inputs = bchunk->inputs(); |
777 | auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input); |
778 | if (it != bchunk_inputs.end()) { |
779 | chunked_inputs.emplace_back(); |
780 | auto input_index = std::distance(bchunk_inputs.begin(), it); |
781 | for (const auto chunki : c10::irange(nchunks)) { |
782 | chunked_inputs.back().push_back( |
783 | bchunk->outputs().at(nchunks * input_index + chunki)); |
784 | } |
785 | continue; |
786 | } |
787 | |
788 | // NB: I decided not to use cloneFrom here, because if we make cloneFrom |
789 | // copy selects one day, it is definitely not what you want here (selects |
790 | // have different types). |
791 | // TODO: Perhaps we should use cloneFrom now, as it seems unlikely |
792 | // to copy select nodes now that we have refactored to have a Value |
793 | // distinct from Node. |
794 | bchunk->addInput(input); |
795 | chunked_inputs.emplace_back(); // alas, to not be C++17 |
796 | for (auto chunk_sel : producer_chunk_outputs) { |
797 | Value* input_chunk_sel = bchunk->addOutput(); |
798 | input_chunk_sel->setType(chunk_sel->type()); |
799 | // Add a fresh value for each output element of the broadcasting chunk |
800 | // node. This is safe because it will be consumed only by the chunked |
801 | // ops. |
802 | aliasDb_->createValue(input_chunk_sel); |
803 | chunked_inputs.back().push_back(input_chunk_sel); |
804 | } |
805 | } |
806 | |
807 | // apply the op to each chunk of the chunked operands, |
808 | // and then rewrite the graph to use them! |
809 | for (auto chunk_sel : producer_chunk_outputs) { |
810 | auto original_inputs = producer_for_chunk_node->inputs(); |
811 | Node* chunked_op = |
812 | block_->owningGraph()->create(producer_for_chunk_node->kind()); |
813 | chunked_op->copyAttributes(*producer_for_chunk_node); |
814 | chunked_op->output()->setType(chunk_sel->type()); |
815 | auto chunked_inputs_it = chunked_inputs.begin(); |
816 | for (Value* original_input : original_inputs) { |
817 | if (original_input->type()->isSubtypeOf(*TensorType::get())) { |
818 | AT_ASSERT(chunked_inputs_it != chunked_inputs.end()); |
819 | chunked_op->addInput( |
820 | // NOLINTNEXTLINE(clang-analyzer-core.DivideZero) |
821 | chunked_inputs_it->at(chunk_sel->offset() % nchunks)); |
822 | ++chunked_inputs_it; |
823 | } else { |
824 | chunked_op->addInput(original_input); |
825 | } |
826 | } |
827 | bchunk->owningGraph()->insertNode(chunked_op); |
828 | chunk_sel->replaceAllUsesWith(chunked_op->output()); |
829 | aliasDb_->replaceWithNewValue(chunk_sel, chunked_op->output()); |
830 | } |
831 | |
832 | bchunk->removeInput(producer_index); |
833 | for (const auto i : c10::irange(nchunks)) { |
834 | (void)i; // Suppress unused variable warning |
835 | bchunk->eraseOutput(nchunks * producer_index); |
836 | } |
837 | |
838 | // The output of producer_for_chunk_node could have been used in some |
839 | // aten::size operators, so we need to clean those up as well (we simply |
840 | // broadcast all its tensor inputs). |
// We need to insert these early in the graph, i.e. immediately after
// producer_for_chunk_node, because the _size_if_not_same uses may appear
// before the bchunk.
844 | WithInsertPoint guard2(producer_for_chunk_node); |
845 | auto size_calc_uses = producer_for_chunk_node->output()->uses(); |
846 | if (!size_calc_uses.empty()) { |
847 | auto tensor_inputs = filter( |
848 | producer_for_chunk_node->inputs(), |
849 | [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); }); |
850 | auto tensor_sizes = fmap(tensor_inputs, [&](Value* v) { |
851 | Value* output = v->owningGraph()->insert(aten::size, {v}); |
852 | aliasDb_->createValue(output); |
853 | return output; |
854 | }); |
855 | AT_ASSERT(!tensor_sizes.empty()); |
856 | Value* output_size = tensor_sizes.size() == 1 |
857 | ? tensor_sizes[0] |
858 | : broadcastSizes(tensor_sizes, aliasDb_); |
859 | for (Use u : size_calc_uses) { |
860 | u.user->output()->replaceAllUsesWith(output_size); |
861 | u.user->destroy(); |
862 | } |
863 | } |
864 | producer_for_chunk_node->destroy(); |
865 | return true; |
866 | } |
867 | |
868 | // returns where to continue scanning, and whether any fusion was made |
869 | std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) { |
870 | if (isFusable(consumer)) { |
871 | // handle inputs in reverse topological order as well... |
// otherwise in f(a, a+b) it will appear that a is used twice if we
// consider the f-a fusion before the f-(a+b) fusion.
874 | auto inputs = sortReverseTopological(consumer->inputs()); |
875 | for (auto producer : inputs) { |
876 | if (tryToMoveChunk(consumer, producer)) { |
// the chunk before this consumer was re-arranged to allow fusion;
// scan this consumer again to perform the fusion
879 | return std::make_pair(consumer->reverseIterator(), true); |
880 | } |
881 | auto fusion_group = tryFuse(consumer, producer); |
882 | if (fusion_group) { |
// after fusion, the consumer moves into a FusionGroup, so `inputs` is no
// longer valid; rescan the new FusionGroup for more fusions...
885 | return std::make_pair(fusion_group.value()->reverseIterator(), true); |
886 | } |
887 | } |
888 | } |
889 | return std::make_pair(++consumer->reverseIterator(), false); |
890 | } |
891 | |
892 | void replaceIntermediateBroadcastingChunks() { |
893 | for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { |
894 | auto* node = *it; |
895 | ++it; // We might delete node, so increment the iterator now. |
896 | if (node->kind() != prim::BroadcastingChunk) { |
897 | continue; |
898 | } |
899 | auto* bchunk = node; |
900 | insertExplicitBroadcast(bchunk); |
901 | |
902 | auto* graph = block_->owningGraph(); |
903 | size_t nchunks = bchunk->i(attr::chunks); |
904 | WithInsertPoint guard(bchunk->next()); |
905 | |
// Split the bchunk into bchunk->inputs().size() chunk nodes.
907 | for (size_t input_offset = 0; input_offset < bchunk->inputs().size(); |
908 | input_offset++) { |
909 | auto* input = bchunk->inputs().at(input_offset); |
910 | |
911 | Node* new_chunk = |
912 | graph->insertNode(graph->create(prim::ConstantChunk, input, 0)); |
913 | new_chunk->copyAttributes(*bchunk); |
914 | for (const auto output_offset : c10::irange(nchunks)) { |
915 | auto new_output = new_chunk->addOutput(); |
916 | auto old_output = |
917 | bchunk->outputs().at(input_offset * nchunks + output_offset); |
918 | new_output->copyMetadata(old_output); |
919 | aliasDb_->replaceWithNewValue(old_output, new_output); |
920 | old_output->replaceAllUsesWith(new_output); |
921 | } |
922 | } |
923 | bchunk->destroy(); |
924 | } |
925 | } |
926 | |
927 | // Builds up expressions that compute shapes of all intermediates (and |
928 | // outputs) of the fusion group, based on the sizes of inputs. You should run |
929 | // DCE to remove those that you end up not using. |
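//
// Illustrative sketch (hypothetical IR): for a subgraph computing
//   %r = aten::mul(%a, %b)
// this emits, next to the fusion group,
//   %sa = aten::size(%a)
//   %sb = aten::size(%b)
//   %sr = prim::BroadcastSizes(%sa, %sb)
// and records %sr as the shape of the subgraph's %r.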
930 | std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) { |
931 | WithInsertPoint insert_guard{fusion_group->next()}; |
932 | std::unordered_map<Value*, Value*> shape_of; |
933 | |
934 | Graph* graph = fusion_group->owningGraph(); |
935 | auto subgraph = fusion_group->g(attr::Subgraph); |
936 | |
937 | auto inputs = fusion_group->inputs(); |
938 | auto sinputs = subgraph->inputs(); |
939 | AT_ASSERT(inputs.size() == sinputs.size()); |
940 | for (const auto i : c10::irange(inputs.size())) { |
941 | if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) { |
942 | Value* soutput = graph->insert(aten::size, {inputs[i]}); |
943 | aliasDb_->createValue(soutput); |
944 | shape_of[sinputs[i]] = soutput; |
945 | } |
946 | } |
947 | |
948 | // When we have a guarantee that an output won't be removed, because it's |
949 | // used in expressions that don't involve size checks, we can use its size |
950 | // instead of computing a long chain of broadcasts, starting from the |
951 | // beginning of the kernel. |
952 | auto outputs = fusion_group->outputs(); |
953 | auto soutputs = subgraph->outputs(); |
954 | AT_ASSERT(outputs.size() == soutputs.size()); |
955 | for (const auto i : c10::irange(outputs.size())) { |
956 | if (usedOnlyInSize(outputs[i])) |
957 | continue; |
958 | Value* soutput = graph->insert(aten::size, {outputs[i]}); |
959 | aliasDb_->createValue(soutput); |
960 | shape_of[soutputs[i]] = soutput; |
961 | } |
962 | |
963 | for (Node* n : subgraph->nodes()) { |
964 | // XXX: Use of shape_of.emplace is crucial to the output shape |
965 | // optimization! |
966 | if (n->kind() == prim::FusedConcat) { |
967 | // This is a bit more involved, because we have to account for the case |
968 | // when inputs have different shapes, but fortunately those tensors are |
969 | // always outputs, and so we can simply avoid replacing their queries, |
970 | // because it won't help us. |
971 | continue; |
972 | } |
973 | if (n->kind() == prim::Constant) { |
974 | continue; |
975 | } |
976 | if (n->kind() == prim::ConstantChunk) { |
977 | Node* sizes_node = graph->insertNode( |
978 | graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2)); |
979 | sizes_node->i_(attr::dim, n->i(attr::dim)); |
980 | sizes_node->i_(attr::chunks, n->i(attr::chunks)); |
981 | for (Value* output : sizes_node->outputs()) { |
982 | aliasDb_->createValue(output); |
983 | } |
984 | Value* regular_size = sizes_node->outputs().at(0); |
985 | Value* last_size = sizes_node->outputs().at(1); |
986 | regular_size->setType(ListType::ofInts()); |
987 | last_size->setType(ListType::ofInts()); |
988 | auto outputs = n->outputs(); |
989 | for (Value* o : outputs.slice(0, outputs.size() - 1)) { |
990 | shape_of.emplace(o, regular_size); |
991 | } |
992 | shape_of.emplace(outputs.at(outputs.size() - 1), last_size); |
993 | continue; |
994 | } |
995 | auto tensor_inputs = filter(n->inputs(), [](Value* v) { |
996 | return v->type()->isSubtypeOf(*TensorType::get()); |
997 | }); |
998 | auto shapes = |
999 | fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); }); |
1000 | AT_ASSERT(!shapes.empty()); |
1001 | shape_of.emplace( |
1002 | n->output(), |
1003 | shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes, aliasDb_)); |
1004 | } |
1005 | return shape_of; |
1006 | } |
1007 | |
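// Remove fusion group outputs whose only external use is aten::size, and
// answer those size queries from the shape expressions built above.
// Illustrative sketch (hypothetical IR): if a group output %o is used only
// as %s = aten::size(%o), then %s is replaced by the precomputed shape of
// %o and the output is erased from both the group and its subgraph.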
1008 | void removeOutputsUsedOnlyInSize(Node* fusion_group) { |
1009 | if (fusion_group->kind() != prim::FusionGroup) |
1010 | return; |
1011 | auto subgraph = fusion_group->g(attr::Subgraph); |
1012 | |
1013 | auto shape_of = buildShapeExpressions(fusion_group); |
1014 | auto outputs = fusion_group->outputs().vec(); |
1015 | auto soutputs = subgraph->outputs().vec(); |
1016 | // XXX: Iterating in this order is not only good for performance reasons! |
1017 | // It is also crucial for correctness (i has to reflect the current true |
1018 | // index of outputs[i])! |
1019 | for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) { |
1020 | auto output = outputs[i]; |
1021 | auto soutput = soutputs[i]; |
1022 | if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) { |
1023 | auto uses = output->uses(); |
1024 | for (Use u : uses) { |
1025 | AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]" )); |
1026 | u.user->output()->replaceAllUsesWith(shape_of.at(soutput)); |
1027 | u.user->destroy(); |
1028 | } |
1029 | fusion_group->eraseOutput(i); |
1030 | subgraph->eraseOutput(i); |
1031 | } |
1032 | } |
1033 | } |
1034 | |
1035 | bool canFuseWithConcat(Value* producer, Node* before_check) { |
1036 | if (!isFusable(producer->node())) { |
1037 | return false; |
1038 | } |
// NB: it is important that this check happens after isFusable, which
// checks that the blocks match and that the node is not a special node
// like prim::Param
1041 | if (!aliasDb_->couldMoveBeforeTopologically( |
1042 | producer->node(), before_check)) { |
1043 | return false; |
1044 | } |
1045 | |
1046 | // If the number of kernel args could exceed the limit, skip. |
1047 | if ((before_check->inputs().size() + before_check->outputs().size() + |
1048 | producer->node()->inputs().size() + |
1049 | producer->node()->outputs().size()) > subgraph_arg_limit_) { |
1050 | return false; |
1051 | } |
1052 | |
1053 | // Fusion groups can be merged with concat's group if and only if |
1054 | // the value they produce isn't already coming from a concat |
1055 | if (producer->node()->kind() == prim::FusionGroup) { |
1056 | auto subgraph = producer->node()->g(attr::Subgraph); |
1057 | auto* node = subgraph->outputs().at(producer->offset())->node(); |
1058 | return node->kind() != prim::FusedConcat; |
1059 | } |
1060 | return true; |
1061 | } |
1062 | |
1063 | Node* createFusedConcat(Node* node) { |
1064 | AT_ASSERT(node->kind() == aten::cat); |
1065 | |
1066 | Graph* graph = node->owningGraph(); |
1067 | Node* list_construct = node->namedInput(attr::tensors)->node(); |
1068 | int64_t dim = node->get<int64_t>(attr::dim).value(); |
1069 | |
1070 | Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs()) |
1071 | ->i_(attr::dim, dim); |
1072 | fused_cat->insertBefore(list_construct); |
1073 | fused_cat->output()->copyMetadata(node->output()); |
1074 | aliasDb_->copyValue(node->output(), fused_cat->output()); |
1075 | |
1076 | // NB: this deletes the fused_cat node from the original graph |
1077 | return createSingletonFusionGroup(fused_cat); |
1078 | } |
1079 | |
1080 | void fuseConcats() { |
1081 | for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend(); |
1082 | ++it) { |
1083 | Node* cat = *it; |
1084 | if (!isFusableCatNode(cat)) { |
1085 | continue; |
1086 | } |
1087 | Node* list_construct = cat->namedInput(attr::tensors)->node(); |
1088 | Node* fused_cat = createFusedConcat(cat); |
1089 | Value* fused_cat_out = fused_cat->output(); |
1090 | |
1091 | auto sorted_inputs = sortReverseTopological(fused_cat->inputs()); |
1092 | size_t input_idx = 0; |
1093 | bool any_fused = false; |
1094 | while (input_idx < sorted_inputs.size()) { |
1095 | Value* input = sorted_inputs[input_idx++]; |
1096 | if (!canFuseWithConcat(input, fused_cat)) { |
1097 | continue; |
1098 | } |
1099 | any_fused = true; |
1100 | auto maybe_group = tryFuse(fused_cat, input); |
1101 | AT_ASSERT(maybe_group && maybe_group == fused_cat); |
1102 | // We could have destroyed multiple inputs when performing this fusion, |
1103 | // so we have to recompute the list and iterate over it again. |
1104 | sorted_inputs = sortReverseTopological(fused_cat->inputs()); |
1105 | input_idx = 0; |
1106 | } |
1107 | |
1108 | if (any_fused) { |
1109 | cat->output()->replaceAllUsesWith(fused_cat_out); |
1110 | it.destroyCurrent(); |
1111 | if (list_construct->output()->uses().empty()) { |
1112 | list_construct->destroy(); |
1113 | } |
1114 | } else { |
1115 | fused_cat->destroy(); |
1116 | } |
1117 | } |
1118 | } |
1119 | |
1120 | void optimizeFusedGraphs() { |
1121 | for (Node* node : block_->nodes()) { |
1122 | if (node->kind() != prim::FusionGroup) { |
1123 | continue; |
1124 | } |
1125 | auto subgraph = node->g(attr::Subgraph); |
1126 | EliminateDeadCode(subgraph); |
1127 | EliminateCommonSubexpression(subgraph); |
1128 | ConstantPooling(subgraph); |
1129 | } |
1130 | } |
1131 | |
1132 | void run() { |
// TODO: the old fuser is not maintained internally; somewhere it is being
// turned on inadvertently for certain workflows. Make this a no-op until we
// identify the location.
1136 | #if defined(FBCODE_CAFFE2) |
1137 | return; |
1138 | #endif |
1139 | |
1140 | // Run the pass until no changes are made. |
// This is necessary, because the algorithm can miss out on certain fusion
// opportunities if run only once. Consider this graph:
1143 | // |
1144 | // %1 = f(...) |
1145 | // %2 = g(%1) |
1146 | // %3 = h(%1) |
1147 | // %4 = l(%3) |
1148 | // return (%4, %2) |
1149 | // |
1150 | // where f, g, h, l are simple map ops. |
1151 | // The first iteration will fuse %4 and %3, and see that %1 is an input, but |
1152 | // can't be fused, because it has a different use before the fusion group |
1153 | // in our topological ordering. Then, %2 will be considered, and fused with |
1154 | // %1. If we do another iteration, the algorithm will consider the fusion of |
1155 | // these two groups and fix the situation. |
1156 | bool any_changed = true; |
1157 | while (any_changed) { |
1158 | any_changed = false; |
1159 | for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { |
1160 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1161 | bool changed; |
1162 | std::tie(it, changed) = scanNode(*it); |
1163 | any_changed |= changed; |
1164 | } |
1165 | } |
1166 | |
1167 | fuseConcats(); |
1168 | |
1169 | optimizeFusedGraphs(); |
1170 | |
1171 | // The graph fuser can add intermediate prim::BroadcastingChunk nodes. |
1172 | // Replace them with broadcasts + chunks. |
1173 | replaceIntermediateBroadcastingChunks(); |
1174 | |
1175 | // Fuse starting chunks into the group. |
1176 | for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { |
1177 | it = scanNodeForChunks(*it); |
1178 | } |
1179 | |
1180 | // Remove outputs that have been added only because we need their size |
1181 | for (Node* n : block_->nodes()) { |
1182 | removeOutputsUsedOnlyInSize(n); |
1183 | } |
1184 | |
1185 | for (Node* node : block_->nodes()) { |
1186 | for (Block* sub_block : node->blocks()) { |
1187 | GraphFuser(aliasDb_, sub_block, callback_, kind_, strict_fuser_check_) |
1188 | .run(); |
1189 | } |
1190 | } |
1191 | } |
1192 | }; |
1193 | |
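// Simplify the prim::BroadcastSizes expressions emitted by
// buildShapeExpressions. Illustrative sketches (hypothetical IR):
//   %s = prim::BroadcastSizes(%a)          becomes just %a
//   %s = prim::BroadcastSizes(%a, %a, %b)  becomes prim::BroadcastSizes(%a, %b)
//   a chain %s = BroadcastSizes(%a, %b); %t = BroadcastSizes(%s, %c), where
//   %s has no other use, is collapsed into %t = BroadcastSizes(%a, %b, %c).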
1194 | void PeepholeOptimizeShapeExpressions(Block* block, AliasDb* db) { |
1195 | auto nodes = block->nodes(); |
1196 | for (auto it = nodes.begin(); it != nodes.end(); ++it) { |
1197 | Node* node = *it; |
1198 | for (Block* subblock : node->blocks()) { |
1199 | PeepholeOptimizeShapeExpressions(subblock, db); |
1200 | } |
1201 | if (node->kind() == prim::BroadcastSizes) { |
1202 | // Remove no-op broadcasts. |
1203 | if (node->inputs().size() == 1) { |
1204 | node->output()->replaceAllUsesWith(node->input()); |
1205 | it.destroyCurrent(); |
1206 | continue; |
1207 | } |
1208 | // Deduplicate inputs, but use their unique() values to ensure |
1209 | // this process only depends on the graph. |
1210 | std::map<size_t, Value*> unique_to_value; |
1211 | for (Value* input : node->inputs()) { |
1212 | unique_to_value.emplace(input->unique(), input); |
1213 | } |
1214 | if (unique_to_value.size() != node->inputs().size()) { |
1215 | std::vector<Value*> inputs; |
1216 | inputs.reserve(unique_to_value.size()); |
1217 | for (auto& entry : unique_to_value) { |
1218 | inputs.push_back(entry.second); |
1219 | } |
1220 | if (inputs.size() == 1) { |
1221 | node->output()->replaceAllUsesWith(inputs[0]); |
1222 | } else { |
1223 | WithInsertPoint insert_guard{node}; |
1224 | node->output()->replaceAllUsesWith(broadcastSizes(inputs, db)); |
1225 | } |
1226 | it.destroyCurrent(); |
1227 | --it; // Revisit the node with deduplicated inputs |
1228 | continue; |
1229 | } |
// Compose simple chains of broadcasts into a single node.
1231 | const auto& uses = node->output()->uses(); |
1232 | if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) { |
1233 | Node* user = uses[0].user; |
1234 | user->removeInput(uses[0].offset); |
1235 | // NB: we don't care about deduplication in here, as we will visit user |
1236 | // later. |
1237 | for (Value* i : node->inputs()) { |
1238 | user->addInput(i); |
1239 | } |
1240 | it.destroyCurrent(); |
1241 | } |
1242 | } |
1243 | } |
1244 | } |
1245 | |
1246 | } // anonymous namespace |
1247 | |
1248 | static bool cpu_fuser_enabled_legacy = false; |
1249 | |
1250 | bool canFuseOnCPULegacy() { |
1251 | return cpu_fuser_enabled_legacy; |
1252 | } |
1253 | |
1254 | void overrideCanFuseOnCPULegacy(bool value) { |
1255 | cpu_fuser_enabled_legacy = value; |
1256 | } |
1257 | |
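// Typical usage (a sketch only; real call sites live elsewhere in the JIT
// pipeline, and for CPU-only graphs the legacy CPU fuser must be enabled
// explicitly):
//   std::shared_ptr<Graph> graph = ...;
//   overrideCanFuseOnCPULegacy(true);
//   FuseGraph(graph, /*strict_fuser_check=*/false);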
1258 | void FuseGraph(std::shared_ptr<Graph>& graph, bool strict_fuser_check) { |
1259 | AliasDb db(graph); |
1260 | GraphFuser(&db, graph->block(), strict_fuser_check).run(); |
1261 | Lint(&db); |
1262 | // After FuseGraph some common subexpressions may come back |
1263 | EliminateCommonSubexpression(graph); |
1264 | // We might have emitted a fair amount of useless shape propagating code, so |
1265 | // remove it |
1266 | EliminateDeadCode(graph); |
1267 | // Improve the quality of shape propagation code that was left |
1268 | PeepholeOptimizeShapeExpressions(graph->block(), &db); |
1269 | } |
1270 | |
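// Hedged usage sketch: the predicate and symbols below are hypothetical,
// not an existing registration. This would fuse maximal regions of nodes of
// kind my::op into subgraphs tagged my::fusion_group:
//   CustomFuseGraph(
//       graph,
//       [](Node* n) { return n->kind() == Symbol::fromQualString("my::op"); },
//       Symbol::fromQualString("my::fusion_group"));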
1271 | void CustomFuseGraph( |
1272 | std::shared_ptr<Graph>& graph, |
1273 | const std::function<bool(Node*)>& fn, |
1274 | Symbol kind, |
1275 | size_t arg_limit) { |
1276 | AliasDb db(graph); |
1277 | auto g = GraphFuser( |
1278 | &db, |
1279 | graph->block(), |
1280 | [=](GraphFuser* gf, Node* n) { return fn(n) || n->kind() == kind; }, |
1281 | kind); |
1282 | g.setInputArgLimit(arg_limit); |
1283 | g.run(); |
1284 | Lint(&db); |
1285 | } |
1286 | |
1287 | } // namespace jit |
1288 | } // namespace torch |
1289 | |