1 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
2 | |
3 | #include <ATen/core/interned_strings.h> |
4 | #include <ATen/core/symbol.h> |
5 | #include <ATen/record_function.h> |
6 | #include <c10/util/FunctionRef.h> |
7 | #include <c10/util/irange.h> |
8 | #include <torch/csrc/jit/codegen/cuda/interface.h> |
9 | #include <torch/csrc/jit/codegen/fuser/interface.h> |
10 | #include <torch/csrc/jit/ir/alias_analysis.h> |
11 | #include <torch/csrc/jit/jit_log.h> |
12 | #include <torch/csrc/jit/jit_opt_limit.h> |
13 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
14 | #include <torch/csrc/jit/passes/constant_pooling.h> |
15 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
16 | #include <torch/csrc/jit/passes/pass_manager.h> |
17 | #include <torch/csrc/jit/passes/remove_redundant_profiles.h> |
18 | #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h> |
19 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
20 | #include <torch/csrc/jit/runtime/custom_operator.h> |
21 | #include <torch/csrc/jit/runtime/graph_executor.h> |
22 | #include <torch/csrc/jit/runtime/operator_options.h> |
23 | #include <torch/csrc/jit/runtime/symbolic_shape_registry.h> |
24 | #include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h> |
25 | #include <torch/csrc/jit/tensorexpr/kernel.h> |
26 | #include <torch/csrc/utils/memory.h> |
27 | |
28 | #include <utility> |
29 | |
30 | // NOLINTNEXTLINE |
31 | C10_DEFINE_bool( |
32 | torch_jit_disable_cat, |
33 | false, |
34 | "disable aten::cat in TE fusion groups" ); |
35 | |
36 | C10_DEFINE_bool( |
37 | torch_jit_enable_dynamic_shape_fusion, |
38 | false, |
39 | "enable TE fusion using dynamic shapes" ); |
40 | |
41 | namespace torch { |
42 | namespace jit { |
43 | |
44 | static bool texpr_reductions_enabled = false; |
45 | |
46 | bool isSupportedForBlock(Node* node) { |
47 | switch (node->kind()) { |
48 | case aten::add: |
49 | case aten::mul: |
50 | return true; |
51 | default: |
52 | return false; |
53 | } |
54 | } |
55 | |
56 | bool usedOnlyInSize(Value* v) { |
57 | const auto& uses = v->uses(); |
58 | return std::all_of(uses.begin(), uses.end(), [](const Use& u) { |
59 | return u.user->matches("aten::size(Tensor self) -> int[]"); |
60 | }); |
61 | } |
62 | |
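| // Helper used by the fuser's shape propagation (see buildShapeExpressions |
| // below): emits a prim::BroadcastSizes node over the given size lists and |
| // registers its output with the alias database. At runtime the node computes |
| // the broadcasted shape, e.g. sizes [2, 1, 4] and [3, 1] broadcast to |
| // [2, 3, 4]. |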
63 | Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db) { |
64 | AT_ASSERT(!sizes.empty()); |
65 | Graph* graph = sizes[0]->owningGraph(); |
66 | Node* broadcast_n = |
67 | graph->insertNode(graph->create(prim::BroadcastSizes, sizes)); |
68 | broadcast_n->output()->setType(ListType::ofInts()); |
69 | db->createValue(broadcast_n->output()); |
70 | return broadcast_n->output(); |
71 | } |
72 | |
73 | namespace tensorexpr { |
74 | |
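| // Operator set that code outside this pass can extend to opt additional ops |
| // into TE fusion (e.g. ops with custom NNC lowerings registered elsewhere). |
| // A minimal sketch of how a caller might use it - the schema below is purely |
| // illustrative, not a real operator: |
| // |
| //   torch::jit::tensorexpr::getCustomOperatorSet().insert( |
| //       {"my_ns::my_op(Tensor self) -> Tensor"}); |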
75 | OperatorSet& getCustomOperatorSet() { |
76 | static OperatorSet _g_custom_operator_set{}; |
77 | return _g_custom_operator_set; |
78 | } |
79 | |
80 | static const OperatorSet& supported_non_eltwise_set() { |
81 | // clang-format off |
82 | static const OperatorSet supported_non_eltwise_set{ |
83 | "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor" , |
84 | "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor" , |
85 | "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor" , |
86 | "aten::matmul(Tensor self, Tensor other) -> Tensor" , |
87 | }; |
88 | // clang-format on |
89 | return supported_non_eltwise_set; |
90 | } |
91 | |
92 | bool isSupported(Node* node) { |
93 | // For Block codegen we allow limited ops. |
94 | if (tensorexpr::getTEGenerateBlockCode()) { |
95 | return isSupportedForBlock(node); |
96 | } |
97 | |
| // clang-format off |
98 | static const OperatorSet supported_reduction_set{ |
99 | "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", |
100 | "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", |
101 | "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", |
102 | "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", |
103 | }; |
104 | static const OperatorSet supported_misc_set{ |
105 | "aten::cat(Tensor[] tensors, int dim=0) -> Tensor" , |
106 | "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)" , |
107 | }; |
108 | // clang-format on |
109 | |
110 | if (get_tensorexpr_elementwise_set().contains(node) || |
111 | node->isMemberOf(supported_non_eltwise_set()) || |
112 | node->isMemberOf(supported_misc_set) || |
113 | node->isMemberOf(getCustomOperatorSet()) || |
114 | (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) { |
115 | // We only insert guards on Tensor types, so we rely on the output |
116 | // of a node being uniquely determined by its input types. |
117 | // bail if any non-Tensor input affects the output type |
118 | // and cannot be reasoned about statically |
119 | |
120 | // Value is either an int or a float (can occur from .item()) |
121 | for (Value* v : node->inputs()) { |
122 | if (v->type()->cast<NumberType>()) { |
123 | return false; |
124 | } |
125 | } |
126 | |
127 | // non-const dtype / device |
128 | for (auto arg_name : {"dtype", "device"}) { |
129 | if (auto index = node->schema().argumentIndexWithName(arg_name)) { |
130 | if (!toIValue(node->input(*index))) { |
131 | return false; |
132 | } |
133 | } |
134 | } |
135 | |
136 | if (FLAGS_torch_jit_disable_cat && node->kind() == aten::cat) { |
137 | return false; |
138 | } |
139 | |
140 | return true; |
141 | } |
142 | |
143 | // unschematized ops |
144 | switch (node->kind()) { |
145 | case prim::ConstantChunk: |
146 | case prim::ListConstruct: |
147 | case prim::TensorExprGroup: |
148 | return true; |
149 | } |
150 | |
151 | return false; |
152 | } |
153 | } // namespace tensorexpr |
154 | |
155 | static bool texpr_fuser_enabled_ = true; |
156 | |
157 | void setTensorExprFuserEnabled(bool val) { |
158 | texpr_fuser_enabled_ = val; |
159 | } |
160 | |
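| // Whether the TE fuser is enabled. The PYTORCH_TENSOREXPR environment |
| // variable, when set, overrides the programmatic flag above: the value "0" |
| // disables the fuser and any other value enables it. |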
161 | bool tensorExprFuserEnabled() { |
162 | static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR"); |
163 | if (!enable_c_str) { |
164 | return texpr_fuser_enabled_; |
165 | } |
166 | if (std::string(enable_c_str) == "0") { |
167 | return false; |
168 | } |
169 | return true; |
170 | } |
171 | |
172 | bool tensorExprDynamicShapeFusionEnabled() { |
173 | return FLAGS_torch_jit_enable_dynamic_shape_fusion; |
174 | } |
175 | |
176 | void setTensorExprDynamicShapeFusionEnabled(bool val) { |
177 | FLAGS_torch_jit_enable_dynamic_shape_fusion = val; |
178 | } |
179 | |
180 | bool setTexprReductionsEnabled(bool value) { |
181 | bool old_value = texpr_reductions_enabled; |
182 | texpr_reductions_enabled = value; |
183 | return old_value; |
184 | } |
185 | |
186 | bool texprReductionsEnabled() { |
187 | return texpr_reductions_enabled; |
188 | } |
189 | |
190 | void removeProfileNodesAndSpecializeTypes(Block* b) { |
191 | for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { |
192 | if (it->kind() == prim::profile) { |
193 | GRAPH_DEBUG("Removing prim::profile: %" , it->output()->debugName()); |
194 | it->output()->replaceAllUsesWith(it->input()); |
195 | auto profiled_type = it->ty(attr::profiled_type)->expect<TensorType>(); |
196 | |
197 | TensorTypePtr input_tensor_type = nullptr; |
198 | bool input_is_optional = false; |
199 | if (it->input()->type()->kind() == c10::TypeKind::TensorType) { |
200 | input_tensor_type = it->input()->type()->expect<TensorType>(); |
201 | } else { |
202 | input_tensor_type = it->input() |
203 | ->type() |
204 | ->expectRef<OptionalType>() |
205 | .getElementType() |
206 | ->expect<TensorType>(); |
207 | input_is_optional = true; |
208 | } |
209 | |
210 | if (input_is_optional) { |
211 | it.destroyCurrent(); |
212 | continue; |
213 | } |
214 | |
215 | // A value can be profiled with differently typed uses. |
216 | // This can occur from: |
217 | // - having a use which is not executed, so the type will be |
218 | // TensorType::get() |
219 | // - control-flow that depends on tensor type: |
220 | // if x.size() == 2 op(x) else op(x) |
221 | // - mutation of the value on a field represented in the tensor type |
222 | // op(x); x.resize_([...]); op(x) |
223 | |
224 | // The most common case today with num_profiles = 1 is from the first |
225 | // case. Here we can just ignore non-profiled uses, and choose any of the |
226 | // profiled uses. Because we guard all tensor types in the runtime, even |
227 | // if we set a Value to have a profiled type from one use and then execute |
228 | // a use with a different profiled type, we will still be correct. |
229 | // In the future we could consider unifying the types of uses, or adding a |
230 | // type refinement node so uses can have the correct corresponding type. |
231 | if (profiled_type == TensorType::get()) { |
232 | continue; |
233 | } |
234 | |
235 | // If we encounter non-identical profiled types for the same value, merge |
236 | // them. This situation can happen if, e.g., loop unrolling duplicates |
237 | // profiled types in a loop body in a manner that isn't logically |
238 | // consistent (see TestTEFuser.test_unrolled_cat). |
239 | if (input_tensor_type == TensorType::get()) { |
240 | it->input()->setType(profiled_type); |
241 | } else { |
242 | it->input()->setType(input_tensor_type->merge(*profiled_type)); |
243 | } |
244 | |
245 | it.destroyCurrent(); |
246 | } else { |
247 | for (Block* ib : it->blocks()) { |
248 | removeProfileNodesAndSpecializeTypes(ib); |
249 | } |
250 | } |
251 | } |
252 | } |
253 | |
254 | void RemoveProfileNodesAndSpecializeTypes(std::shared_ptr<Graph>& graph) { |
255 | GRAPH_DEBUG("Before removeProfileNodesAndSpecializeTypes:\n" , *graph); |
256 | removeProfileNodesAndSpecializeTypes(graph->block()); |
257 | GRAPH_DEBUG("After removeProfileNodesAndSpecializeTypes:\n" , *graph); |
258 | } |
259 | |
260 | bool hasTensorTypeSpecialization(Value* v) { |
261 | if (!v->type()->cast<TensorType>()) { |
262 | return false; |
263 | } |
264 | // Constants & TensorExprGroup will always produce specialized tensor types; |
265 | // TypeCheck nodes are inserted by this pass and are only used by fusion groups |
266 | // that insert proper guards |
267 | if (v->node()->kind() == prim::Constant || |
268 | v->node()->kind() == prim::TypeCheck || |
269 | v->node()->kind() == prim::TensorExprGroup) { |
270 | return false; |
271 | } |
272 | if (v->type() == TensorType::get()) { |
273 | return false; |
274 | } |
275 | return true; |
276 | } |
277 | |
278 | void removeTensorTypeSpecialization(Value* v) { |
279 | if (hasTensorTypeSpecialization(v)) { |
280 | v->setType(TensorType::get()); |
281 | } |
282 | } |
283 | |
284 | void removeTensorTypeSpecializations(Block* block) { |
285 | for (Value* v : block->inputs()) { |
286 | removeTensorTypeSpecialization(v); |
287 | } |
288 | for (Node* n : block->nodes()) { |
289 | for (Block* b : n->blocks()) { |
290 | removeTensorTypeSpecializations(b); |
291 | } |
292 | for (Value* v : n->outputs()) { |
293 | removeTensorTypeSpecialization(v); |
294 | } |
295 | } |
296 | } |
297 | |
298 | void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph) { |
299 | removeTensorTypeSpecializations(graph->block()); |
300 | } |
301 | |
302 | void insertTypeGuard( |
303 | Node* guarded_node, |
304 | tensor_type_converter_t type_converter, |
305 | Symbol kind) { |
306 | GRAPH_DEBUG("Inserting a typecheck guard for a node" , *guarded_node); |
307 | auto subgraph = SubgraphUtils::getSubgraph(guarded_node); |
308 | |
309 | // Fixup types of the subgraph inputs |
310 | std::vector<Value*> inputs_to_check; |
311 | std::vector<TypePtr> guard_types; |
312 | for (Value* input : guarded_node->inputs()) { |
313 | // We only check inputs of the guarded nodes and expect the user to infer |
314 | // intermediate and output shapes |
315 | if (!input->type()->cast<TensorType>()) { |
316 | continue; |
317 | } |
318 | |
319 | // fusion outputs are already guarded |
320 | if (input->node()->kind() == prim::Constant || |
321 | input->node()->kind() == prim::FusionGroup) { |
322 | continue; |
323 | } |
324 | inputs_to_check.push_back(input); |
325 | guard_types.emplace_back( |
326 | type_converter(input->type()->expect<TensorType>())); |
327 | } |
328 | if (inputs_to_check.empty()) { |
329 | return; |
330 | } |
331 | |
332 | // Add prim::TypeCheck node |
333 | // |
334 | // TypeCheck nodes look like the following: |
335 | // %out1 : Float(2, 3), %out2 : Int(10, 30), %types_match : bool = |
336 | // prim::TypeCheck(%inp1 : Tensor, %inp2 : Tensor) |
337 | // |
338 | // They have N inputs whose types we are going to check and N+1 outputs. The |
339 | // first N outputs specify expected types and N+1-th output holds the result |
340 | // of the check (bool). |
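| // |
| // After this pass the guarded region looks roughly like the following |
| // sketch (the false block is later replaced by a prim::FallbackGraph call): |
| // |
| //   %out1, %out2, %ok = prim::TypeCheck(%inp1, %inp2) |
| //   %res = prim::If(%ok) |
| //     block0():  // types matched: run the fusion group on checked values |
| //       ... prim::TensorExprGroup(%out1, %out2) ... |
| //     block1():  // types mismatched: run the unoptimized fallback |
| //       ... prim::FallbackGraph(%inp1, %inp2) ... |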
341 | Node* typecheck_node = |
342 | guarded_node->owningGraph() |
343 | ->create(kind, inputs_to_check, inputs_to_check.size() + 1) |
344 | ->insertBefore(guarded_node); |
345 | typecheck_node->tys_(attr::types, std::move(guard_types)); |
346 | Value* typecheck_result = typecheck_node->output(inputs_to_check.size()); |
347 | |
348 | std::unordered_map<Value*, Value*> typechecked_inputs; |
349 | for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) { |
350 | typechecked_inputs[typecheck_node->input(i)] = typecheck_node->output(i); |
351 | } |
352 | |
353 | // Fixup types of the typecheck node outputs, which are used by the op in |
354 | // execution |
355 | typecheck_node->output(inputs_to_check.size())->setType(BoolType::get()); |
356 | for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) { |
357 | typecheck_node->output(i)->setType(typecheck_node->input(i)->type()); |
358 | } |
359 | |
360 | // Insert if |
361 | auto versioning_if = |
362 | guarded_node->owningGraph() |
363 | ->create(prim::If, {typecheck_result}, guarded_node->outputs().size()) |
364 | ->insertAfter(typecheck_node); |
365 | for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) { |
366 | versioning_if->output(idx)->setType(guarded_node->output(idx)->type()); |
367 | guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx)); |
368 | } |
369 | auto true_block = versioning_if->addBlock(); |
370 | auto false_block = versioning_if->addBlock(); |
371 | |
372 | // Fill in the false block. It should contain the unoptimized |
373 | // copy of the fused subgraph. |
374 | WithInsertPoint guard(false_block->return_node()); |
375 | const auto subgraph_outputs = insertGraph( |
376 | *guarded_node->owningGraph(), *subgraph, guarded_node->inputs()); |
377 | for (Value* output : subgraph_outputs) { |
378 | false_block->registerOutput(output); |
379 | } |
380 | |
381 | // types get copied to the fallback graph, so remove specializations before |
382 | // replacing |
383 | removeTensorTypeSpecializations(false_block); |
384 | replaceBlockWithFallbackGraph(false_block, guarded_node->inputs()); |
385 | |
386 | // Fill in the true block. It has all inputs type-checked and its |
387 | // body should be the fusion group node. |
388 | guarded_node->moveBefore(true_block->return_node()); |
389 | for (size_t idx = 0; idx < guarded_node->inputs().size(); ++idx) { |
390 | if (typechecked_inputs.count(guarded_node->input(idx))) { |
391 | guarded_node->replaceInput( |
392 | idx, typechecked_inputs.at(guarded_node->input(idx))); |
393 | } |
394 | } |
395 | for (Value* output : guarded_node->outputs()) { |
396 | true_block->registerOutput(output); |
397 | } |
398 | } |
399 | |
400 | namespace { |
401 | bool has_unsupported_pin_memory(const Node* node) { |
402 | // can't support non-constant pin_memory or pin_memory = True |
403 | if (auto maybe_index = node->schema().argumentIndexWithName("pin_memory")) { |
404 | int index = *maybe_index; |
405 | auto inp = node->input(index); |
406 | if (inp->type() != NoneType::get() && |
407 | constant_as<bool>(inp).value_or(true)) { |
408 | return true; |
409 | } |
410 | } |
411 | return false; |
412 | } |
413 | } // namespace |
414 | |
415 | class TensorExprFuser { |
416 | public: |
417 | TensorExprFuser( |
418 | std::shared_ptr<Graph> graph, |
419 | size_t min_group_size, |
420 | bool add_composed_op, |
421 | bool fuse_to_dynamic_shapes) |
422 | : graph_(std::move(graph)), |
423 | min_group_size_(min_group_size), |
424 | add_composed_op_(add_composed_op), |
425 | fuse_to_dynamic_shapes_(fuse_to_dynamic_shapes) { |
426 | parseTENotFuseOption(); |
427 | } |
428 | |
429 | // Builds up expressions that compute shapes of all intermediates (and |
430 | // outputs) of the fusion group, based on the sizes of inputs. You should run |
431 | // DCE to remove those that you end up not using. |
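| // |
| // For example (a sketch): if the subgraph computes %c = aten::mul(%a, %b), |
| // this inserts, after the fusion group, something like |
| //   %a_size = aten::size(%a_outer) |
| //   %b_size = aten::size(%b_outer) |
| //   %c_size = prim::BroadcastSizes(%a_size, %b_size) |
| // where %a_outer/%b_outer are the fusion group's inputs, and records the |
| // mapping from the subgraph's %c to %c_size. |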
432 | std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) { |
433 | GRAPH_DUMP("buildShapeExpressions for " , fusion_group->g(attr::Subgraph)); |
434 | WithInsertPoint insert_guard{fusion_group->next()}; |
435 | std::unordered_map<Value*, Value*> shape_of; |
436 | |
437 | Graph* graph = fusion_group->owningGraph(); |
438 | auto subgraph = fusion_group->g(attr::Subgraph); |
439 | |
440 | auto inputs = fusion_group->inputs(); |
441 | auto sinputs = subgraph->inputs(); |
442 | AT_ASSERT(inputs.size() == sinputs.size()); |
443 | for (const auto i : c10::irange(inputs.size())) { |
444 | if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) { |
445 | Value* soutput = graph->insert(aten::size, {inputs[i]}); |
446 | aliasDb_->createValue(soutput); |
447 | GRAPH_DEBUG( |
448 | "Adding a mapping for %" , |
449 | sinputs[i]->debugName(), |
450 | " " , |
451 | getHeader(soutput->node())); |
452 | shape_of[sinputs[i]] = soutput; |
453 | } |
454 | } |
455 | |
456 | // When we have a guarantee that an output won't be removed, because it's |
457 | // used in expressions that don't involve size checks, we can use its size |
458 | // instead of computing a long chain of broadcasts, starting from the |
459 | // beginning of the kernel. |
460 | auto outputs = fusion_group->outputs(); |
461 | auto soutputs = subgraph->outputs(); |
462 | AT_ASSERT(outputs.size() == soutputs.size()); |
463 | for (const auto i : c10::irange(outputs.size())) { |
464 | if (usedOnlyInSize(outputs[i])) |
465 | continue; |
466 | Value* soutput = graph->insert(aten::size, {outputs[i]}); |
467 | aliasDb_->createValue(soutput); |
468 | shape_of[soutputs[i]] = soutput; |
469 | } |
470 | |
471 | for (Node* n : subgraph->nodes()) { |
472 | auto tensor_inputs = filter(n->inputs(), [](Value* v) { |
473 | return v->type()->isSubtypeOf(*TensorType::get()); |
474 | }); |
475 | GRAPH_DEBUG("Building sizes for " , getHeader(n)); |
476 | bool all_inputs_have_sizes = true; |
477 | auto shapes = fmap(tensor_inputs, [&](Value* v) { |
478 | GRAPH_DEBUG("Getting aten::size for %" , v->debugName()); |
479 | all_inputs_have_sizes &= shape_of.count(v); |
480 | return shape_of.count(v) != 0 ? shape_of.at(v) : nullptr; |
481 | }); |
482 | if (!all_inputs_have_sizes) { |
483 | GRAPH_DEBUG( |
484 | "Not all tensor arguments have sizes available to compute the broadcasted size" , |
485 | getHeader(n)); |
486 | continue; |
487 | } |
488 | |
489 | if (n->kind() == prim::ConstantChunk) { |
490 | Node* sizes_node = graph->insertNode( |
491 | graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2)); |
492 | sizes_node->i_(attr::dim, n->i(attr::dim)); |
493 | sizes_node->i_(attr::chunks, n->i(attr::chunks)); |
494 | for (Value* output : sizes_node->outputs()) { |
495 | aliasDb_->createValue(output); |
496 | } |
497 | Value* regular_size = sizes_node->outputs().at(0); |
498 | Value* last_size = sizes_node->outputs().at(1); |
499 | regular_size->setType(ListType::ofInts()); |
500 | last_size->setType(ListType::ofInts()); |
501 | auto outputs = n->outputs(); |
502 | for (Value* o : outputs.slice(0, outputs.size() - 1)) { |
503 | shape_of.emplace(o, regular_size); |
504 | } |
505 | shape_of.emplace(outputs.at(outputs.size() - 1), last_size); |
506 | continue; |
507 | } |
508 | |
509 | // we only support shape calculations for elementwise ops, a few |
510 | // non-elementwise ops like batch_norm, conv, and matmul, and |
511 | // a few exceptions (e.g. prim::ConstantChunk) handled above |
512 | if (!(get_tensorexpr_elementwise_set().contains(n)) && |
513 | !n->isMemberOf(tensorexpr::supported_non_eltwise_set())) { |
514 | continue; |
515 | } |
516 | |
517 | shape_of.emplace( |
518 | n->output(), |
519 | shapes.size() == 1 ? shapes[0] |
520 | : broadcastSizes(shapes, aliasDb_.get())); |
521 | } |
522 | return shape_of; |
523 | } |
524 | |
525 | void removeOutputsUsedOnlyInSize(Node* fusion_group) { |
526 | if (fusion_group->kind() != prim::TensorExprGroup) |
527 | return; |
528 | auto subgraph = fusion_group->g(attr::Subgraph); |
529 | |
530 | auto shape_of = buildShapeExpressions(fusion_group); |
531 | auto outputs = fusion_group->outputs().vec(); |
532 | auto soutputs = subgraph->outputs().vec(); |
533 | // XXX: Iterating in this order is not only good for performance reasons! |
534 | // It is also crucial for correctness (i has to reflect the current true |
535 | // index of outputs[i])! |
536 | for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) { |
537 | auto output = outputs[i]; |
538 | auto soutput = soutputs[i]; |
539 | if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) { |
540 | auto uses = output->uses(); |
541 | for (Use u : uses) { |
542 | AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]")); |
543 | u.user->output()->replaceAllUsesWith(shape_of.at(soutput)); |
544 | u.user->destroy(); |
545 | } |
546 | fusion_group->eraseOutput(i); |
547 | subgraph->eraseOutput(i); |
548 | } |
549 | } |
550 | } |
551 | |
552 | void run() { |
553 | aliasDb_ = torch::make_unique<AliasDb>(graph_); |
554 | RemoveRedundantProfiles(graph_); |
555 | GRAPH_DUMP("After removing redundant profile nodes: " , graph_); |
556 | createFusionGroups(graph_->block()); |
557 | GRAPH_DUMP("After creating fusion groups: " , graph_); |
558 | // we maintain alias db correctness during initial fusion, but it is |
559 | // difficult to maintain correctness after inlining so inline only after |
560 | // fusion is done. |
561 | inlineSmallFusionGroups(graph_->block()); |
562 | GRAPH_DUMP("After inlining small fusion groups: " , graph_); |
563 | if (fuse_to_dynamic_shapes_) { |
564 | VLOG(1) << "TensorExpr fusion with dynamic shapes is enabled" |
565 | << std::endl; |
566 | generalizeFusionGroups(graph_->block()); |
567 | GRAPH_DUMP("After generalizing fusion groups: " , graph_); |
568 | } else { |
569 | prepareFusionGroupAndGuardOutputs(graph_->block()); |
570 | GRAPH_DUMP("After guarding fusion groups: " , graph_); |
571 | } |
572 | } |
573 | |
574 | private: |
575 | Node* getOrCreateTensorExprSubgraph(Node* n) { |
576 | if (n->hasAttribute(attr::Subgraph) && n->kind() == prim::TensorExprGroup) { |
577 | return n; |
578 | } |
579 | GRAPH_UPDATE("Creating a tensorexpr::Group node from: " , *n); |
580 | return SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( |
581 | n, prim::TensorExprGroup, *aliasDb_); |
582 | } |
583 | |
584 | value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* b) { |
585 | value_list result; |
586 | for (auto i : inputs) { |
587 | if (i->node()->owningBlock() == b) { |
588 | result.push_back(i); |
589 | } |
590 | } |
591 | // Sort in reverse topological order |
592 | std::sort(result.begin(), result.end(), [&](Value* a, Value* b) { |
593 | return a->node()->isAfter(b->node()); |
594 | }); |
595 | return result; |
596 | } |
597 | |
598 | // Create a fusion group starting from the node N. |
599 | // We then try to pull inputs into the fusion group and repeat that process |
600 | // until there is nothing we can pull in. |
601 | std::pair<graph_node_list::iterator, bool> createFusionGroup( |
602 | Node* fusion_node) { |
603 | // Allow single-node groups containing conv2d, since we'll only select |
604 | // those in cases where the tensorexpr implementation is faster than the |
605 | // aten implementation. |
606 | if (min_group_size_ == 1 || fusion_node->kind() == aten::conv2d) { |
607 | fusion_node = getOrCreateTensorExprSubgraph(fusion_node); |
608 | } |
609 | |
610 | GRAPH_DEBUG("Iteratively pull input nodes into the fusion group...\n" ); |
611 | auto inputs = sortReverseTopological( |
612 | fusion_node->inputs(), fusion_node->owningBlock()); |
613 | for (auto input : inputs) { |
614 | debugDumpFusionGroup("Current fusion group: " , fusion_node); |
615 | GRAPH_DEBUG("Trying to merge: " , *input->node()); |
616 | if (auto maybe_fusion_group = tryMerge(fusion_node, input->node())) { |
617 | // we successfully merged, so the new group's `inputs` may have |
618 | // changed. So rescan the new group for more merging opportunities. |
619 | return std::make_pair( |
620 | maybe_fusion_group.value()->reverseIterator(), true); |
621 | } |
622 | } |
623 | |
624 | return std::make_pair(++fusion_node->reverseIterator(), false); |
625 | } |
626 | |
627 | static void debugDumpFusionGroup(const std::string& msg, Node* n) { |
628 | // NOLINTNEXTLINE(clang-analyzer-core.NonNullParamChecker) |
629 | GRAPH_DEBUG(msg, *n); |
630 | // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) |
631 | if (n->kind() == prim::TensorExprGroup) { |
632 | GRAPH_DEBUG(*n->g(attr::Subgraph)); |
633 | } |
634 | } |
635 | |
636 | // No-ops in eager mode (e.g. an aten::to whose output type matches its input |
637 | // type) shouldn't be outputs of fusion groups because that would degrade perf |
| // and change aliasing relationships |
638 | static bool unexecutedEagerOp(Node* n) { |
639 | if (n->kind() != aten::to && |
640 | n->kind() != aten::_autocast_to_reduced_precision && |
641 | n->kind() != aten::_autocast_to_full_precision) { |
642 | return false; |
643 | } |
644 | |
645 | return *n->input(0)->type()->expect<TensorType>() == |
646 | *n->output()->type()->expect<TensorType>(); |
647 | } |
648 | |
649 | std::pair<graph_node_list::iterator, bool> scanNode(Node* n) { |
650 | GRAPH_DEBUG("Considering node:" , *n) |
651 | |
652 | if (!canHandle(n)) { |
653 | return std::make_pair(++n->reverseIterator(), false); |
654 | } |
655 | // There are some nodes that we can support, but we don't want to start a |
656 | // fusion group from - skip them. |
657 | if (n->kind() == prim::ListConstruct || n->kind() == aten::slice || |
658 | n->kind() == aten::unsqueeze || n->kind() == prim::ConstantChunk || |
659 | n->kind() == prim::Constant || unexecutedEagerOp(n)) { |
660 | return std::make_pair(++n->reverseIterator(), false); |
661 | } |
662 | return createFusionGroup(n); |
663 | } |
664 | |
665 | // Merge fusible nodes into subgraphs in prim::TensorExprGroup nodes. |
666 | void createFusionGroups(Block* block) { |
667 | bool any_changed = true; |
668 | while (any_changed) { |
669 | any_changed = false; |
670 | for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) { |
671 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
672 | bool changed; |
673 | std::tie(it, changed) = scanNode(*it); |
674 | any_changed |= changed; |
675 | } |
676 | } |
677 | |
678 | for (Node* n : block->nodes()) { |
679 | for (Block* b : n->blocks()) { |
680 | createFusionGroups(b); |
681 | } |
682 | } |
683 | |
684 | // Try to merge adjacent fusion groups together. Because we have only merged |
685 | // by looking at graph inputs, without this we would not attempt to merge |
686 | // adjacent fusion groups that don't have a dependency on each other. |
687 | |
688 | std::vector<Node*> initial_fusion_groups; |
689 | for (Node* n : block->nodes()) { |
690 | if (n->kind() == prim::TensorExprGroup) { |
691 | initial_fusion_groups.push_back(n); |
692 | } |
693 | } |
694 | |
695 | Node* prev_fusion_group = |
696 | !initial_fusion_groups.empty() ? initial_fusion_groups[0] : nullptr; |
697 | |
698 | for (const auto i : c10::irange(1, initial_fusion_groups.size())) { |
699 | // Try merging the just created fusion group into the previous one. |
700 | // If it did not work, then put the previous fusion group into |
701 | // fusion_groups vector - we will not touch it anymore in this loop. |
702 | // If merging succeeded, save the merged group as the "previous" fusion |
703 | // group so that we can try to merge the next one into it. |
704 | |
705 | Node* fusion_group = initial_fusion_groups[i]; |
706 | debugDumpFusionGroup( |
707 | "Trying to merge into the previous fusion group: " , |
708 | prev_fusion_group); |
709 | if (auto merged_fusion_group = |
710 | tryMerge(prev_fusion_group, fusion_group)) { |
711 | prev_fusion_group = *merged_fusion_group; |
712 | debugDumpFusionGroup( |
713 | "Successfully merged into the previous fusion group: " , |
714 | prev_fusion_group); |
715 | } else { |
716 | GRAPH_DEBUG("Cannot merge into the previous fusion group" ); |
717 | prev_fusion_group = fusion_group; |
718 | } |
719 | } |
720 | } |
721 | |
722 | size_t blockSize(Block* block) { |
723 | size_t num = 0; |
724 | for (Node* n : block->nodes()) { |
725 | // Don't count prim::Constants and prim::ListConstructs as these are nodes |
726 | // we only pull in along with another, "main", node. E.g. the |
727 | // ListConstruct nodes would also be pulled into a fusion group if they |
728 | // are inputs of an aten::cat node. |
729 | if (n->kind() == prim::Constant || n->kind() == prim::ListConstruct) { |
730 | continue; |
731 | } |
732 | for (Block* b : n->blocks()) { |
733 | num += blockSize(b); |
734 | } |
735 | num++; |
736 | } |
737 | return num; |
738 | } |
739 | |
740 | bool hasConv(Block* block) { |
741 | for (Node* n : block->nodes()) { |
742 | if (n->kind() == aten::conv2d) { |
743 | return true; |
744 | } |
745 | } |
746 | return false; |
747 | } |
748 | |
749 | bool inlineIfTooSmall(Node* n) { |
750 | if (n->kind() != prim::TensorExprGroup) { |
751 | return false; |
752 | } |
753 | auto subgraph = SubgraphUtils::getSubgraph(n); |
754 | size_t num_nodes = blockSize(subgraph->block()); |
755 | // Allow small subgraphs containing conv2d, since we'll only select those |
756 | // in cases where the tensorexpr implementation is faster than the aten |
757 | // implementation. |
758 | if (num_nodes < min_group_size_ && !hasConv(subgraph->block())) { |
759 | GRAPH_UPDATE("Fusion group is too small, unmerging: " , *n); |
760 | SubgraphUtils::unmergeSubgraph(n); |
761 | return true; |
762 | } |
763 | // Cleanup the subgraph from duplicated constants while we're at it. |
764 | ConstantPooling(subgraph); |
765 | |
766 | if (GRAPH_DEBUG_ENABLED) { |
767 | GRAPH_EXPORT("" , subgraph); |
768 | } |
769 | return false; |
770 | } |
771 | |
772 | void inlineSmallFusionGroups(Block* block) { |
773 | for (auto it = block->nodes().begin(); it != block->nodes().end();) { |
774 | Node* n = *it; |
775 | it++; |
776 | |
777 | for (Block* b : n->blocks()) { |
778 | inlineSmallFusionGroups(b); |
779 | } |
780 | inlineIfTooSmall(n); |
781 | } |
782 | } |
783 | |
784 | c10::optional<Node*> tryMerge(Node* fusion_group, Node* to_merge) { |
785 | if (!canMerge(fusion_group, to_merge)) { |
786 | return c10::nullopt; |
787 | } |
788 | |
789 | std::vector<Node*> nodes_to_merge = {to_merge}; |
790 | |
791 | if (to_merge->kind() == aten::cat) { |
792 | Node* listconstruct = to_merge->input(0)->node(); |
793 | nodes_to_merge.push_back(listconstruct); |
794 | } |
795 | |
796 | // First, try to move all the nodes we want to fuse next to the fusion |
797 | // group. |
798 | Node* move_point = fusion_group; |
799 | for (auto n : nodes_to_merge) { |
800 | GRAPH_UPDATE("Trying to move node next to fusion group: " , getHeader(n)); |
801 | if (!aliasDb_->moveBeforeTopologicallyValid(n, move_point)) { |
802 | GRAPH_UPDATE("Failed to move because of AliasDB checks!" ); |
803 | return c10::nullopt; |
804 | } |
805 | move_point = n; |
806 | } |
807 | |
808 | // Now all the nodes that we're going to fuse are moved next to the fusion |
809 | // group, so we can safely merge them into the fusion group subgraph. |
810 | fusion_group = getOrCreateTensorExprSubgraph(fusion_group); |
811 | |
812 | for (auto n : nodes_to_merge) { |
813 | GRAPH_UPDATE("Merging " , getHeader(n)); |
814 | SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( |
815 | n, fusion_group, *aliasDb_); |
816 | } |
817 | return fusion_group; |
818 | } |
819 | |
820 | bool shapeIsKnown(Value* v) { |
821 | if (v->type()->cast<TensorType>()) { |
822 | if (!v->isCompleteTensor()) { |
823 | return false; |
824 | } |
825 | } |
826 | return true; |
827 | } |
828 | |
829 | bool allShapesAreKnown(Node* node) { |
830 | // TODO: Relax the checks to support dynamic shapes |
831 | for (Value* input : node->inputs()) { |
832 | if (!shapeIsKnown(input)) { |
833 | return false; |
834 | } |
835 | if (input->node()->kind() == prim::ListConstruct) { |
836 | if (!allShapesAreKnown(input->node())) { |
837 | return false; |
838 | } |
839 | } |
840 | } |
841 | for (Value* output : node->outputs()) { |
842 | if (!shapeIsKnown(output)) { |
843 | return false; |
844 | } |
845 | } |
846 | return true; |
847 | } |
848 | |
849 | bool canFuseOnDevice(Value* v) { |
850 | auto type = v->type()->cast<TensorType>(); |
851 | if (!type) { |
852 | return true; |
853 | } |
854 | auto device = type->device(); |
855 | if (!device) { |
856 | return false; |
857 | } |
858 | if (device->is_cpu()) { |
859 | return canFuseOnCPU(); |
860 | } else if (device->is_cuda()) { |
861 | #ifndef C10_MOBILE |
862 | if (fuser::cuda::isEnabled()) { |
863 | return false; |
864 | } |
865 | #endif |
866 | return canFuseOnGPU(); |
867 | } else if (device->is_xpu()) { |
868 | return false; |
869 | } |
870 | return false; |
871 | } |
872 | |
873 | bool isFusableOnDevice(Node* node) { |
874 | for (const auto& input : node->inputs()) { |
875 | if (input->node()->kind() == prim::ListConstruct) { |
876 | if (!isFusableOnDevice(input->node())) { |
877 | return false; |
878 | } |
879 | } |
880 | if (!canFuseOnDevice(input)) { |
881 | return false; |
882 | } |
883 | } |
884 | return true; |
885 | } |
886 | |
887 | bool typesAreSupported(Node* node) { |
888 | // clang-format off |
889 | // clang-format would break up the schema strings so they are no longer discoverable with ctrl-F |
890 | static const OperatorSet float_only_operator_set{ |
891 | "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor" , |
892 | "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor" , |
893 | "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor" , |
894 | "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor" , |
895 | }; |
896 | static const OperatorSet int_only_operator_set{ |
897 | "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor" , |
898 | "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor" , |
899 | "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor" , |
900 | "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor" , |
901 | }; |
902 | static const OperatorSet cpu_compute_heavy_set{ |
903 | "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor" , |
904 | "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor" , |
905 | "aten::matmul(Tensor self, Tensor other) -> Tensor" , |
906 | }; |
907 | static const OperatorSet gpu_only_operator_set{ |
908 | // On CPU, these are slower and less accurate than ATen kernels, because |
909 | // ATen is able to use MKL-VML, whereas the fuser currently can't. The |
910 | // fuser uses sleef instead because sleef provides functions that operate |
911 | // on vectors, instead of large buffers. |
912 | "aten::erf(Tensor self) -> Tensor" , |
913 | "aten::erfc(Tensor self) -> Tensor" , |
914 | }; |
915 | static const OperatorSet pow{ |
916 | "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor" , |
917 | }; |
918 | // clang-format on |
919 | |
920 | // Check types of input values. |
921 | for (const Value* v : node->inputs()) { |
922 | if (auto const& tt = v->type()->cast<TensorType>()) { |
923 | auto const& st = tt->scalarType(); |
924 | auto const& device = tt->device(); |
925 | |
926 | // All tensors must be typed. |
927 | if (!st || !device) { |
928 | return false; |
929 | } |
930 | |
931 | // Byte tensors introduce too many corner cases in type promotion. |
932 | // Better not to try to handle them. |
933 | if (*st == c10::ScalarType::Byte) { |
934 | return false; |
935 | } |
936 | |
937 | // Float16 support has some issues (see e.g. #61336 and #61382), so for |
938 | // now it's disabled. There seem to be some problems in HalfRewriter, |
939 | // but on top of that Float16 has a few kinks on LLVM. Thus, on CPU we |
940 | // additionally disable it until we either move to a more stable version |
941 | // or find workarounds. |
942 | if (*st == c10::ScalarType::Half && *device == c10::kCPU) { |
943 | return false; |
944 | } |
945 | |
946 | if (*st == c10::ScalarType::BFloat16 && *device == c10::kCPU) { |
947 | #ifndef TORCH_ENABLE_LLVM |
948 | return false; |
949 | #endif |
950 | } |
951 | |
952 | // These operators only support floats, because integer divisors need to |
953 | // raise ZeroDivisionError. |
954 | if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) { |
955 | return false; |
956 | } |
957 | |
958 | // These operators have complicated casting rules for floats. |
959 | if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) { |
960 | return false; |
961 | } |
962 | } else if (node->isMemberOf(float_only_operator_set)) { |
963 | // Check scalar operands of float-only ops. |
964 | if (!v->type()->cast<FloatType>()) { |
965 | return false; |
966 | } |
967 | } else if (node->isMemberOf(int_only_operator_set)) { |
968 | if (!v->type()->cast<IntType>()) { |
969 | return false; |
970 | } |
971 | } |
972 | } |
973 | |
974 | // aten::pow has special rules to avoid complicated integer cases. We |
975 | // expect the first arg to be a floating point tensor, and if that's the |
976 | // case the type of the scalar exponent doesn't matter. |
977 | if (node->isMemberOf(pow)) { |
978 | auto const& tt = node->input(0)->type()->cast<TensorType>(); |
979 | if (!tt) { |
980 | return false; |
981 | } |
982 | auto const& st = tt->scalarType(); |
983 | if (!st || !isFloatingType(*st)) { |
984 | return false; |
985 | } |
986 | } |
987 | |
988 | // Operator is only supported on CPU. |
989 | if (node->isMemberOf(cpu_compute_heavy_set)) { |
990 | if (fuse_to_dynamic_shapes_) { |
991 | return false; |
992 | } |
993 | |
994 | auto device = tensorexpr::pickDeviceType(node->inputs()); |
995 | if (!device) { |
996 | device = tensorexpr::pickDeviceType(node->outputs()); |
997 | } |
998 | if (!device || !device->is_cpu()) { |
999 | return false; |
1000 | } |
1001 | } |
1002 | |
1003 | // Operator is only supported on GPU. |
1004 | if (node->isMemberOf(gpu_only_operator_set)) { |
1005 | auto device = tensorexpr::pickDeviceType(node->inputs()); |
1006 | if (!device) { |
1007 | device = tensorexpr::pickDeviceType(node->outputs()); |
1008 | } |
1009 | if (!device || !device->is_cuda()) { |
1010 | return false; |
1011 | } |
1012 | } |
1013 | |
1014 | if (node->kind() == aten::to) { |
1015 | // only support same-device conversion |
1016 | auto device = tensorexpr::pickDeviceType(node->inputs()); |
1017 | auto output_device = tensorexpr::pickDeviceType(node->outputs()); |
1018 | if (!device || !output_device || *device != *output_device) { |
1019 | return false; |
1020 | } |
1021 | // non_blocking only applies in cross-device conversion, which we bail on |
1022 | // copy arg only applies if the op is a no-op, which we don't start a fusion |
| // group from |
1023 | // memory format is handled separately in the NNC output |
1024 | |
1025 | // all non-Tensor arguments must be constant |
1026 | for (size_t i = 1; i < node->inputs().size(); i++) { |
1027 | if (node->inputs().at(i)->node()->kind() != prim::Constant) { |
1028 | return false; |
1029 | } |
1030 | } |
1031 | |
1032 | if (has_unsupported_pin_memory(node)) { |
1033 | return false; |
1034 | } |
1035 | } |
1036 | |
1037 | if (node->kind() == aten::_autocast_to_reduced_precision || |
1038 | node->kind() == aten::_autocast_to_full_precision) { |
1039 | for (auto i : c10::irange(1, node->inputs().size())) { |
1040 | if (node->inputs().at(i)->node()->kind() != prim::Constant) { |
1041 | return false; |
1042 | } |
1043 | } |
1044 | |
1045 | bool is_reduced_precision = |
1046 | node->kind() == aten::_autocast_to_reduced_precision; |
1047 | bool is_full_precision = |
1048 | node->kind() == aten::_autocast_to_full_precision; |
1049 | auto self_tensor = node->inputs()[0]; // input tensor |
1050 | |
1051 | if (auto const& tt = self_tensor->type()->cast<TensorType>()) { |
1052 | auto st = tt->scalarType(); |
1053 | if (!st.has_value()) { |
1054 | return false; |
1055 | } |
1056 | |
1057 | auto device = tt->device(); |
1058 | if (!device.has_value()) { |
1059 | return false; |
1060 | } |
1061 | |
1062 | bool is_cpu = device->is_cpu(); |
1063 | |
1064 | if (*st != at::kFloat && is_reduced_precision && is_cpu) { |
1065 | // On CPU, aten does nothing here unless the data type is float, so for |
1066 | // non-float inputs the aten call is faster than an NNC kernel and NNC |
1067 | // does not pull the node into its fusion group. |
1068 | return false; |
1069 | } |
1070 | |
1071 | if (*st != at::kBFloat16 && is_full_precision && is_cpu) { |
1072 | // On CPU, aten does nothing here unless the data type is BFloat16, so for |
1073 | // other input types the aten call is faster than an NNC kernel and NNC |
1074 | // does not pull the node into its fusion group. |
1075 | return false; |
1076 | } |
1077 | } |
1078 | |
1079 | if (has_unsupported_pin_memory(node)) { |
1080 | return false; |
1081 | } |
1082 | } |
1083 | |
1084 | if (node->kind() == aten::unsqueeze) { |
1085 | // `dim` argument must be a constant. |
1086 | if (node->input(1)->node()->kind() != prim::Constant) { |
1087 | return false; |
1088 | } |
1089 | } |
1090 | |
1091 | if (node->kind() == aten::_convolution && !tensorexpr::isConv2d(node)) { |
1092 | GRAPH_DEBUG("This aten::_convolution node is not a 2D conv" ); |
1093 | return false; |
1094 | } |
1095 | if (node->kind() == aten::_convolution || node->kind() == aten::conv2d) { |
1096 | if (!tensorexpr::conv2dIsSupportedJit(node) && |
1097 | !tensorexpr::mkldnnPrepackedConvIsSupportedJit(node)) { |
1098 | GRAPH_DEBUG("Params of conv2d are not supported" ); |
1099 | return false; |
1100 | } |
1101 | } |
1102 | if (node->kind() == aten::matmul) { |
1103 | if (!tensorexpr::matmulIsSupported(node)) { |
1104 | GRAPH_DEBUG("Shapes of matmul inputs are not supported" ); |
1105 | return false; |
1106 | } |
1107 | } |
1108 | return true; |
1109 | } |
1110 | |
1111 | #define REQ(cond) \ |
1112 | if (!(cond)) { \ |
1113 | GRAPH_DEBUG("Failed cond " #cond "\n"); \ |
1114 | return false; \ |
1115 | } |
1116 | |
1117 | bool canHandle(Node* node) { |
1118 | REQ(allShapesAreKnown(node)); |
1119 | REQ(isFusableOnDevice(node)); |
1120 | REQ(operators_not_to_fuse.find(node->kind()) == |
1121 | operators_not_to_fuse.end()); |
1122 | |
1123 | for (Value* input : node->inputs()) { |
1124 | if (auto const& tt = input->type()->cast<TensorType>()) { |
1125 | auto st = tt->scalarType(); |
1126 | if (!st) { |
1127 | // All tensor types should be known. |
1128 | return false; |
1129 | } |
1130 | if (c10::isComplexType(*st) || c10::isQIntType(*st)) { |
1131 | return false; |
1132 | } |
1133 | } |
1134 | } |
1135 | if (node->kind() == aten::cat) { |
1136 | REQ(node->input(0)->node()->kind() == prim::ListConstruct); |
1137 | REQ(node->input(0)->uses().size() == 1); |
1138 | REQ(node->input(1)->node()->kind() == prim::Constant); |
1139 | auto const& listconstruct = node->input(0)->node(); |
1140 | REQ(tensorexpr::pickDeviceType(listconstruct->inputs())); |
1141 | } else { |
1142 | REQ(tensorexpr::pickDeviceType(node->inputs())); |
1143 | } |
1144 | |
1145 | // Only fuse aten::batch_norm when the parameter 'training' is false |
1146 | if (node->kind() == aten::batch_norm) { |
1147 | REQ(node->input(5)->node()->kind() == prim::Constant); |
1148 | REQ(!toIValue(node->input(5)).value().toBool()); |
1149 | } |
1150 | |
1151 | REQ(tensorexpr::isSupported(node)); |
1152 | REQ(typesAreSupported(node)); |
1153 | |
1154 | // A hook into the optimization limiter to allow bisecting the pass |
1155 | REQ(JIT_OPT_ALLOWED); |
1156 | |
1157 | if (fuse_to_dynamic_shapes_) { |
1158 | // Allow only if the node has a shape function defined. |
1159 | // ListConstruct node is an exception since that is needed to fuse |
1160 | // aten::cat, though it does not have a shape function. |
1161 | REQ(node->kind() == prim::ListConstruct || |
1162 | node->kind() == prim::TensorExprGroup || |
1163 | node->isMemberOf(tensorexpr::getCustomOperatorSet()) || |
1164 | (node->maybeSchema() && shapeComputeGraphForSchema(node->schema()))); |
1165 | } |
1166 | |
1167 | return true; |
1168 | } |
1169 | |
1170 | bool canMerge(Node* consumer, Node* producer) { |
1171 | // Only fuse within a block |
1172 | REQ(consumer->owningBlock() == producer->owningBlock()); |
1173 | |
1174 | // Symbolic checks |
1175 | REQ(canHandle(producer) || producer->kind() == prim::TensorExprGroup); |
1176 | TORCH_INTERNAL_ASSERT( |
1177 | consumer->kind() == prim::TensorExprGroup || canHandle(consumer)); |
1178 | |
1179 | // nvrtc has a limit on the number of arguments allowed in a CUDA kernel. |
1180 | // The specific limit is a function of constant memory size, amount |
1181 | // available to pass arguments, and some implementation dependence. Select a |
1182 | // safe limit here. |
1183 | constexpr size_t subgraphArgLimit = 128; |
1184 | auto const nInputs = consumer->inputs().size() + |
1185 | consumer->outputs().size() + producer->inputs().size() + |
1186 | producer->outputs().size(); |
1187 | REQ(nInputs <= subgraphArgLimit); |
1188 | |
1189 | // Device checks |
1190 | if (consumer->kind() != aten::cat && producer->kind() != aten::cat) { |
1191 | // aten::cat needs special handling because it takes a Tensor[] as its |
1192 | // input. We deal with that in the code below. |
1193 | auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs()); |
1194 | REQ(consumer_device); |
1195 | auto producer_device = tensorexpr::pickDeviceType(producer->inputs()); |
1196 | REQ(producer_device); |
1197 | REQ(*consumer_device == *producer_device); |
1198 | } |
1199 | |
1200 | // Alias checks |
1201 | REQ(aliasDb_->couldMoveBeforeTopologically(producer, consumer)); |
1202 | |
1203 | // Ops that return aliases can only be folded if this is the only use. |
1204 | if (producer->kind() == aten::slice || |
1205 | producer->kind() == aten::unsqueeze || |
1206 | producer->kind() == prim::ConstantChunk) { |
1207 | for (auto& use : producer->output(0)->uses()) { |
1208 | REQ(use.user == consumer); |
1209 | } |
1210 | } |
1211 | |
1212 | if (!consumer->hasAttribute(attr::Subgraph) && |
1213 | consumer->kind() != prim::TensorExprGroup) { |
1214 | // Don't initiate a fusion group from prim::ListConstruct |
1215 | REQ(consumer->kind() != prim::ListConstruct); |
1216 | REQ(consumer->kind() != aten::slice); |
1217 | REQ(consumer->kind() != aten::unsqueeze); |
1218 | REQ(consumer->kind() != prim::ConstantChunk); |
1219 | |
1220 | // Don't initiate a fusion group just for a constant operand |
1221 | REQ(producer->kind() != prim::Constant); |
1222 | } |
1223 | |
1224 | if (producer->kind() == aten::cat) { |
1225 | REQ(producer->input(0)->node()->kind() == prim::ListConstruct); |
1226 | REQ(producer->input(0)->uses().size() == 1); |
1227 | REQ(producer->input(1)->node()->kind() == prim::Constant); |
1228 | auto const& listConstruct = producer->input(0)->node(); |
1229 | // We're merging listconstruct->cat->consumer. cat is the producer here |
1230 | // and we cannot determine its device type - we should use device of the |
1231 | // listconstruct instead |
1232 | auto listconstruct_device = |
1233 | tensorexpr::pickDeviceType(listConstruct->inputs()); |
1234 | auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs()); |
1235 | REQ(listconstruct_device); |
1236 | REQ(consumer_device); |
1237 | REQ(*listconstruct_device == *consumer_device); |
1238 | for (auto const& input : listConstruct->inputs()) { |
1239 | REQ(isFusableOnDevice(input->node())); |
1240 | } |
1241 | REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit); |
1242 | } else if (consumer->kind() == aten::cat) { |
1243 | REQ(consumer->input(0)->node()->kind() == prim::ListConstruct); |
1244 | REQ(consumer->input(0)->uses().size() == 1); |
1245 | REQ(consumer->input(1)->node()->kind() == prim::Constant); |
1246 | auto const& listConstruct = consumer->input(0)->node(); |
1247 | // We're merging listconstruct->cat. cat is the consumer and listconstruct |
1248 | // is the producer. cat doesn't carry a device type of its own, so the only |
1249 | // thing we should check is that listconstruct's device is well defined |
1250 | // (e.g. all its inputs have the same device). |
1251 | auto listconstruct_device = |
1252 | tensorexpr::pickDeviceType(listConstruct->inputs()); |
1253 | REQ(listconstruct_device); |
1254 | REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit); |
1255 | } else { |
1256 | REQ(isFusableOnDevice(producer)); |
1257 | } |
1258 | |
1259 | return true; |
1260 | } |
1261 | #undef REQ |
1262 | |
1263 | void prepareFusionGroupAndGuardOutputs(Block* block) { |
1264 | std::vector<Node*> fusion_groups; |
1265 | for (Node* n : block->nodes()) { |
1266 | for (Block* b : n->blocks()) { |
1267 | prepareFusionGroupAndGuardOutputs(b); |
1268 | } |
1269 | if (n->kind() == prim::TensorExprGroup) { |
1270 | fusion_groups.push_back(n); |
1271 | } |
1272 | } |
1273 | for (Node* fusion_group : fusion_groups) { |
1274 | removeOutputsUsedOnlyInSize(fusion_group); |
1275 | insertTypeGuard( |
1276 | fusion_group, |
1277 | [](const TensorTypePtr& t) { return t; }, |
1278 | prim::TypeCheck); |
1279 | } |
1280 | } |
1281 | |
1282 | void generalizeFusionGroups(Block* block) { |
1283 | std::vector<Node*> fusion_groups; |
1284 | for (Node* n : block->nodes()) { |
1285 | for (Block* b : n->blocks()) { |
1286 | generalizeFusionGroups(b); |
1287 | } |
1288 | if (n->kind() == prim::TensorExprGroup) { |
1289 | fusion_groups.push_back(n); |
1290 | } |
1291 | } |
1292 | for (Node* fusion_group : fusion_groups) { |
1293 | removeOutputsUsedOnlyInSize(fusion_group); |
1294 | VLOG(1) << "GenerateGuard for fusion group: " << *fusion_group; |
1295 | if (!GenerateGuard(fusion_group, add_composed_op_)) { |
1296 | VLOG(1) << " Unfusing the fusion group because GenerateGuard failed" |
1297 | << std::endl; |
1298 | SubgraphUtils::unmergeSubgraph(fusion_group); |
1299 | } |
1300 | } |
1301 | } |
1302 | |
1303 | // This function parses the option provided by the environment variable |
1304 | // "PYTORCH_TENSOREXPR_DONT_FUSE". |
1305 | // This variable allows users to disable fusion on a list of specified |
1306 | // operators that are separated by ':'. e.g., |
1307 | // 'PYTORCH_TENSOREXPR_DONT_FUSE="clamp:mul:add"' disables fusion on |
1308 | // aten::clamp, aten::mul and aten::add. |
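| // For example, a hypothetical invocation that disables those fusions: |
| //   PYTORCH_TENSOREXPR_DONT_FUSE="clamp:mul:add" python script.py |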
1309 | void parseTENotFuseOption() { |
1310 | const char* option = std::getenv("PYTORCH_TENSOREXPR_DONT_FUSE"); |
1311 | std::stringstream in_ss; |
1312 | if (option) { |
1313 | in_ss << option; |
1314 | } |
1315 | |
1316 | std::string line; |
1317 | while (std::getline(in_ss, line, ':')) { |
1318 | if (line.empty()) { |
1319 | continue; |
1320 | } |
1321 | operators_not_to_fuse.insert(c10::Symbol::aten(line)); |
1322 | } |
1323 | } |
1324 | |
1325 | std::shared_ptr<Graph> graph_; |
1326 | std::unique_ptr<AliasDb> aliasDb_ = nullptr; |
1327 | |
1328 | std::set<NodeKind> operators_not_to_fuse; |
1329 | // Minimal size of a fusion group |
1330 | size_t min_group_size_; |
1331 | // compose Runtime Type Guard and Kernel in one op |
1332 | bool add_composed_op_; |
1333 | // generalize static shapes to dynamic shapes |
1334 | bool fuse_to_dynamic_shapes_; |
1335 | }; |
1336 | |
1337 | void FuseTensorExprs( |
1338 | std::shared_ptr<Graph>& graph, |
1339 | size_t min_group_size, |
1340 | bool add_composed_op, |
1341 | bool fuse_to_dynamic_shapes) { |
1342 | GRAPH_DUMP("Before TExprFuser: " , graph); |
1343 | |
1344 | // Temporary change for Block code generation. |
1345 | if (tensorexpr::getTEGenerateBlockCode()) { |
1346 | min_group_size = 1; |
1347 | } |
1348 | |
1349 | if (add_composed_op) { |
1350 | TORCH_INTERNAL_ASSERT( |
1351 | fuse_to_dynamic_shapes, "Fusing static shapes with composed op NYI"); |
1352 | } |
1353 | |
1354 | // Get rid of dead code so that we don't waste effort fusing it. |
1355 | EliminateDeadCode(graph); |
1356 | |
1357 | TensorExprFuser fuser( |
1358 | graph, min_group_size, add_composed_op, fuse_to_dynamic_shapes); |
1359 | fuser.run(); |
1360 | |
1361 | EliminateCommonSubexpression(graph); |
1362 | EliminateDeadCode(graph); |
1363 | |
1364 | GRAPH_DUMP("After TExprFuser: " , graph); |
1365 | } |
1366 | |
1367 | Operation createTensorExprOp(const Node* node) { |
1368 | bool dynamic_shape_fusion_node = |
1369 | node->hasAttribute(attr::striding_inputs_desc); |
1370 | if (!dynamic_shape_fusion_node) { |
1371 | auto kernel = |
1372 | std::make_shared<tensorexpr::TensorExprKernel>(node->g(attr::Subgraph)); |
1373 | return [kernel](Stack& stack) { |
1374 | RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>()); |
1375 | kernel->run(stack); |
1376 | return 0; |
1377 | }; |
1378 | } |
1379 | |
1380 | // Handle the case when dynamic shape fusion is enabled. |
1381 | VLOG(1) << "Compiling a new kernel for " << *node; |
1382 | std::vector<int64_t> sym_shapes; |
1383 | if (node->hasAttribute(attr::symbolic_shape_inputs)) { |
1384 | sym_shapes = node->is(attr::symbolic_shape_inputs); |
1385 | } |
1386 | bool allow_stack_outputs = false; |
1387 | if (node->hasAttribute(attr::allow_stack_outputs)) { |
1388 | allow_stack_outputs = node->i(attr::allow_stack_outputs) == 1; |
1389 | } |
1390 | |
1391 | std::unordered_map<c10::Symbol, tensorexpr::NNCLoweringFunction> |
1392 | custom_lowerings; |
1393 | auto subgraph = node->g(attr::Subgraph); |
1394 | IValue sym_strides = node->ival(attr::striding_inputs_desc); |
1395 | |
1396 | // The striding descriptor is serialized on the node as a vector of vectors |
1397 | // of strings; translate it back to the StrideInput enum |
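| // An illustrative (not verbatim) example of the serialized form for two |
| // tensor inputs, one contiguous and one with symbolic strides: |
| //   [["TENSOR_CONT"], ["S_AS_ARG", "S_ONE"]] |
| // See strideInputFromString for the exact tokens. |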
1398 | std::vector<std::vector<std::string>> sym_strides_strs = |
1399 | sym_strides.to<std::vector<std::vector<std::string>>>(); |
1400 | std::vector<std::vector<StrideInput>> striding_inputs; |
1401 | for (const auto& vec : sym_strides_strs) { |
1402 | std::vector<StrideInput> input_desc; |
1403 | input_desc.reserve(vec.size()); |
1404 | for (const std::string& str : vec) { |
1405 | input_desc.push_back(strideInputFromString(str)); |
1406 | } |
1407 | striding_inputs.push_back(input_desc); |
1408 | } |
1409 | std::unordered_map<const Value*, std::vector<StrideInput>> stride_map; |
1410 | size_t index = 0; |
1411 | for (Value* v : subgraph->inputs()) { |
1412 | if (!v->type()->cast<TensorType>()) { |
1413 | continue; |
1414 | } |
1415 | stride_map[v] = striding_inputs[index]; |
1416 | index++; |
1417 | } |
1418 | std::vector<std::string> output_desc = |
1419 | node->ival(attr::striding_outputs_desc).to<std::vector<std::string>>(); |
1420 | for (size_t i = 0; i < subgraph->outputs().size(); ++i) { |
1421 | stride_map[subgraph->outputs().at(i)] = { |
1422 | strideInputFromString(output_desc.at(i))}; |
1423 | } |
1424 | |
1425 | std::shared_ptr<tensorexpr::TensorExprKernel> kernel = |
1426 | std::make_shared<tensorexpr::TensorExprKernel>( |
1427 | subgraph, |
1428 | custom_lowerings, |
1429 | sym_shapes, |
1430 | /*pre_alloc*/ false, |
1431 | stride_map); |
1432 | |
1433 | auto num_subgraph_inputs = subgraph->inputs().size(); |
1434 | return [kernel, num_subgraph_inputs, allow_stack_outputs](Stack& stack) { |
1435 | RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>()); |
1436 | |
1437 | // Stack contents: |
1438 | // [<outputs>] <inputs> |
1439 | // |
1440 | // If the number of graph inputs is the same as the stack size, then no |
1441 | // outputs are being passed in. Otherwise, output tensors are passed in |
1442 | // at the bottom of the stack. So, we call the appropriate run function |
1443 | // in TensorExprKernel. |
1444 | if (num_subgraph_inputs == stack.size() || !allow_stack_outputs) { |
1445 | kernel->run(stack); |
1446 | } else { |
1447 | kernel->runWithAllocatedOutputs(stack); |
1448 | } |
1449 | return 0; |
1450 | }; |
1451 | } |
1452 | |
1453 | RegisterOperators TensorExprOps({ |
1454 | torch::jit::Operator( |
1455 | prim::TensorExprGroup, |
1456 | createTensorExprOp, |
1457 | AliasAnalysisKind::INTERNAL_SPECIAL_CASE), |
1458 | }); |
1459 | |
1460 | } // namespace jit |
1461 | } // namespace torch |
1462 | |