1 | #include <torch/csrc/jit/passes/cuda_graph_fuser.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <c10/util/irange.h> |
5 | #include <instrumentation.h> |
6 | #include <parser.h> |
7 | #include <partition.h> |
8 | #include <transform_view.h> |
9 | #include <utils.h> |
10 | #include <torch/csrc/jit/frontend/ir_emitter.h> |
11 | #include <torch/csrc/jit/ir/alias_analysis.h> |
12 | #include <torch/csrc/jit/jit_log.h> |
13 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
14 | #include <torch/csrc/jit/passes/constant_pooling.h> |
15 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
16 | #include <torch/csrc/jit/passes/pass_manager.h> |
17 | #include <torch/csrc/jit/passes/remove_mutation.h> |
18 | #include <torch/csrc/jit/passes/restore_mutation.h> |
19 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
20 | #include <torch/csrc/jit/runtime/autodiff.h> |
21 | #include <torch/csrc/jit/runtime/custom_operator.h> |
22 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
23 | #include <torch/csrc/jit/runtime/operator.h> |
24 | |
26 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
27 | |
28 | #include <queue> |
29 | #include <unordered_map> |
30 | |
31 | namespace torch { |
32 | namespace jit { |
33 | namespace fuser { |
34 | namespace cuda { |
35 | |
36 | constexpr size_t NVRTC_KERNEL_ARG_LIMIT = 128; |
37 | |
38 | namespace { |
39 | |
40 | bool usedOnlyInDtype(Value* v) { |
41 | const auto& uses = v->uses(); |
42 | if (uses.empty()) { |
43 | return false; |
44 | } |
45 | return std::all_of(uses.begin(), uses.end(), [](const Use& u) { |
46 | return u.user->matches("prim::dtype(Tensor a) -> int" ); |
47 | }); |
48 | } |
49 | |
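// Emit a prim::BroadcastSizes node over the given int[] size values and return
// its output. The node is inserted after the latest of the inputs' defining
// nodes so that all of its operands are in scope.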
50 | Value* broadcastSizes(at::ArrayRef<Value*> sizes) { |
51 | AT_ASSERT(!sizes.empty()); |
52 | Graph* graph = sizes[0]->owningGraph(); |
53 | Node* insertion_point = sizes[0]->node()->next(); |
54 | for (size_t i = 1; i < sizes.size(); i++) { |
55 | if (insertion_point->isBefore(sizes[i]->node()->next())) { |
56 | insertion_point = sizes[i]->node()->next(); |
57 | } |
58 | } |
59 | WithInsertPoint guard(insertion_point); |
60 | Node* broadcast_n = |
61 | graph->insertNode(graph->create(prim::BroadcastSizes, sizes)); |
62 | broadcast_n->output()->setType(ListType::ofInts()); |
63 | return broadcast_n->output(); |
64 | } |
65 | |
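// Materialize the value recorded on a prim::profile_ivalue node as a graph
// constant. Roughly, a node annotated with profiled_int_list = [0, 1] becomes
// a prim::Constant holding the int list [0, 1]; when no profile attribute is
// present, nullptr is returned and callers must handle that case.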
66 | Value* createConditionalConstant(Node* profile_ivalue) { |
67 | TORCH_INTERNAL_ASSERT(profile_ivalue->kind() == prim::profile_ivalue); |
68 | |
69 | auto graph = profile_ivalue->owningGraph(); |
70 | |
71 | IValue val; // default to None |
72 | if (profile_ivalue->hasAttribute(Symbol::attr("profiled_int_list" ))) { |
73 | // int[] |
74 | val = IValue(profile_ivalue->is(Symbol::attr("profiled_int_list" ))); |
75 | } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool_list" ))) { |
76 | // bool[] |
77 | auto int_list = profile_ivalue->is(Symbol::attr("profiled_bool_list" )); |
78 | std::vector<bool> bool_list(int_list.begin(), int_list.end()); |
79 | val = IValue(bool_list); |
80 | } else if (profile_ivalue->hasAttribute( |
81 | Symbol::attr("profiled_reduction_size" ))) { |
82 | // int[] |
83 | val = IValue(profile_ivalue->is(Symbol::attr("profiled_reduction_size" ))); |
84 | } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_view_size" ))) { |
85 | // int[] |
86 | val = IValue(profile_ivalue->is(Symbol::attr("profiled_view_size" ))); |
87 | } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_bool" ))) { |
88 | // bool |
89 | val = IValue( |
90 | static_cast<bool>(profile_ivalue->i(Symbol::attr("profiled_bool" )))); |
91 | } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_int" ))) { |
92 | // int |
93 | val = IValue( |
94 | static_cast<int>(profile_ivalue->i(Symbol::attr("profiled_int" )))); |
95 | } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_str" ))) { |
96 | // str |
97 | val = IValue(static_cast<std::string>( |
98 | profile_ivalue->s(Symbol::attr("profiled_str" )))); |
99 | } else if (profile_ivalue->hasAttribute(Symbol::attr("profiled_ival" ))) { |
100 | // ival |
101 | val = IValue(profile_ivalue->ival(Symbol::attr("profiled_ival" ))); |
102 | } else { |
103 | GRAPH_DEBUG("no profile info in profile_ivalue node: " , *profile_ivalue); |
104 | TORCH_WARN_ONCE( |
105 | __func__, |
106 | " profile_node " , |
107 | *profile_ivalue, |
108 | " does not have profile information" ); |
109 | return nullptr; |
110 | } |
111 | |
112 | return graph->insertConstant(val); |
113 | } |
114 | |
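// Greedily merges fusible nodes of a block into prim::CudaFusionGroup
// subgraphs. A minimal usage sketch, mirroring how sub-blocks are handled in
// run() below:
//
//   CudaGraphFuser(graph->block(), graph).run();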
115 | struct CudaGraphFuser { |
116 | using FusionCallback = std::function<bool(Node*)>; |
117 | |
118 | Block* block_; |
119 | std::unique_ptr<AliasDb> aliasDb_; |
120 | std::shared_ptr<Graph> graph_; |
121 | Symbol kind_ = prim::CudaFusionGroup; |
122 | std::unordered_map<Value*, Value*> fusion_value_to_runtime_shape_; |
123 | |
124 | // nvrtc has a limit on the number of arguments allowed in a CUDA kernel. |
125 | // The specific limit is a function of constant memory size, amount available |
126 | // to pass arguments, and some implementation dependence. Select a safe |
127 | // limit here. |
128 | // This limit is also applied to other devices in the fuser by default. |
129 | // Change with setInputArgLimit |
130 | size_t subgraph_arg_limit_ = NVRTC_KERNEL_ARG_LIMIT; |
131 | |
132 | CudaGraphFuser(Block* block, std::shared_ptr<Graph> graph) |
133 | : block_(block), graph_(std::move(graph)) {} |
134 | |
135 | void setInputArgLimit(size_t limit) { |
136 | subgraph_arg_limit_ = limit; |
137 | } |
138 | |
139 | value_list tensorInputs(Node* node) { |
140 | return filter(node->inputs(), [](Value* v) { |
141 | return v->type()->isSubtypeOf(*TensorType::get()); |
142 | }); |
143 | } |
144 | |
145 | bool calculatesSize(Node* node) { |
146 | return node->matches("aten::size(Tensor self) -> int[]" ); |
147 | } |
148 | |
149 | bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) { |
150 | auto defining_node = producer->node(); |
151 | for (auto o : defining_node->outputs()) { |
152 | for (auto u : o->uses()) { |
153 | if (u.user != consumer && !calculatesSize(u.user)) |
154 | return false; |
155 | } |
156 | } |
157 | return true; |
158 | } |
159 | |
160 | Graph& getSubgraph(Node* n) { |
161 | AT_ASSERT(n->kind() == kind_); |
162 | return *n->g(attr::Subgraph); |
163 | } |
164 | |
165 | void mergeFusionGroups(Node* consumer_group, Node* producer_group) { |
166 | // Now we have two fusion groups! |
167 | // Revert the fusion - place all inner nodes of producer back in the outer |
168 | // graph. |
169 | std::vector<Node*> temporary_nodes; |
170 | auto producer_subgraph = &getSubgraph(producer_group); |
171 | |
172 | // Initialize a map of inner graph values to outer graph values |
173 | std::unordered_map<Value*, Value*> inner_to_outer; |
174 | auto inner_inputs = producer_subgraph->inputs(); |
175 | auto outer_inputs = producer_group->inputs(); |
176 | for (const auto i : c10::irange(inner_inputs.size())) { |
177 | inner_to_outer[inner_inputs[i]] = outer_inputs[i]; |
178 | } |
179 | |
180 | // Clone all nodes |
181 | for (auto inner : producer_subgraph->nodes()) { |
182 | Node* outer = block_->owningGraph()->createClone( |
183 | inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); }); |
184 | outer->insertBefore(producer_group); |
185 | temporary_nodes.emplace_back(outer); |
186 | auto inner_outputs = inner->outputs(); |
187 | auto outer_outputs = outer->outputs(); |
188 | for (const auto i : c10::irange(inner_outputs.size())) { |
189 | inner_to_outer[inner_outputs[i]] = outer_outputs[i]; |
190 | } |
191 | } |
192 | |
193 | // Replace uses of producer_group outputs and destroy the producer |
194 | auto subgraph_outputs = producer_subgraph->outputs(); |
195 | for (const auto i : c10::irange(subgraph_outputs.size())) { |
196 | auto outer_output = inner_to_outer.at(subgraph_outputs[i]); |
197 | producer_group->outputs()[i]->replaceAllUsesWith(outer_output); |
198 | } |
199 | producer_group->destroy(); |
200 | producer_group = |
201 | nullptr; // Just to get a clear error in case someone uses it |
202 | |
203 | // Inline the temporary nodes into the first group |
204 | auto consumer_subgraph = &getSubgraph(consumer_group); |
205 | for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend(); |
206 | ++it) { |
207 | Node* node = *it; |
208 | Node* merged = mergeNodeIntoGroup(consumer_group, node); |
209 | // If any of the outputs are still used then we need to add them |
210 | auto outputs = node->outputs(); |
211 | for (const auto i : c10::irange(outputs.size())) { |
212 | auto output = outputs[i]; |
213 | if (output->uses().size() == 0) |
214 | continue; |
215 | consumer_subgraph->registerOutput(merged->outputs()[i]); |
216 | auto new_output = consumer_group->addOutput(); |
217 | output->replaceAllUsesWith(new_output); |
218 | new_output->setType(output->type()); |
219 | } |
220 | node->destroy(); |
221 | } |
222 | } |
223 | |
224 | // insert a producer node into a consuming fusion group. |
225 | // DOES NOT WORK if n is a consumer of an output of the fusion group |
226 | // returns the node _inside_ the group that represents the node |
227 | Node* mergeNodeIntoGroup(Node* group, Node* n) { |
228 | AT_ASSERT(n->kind() != kind_); |
229 | auto& subgraph = getSubgraph(group); |
230 | // map from nodes in the surrounding graph to parameters in the fusion |
231 | // group's subgraph that correspond to them |
232 | std::unordered_map<Value*, Value*> inputs_map; |
233 | size_t i = 0; |
234 | size_t tensor_insert_idx = 0; |
235 | for (auto input : group->inputs()) { |
236 | inputs_map[input] = subgraph.inputs()[i++]; |
237 | if (input->type()->isSubtypeOf(*TensorType::get())) |
238 | tensor_insert_idx = i; |
239 | } |
240 | // add n's inputs to the fusion group's input list if we don't already have |
241 | // them |
242 | // we insert tensors first because the fuser assumes that to be the case |
243 | // (as a legacy from tensors only) |
244 | WithInsertPoint guard(*subgraph.nodes().begin()); |
245 | for (auto input : n->inputs()) { |
246 | if (inputs_map.count(input) == 0) { |
247 | // TODO: we are following the convention for no good reason; |
        // we don't need tensors to come before any other inputs.
249 | if (input->type()->isSubtypeOf(*TensorType::get())) { |
250 | auto in_group = subgraph.insertInput(tensor_insert_idx); |
251 | in_group->setType(input->type()); |
252 | inputs_map[input] = in_group; |
253 | group->insertInput(tensor_insert_idx, input); |
254 | tensor_insert_idx++; |
255 | } else if ( |
256 | // TODO: extend the supporting inputs here. |
257 | (input->type()->isSubtypeOf(*FloatType::get()) && |
258 | input->node()->kind() != prim::Constant)) { |
259 | auto in_group = subgraph.addInput(); |
260 | in_group->setType(input->type()); |
261 | inputs_map[input] = in_group; |
262 | group->addInput(input); |
263 | } else if (input->node()->kind() == prim::Constant) { |
264 | // inline the constants directly in the body of the fused group. |
265 | Node* in_const = |
266 | subgraph.createClone(input->node(), [&](Value* v) -> Value* { |
267 | if (v->node()->kind() != prim::profile_ivalue) { |
268 | throw std::runtime_error( |
269 | std::string( |
270 | "merging constant with unexpected input from node" ) + |
271 | v->node()->kind().toDisplayString()); |
272 | } |
273 | group->addInput(v->node()->output()); |
274 | |
275 | // we are doing this just to keep alias_analysis silent with |
276 | // their checks |
277 | auto in_group = subgraph.addInput(); |
278 | in_group->setType(v->type()); |
279 | return in_group; |
280 | }); |
281 | subgraph.insertNode(in_const); |
282 | inputs_map[input] = in_const->output(); |
283 | } else { |
        // TODO: we need to figure out which scalar input types are supported
285 | auto in_group = subgraph.addInput(); |
286 | in_group->setType(input->type()); |
287 | inputs_map[input] = in_group; |
288 | group->addInput(input); |
289 | } |
290 | } |
291 | } |
292 | // copy n into the graph, remapping its inputs to internal nodes |
293 | Node* in_graph = subgraph.createClone( |
294 | n, [&](Value* k) -> Value* { return inputs_map[k]; }); |
295 | // if n's outputs are already inputs to the fusion group, |
296 | // we need to remove them because n is now inside the fusion group. |
297 | // |
298 | // i.e., |
299 | // x = f(w); group(x, y, z) becomes group(w, y, z). |
300 | // x, y, z = f(w); group(x, y, z) becomes group(w). |
301 | // |
302 | // remapping nodes that used the input to the newly-merged node |
303 | // n is not an input when the fusion group is empty |
304 | auto inputs = group->inputs(); |
305 | for (const auto i : c10::irange(n->outputs().size())) { |
306 | auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]); |
307 | if (it != inputs.end()) { |
308 | size_t p = it - inputs.begin(); |
309 | group->removeInput(p); |
310 | subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]); |
311 | subgraph.eraseInput(p); |
312 | } |
313 | } |
314 | return subgraph.insertNode(in_graph); |
315 | } |
316 | |
317 | // turn consumer node n into a fusion group with just n inside |
318 | // to prepare for fusion and replace uses of n with the new group |
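  // Roughly: %y = aten::relu(%x) becomes %y = prim::CudaFusionGroup(%x), with
  // the group's subgraph containing the cloned relu (relu is only an
  // illustrative placeholder here).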
319 | Node* createSingletonFusionGroup(Node* n) { |
320 | auto group = block_->owningGraph()->createWithSubgraph(kind_); |
    // propagate position information for the new node so we can always
322 | // have a valid mapping |
323 | group->insertBefore(n); |
324 | Node* mergedNode = mergeNodeIntoGroup(group, n); |
325 | for (const auto i : c10::irange(n->outputs().size())) { |
326 | getSubgraph(group).registerOutput(mergedNode->output(i)); |
327 | auto sel = group->addOutput(); |
328 | sel->copyMetadata(n->output(i)); |
329 | } |
330 | n->replaceAllUsesWith(group); |
331 | n->destroy(); |
332 | return group; |
333 | } |
334 | |
335 | at::optional<Node*> tryFuse(Node* consumer, Node* producer) { |
336 | // this handles cases where producer can be moved _into_ the fusion group of |
337 | // consumer. |
338 | // TODO: extend to fusion of consumer into _producer's_ fusion blob |
339 | // if the consumer allInputsAreThisProducer(consumer,producer) |
340 | // we can move the consumer up into the producer. |
341 | // but this requires better handling of merging fusion groups so it is not |
342 | // done now |
343 | bool shouldFuse = |
344 | fuser::cuda::isFusibleCudaFusionGroup(consumer, producer) && |
345 | // Rearrange nodes such that all uses of producer's outputs are after |
346 | // consumer. Fusion will rewrite those later uses to use the version of |
347 | // producer generated by the fused blob. In this case, producer becomes |
348 | // an output of the fusion group. |
349 | aliasDb_->moveBeforeTopologicallyValid(producer, consumer); |
350 | |
351 | if (!shouldFuse) { |
352 | return at::nullopt; |
353 | } |
354 | |
355 | if ((consumer->inputs().size() + consumer->outputs().size() + |
356 | producer->inputs().size() + producer->outputs().size()) > |
357 | subgraph_arg_limit_) { |
358 | return at::nullopt; |
359 | } |
360 | |
361 | auto group = consumer; |
362 | if (consumer->kind() != kind_) { |
363 | group = createSingletonFusionGroup(consumer); |
364 | } |
365 | |
366 | if (producer->kind() == kind_) { |
367 | mergeFusionGroups(group, producer); |
368 | return group; |
369 | } |
370 | Node* merged = mergeNodeIntoGroup(group, producer); |
    // remaining uses of this producer can occur because we allow
    // fusion in cases where uses remain after the consumer.
    // If these exist, re-route them to the version of the producer
    // created inside the FusionGroup.
375 | |
376 | // We need to apply this to all outputs from producer->node(); |
377 | auto producer_outputs = producer->outputs(); |
378 | for (const auto i : c10::irange(producer_outputs.size())) { |
379 | if (producer_outputs[i]->uses().size() != 0) { |
380 | getSubgraph(group).registerOutput(merged->outputs()[i]); |
381 | Value* new_producer = group->addOutput(); |
382 | new_producer->copyMetadata(producer_outputs[i]); |
383 | producer_outputs[i]->replaceAllUsesWith(new_producer); |
384 | } |
385 | } |
386 | producer->destroy(); |
387 | return group; |
388 | } |
389 | |
390 | c10::optional<Node*> findFusedChunk(Node* group, Value* input) { |
391 | AT_ASSERT(group->kind() == kind_); |
392 | auto it = std::find(group->inputs().begin(), group->inputs().end(), input); |
393 | if (it == group->inputs().end()) { |
394 | return c10::nullopt; |
395 | } |
396 | size_t input_index = it - group->inputs().begin(); |
397 | auto& subgraph = getSubgraph(group); |
398 | auto* subgraph_input = subgraph.inputs().at(input_index); |
399 | // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use |
400 | auto* node = subgraph_input->uses().at(0).user; |
401 | if (node->kind() == prim::ConstantChunk) { |
402 | AT_ASSERT(subgraph_input->uses().size() == 1); |
403 | return node; |
404 | } |
405 | return c10::nullopt; |
406 | } |
407 | |
408 | void fuseChunkByReusingExistingFusedChunk( |
409 | Node* group, |
410 | Node* chunk, |
411 | Node* existingFusedChunk) { |
412 | if (chunk->outputs().size() != existingFusedChunk->outputs().size()) { |
413 | return; |
414 | } |
415 | auto& subgraph = getSubgraph(group); |
416 | for (const auto i : c10::irange(chunk->outputs().size())) { |
417 | // Find the input to the FusionGroup (group) |
418 | auto* replacement_val = existingFusedChunk->outputs().at(i); |
419 | auto* val = chunk->outputs().at(i); |
420 | auto it = std::find(group->inputs().begin(), group->inputs().end(), val); |
421 | auto input_index = it - group->inputs().begin(); |
422 | |
423 | // Rewrite the graph to use replacement_val |
424 | auto group_input = subgraph.inputs().at(input_index); |
425 | group_input->replaceAllUsesWith(replacement_val); |
426 | |
427 | // Remove the input, it's no longer needed |
428 | group->removeInput(input_index); |
429 | subgraph.eraseInput(input_index); |
430 | } |
431 | chunk->destroy(); |
432 | } |
433 | |
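  // Keep only the inputs defined in this block and order them so that the most
  // recently produced values come first; scanNode visits candidate producers
  // in this order.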
434 | value_list sortReverseTopological(ArrayRef<Value*> inputs) { |
435 | value_list result; |
436 | for (auto i : inputs) { |
437 | if (i->node()->owningBlock() == block_) { |
438 | result.push_back(i); |
439 | } |
440 | } |
441 | // Sort in reverse topological order |
442 | std::sort(result.begin(), result.end(), [&](Value* a, Value* b) { |
443 | return a->node()->isAfter(b->node()); |
444 | }); |
445 | return result; |
446 | } |
447 | |
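  // Insert an aten::broadcast_tensors call over `inputs` followed by a
  // prim::ListUnpack that recovers the individual broadcasted tensors.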
448 | at::ArrayRef<Value*> broadcast_tensors(value_list inputs) { |
449 | AT_ASSERT(inputs.size() > 0); |
450 | auto* g = inputs[0]->owningGraph(); |
451 | auto* input_list = |
452 | g->insertNode(g->createList(TensorType::get(), inputs))->output(); |
453 | auto* output_list = g->insert(aten::broadcast_tensors, {input_list}); |
454 | auto* unpack_node = g->insertNode( |
455 | g->create(prim::ListUnpack, {output_list}, inputs.size())); |
456 | return unpack_node->outputs(); |
457 | } |
458 | |
459 | void insertExplicitBroadcast(Node* node) { |
460 | WithInsertPoint insert_guard{node}; |
461 | auto tensors = tensorInputs(node); |
462 | auto new_tensors = broadcast_tensors(tensors); |
463 | |
464 | // Replace tensors inputs with broadcasted values |
465 | auto new_tensors_it = new_tensors.begin(); |
466 | for (const auto i : c10::irange(node->inputs().size())) { |
467 | if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) { |
468 | AT_ASSERT(new_tensors_it != new_tensors.end()); |
469 | node->replaceInput(i, *(new_tensors_it++)); |
470 | } |
471 | } |
472 | } |
473 | |
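  // Replace a prim::ConstantChunk node with an equivalent
  // prim::BroadcastingChunk node (same chunk count and attributes), rewiring
  // all outputs. See the block comment below for why the combined node is
  // used.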
474 | Node* promoteChunkToBroadcastingChunk(Node* chunk) { |
475 | AT_ASSERT(chunk->kind() == prim::ConstantChunk); |
476 | |
477 | size_t nchunks = chunk->i(attr::chunks); |
478 | Node* bchunk = |
479 | chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks); |
480 | bchunk->addInput(chunk->input()); |
481 | for (const auto i : c10::irange(nchunks)) { |
482 | auto* old_output = chunk->outputs().at(i); |
483 | auto* new_output = bchunk->outputs().at(i); |
484 | new_output->copyMetadata(old_output); |
485 | old_output->replaceAllUsesWith(new_output); |
486 | } |
487 | bchunk->copyAttributes(*chunk); |
488 | bchunk->insertAfter(chunk); |
489 | chunk->destroy(); |
490 | return bchunk; |
491 | } |
492 | |
493 | // in places where op can be fused into a consumer but chunk is in the way |
494 | // distribute chunk to op's operands: |
495 | // replace a,b = chunk(op(x,y,z)) with: |
496 | // x', y', z' = broadcast_tensors([x, y, z]) |
497 | // x0,x1 = chunk(x') (x0 has a's type, x1 has b's type) |
498 | // y0,y1 = chunk(y') (y0 has a's type, y1 has b's type) |
499 | // z0,z1 = chunk(z') (z0 has a's type, z1 has b's type) |
  //     a = op(x0,y0,z0) (a and b keep their original sizes but are now contiguous)
  //     b = op(x1,y1,z1)
502 | // |
503 | // The graph fuser uses an intermediate prim::BroadcastingChunk node to |
504 | // represent this behavior concisely. BroadcastingChunk(x, y, z) broadcasts |
505 | // all of its inputs and then chunks each input, in order, the same way. |
506 | // The above graph is equivalent to: |
507 | // x0, x1, y0, y1, z0, z1 = BroadcastingChunk(x, y, z) |
508 | // a = op(x0,y0,z0) |
  //     b = op(x1,y1,z1)
510 | // |
511 | // NB: The explicit broadcast is important for correctness. |
512 | // Let's say we have: |
513 | // %z = aten::mul(%x, %y) |
514 | // %z.1, %z.2 = aten::chunk(%z, ...) |
515 | // ... = prim::CudaFusionGroup(%z.1, %z.2, ...) |
516 | // It's possible that %x and %y do not have the same size as %z and |
517 | // need to be expanded first so that they can be chunked like %z |
518 | // |
519 | // NB: Chunk motion only occurs with fusable consumers, which implies |
520 | // that there is always some other operation, e.g., a+b, that happens |
521 | // after the chunk, and will be put into the fusion group. This is |
522 | // important, because distributing the chunk changes the contiguity |
523 | // of a and b, and so the results would be invalid, except that we know |
524 | // that simple_mappable operations will restore contiguity before |
525 | // we exit the fusion group. |
526 | // |
527 | // NB: The intermediate BroadcastingChunk is important for moving chunks past |
528 | // more than one operation: the graph fuser is not able to easily move |
529 | // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusable |
530 | // ops |
531 | // x = f(v, w) |
532 | // z = g(x, y) |
533 | // a, b = chunk(z) |
534 | // c = h(a, b) |
535 | // becomes (with the broadcast_tensors + chunk approach): |
536 | // x = f(v, w) |
537 | // x', y' = broadcast_tensors([x, y]) |
538 | // ax, bx = chunk(x') |
539 | // ay, by = chunk(y') |
540 | // a = g(ax, ay) |
541 | // b = g(bx, by) |
542 | // c = h(a, b) |
543 | // The broadcast_tensors node makes it harder to move f into the resulting |
544 | // FusionGroup of g, g, and h. Keeping the broadcasting and chunk behavior |
545 | // together results in: |
546 | // x = f(v, w) |
547 | // ax, bx, ay, by = BroadcastingChunk(x, y) |
548 | // a = g(ax, ay) |
549 | // b = g(bx, by) |
550 | // c = h(a, b) |
551 | // making it easier to move f after the BroadcastingChunk: |
552 | // ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w) |
553 | // ax = f(av, aw) |
  //     bx = f(bv, bw)
555 | // a = g(ax, ay) |
556 | // b = g(bx, by) |
557 | // c = h(a, b) |
558 | |
559 | bool tryToMoveChunk(Node* consumer, Value* producer) { |
560 | // is the output from a chunk/bchunk node? |
561 | auto* chunk = producer->node(); |
562 | if (chunk->kind() != prim::ConstantChunk && |
563 | chunk->kind() != prim::BroadcastingChunk) |
564 | return false; |
565 | |
566 | // try to find a producer to move after the chunk/bchunk. The producer must |
567 | // be fusable into the consumer. |
568 | auto it = std::find_if( |
569 | chunk->inputs().begin(), |
570 | chunk->inputs().end(), |
571 | [&](Value* producer_for_chunk) { |
572 | return fuser::cuda::isFusibleCudaFusionGroup( |
573 | consumer, producer_for_chunk->node()) && |
574 | isElementWiseNode(consumer) && |
575 | allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk); |
576 | }); |
577 | if (it == chunk->inputs().end()) { |
578 | return false; |
579 | } |
580 | Value* producer_for_chunk = *it; |
581 | size_t producer_index = it - chunk->inputs().begin(); |
582 | |
583 | // all uses of the chunk must be in this consumer |
584 | for (auto s : chunk->outputs()) { |
585 | for (auto u : s->uses()) { |
586 | if (u.user != consumer) |
587 | return false; |
588 | } |
589 | } |
    // producer_for_chunk_node must have a single output (multiple-return
    // operators are not handled here)
591 | Node* producer_for_chunk_node = producer_for_chunk->node(); |
592 | AT_ASSERT(producer_for_chunk_node->outputs().size() == 1); |
593 | |
594 | // Convert chunk to bchunk, if it isn't one already. The bchunk represents a |
595 | // broadcast and one or more chunk operations. |
596 | auto* bchunk = chunk; |
597 | if (chunk->kind() == prim::ConstantChunk) { |
598 | bchunk = promoteChunkToBroadcastingChunk(chunk); |
599 | } |
600 | size_t nchunks = bchunk->i(attr::chunks); |
601 | TORCH_INTERNAL_ASSERT(nchunks > 0, "number of chunks cannot be zero" ); |
602 | WithInsertPoint guard(bchunk->next()); |
603 | |
604 | std::vector<Value*> producer_chunk_outputs; |
605 | for (const auto i : c10::irange(nchunks)) { |
606 | producer_chunk_outputs.push_back( |
607 | bchunk->output(nchunks * producer_index + i)); |
608 | } |
609 | |
610 | // Add each of op's operands to the bchunk node. |
611 | // chunked_inputs[input_nr][chunk_output_idx] |
612 | // = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr]) |
613 | std::vector<std::vector<Value*>> chunked_inputs; |
614 | |
615 | // We have asserted single output earlier |
616 | auto producer_output_sizes = |
617 | producer_for_chunk_node->output()->type()->cast<TensorType>()->sizes(); |
618 | |
619 | for (auto input : producer_for_chunk_node->inputs()) { |
620 | // XXX: we only work with pointwise ops in here, so we know it is valid to |
621 | // push the concat only through tensor arguments (and all other args can |
622 | // be safely ignored). |
623 | if (!input->type()->isSubtypeOf(*TensorType::get())) |
624 | continue; |
625 | |
626 | // if 'input' is already an input to the bchunk, reuse it. |
627 | auto bchunk_inputs = bchunk->inputs(); |
628 | auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input); |
629 | if (it != bchunk_inputs.end()) { |
630 | chunked_inputs.emplace_back(); |
631 | auto input_index = std::distance(bchunk_inputs.begin(), it); |
632 | for (const auto chunk : c10::irange(nchunks)) { |
633 | chunked_inputs.back().push_back( |
634 | bchunk->outputs().at(nchunks * input_index + chunk)); |
635 | } |
636 | continue; |
637 | } |
638 | |
639 | // NB: I decided not to use cloneFrom here, because if we make cloneFrom |
640 | // copy selects one day, it is definitely not what you want here (selects |
641 | // have different types). |
642 | // TODO: Perhaps we should use cloneFrom now, as it seems unlikely |
643 | // to copy select nodes now that we have refactored to have a Value |
644 | // distinct from Node. |
645 | bchunk->addInput(input); |
646 | chunked_inputs.emplace_back(); // alas, to not be C++17 |
647 | |
648 | // properly compute strides for BroadcastingChunk |
649 | // |
650 | // We copy stride of each dimension from input to output for |
      // BroadcastingChunk. Note that Chunk itself should not alter strides;
      // however, broadcasted dimensions should have a stride of 0. Broadcasting
      // can happen on existing dimensions of the input (case 1), as well as on
      // extended dimensions that do not exist in the input (case 2).
655 | // e.g. |
656 | // If we look at an input tensor t0 with shape [3, 1] broadcasted to |
657 | // output tensor t1 with shape [4, 1, 3, 3], |
658 | // We set stride to zero in case of broadcast, which could happen in: |
659 | // case1: t1.dim[3] (broadcasted as in the description above) |
660 | // case2: t1.dim[0] (broadcasted implicitly) |
661 | std::vector<int64_t> strides; |
662 | auto input_type = input->type()->cast<TensorType>(); |
663 | auto input_sizes = input_type->sizes(); |
664 | auto input_strides = input_type->strides(); |
665 | if (producer_output_sizes.isComplete() && input_sizes.isComplete() && |
666 | input_strides.isComplete()) { |
667 | auto input_c_sizes = input_sizes.concrete_sizes().value(); |
668 | auto input_c_strides = input_strides.concrete_sizes().value(); |
669 | auto output_c_sizes = producer_output_sizes.concrete_sizes().value(); |
670 | int output_index = int(output_c_sizes.size()) - 1; |
671 | strides.resize(output_index + 1); |
672 | AT_ASSERT(output_index >= int(input_c_sizes.size()) - 1); |
673 | for (int input_index = int(input_c_sizes.size()) - 1; input_index >= 0; |
674 | input_index--, output_index--) { |
          // in broadcast case 1, we set the stride to 0;
          // otherwise, the stride remains the same.
677 | if (input_c_sizes[input_index] == 1 && |
678 | output_c_sizes[output_index] != 1) { |
679 | strides[output_index] = 0; |
680 | } else { |
681 | strides[output_index] = input_c_strides[input_index]; |
682 | } |
683 | } |
684 | |
685 | // continue on expanding dimensions to set stride to 0 for case2 |
686 | while (output_index >= 0) { |
687 | strides[output_index] = |
688 | output_c_sizes[output_index] == 1 ? strides[output_index + 1] : 0; |
689 | output_index--; |
690 | } |
691 | } |
692 | |
693 | for (auto chunk_sel : producer_chunk_outputs) { |
694 | Value* input_chunk_sel = bchunk->addOutput(); |
695 | auto chunk_sel_type = chunk_sel->type()->cast<TensorType>(); |
696 | if (strides.empty() || !chunk_sel_type->sizes().isComplete()) { |
697 | input_chunk_sel->setType(chunk_sel_type); |
698 | } else { |
699 | input_chunk_sel->setType(chunk_sel_type->withSizesStrides( |
700 | chunk_sel_type->sizes().concrete_sizes().value(), strides)); |
701 | } |
702 | chunked_inputs.back().push_back(input_chunk_sel); |
703 | } |
704 | } |
705 | |
706 | // apply the op to each chunk of the chunked operands, |
707 | // and then rewrite the graph to use them! |
708 | for (auto chunk_sel : producer_chunk_outputs) { |
709 | auto original_inputs = producer_for_chunk_node->inputs(); |
710 | Node* chunked_op = |
711 | block_->owningGraph()->create(producer_for_chunk_node->kind()); |
712 | chunked_op->copyAttributes(*producer_for_chunk_node); |
713 | chunked_op->output()->setType(chunk_sel->type()); |
714 | auto chunked_inputs_it = chunked_inputs.begin(); |
715 | for (Value* original_input : original_inputs) { |
716 | if (original_input->type()->isSubtypeOf(*TensorType::get())) { |
717 | AT_ASSERT(chunked_inputs_it != chunked_inputs.end()); |
718 | chunked_op->addInput( |
719 | // NOLINTNEXTLINE(clang-analyzer-core.DivideZero) |
720 | chunked_inputs_it->at(chunk_sel->offset() % nchunks)); |
721 | ++chunked_inputs_it; |
722 | } else { |
723 | chunked_op->addInput(original_input); |
724 | } |
725 | } |
726 | bchunk->owningGraph()->insertNode(chunked_op); |
727 | chunk_sel->replaceAllUsesWith(chunked_op->output()); |
728 | } |
729 | |
730 | bchunk->removeInput(producer_index); |
731 | for (const auto _ : c10::irange(nchunks)) { |
732 | bchunk->eraseOutput(nchunks * producer_index); |
733 | } |
734 | |
735 | // The output of producer_for_chunk_node could have been used in some |
736 | // aten::size operators, so we need to clean those up as well (we simply |
737 | // broadcast all its tensor inputs). |
738 | // We need to insert these early in the graph, i.e. immediately after |
739 | // the producer_for_chunk_node as we will have the _size_if_not_same |
740 | // that may be before the bchunk. |
741 | WithInsertPoint guard2(producer_for_chunk_node); |
742 | auto size_calc_uses = producer_for_chunk_node->output()->uses(); |
743 | if (!size_calc_uses.empty()) { |
744 | auto tensor_inputs = filter( |
745 | producer_for_chunk_node->inputs(), |
746 | [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); }); |
747 | auto tensor_sizes = fmap(tensor_inputs, [](Value* v) { |
748 | return v->owningGraph()->insert(aten::size, {v}); |
749 | }); |
750 | AT_ASSERT(!tensor_sizes.empty()); |
751 | Value* output_size = tensor_sizes.size() == 1 |
752 | ? tensor_sizes[0] |
753 | : broadcastSizes(tensor_sizes); |
754 | for (Use u : size_calc_uses) { |
755 | u.user->output()->replaceAllUsesWith(output_size); |
756 | u.user->destroy(); |
757 | } |
758 | } |
759 | producer_for_chunk_node->destroy(); |
760 | return true; |
761 | } |
762 | |
763 | // returns where to continue scanning, and whether any fusion was made |
764 | std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) { |
765 | if (fuser::cuda::isFusibleCudaFusionGroup(consumer)) { |
766 | // handle inputs in reverse topological order as well... |
      // otherwise in f(a, a+b) it will appear that a is used twice if we
      // consider the f-a fusion before the f-(a+b) fusion.
769 | auto inputs = sortReverseTopological(consumer->inputs()); |
770 | for (auto producer : inputs) { |
771 | if (tryToMoveChunk(consumer, producer)) { |
772 | // the chunk before this consumer was re-arranged to allow fusion, |
773 | // we scan this consumer again to perform the fusion |
774 | return std::make_pair(consumer->reverseIterator(), true); |
775 | } |
776 | if (getSingletonFusion() && consumer->kind() != kind_) { |
777 | consumer = createSingletonFusionGroup(consumer); |
778 | } |
779 | auto fusion_group = tryFuse(consumer, producer->node()); |
780 | if (fusion_group) { |
781 | // after fusion, consumer moves into a FusionGroup, so inputs is no |
782 | // longer valid so we rescan the new FusionGroup for more fusions... |
783 | return std::make_pair(fusion_group.value()->reverseIterator(), true); |
784 | } |
785 | |
786 | // horizontal fusion only applies on non-scalar tensor inputs |
787 | if (getHorizontalFusion() && |
788 | producer->type()->isSubtypeOf(*TensorType::get()) && |
789 | !is_cpu_scalar(*producer->type()->cast<TensorType>())) { |
        // fusing nodes that share inputs can save memory bandwidth by
        // reducing the number of tensor reads.
792 | for (const auto& u : producer->uses()) { |
          // only merge nodes before the consumer, since any sibling after the
          // consumer has already considered merging this consumer into it.
796 | if (u.user->isBefore(consumer)) { |
797 | auto fusion_group = tryFuse(consumer, u.user); |
798 | if (fusion_group) { |
799 | return std::make_pair( |
800 | fusion_group.value()->reverseIterator(), true); |
801 | } |
802 | } |
803 | } |
804 | } |
805 | } |
806 | } |
807 | return std::make_pair(++consumer->reverseIterator(), false); |
808 | } |
809 | |
810 | void replaceIntermediateBroadcastingChunks() { |
811 | for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { |
812 | auto* node = *it; |
813 | ++it; // We might delete node, so increment the iterator now. |
814 | if (node->kind() != prim::BroadcastingChunk) { |
815 | continue; |
816 | } |
817 | auto* bchunk = node; |
818 | insertExplicitBroadcast(bchunk); |
819 | |
820 | auto* graph = block_->owningGraph(); |
821 | size_t nchunks = bchunk->i(attr::chunks); |
822 | WithInsertPoint guard(bchunk->next()); |
823 | |
824 | // Split the bchunk into bchunks.inputs().size() number of chunk nodes. |
825 | for (const auto input_offset : c10::irange(bchunk->inputs().size())) { |
826 | auto* input = bchunk->inputs().at(input_offset); |
827 | |
828 | Node* new_chunk = |
829 | graph->insertNode(graph->create(prim::ConstantChunk, input, 0)); |
830 | new_chunk->copyAttributes(*bchunk); |
831 | for (const auto output_offset : c10::irange(nchunks)) { |
832 | auto new_output = new_chunk->addOutput(); |
833 | auto old_output = |
834 | bchunk->outputs().at(input_offset * nchunks + output_offset); |
835 | new_output->copyMetadata(old_output); |
836 | old_output->replaceAllUsesWith(new_output); |
837 | } |
838 | } |
839 | bchunk->destroy(); |
840 | } |
841 | } |
842 | |
843 | bool usedInDtype(Value* v) { |
844 | const auto& uses = v->uses(); |
845 | return std::any_of(uses.begin(), uses.end(), [](const Use& u) { |
846 | return u.user->matches("prim::dtype(Tensor a) -> int" ); |
847 | }); |
848 | } |
849 | |
850 | bool usedOnlyInDtypeAndSize(Value* v) { |
851 | const auto& uses = v->uses(); |
852 | return std::all_of(uses.begin(), uses.end(), [](const Use& u) { |
853 | return u.user->matches("prim::dtype(Tensor a) -> int" ) || |
854 | u.user->matches("aten::size(Tensor self) -> int[]" ); |
855 | }); |
856 | } |
857 | |
858 | // Builds up expressions that compute shapes of all intermediates (and |
859 | // outputs) of the fusion group, based on the sizes of inputs. You should run |
860 | // DCE to remove those that you end up not using. |
861 | // TODO: Add shape support for view, reshape, unsqueeze, and squeeze |
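  // Roughly, for a pointwise node %t2 = aten::mul(%t0, %t1) inside the fusion,
  // the expression emitted outside the group looks like:
  //   %s0 = aten::size(%t0_outer)
  //   %s1 = aten::size(%t1_outer)
  //   %s2 = prim::BroadcastSizes(%s0, %s1)
  // (the names above are illustrative only).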
862 | std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) { |
863 | WithInsertPoint insert_guard{fusion_group->next()}; |
864 | std::unordered_map<Value*, Value*> shape_of; |
865 | |
866 | Graph* graph = fusion_group->owningGraph(); |
867 | auto subgraph = fusion_group->g(attr::Subgraph); |
868 | |
869 | auto inputs = fusion_group->inputs(); |
870 | auto sinputs = subgraph->inputs(); |
871 | AT_ASSERT(inputs.size() == sinputs.size()); |
872 | for (const auto i : c10::irange(inputs.size())) { |
873 | if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) { |
874 | auto sinput_value = graph->insert(aten::size, {inputs[i]}); |
875 | shape_of[sinputs[i]] = sinput_value; |
876 | sinput_value->node()->moveBefore(fusion_group); |
877 | } |
878 | } |
879 | |
880 | // When we have a guarantee that an output won't be removed, because it's |
881 | // used in expressions that don't involve size checks, we can use its size |
882 | // instead of computing a long chain of broadcasts, starting from the |
883 | // beginning of the kernel. |
884 | auto outputs = fusion_group->outputs(); |
885 | auto soutputs = subgraph->outputs(); |
886 | AT_ASSERT(outputs.size() == soutputs.size()); |
887 | for (const auto i : c10::irange(outputs.size())) { |
888 | if (usedOnlyInDtypeAndSize(outputs[i])) |
889 | continue; |
890 | if (soutputs[i]->type()->isSubtypeOf(TensorType::get())) { |
891 | shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]}); |
892 | } |
893 | } |
894 | |
895 | // Place all the shape expressions for intermediates in fusion |
896 | // before the CudaFusionGroup |
897 | graph->setInsertPoint(fusion_group); |
898 | |
899 | // hmmm, do I need to setInsertPoint... |
900 | const auto map_inputs = [&](Value* v) -> Value* { |
901 | // if constant ever has an input, it has to come from |
902 | // profile_ivalue dependency |
903 | if (v->node()->kind() == prim::Param && |
904 | fusion_group->input(v->offset())->node()->kind() == |
905 | prim::profile_ivalue) { |
906 | // we need to map it along profile_ivalue dependency |
907 | return fusion_group->input(v->offset()); |
908 | } else { |
909 | throw std::runtime_error( |
910 | std::string("unexpected input from node" ) + |
911 | v->node()->kind().toDisplayString()); |
912 | } |
913 | }; |
914 | |
915 | for (Node* n : subgraph->nodes()) { |
916 | // XXX: Use of shape_of.emplace is crucial to the output shape |
917 | // optimization! |
918 | if (n->kind() == prim::FusedConcat) { |
919 | // This is a bit more involved, because we have to account for the case |
920 | // when inputs have different shapes, but fortunately those tensors are |
921 | // always outputs, and so we can simply avoid replacing their queries, |
922 | // because it won't help us. |
923 | continue; |
924 | } |
925 | if (n->kind() == prim::Constant) { |
926 | continue; |
927 | } |
928 | if (n->kind() == prim::ConstantChunk) { |
929 | TORCH_INTERNAL_ASSERT( |
930 | shape_of.count(n->input()) > 0, |
931 | "buildShapeExpressions failed at accessing input shapes" ); |
932 | Node* sizes_node = graph->insertNode( |
933 | graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2)); |
934 | sizes_node->i_(attr::dim, n->i(attr::dim)); |
935 | sizes_node->i_(attr::chunks, n->i(attr::chunks)); |
936 | Value* regular_size = sizes_node->outputs().at(0); |
937 | Value* last_size = sizes_node->outputs().at(1); |
938 | regular_size->setType(ListType::ofInts()); |
939 | last_size->setType(ListType::ofInts()); |
940 | auto outputs = n->outputs(); |
941 | for (Value* o : outputs.slice(0, outputs.size() - 1)) { |
942 | shape_of.emplace(o, regular_size); |
943 | } |
944 | shape_of.emplace(outputs.at(outputs.size() - 1), last_size); |
945 | continue; |
946 | } |
947 | // extended shape expression support to reduction operations |
948 | // TODO: `aten::sum` is too flexible, we should restrict for a better |
949 | // match |
950 | // TODO: Add python tests where we check for existing ops and their |
951 | // shape expression logic. |
952 | static std::unordered_set<Symbol> reduction_ops( |
953 | {aten::sum, aten::mean, aten::var, aten::std}); |
954 | if (reduction_ops.find(n->kind()) != reduction_ops.end()) { |
        // TODO: expand support to wire non-constant inputs; this is currently
        // blocked by the profiling executor not being able to profile scalar
        // inputs.
957 | TORCH_INTERNAL_ASSERT( |
958 | n->input(1)->node()->kind() == prim::Constant && |
959 | n->input(2)->node()->kind() == prim::Constant, |
960 | "only supports reduction axes and keepdim being constant" ); |
961 | |
962 | Node* in1_const = graph->createClone(n->input(1)->node(), map_inputs); |
963 | graph->insertNode(in1_const); |
964 | Node* in2_const = graph->createClone(n->input(2)->node(), map_inputs); |
965 | graph->insertNode(in2_const); |
966 | |
967 | TORCH_INTERNAL_ASSERT( |
968 | shape_of.count(n->input(0)) > 0, |
969 | "buildShapeExpressions failed at accessing input shapes" ); |
970 | std::vector<Value*> inputs = { |
971 | shape_of.at(n->input(0)), in1_const->output(), in2_const->output()}; |
972 | Node* size_node = |
973 | graph->insertNode(graph->create(prim::ReductionSizes, inputs, 1)); |
974 | Value* size = size_node->output(0); |
975 | size->setType(ListType::ofInts()); |
976 | shape_of.emplace(n->output(), size); |
977 | continue; |
978 | } |
979 | // TODO: output(1) & output(2) should also be marked |
980 | if (n->kind() == aten::native_layer_norm) { |
981 | TORCH_INTERNAL_ASSERT( |
982 | shape_of.count(n->input(0)) > 0, |
983 | "buildShapeExpressions failed at accessing input shapes" ); |
984 | shape_of.emplace(n->output(0), shape_of.at(n->input(0))); |
985 | continue; |
986 | } |
987 | // TODO: output(1) & output(2) should also be marked |
988 | if (n->kind() == aten::native_layer_norm_backward) { |
989 | TORCH_INTERNAL_ASSERT( |
990 | shape_of.count(n->input(0)) > 0, |
991 | "buildShapeExpressions failed at accessing input shapes" ); |
992 | shape_of.emplace(n->output(0), shape_of.at(n->input(0))); |
993 | if (shape_of.count(n->input(5)) > 0) { |
994 | shape_of.emplace(n->output(1), shape_of.at(n->input(5))); |
995 | } |
996 | if (shape_of.count(n->input(6)) > 0) { |
997 | shape_of.emplace(n->output(2), shape_of.at(n->input(6))); |
998 | } |
999 | continue; |
1000 | } |
1001 | // TODO: output(1) & output(2) should also be marked |
1002 | if (n->kind() == aten::native_batch_norm || |
1003 | n->kind() == aten::_batch_norm_impl_index) { |
1004 | TORCH_INTERNAL_ASSERT( |
1005 | shape_of.count(n->input(0)) > 0, |
1006 | "buildShapeExpressions failed at accessing input shapes" ); |
1007 | shape_of.emplace(n->output(0), shape_of.at(n->input(0))); |
1008 | continue; |
1009 | } |
1010 | // TODO: output(1) & output(2) should also be marked |
1011 | if (n->kind() == aten::native_batch_norm_backward) { |
1012 | TORCH_INTERNAL_ASSERT( |
1013 | shape_of.count(n->input(0)) > 0, |
1014 | "buildShapeExpressions failed at accessing input shapes" ); |
1015 | shape_of.emplace(n->output(0), shape_of.at(n->input(0))); |
1016 | if (shape_of.count(n->input(2)) > 0) { |
1017 | shape_of.emplace(n->output(1), shape_of.at(n->input(2))); |
1018 | // use shape of weight here for grad_bias |
1019 | shape_of.emplace(n->output(2), shape_of.at(n->input(2))); |
1020 | } |
1021 | continue; |
1022 | } |
1023 | if (n->kind() == aten::_batch_norm_impl_index_backward) { |
1024 | TORCH_INTERNAL_ASSERT( |
1025 | shape_of.count(n->input(1)) > 0, |
1026 | "buildShapeExpressions failed at accessing input shapes" ); |
1027 | shape_of.emplace(n->output(0), shape_of.at(n->input(1))); |
1028 | if (shape_of.count(n->input(3)) > 0) { |
1029 | shape_of.emplace(n->output(1), shape_of.at(n->input(3))); |
1030 | // use shape of weight here for grad_bias |
1031 | shape_of.emplace(n->output(2), shape_of.at(n->input(3))); |
1032 | } |
1033 | continue; |
1034 | } |
1035 | if (n->kind() == aten::native_dropout) { |
1036 | TORCH_INTERNAL_ASSERT( |
1037 | shape_of.count(n->input(0)) > 0, |
1038 | "buildShapeExpressions failed at accessing input shapes" ); |
1039 | shape_of.emplace(n->output(0), shape_of.at(n->input(0))); |
1040 | shape_of.emplace(n->output(1), shape_of.at(n->input(0))); |
1041 | continue; |
1042 | } |
1043 | if (n->kind() == prim::unsqueeze_copy) { |
1044 | TORCH_INTERNAL_ASSERT( |
1045 | shape_of.count(n->input(0)) > 0, |
1046 | "buildShapeExpressions failed at accessing input shapes" ); |
1047 | TORCH_INTERNAL_ASSERT( |
1048 | n->input(1)->node()->kind() == prim::Constant, |
1049 | "only supports unsqueeze axes being constant" ); |
1050 | Node* dim_const = graph->createClone(n->input(1)->node(), map_inputs); |
1051 | graph->insertNode(dim_const); |
1052 | std::vector<Value*> inputs = { |
1053 | shape_of.at(n->input(0)), dim_const->output()}; |
1054 | Node* size_node = graph->insertNode(graph->create( |
1055 | Symbol::fromQualString("prim::infer_unsqueeze_size" ), inputs, 1)); |
1056 | Value* size = size_node->output(0); |
1057 | size->setType(ListType::ofInts()); |
1058 | shape_of.emplace(n->output(), size); |
1059 | continue; |
1060 | } |
1061 | if (n->kind() == prim::squeeze_copy) { |
1062 | TORCH_INTERNAL_ASSERT( |
1063 | shape_of.count(n->input(0)) > 0, |
1064 | "buildShapeExpressions failed at accessing input shapes" ); |
1065 | TORCH_INTERNAL_ASSERT( |
1066 | n->inputs().size() == 2 || n->inputs().size() == 1, |
1067 | "prim::squeeze_copy expects one or two inputs" ); |
1068 | std::vector<Value*> inputs = {shape_of.at(n->input(0))}; |
1069 | |
1070 | if (n->inputs().size() == 2) { |
1071 | TORCH_INTERNAL_ASSERT( |
1072 | n->input(1)->node()->kind() == prim::Constant, |
1073 | "only supports squeeze axes being constant" ); |
1074 | Node* dim_const = graph->createClone(n->input(1)->node(), map_inputs); |
1075 | graph->insertNode(dim_const); |
1076 | inputs.push_back(dim_const->output()); |
1077 | } |
1078 | Node* size_node = graph->insertNode(graph->create( |
1079 | Symbol::fromQualString("prim::infer_squeeze_size" ), inputs, 1)); |
1080 | Value* size = size_node->output(0); |
1081 | size->setType(ListType::ofInts()); |
1082 | shape_of.emplace(n->output(), size); |
1083 | continue; |
1084 | } |
1085 | |
1086 | auto tensor_inputs = filter(n->inputs(), [](Value* v) { |
1087 | return v->type()->isSubtypeOf(*TensorType::get()); |
1088 | }); |
1089 | auto shapes = fmap(tensor_inputs, [&](Value* v) { |
1090 | TORCH_INTERNAL_ASSERT( |
1091 | shape_of.count(v) > 0, |
1092 | "buildShapeExpressions failed at accessing input shapes" ); |
1093 | return shape_of.at(v); |
1094 | }); |
1095 | AT_ASSERT(!shapes.empty()); |
1096 | shape_of.emplace( |
1097 | n->output(0), |
1098 | shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes)); |
1099 | } |
1100 | return shape_of; |
1101 | } |
1102 | |
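  // For each fusion-group output that is consumed only by aten::size (and
  // possibly prim::dtype), rewire the size queries to the shape expressions
  // built by buildShapeExpressions and erase the output once no dtype use
  // remains.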
1103 | void removeOutputsUsedOnlyInSize(Node* fusion_group) { |
1104 | if (fusion_group->kind() != prim::CudaFusionGroup) |
1105 | return; |
1106 | auto subgraph = fusion_group->g(attr::Subgraph); |
1107 | |
    // TODO: failure in buildShapeExpressions should not break fusion execution;
    // we can add a try/catch here to bail out of removeOutputsUsedOnlyInSize.
1110 | GRAPH_DEBUG("before build shape expression: " , *graph_); |
1111 | auto shape_map = buildShapeExpressions(fusion_group); |
1112 | fusion_value_to_runtime_shape_.insert(shape_map.begin(), shape_map.end()); |
1113 | GRAPH_DEBUG("after build shape expression: " , *graph_); |
1114 | |
1115 | auto outputs = fusion_group->outputs().vec(); |
1116 | auto soutputs = subgraph->outputs().vec(); |
1117 | // XXX: Iterating in this order is not only good for performance reasons! |
1118 | // It is also crucial for correctness (i has to reflect the current true |
1119 | // index of outputs[i])! |
1120 | for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) { |
1121 | auto output = outputs[i]; |
1122 | auto soutput = soutputs[i]; |
1123 | if (usedOnlyInDtypeAndSize(output) && shape_map.count(soutput) > 0) { |
1124 | bool has_dtype = usedInDtype(output); |
1125 | auto uses = output->uses(); |
1126 | for (Use u : uses) { |
1127 | if (u.user->matches("aten::size(Tensor self) -> int[]" )) { |
1128 | u.user->output()->replaceAllUsesWith(shape_map.at(soutput)); |
1129 | u.user->destroy(); |
1130 | } else if (u.user->matches("prim::dtype(Tensor a) -> int" )) { |
1131 | continue; |
1132 | } else { |
1133 | AT_ASSERT( |
1134 | false, |
1135 | "unrecognized consumer should not trigger removeOutputsUsedOnlyInSize" ); |
1136 | } |
1137 | } |
1138 | // We only wipe the output when there's no more dtype consumer. |
1139 | // This is to be removed by `removeOutputUsedOnlyInDtype` |
1140 | if (!has_dtype) { |
1141 | fusion_group->eraseOutput(i); |
1142 | subgraph->eraseOutput(i); |
1143 | } |
1144 | } |
1145 | } |
1146 | GRAPH_DEBUG("after build shape expression and re-wiring: " , *graph_); |
1147 | } |
1148 | |
1149 | void refreshAliasDb() { |
1150 | aliasDb_ = torch::make_unique<AliasDb>(graph_); |
1151 | } |
1152 | |
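  // Remove additions/subtractions of 0 and multiplications/divisions by 1 when
  // doing so cannot change the result's dtype. Roughly,
  //   %y = aten::add(%x, 0, 1)   and   %y = aten::mul(%x, 1)
  // are rewritten so that uses of %y read %x directly.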
1153 | void removeNoopBinaryOps(Block* block) { |
1154 | for (Node* node : block->nodes()) { |
1155 | for (Block* b : node->blocks()) { |
1156 | removeNoopBinaryOps(b); |
1157 | } |
1158 | |
1159 | if (node->matches( |
1160 | "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
1161 | /*const_inputs=*/{attr::alpha, attr::other}) || |
1162 | node->matches( |
1163 | "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
1164 | /*const_inputs=*/{attr::alpha, attr::other})) { |
1165 | // x + 0 == x - 0 == x |
        // if either scalar input is a float, then removing this operator could
1167 | // remove type promotion and affect semantics |
1168 | auto scalar_type = |
1169 | node->input(0)->type()->expectRef<TensorType>().scalarType(); |
1170 | if (!scalar_type.has_value() || |
1171 | !at::isFloatingType(scalar_type.value())) { |
1172 | auto inps = node->inputs(); |
1173 | if (!inps.at(1)->type()->isSubtypeOf(IntType::get()) || |
1174 | !inps.at(2)->type()->isSubtypeOf(IntType::get())) { |
1175 | continue; |
1176 | } |
1177 | } |
1178 | |
1179 | if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 && |
1180 | node->get<at::Scalar>(attr::other)->toDouble() == 0) { |
1181 | GRAPH_UPDATE( |
1182 | getHeader(node), |
1183 | " (x + 0 == x - 0 == x) is replaced with " , |
1184 | node->input(0)->debugName()); |
1185 | node->output()->replaceAllUsesWith(node->input(0)); |
1186 | } |
1187 | } else if ( |
1188 | node->matches( |
1189 | "aten::mul(Tensor self, Scalar other) -> Tensor" , |
1190 | /*const_inputs=*/attr::other) || |
1191 | node->matches( |
1192 | "aten::div(Tensor self, Scalar other) -> Tensor" , |
1193 | /*const_inputs=*/attr::other)) { |
1194 | // x * 1 == x / 1 == x |
        // if the node is a division or `other` isn't an integer, then removing
1196 | // this operator could remove type promotion and affect semantics |
1197 | auto scalar_type = |
1198 | node->input(0)->type()->expectRef<TensorType>().scalarType(); |
1199 | if (!scalar_type.has_value() || |
1200 | !at::isFloatingType(scalar_type.value())) { |
1201 | if (node->kind() == aten::div || |
1202 | !node->input(1)->type()->isSubtypeOf(IntType::get())) { |
1203 | continue; |
1204 | } |
1205 | } |
1206 | |
1207 | if (node->get<at::Scalar>(attr::other)->toDouble() == 1) { |
1208 | GRAPH_UPDATE( |
1209 | getHeader(node), |
1210 | " (x * 1 == x / 1 == x) is replaced with " , |
1211 | node->input(0)->debugName()); |
1212 | node->output()->replaceAllUsesWith(node->input(0)); |
1213 | } |
1214 | } |
1215 | } |
1216 | } |
1217 | |
1218 | void optimizeFusedGraphs() { |
1219 | for (Node* node : block_->nodes()) { |
1220 | if (node->kind() != kind_) { |
1221 | continue; |
1222 | } |
1223 | auto subgraph = node->g(attr::Subgraph); |
1224 | GRAPH_DEBUG("before optimizing: " , *subgraph); |
1225 | removeNoopBinaryOps(subgraph->block()); |
1226 | EliminateDeadCode(subgraph); |
1227 | EliminateCommonSubexpression(subgraph); |
1228 | ConstantPooling(subgraph); |
1229 | GRAPH_DEBUG("after optimizing: " , *subgraph); |
1230 | } |
1231 | } |
1232 | |
1233 | void run() { |
1234 | // Run the pass until no changes are made. |
1235 | // This is necessary, because the algorithm can miss out on certain fusion |
1236 | // opportunities if ran only once. Consider this graph: |
1237 | // |
1238 | // %1 = f(...) |
1239 | // %2 = g(%1) |
1240 | // %3 = h(%1) |
1241 | // %4 = l(%3) |
1242 | // return (%4, %2) |
1243 | // |
1244 | // where f, g, h, l are simple map ops. |
1245 | // The first iteration will fuse %4 and %3, and see that %1 is an input, but |
1246 | // can't be fused, because it has a different use before the fusion group |
1247 | // in our topological ordering. Then, %2 will be considered, and fused with |
1248 | // %1. If we do another iteration, the algorithm will consider the fusion of |
1249 | // these two groups and fix the situation. |
1250 | bool any_changed = true; |
1251 | while (any_changed) { |
1252 | any_changed = false; |
1253 | refreshAliasDb(); |
1254 | for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { |
1255 | bool changed = false; |
1256 | std::tie(it, changed) = scanNode(*it); |
1257 | any_changed |= changed; |
1258 | } |
1259 | } |
1260 | |
1261 | GRAPH_DEBUG("after scan and merge" , *graph_); |
1262 | refreshAliasDb(); |
1263 | |
1264 | optimizeFusedGraphs(); |
1265 | |
1266 | // The graph fuser can add intermediate prim::BroadcastingChunk nodes. |
1267 | // Replace them with broadcasts + chunks. |
1268 | replaceIntermediateBroadcastingChunks(); |
1269 | |
1270 | // Fuse starting chunks into the group. |
1271 | // for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { |
1272 | // it = scanNodeForChunks(*it); |
1273 | //} |
1274 | |
1275 | GRAPH_DEBUG("before removeOutputsUsedOnlyInSize" , *graph_); |
1276 | // Remove outputs that have been added only because we need their size |
1277 | for (Node* n : block_->nodes()) { |
1278 | removeOutputsUsedOnlyInSize(n); |
1279 | } |
1280 | GRAPH_DEBUG("after removeOutputsUsedOnlyInSize" , *graph_); |
1281 | |
1282 | for (Node* node : block_->nodes()) { |
1283 | for (Block* sub_block : node->blocks()) { |
1284 | CudaGraphFuser sub_block_cfg(sub_block, graph_); |
1285 | sub_block_cfg.run(); |
1286 | // Accumulate runtime shapes for all sub-blocks |
1287 | fusion_value_to_runtime_shape_.insert( |
1288 | sub_block_cfg.fusion_value_to_runtime_shape_.begin(), |
1289 | sub_block_cfg.fusion_value_to_runtime_shape_.end()); |
1290 | } |
1291 | } |
1292 | } |
1293 | }; |
1294 | |
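// A guarded fusion region looks roughly like:
//   %ok = prim::CudaFusionGuard(%inputs...)
//   %outs... = prim::If(%ok)
//     block0(): ... prim::CudaFusionGroup ...
//     block1(): ... prim::FallbackGraph ...
// (possibly with the guard output gathered through prim::ListConstruct and
// aten::all before the prim::If). This helper removes the guard and the If,
// hoisting the single FallbackGraph node in their place so only the
// unspecialized fallback path remains.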
1295 | void removeCudaFusionPathForGuardNode(Node* n) { |
1296 | auto uses = n->output()->uses(); |
1297 | TORCH_INTERNAL_ASSERT( |
1298 | uses.size() == 1, |
1299 | "CudaFusionGuard should only be used once by prim::If or prim::ListConstruct" ); |
1300 | Node* if_node = uses[0].user; |
1301 | if (if_node->kind() != prim::If) { |
1302 | TORCH_INTERNAL_ASSERT( |
1303 | if_node->kind() == prim::ListConstruct, |
1304 | "CudaFusionGuard is not used by neither prim::If or prim::ListConstruct" ); |
1305 | // break all inputs so producer prim::CudaFusionGuard can be removed later |
1306 | if_node->removeAllInputs(); |
1307 | auto list_use = if_node->output()->uses(); |
1308 | TORCH_INTERNAL_ASSERT( |
1309 | list_use.size() == 1 && list_use[0].user->kind() == aten::all, |
1310 | "prim::ListConstruct should only be used once by aten::all" ); |
1311 | auto all_use = list_use[0].user->output()->uses(); |
1312 | TORCH_INTERNAL_ASSERT( |
1313 | all_use.size() == 1 && all_use[0].user->kind() == prim::If, |
1314 | "aten::all should only be used once by prim::If" ); |
1315 | if_node = all_use[0].user; |
1316 | } |
1317 | |
1318 | auto fall_back_graph = if_node->blocks()[1]; |
1319 | Node* fallback_node = nullptr; |
1320 | for (auto fb_n : fall_back_graph->nodes()) { |
1321 | TORCH_INTERNAL_ASSERT( |
1322 | fb_n->kind() == prim::FallbackGraph, |
1323 | "CudaFusionGuard fallback path should only have single fallback node" ); |
1324 | TORCH_INTERNAL_ASSERT( |
1325 | fallback_node == nullptr, |
1326 | "CudaFusionGuard fallback path should only have single fallback node" ); |
1327 | fallback_node = fb_n; |
1328 | } |
1329 | |
1330 | TORCH_INTERNAL_ASSERT( |
1331 | fallback_node != nullptr, |
1332 | "CudaFusionGuard fallback path found no fallback node" ); |
1333 | fallback_node->moveBefore(n); |
1334 | |
1335 | TORCH_INTERNAL_ASSERT( |
1336 | fallback_node->outputs().size() == if_node->outputs().size(), |
1337 | "CudaFusionGuard fallback should have same number of outputs as with nesting if block" ); |
1338 | |
1339 | if_node->replaceAllUsesWith(fallback_node); |
1340 | if_node->destroy(); |
1341 | n->destroy(); |
1342 | } |
1343 | |
1344 | bool missingCompleteTypes(const std::vector<TypePtr>& types) { |
1345 | for (const auto& type : types) { |
1346 | if (auto tensor_type = type->cast<TensorType>()) { |
1347 | // if we find one missing value, we know that we are not going to be able |
1348 | // to generate a kernel, so we bail out; |
1349 | if (!tensor_type->device().has_value() || |
1350 | !tensor_type->dim().has_value() || |
1351 | !tensor_type->scalarType().has_value()) { |
1352 | return true; |
1353 | } |
1354 | } |
1355 | } |
1356 | return false; |
1357 | } |
1358 | |
1359 | void removeFusionWithMissingProfilingInformation(Block* block) { |
1360 | FUSER_PERF_SCOPE("removeFusionWithMissingProfilingInformation" ); |
1361 | std::vector<Node*> removeCudaFusionNodes; |
1362 | |
1363 | for (auto node : block->nodes()) { |
1364 | if (node->kind() == prim::CudaFusionGuard && |
1365 | missingCompleteTypes(node->tys(attr::types))) { |
1366 | removeCudaFusionNodes.push_back(node); |
1367 | } |
1368 | for (auto sub_block : node->blocks()) { |
1369 | removeFusionWithMissingProfilingInformation(sub_block); |
1370 | } |
1371 | } |
1372 | |
1373 | for (auto node : removeCudaFusionNodes) { |
1374 | removeCudaFusionPathForGuardNode(node); |
1375 | } |
1376 | } |
1377 | |
1378 | void compileFusionRecursive(Block* block) { |
1379 | FUSER_PERF_SCOPE("compileFusionRecursive" ); |
1380 | |
1381 | for (auto node : block->nodes()) { |
1382 | if (node->kind() == prim::CudaFusionGroup) { |
1383 | fuser::cuda::compileFusionGroup(node); |
1384 | } |
1385 | for (auto sub_block : node->blocks()) { |
1386 | compileFusionRecursive(sub_block); |
1387 | } |
1388 | } |
1389 | } |
1390 | |
1391 | void PeepholeOptimizeShapeExpressions(Block* block) { |
1392 | FUSER_PERF_SCOPE("PeepholeOptimizeShapeExpressions" ); |
1393 | |
1394 | auto nodes = block->nodes(); |
1395 | for (auto it = nodes.begin(); it != nodes.end(); ++it) { |
1396 | Node* node = *it; |
1397 | for (Block* subblock : node->blocks()) { |
1398 | PeepholeOptimizeShapeExpressions(subblock); |
1399 | } |
1400 | if (node->kind() == prim::BroadcastSizes) { |
1401 | // Remove no-op broadcasts. |
1402 | if (node->inputs().size() == 1) { |
1403 | node->output()->replaceAllUsesWith(node->input()); |
1404 | it.destroyCurrent(); |
1405 | continue; |
1406 | } |
1407 | // Deduplicate inputs, but use their unique() values to ensure |
1408 | // this process only depends on the graph. |
1409 | std::map<size_t, Value*> unique_to_value; |
1410 | for (Value* input : node->inputs()) { |
1411 | unique_to_value.emplace(input->unique(), input); |
1412 | } |
1413 | if (unique_to_value.size() != node->inputs().size()) { |
1414 | std::vector<Value*> inputs; |
1415 | inputs.reserve(unique_to_value.size()); |
1416 | for (auto& entry : unique_to_value) { |
1417 | inputs.push_back(entry.second); |
1418 | } |
1419 | if (inputs.size() == 1) { |
1420 | node->output()->replaceAllUsesWith(inputs[0]); |
1421 | } else { |
1422 | WithInsertPoint insert_guard{node}; |
1423 | node->output()->replaceAllUsesWith(broadcastSizes(inputs)); |
1424 | } |
1425 | it.destroyCurrent(); |
1426 | --it; // Revisit the node with deduplicated inputs |
1427 | continue; |
1428 | } |
1429 | // Compose simple chains of broadcasts into a single node. |
1430 | const auto& uses = node->output()->uses(); |
1431 | if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) { |
1432 | Node* user = uses[0].user; |
1433 | user->removeInput(uses[0].offset); |
1434 | // NB: we don't care about deduplication in here, as we will visit user |
1435 | // later. |
1436 | for (Value* i : node->inputs()) { |
1437 | user->addInput(i); |
1438 | } |
1439 | it.destroyCurrent(); |
1440 | } |
1441 | } |
1442 | } |
1443 | } |
1444 | |
1445 | // view_sizes_runtime is the profiled-ivalue argument for view-size. |
1446 | // view_sizes_constant_list is the constant list recorded during profiling runs. |
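// For illustration, the guard inserted before the versioning prim::If looks
// roughly like this (a sketch; value names are made up):
//   %constraints = prim::Constant[value=<constraint string>]()
//   %check : bool = prim::CudaFusionViewGuard(%self_sizes, %view_sizes, %constraints)
// and %check is later and-ed with the other profile_ivalue checks.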
1447 | Value* guardView( |
1448 | Node* fusion, |
1449 | std::unordered_map<Value*, Value*>& fusion_value_to_runtime_size, |
1450 | Node* versioning_if, |
1451 | Node* view, |
1452 | Value* view_sizes_runtime) { |
1453 | // 1. Get self tensor sizes and view_sizes |
1454 | auto self_value = view->inputs().front(); |
1455 | auto self_type = self_value->type()->cast<TensorType>(); |
1456 | auto self_sizes_constant_list = getTensorSizes(self_type); |
1457 | |
1458 | auto view_sizes_constant_list = |
1459 | constant_as<c10::List<int64_t>>(view->inputs().back()); |
1460 | TORCH_INTERNAL_ASSERT(view_sizes_constant_list.has_value()); |
1461 | std::vector<int64_t> view_sizes = view_sizes_constant_list->vec(); |
1462 | // 2. Get constraints for self tensor and view_sizes |
1463 | auto constraints = |
1464 | analyzeViewConstraint(self_sizes_constant_list, view_sizes); |
1465 | |
1466 | // 3. Add constraints as constant to graph |
1467 | auto full_constraints = fusion->owningGraph()->insertConstant( |
1468 | IValue(constraints.conglomerateString())); |
1469 | full_constraints->node()->moveBefore(versioning_if); |
1470 | |
1471 | // 4. Create CudaFusionViewGuard using input tensor, profile_ivalue |
1472 | // for view_sizes list, and constraints |
1473 | TORCH_INTERNAL_ASSERT( |
1474 | fusion_value_to_runtime_size.find(self_value) != |
1475 | fusion_value_to_runtime_size.end(), |
1476 | "Failed to find runtime size for fusion value:\t" , |
1477 | self_value->node()->kind().toDisplayString()); |
1478 | Node* viewcheck_node = |
1479 | fusion->owningGraph() |
1480 | ->create( |
1481 | c10::Symbol::fromQualString("prim::CudaFusionViewGuard" ), |
1482 | {fusion_value_to_runtime_size.at(self_value), |
1483 | view_sizes_runtime, |
1484 | full_constraints}, |
1485 | 1) |
1486 | ->insertBefore(versioning_if); |
1487 | return viewcheck_node->output(); |
1488 | } |
1489 | |
1490 | //! [ Note -- CudaFusionGuard implementation ] |
1491 | //! |
1492 | //! shamelessly copying code from NNC (tensorexpr_fuser) with very little |
1493 | //! modification, original code at: |
1494 | //! `../../passes/tensorexpr_fuser.cpp:guardFusionGroup` |
1495 | //! |
1496 | //! Add prim::CudaFusionGuard node to ensure that accepted profiling information |
1497 | //! is not violated at runtime. |
1498 | //! |
1499 | //! We replace a single |
1500 | //! |
1501 | //! outputs = prim::CudaFusionGroup[cache_id](inputs) |
1502 | //! |
1503 | //! with the following pattern: |
1504 | //! |
1505 | //! %1 : bool = prim::CudaFusionGuard[types=[...]](inputs) |
1506 | //! outputs = prim::If(%1) |
1507 | //! block0(): |
1508 | //! outputs = prim::CudaFusionGroup[cache_id](inputs) |
1509 | //! -> (outputs) |
1510 | //! block1(): |
1511 | //! %2 : Function = prim::Constant[name="fallback_function", fallback=1]() |
1512 | //! outputs = prim::CallFunction(%2, inputs) |
1513 | //! -> (outputs) |
1514 | //! |
1515 | //! `prim::CudaFusionGuard` stores all profiled data types in attribute |
1516 | //! `attr::types`. |
1517 | //! At runtime, we check input tensors against our profiled data types and |
1518 | //! return an output that holds the result of the check (bool). |
1519 | //! See [ Note -- type guard logic in CudaFusionGuard ] |
1520 | //! |
1521 | //! This ensures that `prim::CudaFusionGroup` only executes with compatible |
1522 | //! inputs. In case of check failure, execution goes through the false block, |
1523 | //! which recursively goes through another profiling / optimization iteration |
1524 | //! (this can be tuned by `bailout_depth`). |
1525 | //! |
1526 | //! TODO: we also need to assert/check reduction axes and replace it with |
1527 | //! constants in `CudaFusionGroup` |
1528 | void guardFusionGroup( |
1529 | Node* fusion, |
1530 | std::unordered_map<Value*, Value*>& fusion_value_to_runtime_size) { |
1531 | // Fixup types of the subgraph inputs |
1532 | std::vector<TypePtr> guard_types; |
1533 | std::vector<Value*> tensor_inputs_to_check; |
1534 | std::set<size_t> profiled_ivalue_indices; |
1535 | |
1536 | for (const auto index : c10::irange(fusion->inputs().size())) { |
1537 | Value* input = fusion->inputs()[index]; |
1538 | if (input->type()->cast<TensorType>()) { |
1539 | // We only check inputs of the fusion group and expect NNC to infer |
1540 | // intermediate and output shapes |
1541 | |
1542 | // note: modified from original implementation, we are guarding fusion |
1543 | // outputs |
1544 | if (input->node()->kind() == prim::Constant) { |
1545 | continue; |
1546 | } |
1547 | tensor_inputs_to_check.push_back(input); |
1548 | guard_types.push_back(input->type()); |
1549 | } else if (input->node()->kind() == prim::profile_ivalue) { |
1550 | // Conditional constant from profiled_ivalue, should be guarded |
1551 | profiled_ivalue_indices.insert(index); |
1552 | } |
1553 | } |
1554 | |
1555 | // insert the if block first; |
1556 | auto versioning_if = |
1557 | fusion->owningGraph()->create(prim::If, fusion->outputs().size()); |
1558 | for (const auto idx : c10::irange(fusion->outputs().size())) { |
1559 | versioning_if->output(idx)->setType(fusion->output(idx)->type()); |
1560 | fusion->output(idx)->replaceAllUsesWith(versioning_if->output(idx)); |
1561 | } |
1562 | auto true_block = versioning_if->addBlock(); |
1563 | auto false_block = versioning_if->addBlock(); |
1564 | |
1565 | // insert typecheck_node; |
1566 | Node* typecheck_node = |
1567 | fusion->owningGraph() |
1568 | ->create(prim::CudaFusionGuard, tensor_inputs_to_check, 1) |
1569 | ->insertBefore(fusion); |
1570 | // fix output to BoolType |
1571 | typecheck_node->output()->setType(BoolType::get()); |
1572 | Value* typecheck_result = typecheck_node->output(); |
1573 | typecheck_node->tys_(attr::types, guard_types); |
1574 | |
1575 | versioning_if->insertAfter(typecheck_node); |
1576 | |
1577 | auto fusion_graph = fusion->g(attr::Subgraph); |
1578 | std::vector<Value*> check_flags = {}; |
1579 | |
1580 | // Fill in the false block. It should contain the unoptimized |
1581 | // copy of the fused subgraph, unless we have conditional constants from |
1582 | // profiled_ivalue; |
1583 | std::shared_ptr<Graph> fb_graph; // resource holder; |
1584 | // Restore the dependency for constant introduced by profiled_ivalue within |
1585 | // the graph. |
1586 | if (!profiled_ivalue_indices.empty()) { |
1587 | // This is necessary as it cleans up the fallback graph, which was copied |
1588 | // from the subgraph; the two graphs would differ since we cannot use |
1589 | // conditional constants in the fallback |
1590 | |
1591 | // 1. RESTORE conditional constant dependency in fallback group; |
1592 | fb_graph = fusion_graph->copy(); |
1593 | GRAPH_DEBUG("re-wiring fallback graph" , *fb_graph); |
1594 | |
1595 | for (const auto& offset : profiled_ivalue_indices) { |
1596 | auto val = fb_graph->inputs()[offset]; |
1597 | auto uses = val->uses(); |
1598 | // since we are updating uses of val in the loop, we have to copy |
1599 | // val->uses() beforehand. |
1600 | for (const auto& use : uses) { |
1601 | // re-wire inputs and remove conditional constant nodes; |
1602 | TORCH_INTERNAL_ASSERT( |
1603 | use.user->kind() == prim::Constant, |
1604 | "profile_ivalue at index: " , |
1605 | offset, |
1606 | " can only be used by conditional constant, instead got: " , |
1607 | use.user->kind().toDisplayString()); |
1608 | use.user->output()->replaceAllUsesWith(val); |
1609 | use.user->destroy(); |
1610 | } |
1611 | } |
1612 | |
1613 | WithInsertPoint guard(false_block->return_node()); |
1614 | const auto subgraph_outputs = |
1615 | insertGraph(*fusion->owningGraph(), *fb_graph, fusion->inputs()); |
1616 | for (Value* output : subgraph_outputs) { |
1617 | false_block->registerOutput(output); |
1618 | } |
1619 | // types get copied to the fallback graph, so remove specializations before |
1620 | // replacing |
1621 | // TODO: this is not exposed here, I need to remove that before inserting |
1622 | // the graph |
1623 | // removeTensorTypeSpecializations(false_block); |
1624 | replaceBlockWithFallbackGraph(false_block, fusion->inputs()); |
1625 | |
1626 | // 2. REMOVE conditional constant dependency in fusion group |
1627 | size_t compensation = 0; |
1628 | |
1629 | // get a constant true, which is used by `and` pattern later |
1630 | auto const_true = fusion->owningGraph()->insertConstant(IValue(true)); |
1631 | const_true->node()->moveBefore(versioning_if); |
1632 | |
1633 | for (const auto& original_offset : profiled_ivalue_indices) { |
1634 | size_t offset = original_offset - compensation; |
1635 | |
1636 | // step a. handle fusion |
1637 | // remove inputs to fusion, and update check logic for fallback |
1638 | auto profiled_ival = fusion->input(offset)->node()->input(); |
1639 | auto const_o = createConditionalConstant(fusion->input(offset)->node()); |
1640 | TORCH_INTERNAL_ASSERT( |
1641 | const_o, |
1642 | "profile_ivalue node are expected to have profile information, at node: " , |
1643 | *fusion->input(offset)->node()); |
1644 | const_o->node()->moveBefore(versioning_if); |
1645 | Value* ivalue_check = nullptr; |
1646 | |
1647 | if (fusion->input(offset)->node()->hasAttribute( |
1648 | Symbol::attr("profiled_bool" ))) { |
1649 | // aten::eq doesn't support comparison between two booleans |
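// (a ^ b) ^ true is equivalent to (a == b) for booleans, so the two
// __xor__ nodes below implement an equality check between the runtime
// profiled value and the conditional constant.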
1650 | auto xor_n = fusion->owningGraph() |
1651 | ->create(aten::__xor__, {profiled_ival, const_o}, 1) |
1652 | ->insertBefore(versioning_if); |
1653 | xor_n->output()->setType(BoolType::get()); |
1654 | ivalue_check = |
1655 | fusion->owningGraph() |
1656 | ->create(aten::__xor__, {xor_n->output(), const_true}, 1) |
1657 | ->insertBefore(versioning_if) |
1658 | ->output(); |
1659 | } else if (fusion->input(offset)->node()->hasAttribute( |
1660 | Symbol::attr("profiled_reduction_size" ))) { |
1661 | // TODO(profile_size): check sizes here with special size comparison op |
1662 | // TORCH_INTERNAL_ASSERT(false, "not implemented yet"); |
1663 | ivalue_check = |
1664 | fusion->owningGraph() |
1665 | ->create( |
1666 | c10::Symbol::fromQualString("prim::CudaFusionSizeEq" ), |
1667 | {profiled_ival, const_o}, |
1668 | 1) |
1669 | ->insertBefore(versioning_if) |
1670 | ->output(); |
1671 | } else if (fusion->input(offset)->node()->hasAttribute( |
1672 | Symbol::attr("profiled_view_size" ))) { |
1673 | // TODO: Add support for dynamic split to view guard |
1674 | |
1675 | // Path from profile-ivalue to prim::view_copy operation |
1676 | // profile-ivalue -> Constant -> CudaFusionGroup |
1677 | // Get argument position in CudaFusionGroup |
1678 | // Get argument in subgraph for CudaFusionGroup |
1679 | // CudaFusionGroup argument -> Constant List -> prim::view_copy |
1680 | auto subgraph_arg = fusion_graph->inputs()[offset]; |
1681 | auto constant = subgraph_arg->uses().front().user->output(); |
1682 | |
1683 | TORCH_INTERNAL_ASSERT(!constant->uses().empty()); |
1684 | auto view = constant->uses().front().user; |
1685 | TORCH_INTERNAL_ASSERT( |
1686 | view->kind() == prim::view_copy || |
1687 | view->kind() == prim::reshape_copy); |
1688 | |
1689 | ivalue_check = guardView( |
1690 | fusion, |
1691 | fusion_value_to_runtime_size, |
1692 | versioning_if, |
1693 | view, |
1694 | profiled_ival); |
1695 | } else if (fusion->input(offset)->node()->hasAttribute( |
1696 | Symbol::attr("profiled_ival" ))) { |
1697 | ivalue_check = |
1698 | fusion->owningGraph() |
1699 | ->create( |
1700 | c10::Symbol::fromQualString("prim::CudaFusionIvalGuard" ), |
1701 | {profiled_ival, const_o}, |
1702 | 1) |
1703 | ->insertBefore(versioning_if) |
1704 | ->output(); |
1705 | } else { |
1706 | ivalue_check = fusion->owningGraph() |
1707 | ->create(aten::eq, {profiled_ival, const_o}, 1) |
1708 | ->insertBefore(versioning_if) |
1709 | ->output(); |
1710 | } |
1711 | ivalue_check->setType(BoolType::get()); |
1712 | |
1713 | // aggregate flags; |
1714 | check_flags.emplace_back(ivalue_check); |
1715 | |
1716 | // remove inputs to fusion; |
1717 | fusion->removeInput(offset); |
1718 | |
1719 | // step b. remove the extra dependency inside fusion; |
1720 | for (const auto& use : fusion_graph->inputs()[offset]->uses()) { |
1721 | TORCH_INTERNAL_ASSERT( |
1722 | use.user->kind() == prim::Constant, |
1723 | "profile_ivalue at index: " , |
1724 | offset, |
1725 | " can only be used by conditional constant, instead got: " , |
1726 | use.user->kind().toDisplayString()); |
1727 | use.user->removeAllInputs(); |
1728 | } |
1729 | fusion_graph->eraseInput(offset); |
1730 | compensation++; |
1731 | } |
1732 | // update graph in fusion node |
1733 | fusion->g_(attr::Subgraph, fusion_graph); |
1734 | } |
1735 | |
1736 | if (!check_flags.empty()) { |
1737 | // attaching output from CudaFusionGuard to profile ivalue checks |
1738 | check_flags.emplace_back(typecheck_result); |
1739 | auto graph = fusion->owningGraph(); |
1740 | auto bool_list_node = |
1741 | graph->insertNode(graph->createList(BoolType::get(), check_flags)); |
1742 | bool_list_node->moveBefore(versioning_if); |
1743 | Value* bool_list = bool_list_node->output(); |
1744 | // new typecheck_result |
1745 | typecheck_result = graph->insert(aten::all, {bool_list}); |
1746 | typecheck_result->node()->moveBefore(versioning_if); |
1747 | } |
1748 | |
1749 | if (profiled_ivalue_indices.empty()) { |
1750 | WithInsertPoint guard(false_block->return_node()); |
1751 | const auto subgraph_outputs = |
1752 | insertGraph(*fusion->owningGraph(), *fusion_graph, fusion->inputs()); |
1753 | for (Value* output : subgraph_outputs) { |
1754 | false_block->registerOutput(output); |
1755 | } |
1756 | // types get copied to the fallback graph, so remove specializations before |
1757 | // replacing |
1758 | // TODO: this is not exposed here, I need to remove that before inserting |
1759 | // the graph |
1760 | // removeTensorTypeSpecializations(false_block); |
1761 | replaceBlockWithFallbackGraph(false_block, fusion->inputs()); |
1762 | } |
1763 | |
1764 | // wiring up if block |
1765 | versioning_if->addInput(typecheck_result); |
1766 | |
1767 | // Fill in the true block. It has all inputs type-checked and its |
1768 | // body should be the fusion group node. |
1769 | fusion->moveBefore(true_block->return_node()); |
1770 | for (Value* output : fusion->outputs()) { |
1771 | true_block->registerOutput(output); |
1772 | } |
1773 | } |
1774 | |
1775 | void guardFusionGroups( |
1776 | Block* block, |
1777 | std::unordered_map<Value*, Value*>& fusion_value_to_runtime_size) { |
1778 | std::vector<Node*> fusions; |
1779 | for (Node* n : block->nodes()) { |
1780 | for (Block* b : n->blocks()) { |
1781 | guardFusionGroups(b, fusion_value_to_runtime_size); |
1782 | } |
1783 | if (n->kind() == prim::CudaFusionGroup) { |
1784 | fusions.push_back(n); |
1785 | } |
1786 | } |
1787 | for (Node* fusion : fusions) { |
1788 | // step 1: a. add prim::CudaFusionGuard and fallback logic |
1789 | // b. insert guard logic of profile_ivalue with if block |
1790 | // c. restore conditional constant to non-constant for fallback |
1791 | guardFusionGroup(fusion, fusion_value_to_runtime_size); |
1792 | } |
1793 | } |
1794 | |
1795 | void dumpFusionGroups(std::shared_ptr<Graph>& g) { |
1796 | DepthFirstGraphNodeIterator it(g); |
1797 | Node* n = nullptr; |
1798 | GRAPH_DEBUG("Exporting all NVFuser fusions:" ); |
1799 | while ((n = it.next()) != nullptr) { |
1800 | if (n->kind() == prim::FallbackGraph) { |
1801 | GRAPH_EXPORT("" , n->g(attr::Subgraph)); |
1802 | } |
1803 | } |
1804 | } |
1805 | |
1806 | // rewire const integer index & empty byte-typed reserve space tensor outputs, |
1807 | // so `CudaFusionGroup` doesn't have to handle those |
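// For illustration, a rough sketch of the rewiring (value names are made up):
//   %out, %save_mean, %save_var, %reserve, %idx = aten::_batch_norm_impl_index(...)
// If %idx (offset 4) and/or %reserve (offset 3) escape the fusion group as
// outputs, their uses are replaced outside the group by a constant 0 and an
// empty tensor created with aten::empty, respectively, and the corresponding
// outputs are erased from both the subgraph and the CudaFusionGroup node.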
1808 | void alterBatchNormImplIndex(Node* node) { |
1809 | std::set<size_t> bn_index_out_indices; |
1810 | std::set<size_t> bn_buffer_out_indices; |
1811 | |
1812 | auto subgraph = node->g(attr::Subgraph); |
1813 | for (const auto i : c10::irange(subgraph->outputs().size())) { |
1814 | auto val = subgraph->outputs()[i]; |
1815 | if (val->node()->kind() == aten::_batch_norm_impl_index && |
1816 | val->offset() == 4) { |
1817 | bn_index_out_indices.emplace(i); |
1818 | } else if ( |
1819 | val->node()->kind() == aten::_batch_norm_impl_index && |
1820 | val->offset() == 3) { |
1821 | bn_buffer_out_indices.emplace(i); |
1822 | } |
1823 | } |
1824 | |
1825 | if (!bn_index_out_indices.empty()) { |
1826 | // we set the index output to 0 so that backward goes through |
1827 | // native_batch_norm, which is what we support; |
1828 | auto const_1 = node->owningGraph()->insertConstant(IValue(0)); |
1829 | const_1->node()->moveBefore(node); |
1830 | for (auto i : bn_index_out_indices) { |
1831 | node->outputs()[i]->replaceAllUsesWith(const_1); |
1832 | } |
1833 | } |
1834 | |
1835 | if (!bn_buffer_out_indices.empty()) { |
1836 | auto graph = node->owningGraph(); |
1837 | std::vector<int64_t> sizes{0}; // empty tensor with no size; |
1838 | // std::vector<int64_t> sizes; // empty tensor with no size; |
1839 | auto const_size_0 = node->owningGraph()->insertConstant(IValue(sizes)); |
1840 | const_size_0->node()->moveBefore(node); |
1841 | auto const_0 = node->owningGraph()->insertConstant(IValue(0)); |
1842 | const_0->node()->moveBefore(node); |
1843 | auto none_val = node->owningGraph()->insertConstant(IValue()); |
1844 | none_val->node()->moveBefore(node); |
1845 | auto device = |
1846 | graph->insertNode(graph->create(prim::device, {node->inputs()[0]}, 1)); |
1847 | device->moveBefore(node); |
1848 | device->output()->setType(DeviceObjType::get()); |
1849 | auto empty_tensor = graph->insertNode(graph->create( |
1850 | aten::empty, |
1851 | {const_size_0, const_0, none_val, device->output(), none_val, none_val}, |
1852 | 1)); |
1853 | empty_tensor->moveBefore(node); |
1854 | for (auto i : bn_buffer_out_indices) { |
1855 | node->outputs()[i]->replaceAllUsesWith(empty_tensor->output()); |
1856 | } |
1857 | } |
1858 | |
1859 | bn_index_out_indices.insert( |
1860 | bn_buffer_out_indices.begin(), bn_buffer_out_indices.end()); |
1861 | for (auto iter = bn_index_out_indices.crbegin(); |
1862 | iter != bn_index_out_indices.crend(); |
1863 | ++iter) { |
1864 | subgraph->eraseOutput(*iter); |
1865 | node->eraseOutput(*iter); |
1866 | } |
1867 | } |
1868 | |
1869 | // rewire empty byte-typed reserve space tensor input to an empty float-typed |
1870 | // tensor, because `CudaFusionGroup` doesn't support byte-typed tensor, nor does |
1871 | // it use reserve space. |
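// A rough sketch of the rewiring (value names are made up): the byte-typed
// `reserve` input that the forward graph feeds into
// aten::_batch_norm_impl_index_backward is replaced on the CudaFusionGroup
// node by a freshly created empty float tensor (aten::empty on the same
// device), and the matching subgraph input is retyped to Float.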
1872 | void alterBatchNormImplIndexBackward(Node* node) { |
1873 | std::set<size_t> bn_buffer_in_indices; |
1874 | |
1875 | auto subgraph = node->g(attr::Subgraph); |
1876 | for (auto n : subgraph->nodes()) { |
1877 | if (n->kind() == aten::_batch_norm_impl_index_backward) { |
1878 | // The 11th input is `reserve`, which is not used by the codegen kernel, and |
1879 | // its `Byte` type is not supported. So we disconnect it here to avoid a |
1880 | // codegen error |
1881 | auto byte_input = n->inputs()[11]; |
1882 | // TODO: let's check the data type for buffer and skip if it's good |
1883 | // TODO: we can actually support it by adding an extra inputs to the |
1884 | // subgraph |
1885 | // TODO: assert on empty buffer |
1886 | TORCH_INTERNAL_ASSERT( |
1887 | byte_input->node() == subgraph->param_node(), |
1888 | "Assumption that reserve input to aten::_batch_norm_impl_index_backward comes from forward graph is broken" ); |
1889 | bn_buffer_in_indices.emplace(byte_input->offset()); |
1890 | } |
1891 | } |
1892 | |
1893 | if (!bn_buffer_in_indices.empty()) { |
1894 | auto graph = node->owningGraph(); |
1895 | std::vector<int64_t> sizes{0}; // empty tensor with no size; |
1896 | // std::vector<int64_t> sizes{}; // empty tensor with no size; |
1897 | auto const_size_0 = node->owningGraph()->insertConstant(IValue(sizes)); |
1898 | const_size_0->node()->moveBefore(node); |
1899 | auto const_0 = node->owningGraph()->insertConstant(IValue(6)); // 6 == at::ScalarType::Float |
1900 | const_0->node()->moveBefore(node); |
1901 | auto none_val = node->owningGraph()->insertConstant(IValue()); |
1902 | none_val->node()->moveBefore(node); |
1903 | auto device = |
1904 | graph->insertNode(graph->create(prim::device, {node->inputs()[1]}, 1)); |
1905 | device->moveBefore(node); |
1906 | device->output()->setType(DeviceObjType::get()); |
1907 | auto empty_tensor = graph->insertNode(graph->create( |
1908 | aten::empty, |
1909 | {const_size_0, const_0, none_val, device->output(), none_val, none_val}, |
1910 | 1)); |
1911 | empty_tensor->moveBefore(node); |
1912 | |
1913 | for (const auto& item : bn_buffer_in_indices) { |
1914 | subgraph->inputs()[item]->setType( |
1915 | node->inputs()[item]->type()->cast<TensorType>()->withScalarType( |
1916 | at::ScalarType::Float)); |
1917 | node->replaceInput(item, empty_tensor->output()); |
1918 | } |
1919 | } |
1920 | } |
1921 | |
1922 | void alterBatchNormImpls(Block* block) { |
1923 | std::vector<Node*> fusions; |
1924 | for (Node* n : block->nodes()) { |
1925 | for (Block* b : n->blocks()) { |
1926 | alterBatchNormImpls(b); |
1927 | } |
1928 | if (n->kind() == prim::CudaFusionGroup) { |
1929 | fusions.push_back(n); |
1930 | } |
1931 | } |
1932 | for (Node* fusion : fusions) { |
1933 | // remove index & reserve from outputs; |
1934 | alterBatchNormImplIndex(fusion); |
1935 | // remove reserve from inputs; |
1936 | alterBatchNormImplIndexBackward(fusion); |
1937 | } |
1938 | } |
1939 | |
1940 | // We absorb `prim::dtype` node into CudaFusion structure. The structure below |
1941 | // |
1942 | // %1 = prim::CudaFusionGuard(...) |
1943 | // %2, %3 = prim::If(...) |
1944 | // block0(): |
1945 | // %4, %5 = prim::CudaFusionGroup(...) |
1946 | // -> (%4, %5) |
1947 | // block1(): |
1948 | // %6, %7 = prim::FallbackGraph(...) |
1949 | // -> (%6, %7) |
1950 | // %4 = prim::dtype(%3) |
1951 | // ... (uses %2, %4, but never references %3 anymore) |
1952 | // |
1953 | // is updated to: |
1954 | // |
1955 | // %1 = prim::CudaFusionGuard(...) |
1956 | // %2, %3 = prim::If(...) |
1957 | // block0(): |
1958 | // %4 = prim::CudaFusionGroup(...) # %5 is also removed from subgraph |
1959 | // %8 = prim::Constant[value=...]() |
1960 | // -> (%4, %8) |
1961 | // block1(): |
1962 | // %6, %7 = prim::FallbackGraph(...) |
1963 | // %9 = prim::dtype(%7) |
1964 | // -> (%6, %9) |
1965 | // # %4 = prim::dtype(%3) is removed. All references to %4 are replaced with %3 |
1966 | // ... (uses %2, %3, but never references %4 anymore) |
1967 | void removeOutputUsedOnlyInDtype(Node* fusion_node) { |
1968 | auto fusion_block = fusion_node->owningBlock(); |
1969 | TORCH_INTERNAL_ASSERT( |
1970 | fusion_block->owningNode() && |
1971 | fusion_block->owningNode()->kind() == prim::If, |
1972 | "CudaFusionGroup should be inside `prim::CudaFusionGuard` / `prim::If`" ); |
1973 | |
1974 | auto if_node = fusion_block->owningNode(); |
1975 | auto fusion_node_graph = fusion_node->g(attr::Subgraph); |
1976 | auto fallback_block = if_node->blocks()[1]; |
1977 | |
1978 | bool updated = false; |
1979 | // Iterating in this order is crucial for correctness (i has to reflect the |
1980 | // current true index of outputs[i])! |
1981 | for (int64_t i = static_cast<int64_t>(if_node->outputs().size()) - 1; i >= 0; |
1982 | --i) { |
1983 | auto output = if_node->outputs()[i]; |
1984 | // If the output is only used by dtype, we eliminate it and rely on |
1985 | // profiled/static scalar type inference to save on memory IO. |
1986 | if (usedOnlyInDtype(output)) { |
1987 | updated = true; |
1988 | { |
1989 | // update fusion_block to output profiled scalar type |
1990 | auto fusion_output = fusion_block->outputs()[i]; |
1991 | auto tensor_type = fusion_output->type()->cast<TensorType>(); |
1992 | TORCH_INTERNAL_ASSERT( |
1993 | tensor_type, "non tensor fed to dtype is not supported" ); |
1994 | auto scalar_type = tensor_type->scalarType(); |
1995 | TORCH_INTERNAL_ASSERT( |
1996 | scalar_type.has_value(), |
1997 | "ScalarType should be static for Tensors in fusion for amp optimization" ); |
1998 | auto type_const = |
1999 | fusion_block->owningGraph()->insertConstant(IValue(scalar_type)); |
2000 | type_const->setType(IntType::get()); |
2001 | type_const->node()->moveBefore(fusion_block->return_node()); |
2002 | fusion_block->replaceOutput(i, type_const); |
2003 | |
2004 | // removing the dangling output tensor from CudaFusionGroup would |
2005 | // require tracing output i from block to output j in CudaFusionGroup. |
2006 | // We choose to instead do that later by simply checking uses |
2007 | } |
2008 | |
2009 | { |
2010 | // update fallback_block to output dtype instead of tensor |
2011 | auto tensor_output = fallback_block->outputs()[i]; |
2012 | auto dtype_node = fallback_block->owningGraph()->create( |
2013 | prim::dtype, tensor_output, 1); |
2014 | dtype_node->output()->setType(IntType::get()); |
2015 | fallback_block->appendNode(dtype_node); |
2016 | fallback_block->replaceOutput(i, dtype_node->output()); |
2017 | } |
2018 | |
2019 | // we just short-cut the `dtype` node since we are already outputting the dtype |
2020 | auto uses = output->uses(); |
2021 | for (Use u : uses) { |
2022 | AT_ASSERT(u.user->matches("prim::dtype(Tensor a) -> int" )); |
2023 | u.user->output()->replaceAllUsesWith(output); |
2024 | u.user->destroy(); |
2025 | } |
2026 | output->setType(IntType::get()); |
2027 | } |
2028 | } |
2029 | |
2030 | if (updated) { |
2031 | // Remove fusion node outputs with no uses; |
2032 | for (int64_t i = static_cast<int64_t>(fusion_node->outputs().size()) - 1; |
2033 | i >= 0; |
2034 | --i) { |
2035 | if (fusion_node->output(i)->uses().empty()) { |
2036 | GRAPH_UPDATE( |
2037 | "removing output: " , i, " from fusion node: " , *fusion_node); |
2038 | fusion_node->eraseOutput(i); |
2039 | fusion_node_graph->eraseOutput(i); |
2040 | } |
2041 | } |
2042 | |
2043 | fusion_node->g_(attr::Subgraph, fusion_node_graph); |
2044 | } |
2045 | } |
2046 | |
2047 | // For output tensors in a fusion group that are only used by a dtype node, |
2048 | // with CudaFusionGuard we can short-cut them with a constant dtype directly |
2049 | // to save IO memory bandwidth. |
2050 | // We do this after inserting the guard, instead of during graph |
2051 | // fusion/partitioning, because the fallback needs to be handled differently: |
2052 | // the fallback is not inside CudaFusionGuard, and hence doesn't have the |
2053 | // dtype as a constant. |
2054 | void removeOutputUsedOnlyInDtype(Block* block) { |
2055 | std::vector<Node*> fusions; |
2056 | for (Node* n : block->nodes()) { |
2057 | for (Block* b : n->blocks()) { |
2058 | removeOutputUsedOnlyInDtype(b); |
2059 | } |
2060 | if (n->kind() == prim::CudaFusionGroup) { |
2061 | fusions.push_back(n); |
2062 | } |
2063 | } |
2064 | for (Node* fusion : fusions) { |
2065 | // short-cut fusion outputs that are only used by dtype; |
2066 | removeOutputUsedOnlyInDtype(fusion); |
2067 | } |
2068 | } |
2069 | |
2070 | void RemoveProfileIValue(Node* profile_ivalue) { |
2071 | for (const auto& use : profile_ivalue->output()->uses()) { |
2072 | if (use.user->kind() == prim::Constant) { |
2073 | use.user->output()->replaceAllUsesWith(profile_ivalue->input()); |
2074 | use.user->destroy(); |
2075 | } |
2076 | } |
2077 | profile_ivalue->output()->replaceAllUsesWith(profile_ivalue->input()); |
2078 | profile_ivalue->destroy(); |
2079 | } |
2080 | |
2081 | void ExtractProfileIValue(Node* profile_ivalue) { |
2082 | auto const_o = createConditionalConstant(profile_ivalue); |
2083 | if (const_o) { |
2084 | auto const_n = const_o->node(); |
2085 | const_n->moveAfter(profile_ivalue); |
2086 | profile_ivalue->output()->replaceAllUsesAfterNodeWith(const_n, const_o); |
2087 | // special wiring, we add this input to constant simply in order to create |
2088 | // dependency, which we can trace and remove later; |
2089 | const_n->addInput(profile_ivalue->output()); |
2090 | } else { |
2091 | // no profile value available, remove profile_ivalue node; |
2092 | RemoveProfileIValue(profile_ivalue); |
2093 | } |
2094 | } |
2095 | |
2096 | // Break a `linear` layer into `matmul` and `add_optional`. This allows us to |
2097 | // fuse the binary operation without supporting gemm. |
2098 | // Note that we do not break a `linear` layer without bias. |
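// For illustration, a rough sketch of the decomposition (value names are
// made up):
//   %out = aten::linear(%input, %weight, %bias)
// becomes
//   %weight_t = aten::t(%weight)
//   %mm = aten::matmul(%input, %weight_t)
//   %out = prim::add_optional(%mm, %bias)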
2099 | void decomposeLinearOps(Block* block) { |
2100 | std::vector<Node*> linear_nodes; |
2101 | for (Node* n : block->nodes()) { |
2102 | for (Block* b : n->blocks()) { |
2103 | decomposeLinearOps(b); |
2104 | } |
2105 | // only decompose `linear` layer with bias |
2106 | if (n->kind() == aten::linear && |
2107 | !n->input(2)->type()->isSubtypeOf( |
2108 | static_cast<c10::TypePtr>(NoneType::get()))) { |
2109 | linear_nodes.push_back(n); |
2110 | } |
2111 | } |
2112 | |
2113 | auto graph = block->owningGraph(); |
2114 | for (Node* n : linear_nodes) { |
2115 | WithInsertPoint guard(n); |
2116 | auto weight_t = graph->insertNode(graph->create(aten::t, {n->input(1)}, 1)); |
2117 | auto matmul = graph->insertNode( |
2118 | graph->create(aten::matmul, {n->input(0), weight_t->output()}, 1)); |
2119 | auto input_tensor_type = n->input(0)->type()->cast<c10::TensorType>(); |
2120 | if (!input_tensor_type) { |
2121 | TORCH_WARN_ONCE( |
2122 | "linear input 0 is required to be tensor for linear decompose" ); |
2123 | continue; |
2124 | } |
2125 | auto mat0_size = input_tensor_type->sizes().concrete_sizes(); |
2126 | auto mat1_size = |
2127 | n->input(1)->type()->cast<c10::TensorType>()->sizes().concrete_sizes(); |
2128 | |
2129 | // TODO: Continuing here is not necessary when we can handle matmul, right |
2130 | // now we are splitting the linear between matmul & bias_add. Our fuser can |
2131 | // only take the second half and we would need the size information. |
2132 | if (!mat0_size.has_value() || !mat1_size.has_value()) { |
2133 | TORCH_WARN_ONCE( |
2134 | "concrete shape for linear input & weight are required to decompose into matmul + bias" ); |
2135 | continue; |
2136 | } |
2137 | |
2138 | // only decompose for inputs with nDims >= 4, since lower-rank linear is |
2139 | // already fused in eager mode |
2140 | if (mat0_size->size() < 4) { |
2141 | continue; |
2142 | } |
2143 | |
2144 | auto out_size = mat0_size.value(); |
2145 | TORCH_INTERNAL_ASSERT( |
2146 | mat1_size->size() == 2 || mat1_size->size() == 1, |
2147 | "weight dimension for linear is expected to be 1 or 2, but got: " , |
2148 | mat1_size->size()); |
2149 | if (mat1_size->size() == 2) { |
2150 | out_size[out_size.size() - 1] = mat1_size.value()[0]; |
2151 | } else if (mat1_size->size() == 1) { |
2152 | out_size.pop_back(); |
2153 | } |
2154 | matmul->output()->setType(input_tensor_type->withSizes(out_size)); |
2155 | |
2156 | // TODO: memory stride should be considered here, our inference above is not |
2157 | // safe. |
2158 | auto bias = graph->insertNode( |
2159 | graph->create(prim::add_optional, {matmul->output(0), n->input(2)}, 1)); |
2160 | bias->output()->setType(matmul->output(0)->type()); |
2161 | |
2162 | n->output()->replaceAllUsesWith(bias->output()); |
2163 | n->destroy(); |
2164 | } |
2165 | } |
2166 | |
2167 | // Replace 'operation' with 'operation_copy' to guard alias operations. |
2168 | // Currently covers expand, expand_as, permute, transpose, and t |
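// For illustration, a rough sketch (value names are made up): when neither
// %x nor %y has writers in the alias db,
//   %y = aten::transpose(%x, %dim0, %dim1)
// is rewritten to
//   %y = prim::transpose_copy(%x, %dim0, %dim1)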
2169 | void replaceAliasOpsWithCopy(std::shared_ptr<Graph>& graph, Block* block) { |
2170 | static std::unordered_map<Symbol, Symbol> alias_to_copy_mapping( |
2171 | {{aten::expand, prim::expand_copy}, |
2172 | {aten::expand_as, prim::expand_as_copy}, |
2173 | {aten::permute, prim::permute_copy}, |
2174 | {aten::transpose, prim::transpose_copy}, |
2175 | {aten::t, prim::t_copy}}); |
2176 | // TODO: revert disabled aten::view |
2177 | // ({{aten::view, prim::view_copy}, |
2178 | // {aten::reshape, prim::reshape_copy}, |
2179 | // {aten::squeeze, prim::squeeze_copy}, |
2180 | // {aten::unsqueeze, prim::unsqueeze_copy}, |
2181 | // {aten::flatten, prim::flatten_copy}}); |
2182 | |
2183 | std::vector<Node*> maybe_safe_alias_nodes; |
2184 | for (Node* n : block->nodes()) { |
2185 | for (Block* b : n->blocks()) { |
2186 | replaceAliasOpsWithCopy(graph, b); |
2187 | } |
2188 | if (alias_to_copy_mapping.find(n->kind()) != alias_to_copy_mapping.end()) { |
2189 | maybe_safe_alias_nodes.push_back(n); |
2190 | } |
2191 | } |
2192 | |
2193 | auto alias_db = std::make_unique<AliasDb>(graph); |
2194 | |
2195 | auto safeToChangeAliasToCopy = [&alias_db](Node* n) { |
2196 | return !alias_db->hasWriters(n->input(0)) && |
2197 | !alias_db->hasWriters(n->output(0)); |
2198 | }; |
2199 | |
2200 | auto replaceAliasWithCopy = [&graph, &alias_db](Node* n) { |
2201 | WithInsertPoint guard(n); |
2202 | auto copy_op = graph->insertNode( |
2203 | graph->create(alias_to_copy_mapping[n->kind()], n->inputs(), 1)); |
2204 | copy_op->output()->setType(n->output(0)->type()); |
2205 | |
2206 | // adding newly created value into alias_db; |
2207 | alias_db->createValue(copy_op->output()); |
2208 | |
2209 | n->output()->replaceAllUsesWith(copy_op->output()); |
2210 | n->destroy(); |
2211 | }; |
2212 | |
2213 | for (Node* n : maybe_safe_alias_nodes) { |
2214 | if (!safeToChangeAliasToCopy(n)) { |
2215 | continue; |
2216 | } |
2217 | replaceAliasWithCopy(n); |
2218 | } |
2219 | } |
2220 | |
2221 | // Revert all 'operation_copy' back to 'operation' except in CudaFusionGroup, |
2222 | // i.e., any non-fused alias operation, including within prim::FallbackGraph. |
2223 | // Currently covers expand, expand_as, permute, transpose, and t |
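// For illustration, the inverse rewrite (a sketch): outside any
// CudaFusionGroup,
//   %y = prim::transpose_copy(%x, %dim0, %dim1)
// is reverted to
//   %y = aten::transpose(%x, %dim0, %dim1)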
2224 | void revertAliasCopyOps(std::shared_ptr<Graph>& graph, Block* block) { |
2225 | static std::unordered_map<Symbol, Symbol> copy_to_alias_mapping( |
2226 | {{prim::expand_copy, aten::expand}, |
2227 | {prim::expand_as_copy, aten::expand_as}, |
2228 | {prim::permute_copy, aten::permute}, |
2229 | {prim::transpose_copy, aten::transpose}, |
2230 | {prim::t_copy, aten::t}}); |
2231 | // TODO: revert disabled aten::view |
2232 | // ({{prim::view_copy, aten::view}, |
2233 | // {prim::flatten_copy, aten::flatten}, |
2234 | // {prim::reshape_copy, aten::reshape}, |
2235 | // {prim::squeeze_copy, aten::squeeze}, |
2236 | // {prim::unsqueeze_copy, aten::unsqueeze}}); |
2237 | |
2238 | std::vector<Node*> alias_copy_ops; |
2239 | for (Node* n : block->nodes()) { |
2240 | // Allow alias copy ops in CudaFusionGroup |
2241 | if (n->kind() == prim::CudaFusionGroup) { |
2242 | continue; |
2243 | } |
2244 | // Revert alias copy ops within FallbackGraph |
2245 | if (n->kind() == prim::FallbackGraph) { |
2246 | auto subgraph = n->g(attr::Subgraph); |
2247 | revertAliasCopyOps(subgraph, subgraph->block()); |
2248 | } |
2249 | for (Block* b : n->blocks()) { |
2250 | revertAliasCopyOps(graph, b); |
2251 | } |
2252 | // Revert any non-fused alias copy ops |
2253 | if (copy_to_alias_mapping.find(n->kind()) != copy_to_alias_mapping.end()) { |
2254 | alias_copy_ops.push_back(n); |
2255 | } |
2256 | } |
2257 | |
2258 | auto replaceCopyWithAlias = [&graph](Node* n) { |
2259 | WithInsertPoint guard(n); |
2260 | auto alias_op = graph->insertNode( |
2261 | graph->create(copy_to_alias_mapping[n->kind()], n->inputs(), 1)); |
2262 | alias_op->output()->setType(n->output(0)->type()); |
2263 | n->output()->replaceAllUsesWith(alias_op->output()); |
2264 | n->destroy(); |
2265 | }; |
2266 | |
2267 | for (Node* n : alias_copy_ops) { |
2268 | replaceCopyWithAlias(n); |
2269 | } |
2270 | } |
2271 | |
2272 | // Break a `conv2d` layer into `conv2d` and `add_optional`. This allows us to |
2273 | // fuse the binary operation without supporting gemm. |
2274 | // Note that we do not break a `conv2d` layer without bias. |
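// For illustration, a rough sketch of the decomposition (value names are
// made up):
//   %out = aten::conv2d(%input, %weight, %bias, ...)
// becomes
//   %bias_3d = aten::unsqueeze(aten::unsqueeze(%bias, -1), -1)  // (C) -> (C, 1, 1)
//   %conv = aten::conv2d(%input, %weight, %none, ...)          // bias replaced with None
//   %out = prim::add_optional(%conv, %bias_3d)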
2275 | void decomposeConvOps(Block* block) { |
2276 | std::vector<Node*> conv_nodes; |
2277 | for (Node* n : block->nodes()) { |
2278 | for (Block* b : n->blocks()) { |
2279 | decomposeConvOps(b); |
2280 | } |
2281 | // TODO: expand this to convXd |
2282 | // only decompose `conv2d` layer with bias. |
2283 | if (n->kind() == aten::conv2d && |
2284 | n->input(2)->type()->isSubtypeOf(TensorType::get())) { |
2285 | conv_nodes.push_back(n); |
2286 | } |
2287 | } |
2288 | |
2289 | auto graph = block->owningGraph(); |
2290 | for (Node* n : conv_nodes) { |
2291 | // TODO: only handling conv2d at this moment, expand this to convXd |
2292 | WithInsertPoint guard(n); |
2293 | |
2294 | auto const_neg_1 = n->owningGraph()->insertConstant(IValue(-1)); |
2295 | auto const_none = n->owningGraph()->insertConstant(IValue()); |
2296 | |
2297 | auto bias_tensor_type = n->input(2)->type()->cast<c10::TensorType>(); |
2298 | auto bias_size_opt = bias_tensor_type->sizes().concrete_sizes(); |
2299 | if (!bias_size_opt.has_value()) { |
2300 | TORCH_WARN_ONCE( |
2301 | "concrete shape for bias input is required to decompose into conv + bias" ); |
2302 | continue; |
2303 | } |
2304 | // bias shape (C) |
2305 | auto bias_size = bias_size_opt.value(); |
2306 | |
2307 | auto tmp = graph->insertNode( |
2308 | graph->create(aten::unsqueeze, {n->input(2), const_neg_1}, 1)); |
2309 | // new shape (C, 1) |
2310 | bias_size.emplace_back(1); |
2311 | tmp->output()->setType(bias_tensor_type->withSizes(bias_size)); |
2312 | |
2313 | auto unsqueezed_bias = graph->insertNode( |
2314 | graph->create(aten::unsqueeze, {tmp->output(), const_neg_1}, 1)); |
2315 | // new shape (C, 1, 1) |
2316 | bias_size.emplace_back(1); |
2317 | unsqueezed_bias->output()->setType(bias_tensor_type->withSizes(bias_size)); |
2318 | |
2319 | // replace bias input to none |
2320 | n->replaceInput(2, const_none); |
2321 | |
2322 | // add bias as a new node |
2323 | auto bias_n = graph->insertNode(graph->create( |
2324 | prim::add_optional, {n->output(0), unsqueezed_bias->output()}, 1)); |
2325 | bias_n->output()->setType(n->output(0)->type()); |
2326 | // moving add_optional after conv2d since it uses its output. |
2327 | bias_n->moveAfter(n); |
2328 | |
2329 | // replace later uses |
2330 | n->output(0)->replaceAllUsesAfterNodeWith(bias_n, bias_n->output()); |
2331 | } |
2332 | } |
2333 | |
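// Rewrite a small set of in-place ops (in-place activations plus aten::add_ /
// aten::mul_ / aten::div_ / aten::sub_) into their functional counterparts
// via RemoveTensorMutation, so they can participate in fusion.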
2334 | bool removeInplaceOperations(const std::shared_ptr<Graph>& graph) { |
2335 | // TODO: we should probably get a list that's close to what our fuser handles |
2336 | static std::unordered_set<Symbol> inplace_ops = []() { |
2337 | std::unordered_set<Symbol> target_ops; |
2338 | for (const auto& iter : activation_type_promotion_mapping) { |
2339 | std::string name = std::string(iter.first.toQualString()) + "_" ; |
2340 | target_ops.insert(Symbol::fromQualString(name)); |
2341 | } |
2342 | |
2343 | target_ops.insert(Symbol::fromQualString("aten::add_" )); |
2344 | target_ops.insert(Symbol::fromQualString("aten::mul_" )); |
2345 | target_ops.insert(Symbol::fromQualString("aten::div_" )); |
2346 | target_ops.insert(Symbol::fromQualString("aten::sub_" )); |
2347 | return target_ops; |
2348 | }(); |
2349 | |
2350 | return RemoveTensorMutation( |
2351 | graph, [&](Node* node) { return inplace_ops.count(node->kind()) != 0; }); |
2352 | } |
2353 | |
2354 | // Recursively traverse blocks, gather all nodes with given symbol, |
2355 | // and then apply mutator function. |
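// For example (as used further below in CudaFuseGraph):
//   mutateNode(graph->block(), prim::profile_ivalue, ExtractProfileIValue);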
2356 | void mutateNode( |
2357 | Block* block, |
2358 | Symbol symbol, |
2359 | const std::function<void(Node*)>& func) { |
2360 | // Recursively call mutateNode on blocks |
2361 | // Gather all nodes with given symbol |
2362 | std::vector<Node*> nodes; |
2363 | for (Node* n : block->nodes()) { |
2364 | for (Block* b : n->blocks()) { |
2365 | mutateNode(b, symbol, func); |
2366 | } |
2367 | if (n->kind() == symbol) { |
2368 | nodes.push_back(n); |
2369 | } |
2370 | } |
2371 | |
2372 | // Apply mutator function to every node |
2373 | for (Node* n : nodes) { |
2374 | func(n); |
2375 | } |
2376 | } |
2377 | |
2378 | // For the given CudaFusionGroup, separate nested views and remove any unused, |
2379 | // intermediate views |
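// For illustration, a rough sketch (value names are made up), within the
// fusion subgraph:
//   %a = prim::view_copy(%x, %shape0)
//   %b = prim::view_copy(%a, %shape1)
// is rewritten so that %b consumes %x directly, and %a is destroyed if it
// has no remaining uses.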
2380 | void separateNestedViews(Node* cuda_fusion_group) { |
2381 | TORCH_INTERNAL_ASSERT(cuda_fusion_group->kind() == prim::CudaFusionGroup); |
2382 | |
2383 | auto isView = [](Node* node) { |
2384 | static std::unordered_set<Symbol> alias_op_set( |
2385 | {prim::view_copy, prim::reshape_copy}); |
2386 | return alias_op_set.find(node->kind()) != alias_op_set.end(); |
2387 | }; |
2388 | |
2389 | // node -> input / output values |
2390 | auto isNestedView = [&isView](Node* node) { |
2391 | return isView(node) && isView(node->input(0)->node()); |
2392 | }; |
2393 | |
2394 | auto subgraph = cuda_fusion_group->g(attr::Subgraph); |
2395 | for (auto node : subgraph->block()->nodes()) { |
2396 | if (isNestedView(node)) { |
2397 | // grandparent -> (view / reshape) parent -> (view / reshape) node |
2398 | auto parent_value = node->input(0); |
2399 | auto parent = parent_value->node(); |
2400 | |
2401 | auto grandparent_value = parent->input(0); |
2402 | C10_UNUSED auto grandparent = grandparent_value->node(); |
2403 | |
2404 | // Before: gp -> x -> n |
2405 | // After: gp -> x / gp -> n |
2406 | // Delete x if no more uses |
2407 | node->replaceInputWith(parent_value, grandparent_value); |
2408 | if (!parent->hasUses()) { |
2409 | parent->destroy(); |
2410 | } |
2411 | } |
2412 | } |
2413 | } |
2414 | |
2415 | } // anonymous namespace |
2416 | |
2417 | void CudaFuseGraph(std::shared_ptr<Graph>& graph) { |
2418 | FUSER_PERF_SCOPE("nvFuser::Manager::CudaFuseGraph" ); |
2419 | GRAPH_DUMP("Before Fusion: " , graph); |
2420 | |
2421 | // TODO: extract & guard profile_ivalue; but how do we restore it??? |
2422 | // I don't know how to store edge/node in attribute. so let's abuse data flow |
2423 | // dependency and add inputs to conditional constant generated by |
2424 | // aten::profile_ivalue |
2425 | mutateNode(graph->block(), prim::profile_ivalue, ExtractProfileIValue); |
2426 | GRAPH_DEBUG("insert conditional constant from profile_ivalue: " , *graph); |
2427 | |
2428 | // TODO: we need to properly restore shape information after fusion. |
2429 | // shamelessly use tool from NNC. |
2430 | RemoveProfileNodesAndSpecializeTypes(graph); |
2431 | GRAPH_DEBUG("After Profiling Nodes Removed: " , *graph); |
2432 | |
2433 | // replace inplace operation to functional version to expose fusion |
2434 | // opportunities |
2435 | removeInplaceOperations(graph); |
2436 | GRAPH_DEBUG("Remove inplace operations: " , *graph); |
2437 | |
2438 | // TODO: separate passes into different file; |
2439 | if (isOptionEnabled(EnableOption::LinearDecomposition)) { |
2440 | // TODO: restore decomposition after fusion, in case we are decomposing |
2441 | // operation that can't be fused; |
2442 | decomposeLinearOps(graph->block()); |
2443 | } |
2444 | GRAPH_DEBUG("After decompose Linear Ops by nvfuser: " , *graph); |
2445 | |
2446 | if (isOptionEnabled(EnableOption::ConvDecomposition)) { |
2447 | decomposeConvOps(graph->block()); |
2448 | } |
2449 | GRAPH_DEBUG("After decompose decompose Conv Ops by nvfuser: " , *graph); |
2450 | |
2451 | replaceAliasOpsWithCopy(graph, graph->block()); |
2452 | GRAPH_DEBUG("replace alias_op with alias_copy by nvfuser: " , *graph); |
2453 | |
2454 | CudaGraphFuser cgf(graph->block(), graph); |
2455 | cgf.run(); |
2456 | GRAPH_DEBUG("After Fusion: " , *graph); |
2457 | |
2458 | // guard input types as well as conditional constants from |
2459 | // aten::profile_ivalue |
2460 | guardFusionGroups(graph->block(), cgf.fusion_value_to_runtime_shape_); |
2461 | GRAPH_DEBUG("After Guard Fusion: " , *graph); |
2462 | |
2463 | // mutate `aten::_batch_norm_impl_index` and |
2464 | // `aten::_batch_norm_impl_index_backward` node in the fusion group to WAR |
2465 | // the lack of fusion support on integer output as well as byte-typed tensor. |
2466 | alterBatchNormImpls(graph->block()); |
2467 | GRAPH_DEBUG("After _batch_norm_impl_index: " , *graph); |
2468 | |
2469 | mutateNode(graph->block(), prim::profile_ivalue, RemoveProfileIValue); |
2470 | |
2471 | GRAPH_DEBUG("Before remove missing profiling: " , *graph); |
2472 | removeFusionWithMissingProfilingInformation(graph->block()); |
2473 | GRAPH_DEBUG("After remove missing profiling: " , *graph); |
2474 | |
2475 | // optimization targeting AMP |
2476 | removeOutputUsedOnlyInDtype(graph->block()); |
2477 | GRAPH_DEBUG("After removeOutputUsedOnlyInDtype: " , *graph); |
2478 | |
2479 | mutateNode(graph->block(), prim::CudaFusionGroup, separateNestedViews); |
2480 | GRAPH_DEBUG( |
2481 | "separate nested and delete redundant views in CudaFusionGroup:" , *graph); |
2482 | |
2483 | revertAliasCopyOps(graph, graph->block()); |
2484 | GRAPH_DEBUG("revert alias_copy ops by nvfuser: " , *graph); |
2485 | |
2486 | dumpFusionGroups(graph); |
2487 | |
2488 | // After FuseGraph some common subexpressions may come back |
2489 | EliminateCommonSubexpression(graph); |
2490 | // We might have emitted a fair amount of useless shape propagating code, so |
2491 | // remove it |
2492 | EliminateDeadCode(graph); |
2493 | |
2494 | GRAPH_DEBUG("After ECS & Dead code removal: " , *graph); |
2495 | // Improve the quality of shape propagation code that was left |
2496 | PeepholeOptimizeShapeExpressions(graph->block()); |
2497 | GRAPH_DEBUG("After PeepholeOptimizeShapeExpressions: " , *graph); |
2498 | |
2499 | // TODO: we need to properly restore shape information after fusion. |
2500 | // shamelessly use tool from NNC. |
2501 | RemoveTensorTypeSpecializations(graph); |
2502 | |
2503 | GRAPH_DUMP("Before Compilation: " , graph); |
2504 | // Compile CudaFusionGroup |
2505 | compileFusionRecursive(graph->block()); |
2506 | } |
2507 | |
2508 | } // namespace cuda |
2509 | } // namespace fuser |
2510 | } // namespace jit |
2511 | } // namespace torch |
2512 | |