#include <type_inference.h>

#include <ATen/AccumulateType.h>
#include <c10/core/ScalarType.h>
#include <instrumentation.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/ExpandUtils.h>
#include <ATen/core/jit_type.h>
#include <ATen/native/TypeProperties.h>
#include <type_promotion.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

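// Map a tensor's scalar type to the accumulation type used for reductions on
// CUDA (e.g. half/bfloat16 accumulate in float).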
at::ScalarType toAccumulateType(const TensorTypePtr& op) {
  TORCH_INTERNAL_ASSERT(
      op->scalarType().has_value(), "Missing Type Information.");
  return at::toAccumulateType(op->scalarType().value(), true /* is_cuda */);
}

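// True only when the tensor type is non-null and carries both a device and a
// scalar type.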
bool hasTypeAndDevice(const TensorTypePtr& op) {
  return op != nullptr && op->device().has_value() &&
      op->scalarType().has_value();
}

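// Stamp dtype/device onto the selected output, but only when that output does
// not already carry both; richer existing type information is left untouched.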
void copyScalarTypeAndDeviceToOutput(
    c10::optional<c10::ScalarType> dtype,
    c10::optional<c10::Device> device,
    Node* node,
    size_t index = 0) {
  auto out = node->output(index)->type()->cast<TensorType>();
  TORCH_INTERNAL_ASSERT(
      out != nullptr,
      "Expected target node's type pointer to be non-null, but got nullptr");
  if (!hasTypeAndDevice(out)) {
    node->output(index)->setType(
        TensorType::create(dtype, device, c10::nullopt, c10::nullopt));
  }
}

void copyScalarTypeAndDeviceToOutput(
    TensorTypePtr from,
    Node* node,
    size_t index = 0) {
  copyScalarTypeAndDeviceToOutput(
      from->scalarType(), from->device(), node, index);
}

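// Fetch input `index` as a TensorType. With `optional` set, a non-tensor input
// yields nullptr; otherwise the input must be a tensor with known dtype and
// device.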
TensorTypePtr getInputTensorType(
    Node* node,
    size_t index,
    bool optional = false) {
  auto tensor_type = node->input(index)->type()->cast<TensorType>();
  if (optional && tensor_type == nullptr) {
    return tensor_type;
  }

  // !optional implies tensor_type != nullptr
  TORCH_CHECK(
      optional || tensor_type != nullptr,
      "Input ",
      index,
      " for operation ",
      node->kind().toDisplayString(),
      " needs to be a tensor.");

  TORCH_CHECK(
      hasTypeAndDevice(tensor_type),
      "Input ",
      index,
      " for operation ",
      node->kind().toDisplayString(),
      " is missing Type or Device Information.");
  return tensor_type;
}

/* NaiveTypePropagator
 * Populate type/device tags on tensors; this is a transitional module to
 * cover the absence of type inference in the codegen CUDA fuser.
 *
 * We only cover operations supported in codegen, and we focus on propagating
 * concrete types.
 * It does NOT handle aliases (not supported in codegen anyway); type promotion
 * is not guaranteed to be consistent with PyTorch (we need to serve the needs
 * of codegen instead).
 */
class NaiveTypePropagator {
 public:
  NaiveTypePropagator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void PropagateOnBlock(Block* block) {
    for (Node* node : block->nodes()) {
      PropagateOnNode(node);
    }
  }

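  // Dispatch on the node kind and stamp dtype/device onto each output value.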
  void PropagateOnNode(Node* node) {
    switch (node->kind()) {
      // Constant:
      case prim::Constant: {
        if (node->output()->type()->isSubtypeOf(TensorType::get())) {
          node->output()->inferTypeFrom(node->t(attr::value));
        }
        break;
      }
      // unary operations
      case aten::threshold:
      case aten::clamp:
      case aten::abs:
      case aten::neg:
      case aten::ceil:
      case aten::floor:
      case aten::round:
      case aten::trunc:
      case aten::frac:
      case aten::leaky_relu:
      case aten::relu:
      case aten::silu:
      case aten::gelu:
      case aten::softplus:
      case aten::bitwise_not:
      // TODO: rand_like should support cast.
      case aten::rand_like: {
        unary_type(node);
        break;
      }
      // unary float operations
      case aten::log:
      case aten::log10:
      case aten::log1p:
      case aten::log2:
      case aten::lgamma:
      case aten::exp:
      case aten::expm1:
      case aten::erf:
      case aten::erfc:
      case aten::cos:
      case aten::acos:
      case aten::cosh:
      case aten::sin:
      case aten::asin:
      case aten::sinh:
      case aten::tan:
      case aten::atan:
      case aten::atanh:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::reciprocal:
      case aten::sigmoid:
      case aten::tanh: {
        unary_float_type(node);
        break;
      }
      // unary is
      case aten::isfinite:
      case aten::isinf:
      case aten::isnan:
      case aten::isneginf:
      case aten::isposinf:
      case aten::isreal: {
        copyScalarTypeAndDeviceToOutput(
            c10::ScalarType::Bool, c10::nullopt, node);
        break;
      }
      // binary float
      case aten::atan2: {
        binary_type(node, TypePromotion::float_op_config);
        break;
      }
      // binary operations that forward meta info and broadcast shape:
      case aten::gelu_backward:
      case aten::tanh_backward:
      case aten::mul:
      case aten::div:
      case aten::min:
      case aten::max:
      // TODO: first operand for pow can be Tensor / Scalar
      case aten::pow:
      case aten::remainder:
      case aten::threshold_backward:
      case aten::fmod:
      case aten::lerp:
      // add/sub can be ternary ops, and the third argument (alpha) contributes
      // neither to type promotion nor to shape.
      // TODO: Include alpha check for add/sub
      case aten::add:
      case aten::sub:
      case aten::rsub:
      case aten::bitwise_and:
      case aten::__and__:
      case aten::bitwise_or:
      case aten::__or__:
      case aten::bitwise_xor:
      case aten::__xor__:
      case aten::bitwise_left_shift:
      case aten::__lshift__:
      case aten::bitwise_right_shift:
      case aten::__rshift__: {
        binary_type(node);
        break;
      }
      // binary comparison
      case aten::lt:
      case aten::le:
      case aten::gt:
      case aten::ge:
      case aten::ne:
      case aten::eq: {
        binary_broadcast_type(
            node,
            getInputTensorType(node, 0, false),
            getInputTensorType(node, 1, true),
            at::ScalarType::Bool);
        break;
      }
      case aten::where: {
        binary_broadcast_type(
            node,
            getInputTensorType(node, 1, true),
            getInputTensorType(node, 2, true));
        break;
      }
      case aten::addcmul: {
        auto promoted_type = binary_broadcast_type(
            nullptr,
            getInputTensorType(node, 1, true),
            getInputTensorType(node, 2, true));
        binary_broadcast_type(
            node, promoted_type, getInputTensorType(node, 0, true));
        break;
      }
      case aten::native_dropout: {
        auto out_type = getInputTensorType(node, 0);
        copyScalarTypeAndDeviceToOutput(out_type, node, 0);
        copyScalarTypeAndDeviceToOutput(
            out_type->withScalarType(at::ScalarType::Bool), node, 1);
        break;
      }
      case aten::native_dropout_backward:
      case aten::dropout:
      case aten::instance_norm:
      case aten::batch_norm:
      case aten::layer_norm: {
        copyScalarTypeAndDeviceToOutput(getInputTensorType(node, 0), node);
        break;
      }
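      // For the norm backward ops, a constant boolean output mask input
      // selects which gradient outputs are materialized and therefore which
      // outputs need to be typed here.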
      case aten::_batch_norm_impl_index_backward:
      case aten::native_batch_norm_backward: {
        int mask_index = -1;
        if (node->kind() ==
            c10::Symbol::fromQualString(
                "aten::_batch_norm_impl_index_backward")) {
          mask_index = 10;
        } else if (
            node->kind() ==
            c10::Symbol::fromQualString("aten::native_batch_norm_backward")) {
          mask_index = 9;
        } else {
          TORCH_INTERNAL_ASSERT(
              false,
              "unidentified node kind: ",
              node->kind().toDisplayString());
        }
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list =
            constant_as<c10::List<bool>>(node->input(mask_index));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(),
            "Missing output mask for batch_norm_backward");
        std::vector<int> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<int>(value));
        }

        auto grad_input_type = getInputTensorType(node, 1);
        if (output_mask[0]) {
          copyScalarTypeAndDeviceToOutput(grad_input_type, node, 0);
        }

        if (output_mask[1]) {
          if (auto weight_type = getInputTensorType(node, 3, true)) {
            auto acc_weight_type =
                weight_type->withScalarType(toAccumulateType(weight_type));
            copyScalarTypeAndDeviceToOutput(acc_weight_type, node, 1);
          }
        }

        // TODO: Use shape information from weight tensor
        // OR get dtype information for bias tensor
        if (output_mask[2]) {
          auto bias_type = TensorType::create(
              toAccumulateType(grad_input_type),
              *grad_input_type->device(),
              c10::nullopt,
              c10::nullopt);
          copyScalarTypeAndDeviceToOutput(bias_type, node, 2);
        }
        break;
      }
      case aten::_batch_norm_impl_index: {
        auto out_type = getInputTensorType(node, 0);
        copyScalarTypeAndDeviceToOutput(out_type, node, 0);

        auto mean_invstd_type = TensorType::create(
            toAccumulateType(out_type),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        copyScalarTypeAndDeviceToOutput(mean_invstd_type, node, 1);
        copyScalarTypeAndDeviceToOutput(mean_invstd_type, node, 2);

        // TODO: not that it matters, but mark the right type here;
        auto reserve_type = TensorType::create(
            *out_type->scalarType(),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        copyScalarTypeAndDeviceToOutput(reserve_type, node, 3);
        node->output(4)->setType(IntType::get());
        break;
      }
      case aten::native_batch_norm:
      case aten::native_layer_norm: {
        auto out_type = getInputTensorType(node, 0);
        copyScalarTypeAndDeviceToOutput(out_type, node, 0);

        auto mean_invstd_type = TensorType::create(
            toAccumulateType(out_type),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        copyScalarTypeAndDeviceToOutput(mean_invstd_type, node, 1);
        copyScalarTypeAndDeviceToOutput(mean_invstd_type, node, 2);
        break;
      }
      case aten::native_layer_norm_backward: {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list = constant_as<c10::List<bool>>(node->input(7));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(),
            "Missing output mask for layer_norm_backward");
        std::vector<int> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<int>(value));
        }

        if (output_mask[0]) {
          copyScalarTypeAndDeviceToOutput(getInputTensorType(node, 0), node, 0);
        }

        if (output_mask[1]) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          if (auto weight_type = getInputTensorType(node, 5, true)) {
            copyScalarTypeAndDeviceToOutput(weight_type, node, 1);
          }
        }

        if (output_mask[2]) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          if (auto bias_type = getInputTensorType(node, 6, true)) {
            copyScalarTypeAndDeviceToOutput(bias_type, node, 2);
          }
        }
        break;
      }
      case aten::log_softmax:
      case aten::softmax: {
        auto out_type = getInputTensorType(node, 0);

        // accept dtype input to `aten::softmax` node
        if (!node->input(2)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          if (auto opt_ivalue = toIValue(node->input(2))) {
            out_type = out_type->withScalarType(opt_ivalue->toScalarType());
          }
        }
        copyScalarTypeAndDeviceToOutput(out_type, node);
        break;
      }
      case aten::_softmax: {
        auto out_type = getInputTensorType(node, 0);

        const auto half_to_float = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            half_to_float.has_value(),
            "half_to_float bool doesn't have a value.");
        if (half_to_float.value()) {
          out_type = out_type->withScalarType(at::ScalarType::Float);
        }

        copyScalarTypeAndDeviceToOutput(out_type, node);
        break;
      }
      case aten::_log_softmax_backward_data:
      case aten::_softmax_backward_data: {
        auto out_type = getInputTensorType(node, 0);
        if (auto opt_ivalue = toIValue(node->input(3))) {
          out_type = out_type->withScalarType(opt_ivalue->toScalarType());
        }
        copyScalarTypeAndDeviceToOutput(out_type, node);
        break;
      }
      case aten::amax:
      case aten::amin:
      case aten::mean:
      case aten::sum: {
        auto out_type = getInputTensorType(node, 0);

        // accept dtype input to `aten::sum` and `aten::mean` nodes
        if (node->kind() == aten::mean || node->kind() == aten::sum) {
          if (!node->input(3)->type()->isSubtypeOf(
                  static_cast<c10::TypePtr>(NoneType::get()))) {
            if (auto opt_ivalue = toIValue(node->input(3))) {
              out_type = out_type->withScalarType(opt_ivalue->toScalarType());
            }
          }
        }
        const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
        const auto keepdim = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            dims.has_value() && keepdim.has_value(),
            "Shape inference cannot handle options.");
        unary_reduce_type(node, out_type, dims->vec(), keepdim.value());
        break;
      }
      case aten::std:
      case aten::var: {
        auto out_type = getInputTensorType(node, 0);
        const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
        const auto keepdim = constant_as<bool>(node->input(3));
        TORCH_CHECK(
            dims.has_value() && keepdim.has_value(),
            "Shape inference cannot handle options.");
        unary_reduce_type(node, out_type, dims->vec(), keepdim.value());
        break;
      }
      case aten::sum_to_size:
      case aten::_grad_sum_to_size: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        copyScalarTypeAndDeviceToOutput(out_type->withDim(c10::nullopt), node);
        break;
      }
      case prim::expand_copy:
      case prim::expand_as_copy:
      case prim::flatten_copy:
      case prim::permute_copy:
      case prim::reshape_copy:
      case prim::squeeze_copy:
      case prim::t_copy:
      case prim::transpose_copy:
      case prim::unsqueeze_copy:
      case prim::view_copy: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        copyScalarTypeAndDeviceToOutput(out_type, node);
        break;
      }
      case aten::type_as: {
        const auto type0 = getInputTensorType(node, 0);
        const auto type1 = getInputTensorType(node, 1);
        copyScalarTypeAndDeviceToOutput(
            type0->withScalarType(type1->scalarType()), node);
        break;
      }
      case aten::to:
      case aten::_to_copy: {
        const auto type0 = getInputTensorType(node, 0);
        const auto out_dtype = toIValue(node->input(1));
        if (out_dtype.has_value() && out_dtype->isInt()) {
          copyScalarTypeAndDeviceToOutput(
              type0->withScalarType(out_dtype->toScalarType()), node);
        } else {
          TORCH_CHECK(
              !out_dtype.has_value() || out_dtype->isNone(),
              "dtype for cast unrecognized ",
              out_dtype->tagKind());
          copyScalarTypeAndDeviceToOutput(type0, node);
        }
        break;
      }
      case prim::add_optional: {
        const auto type0 = getInputTensorType(node, 0);
        // const auto type1 = getInputTensorType(node, 1, true);
        // note: add_optional is supposed to replace an inplace add on input0,
        // so we just directly forward dtype
        TORCH_CHECK(type0 != nullptr);
        copyScalarTypeAndDeviceToOutput(type0, node);
        break;
      }
      case aten::_autocast_to_reduced_precision: {
        const auto in_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(in_type),
            "Type and device propagation has failed, or was not provided enough information.");
        const auto in_device = in_type->device();
        const auto cuda_enabled = constant_as<bool>(node->input(1));
        const auto cpu_enabled = constant_as<bool>(node->input(2));
        const auto cuda_dtype = constant_as<c10::ScalarType>(node->input(3));
        const auto cpu_dtype = constant_as<c10::ScalarType>(node->input(4));
        TORCH_CHECK(
            cuda_enabled.has_value() && cpu_enabled.has_value() &&
                cuda_dtype.has_value() && cpu_dtype.has_value(),
            "_autocast_to_reduced_precision requires all scalar inputs to be constant.");
        if (in_type->scalarType() == at::ScalarType::Float) {
          if (in_device->is_cuda() && cuda_enabled.value()) {
            copyScalarTypeAndDeviceToOutput(
                in_type->withScalarType(cuda_dtype.value()), node);
            break;
          } else if (in_device->is_cpu() && cpu_enabled.value()) {
            copyScalarTypeAndDeviceToOutput(
                in_type->withScalarType(cpu_dtype.value()), node);
            break;
          }
        }
        copyScalarTypeAndDeviceToOutput(in_type, node);
        break;
      }
      case aten::_autocast_to_full_precision: {
        const auto in_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(in_type),
            "Type and device propagation has failed, or was not provided enough information.");
        const auto in_scalar_type = in_type->scalarType();
        const auto in_device = in_type->device();
        const auto cuda_enabled = constant_as<bool>(node->input(1));
        const auto cpu_enabled = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            cuda_enabled.has_value() && cpu_enabled.has_value(),
            "_autocast_to_full_precision requires enable flag to be constant.");

        if ((in_scalar_type == at::ScalarType::Half ||
             in_scalar_type == at::ScalarType::BFloat16) &&
            ((in_device->is_cuda() && cuda_enabled.value()) ||
             (in_device->is_cpu() && cpu_enabled.value()))) {
          copyScalarTypeAndDeviceToOutput(
              in_type->withScalarType(at::ScalarType::Float), node);
        } else {
          copyScalarTypeAndDeviceToOutput(in_type, node);
        }
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "type inference failed, unrecognized operation encountered: ",
            node->kind().toDisplayString());
        // TODO: generate a proper error log, as this probably means something
        // unexpected happened.
        break;
    }
  }

  void run() {
    PropagateOnBlock(graph_->block());
  }

 protected:
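  // Forward dtype/device from the single tensor input to the output.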
  void unary_type(Node* node) {
    auto op = getInputTensorType(node, 0, false);
    copyScalarTypeAndDeviceToOutput(op, node);
  }

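  // Unary ops that always produce a floating-point result: promote the input
  // dtype with the float-op promotion config (integral/bool inputs become
  // floating point) and keep the input device.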
  void unary_float_type(Node* node) {
    auto op = getInputTensorType(node, 0, false);
    copyScalarTypeAndDeviceToOutput(
        computeTypes(TypePromotion::float_op_config, {op}),
        *op->device(),
        node);
  }

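  // Reductions keep the input dtype/device; `dims` and `keepdim` only affect
  // the output shape, which is not tracked here.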
  void unary_reduce_type(
      Node* node,
      const TensorTypePtr& op,
      const std::vector<int64_t>& dims,
      bool keepdim) {
    TORCH_CHECK(
        hasTypeAndDevice(op),
        "Type and device propagation has failed, or was not provided enough information.");
    copyScalarTypeAndDeviceToOutput(op, node);
  }

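  // Promote the two operand types under the given promotion config; the output
  // device is taken from the first tensor operand.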
  void binary_type(
      Node* node,
      TypePromotionConfig config = TypePromotion::default_op_config) {
    auto op0 = node->input(0)->type();
    auto op1 = node->input(1)->type();
    auto op0_tensor_type = op0->cast<TensorType>();
    auto op1_tensor_type = op1->cast<TensorType>();
    TORCH_CHECK(
        hasTypeAndDevice(op0_tensor_type) || hasTypeAndDevice(op1_tensor_type),
        "At least one operand must be a tensor.");
    auto ptr = (op0_tensor_type != nullptr) ? op0_tensor_type : op1_tensor_type;
    copyScalarTypeAndDeviceToOutput(
        computeTypes(config, {op0, op1}), *ptr->device(), node);
  }

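  // Compute the promoted dtype (or take the explicit `scalar_type` override)
  // and the device for a broadcasting binary op. When `node` is non-null the
  // result is also stamped onto its output; the promoted TensorType is
  // returned either way.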
  // TODO: we should comply with codegen type promotion.
  TensorTypePtr binary_broadcast_type(
      Node* node,
      TensorTypePtr const& op0,
      TensorTypePtr const& op1,
      c10::optional<at::ScalarType> scalar_type = c10::nullopt) {
    TensorTypePtr out;
    TORCH_CHECK(
        op0 != nullptr || op1 != nullptr,
        "Scalar operations on binary broadcast type, not supported yet.");

    c10::ScalarType promoted_scalar_type;
    c10::optional<c10::Device> device;
    if (op0 != nullptr && op1 != nullptr) {
      TORCH_CHECK(
          hasTypeAndDevice(op0) && hasTypeAndDevice(op1),
          "Type and device propagation has failed, or was not provided enough information.");
      promoted_scalar_type = scalar_type.has_value()
          ? *scalar_type
          : c10::promoteTypes(*op0->scalarType(), *op1->scalarType());
      device = *op0->device();
    } else {
      auto ptr = (op0 != nullptr) ? op0 : op1;
      TORCH_CHECK(
          hasTypeAndDevice(ptr),
          "Type and device propagation has failed, or was not provided enough information.");
      promoted_scalar_type =
          scalar_type.has_value() ? *scalar_type : *ptr->scalarType();
      device = *ptr->device();
    }
    if (node != nullptr) {
      copyScalarTypeAndDeviceToOutput(promoted_scalar_type, device, node);
    }

    return TensorType::create(
        promoted_scalar_type, device, c10::nullopt, c10::nullopt);
  }

 private:
  std::shared_ptr<Graph> graph_;
};

} // namespace

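// Entry point: run naive type/device propagation over the whole graph.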
void TypePropagate(std::shared_ptr<Graph>& graph) {
  FUSER_PERF_SCOPE("TypePropagate");
  GRAPH_DUMP("Before TypePropagate: ", graph);
  NaiveTypePropagator(graph).run();
  GRAPH_DUMP("After TypePropagate: ", graph);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch