#include <type_inference.h>

#include <ATen/AccumulateType.h>
#include <c10/core/ScalarType.h>
#include <instrumentation.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/ExpandUtils.h>
#include <ATen/core/jit_type.h>
#include <ATen/native/TypeProperties.h>
#include <type_promotion.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

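// Returns the accumulation dtype used on CUDA for reductions/statistics of
// the given tensor type (e.g. float for half/bfloat16 inputs).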
at::ScalarType toAccumulateType(const TensorTypePtr& op) {
  TORCH_INTERNAL_ASSERT(
      op->scalarType().has_value(), "Missing Type Information.");
  return at::toAccumulateType(op->scalarType().value(), true /* is_cuda */);
}

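// True if the tensor type carries both a scalar type and a device; these are
// the two pieces of metadata this pass propagates.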
bool hasTypeAndDevice(const TensorTypePtr& op) {
  return op != nullptr && op->device().has_value() &&
      op->scalarType().has_value();
}

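// Sets the scalar type and device on the node's output at `index`, but only
// when the output does not already carry both; existing metadata is kept.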
void copyScalarTypeAndDeviceToOutput(
    c10::optional<c10::ScalarType> dtype,
    c10::optional<c10::Device> device,
    Node* node,
    size_t index = 0) {
  auto out = node->output(index)->type()->cast<TensorType>();
  TORCH_INTERNAL_ASSERT(
      out != nullptr,
      "Expected the target node's type pointer to be non-null, but got nullptr");
  if (!hasTypeAndDevice(out)) {
    node->output(index)->setType(
        TensorType::create(dtype, device, c10::nullopt, c10::nullopt));
  }
}

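// Convenience overload that copies the scalar type and device from an
// existing TensorType.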
void copyScalarTypeAndDeviceToOutput(
    TensorTypePtr from,
    Node* node,
    size_t index = 0) {
  copyScalarTypeAndDeviceToOutput(
      from->scalarType(), from->device(), node, index);
}

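// Fetches the TensorType of the node's input at `index`. When `optional` is
// true, a non-tensor input yields nullptr; otherwise it is a hard error.
// Tensor inputs must already carry scalar type and device information.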
TensorTypePtr getInputTensorType(
    Node* node,
    size_t index,
    bool optional = false) {
  auto tensor_type = node->input(index)->type()->cast<TensorType>();
  if (optional && tensor_type == nullptr) {
    return tensor_type;
  }

  // (not optional) implies (tensor_type not equal nullptr)
  TORCH_CHECK(
      optional || tensor_type != nullptr,
      "Input ",
      index,
      " for operation ",
      node->kind().toDisplayString(),
      " needs to be a tensor.");

  TORCH_CHECK(
      hasTypeAndDevice(tensor_type),
      "Input ",
      index,
      " for operation ",
      node->kind().toDisplayString(),
      " is missing Type or Device Information.");
  return tensor_type;
}

/* NaiveTypePropagator
 * Populates the type/device tag on tensors. This is a transitional module
 * that covers the absence of type inference in the codegen CUDA fuser.
 *
 * We only cover operations supported in codegen, and we focus on propagating
 * concrete types.
 * It does NOT handle aliases (not supported in codegen anyway); type promotion
 * is not guaranteed to be consistent with PyTorch (it needs to serve codegen
 * instead).
 */
class NaiveTypePropagator {
 public:
  NaiveTypePropagator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void PropagateOnBlock(Block* block) {
    for (Node* node : block->nodes()) {
      PropagateOnNode(node);
    }
  }

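  // Dispatches on the node kind and stamps scalar type / device metadata onto
  // the node's outputs; unrecognized operations are rejected with an error.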
  void PropagateOnNode(Node* node) {
    switch (node->kind()) {
      // Constant:
      case prim::Constant: {
        if (node->output()->type()->isSubtypeOf(TensorType::get())) {
          node->output()->inferTypeFrom(node->t(attr::value));
        }
        break;
      }
      // unary operations
      case aten::threshold:
      case aten::clamp:
      case aten::abs:
      case aten::neg:
      case aten::ceil:
      case aten::floor:
      case aten::round:
      case aten::trunc:
      case aten::frac:
      case aten::leaky_relu:
      case aten::relu:
      case aten::silu:
      case aten::gelu:
      case aten::softplus:
      case aten::bitwise_not:
      // TODO: rand_like should support cast.
      case aten::rand_like: {
        unary_type(node);
        break;
      }
      // unary float operations
      case aten::log:
      case aten::log10:
      case aten::log1p:
      case aten::log2:
      case aten::lgamma:
      case aten::exp:
      case aten::expm1:
      case aten::erf:
      case aten::erfc:
      case aten::cos:
      case aten::acos:
      case aten::cosh:
      case aten::sin:
      case aten::asin:
      case aten::sinh:
      case aten::tan:
      case aten::atan:
      case aten::atanh:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::reciprocal:
      case aten::sigmoid:
      case aten::tanh: {
        unary_float_type(node);
        break;
      }
      // unary is
      case aten::isfinite:
      case aten::isinf:
      case aten::isnan:
      case aten::isneginf:
      case aten::isposinf:
      case aten::isreal: {
        copyScalarTypeAndDeviceToOutput(
            c10::ScalarType::Bool, c10::nullopt, node);
        break;
      }
      // binary float
      case aten::atan2: {
        binary_type(node, TypePromotion::float_op_config);
        break;
      }
      // binary operations that forward meta info and broadcast shape:
      case aten::gelu_backward:
      case aten::tanh_backward:
      case aten::mul:
      case aten::div:
      case aten::min:
      case aten::max:
      // TODO: first operand for pow can be Tensor / Scalar
      case aten::pow:
      case aten::remainder:
      case aten::threshold_backward:
      case aten::fmod:
      case aten::lerp:
      // add/sub can be ternary ops, and the third argument does not
      // contribute to either type promotion or shape.
      // TODO: Include alpha check for add/sub
      case aten::add:
      case aten::sub:
      case aten::rsub:
      case aten::bitwise_and:
      case aten::__and__:
      case aten::bitwise_or:
      case aten::__or__:
      case aten::bitwise_xor:
      case aten::__xor__:
      case aten::bitwise_left_shift:
      case aten::__lshift__:
      case aten::bitwise_right_shift:
      case aten::__rshift__: {
        binary_type(node);
        break;
      }
      // binary comparison
      case aten::lt:
      case aten::le:
      case aten::gt:
      case aten::ge:
      case aten::ne:
      case aten::eq: {
        binary_broadcast_type(
            node,
            getInputTensorType(node, 0, false),
            getInputTensorType(node, 1, true),
            at::ScalarType::Bool);
        break;
      }
      case aten::where: {
        binary_broadcast_type(
            node,
            getInputTensorType(node, 1, true),
            getInputTensorType(node, 2, true));
        break;
      }
      case aten::addcmul: {
        auto promoted_type = binary_broadcast_type(
            nullptr,
            getInputTensorType(node, 1, true),
            getInputTensorType(node, 2, true));
        binary_broadcast_type(
            node, promoted_type, getInputTensorType(node, 0, true));
        break;
      }
      case aten::native_dropout: {
        auto out_type = getInputTensorType(node, 0);
        copyScalarTypeAndDeviceToOutput(out_type, node, 0);
        copyScalarTypeAndDeviceToOutput(
            out_type->withScalarType(at::ScalarType::Bool), node, 1);
        break;
      }
      case aten::native_dropout_backward:
      case aten::dropout:
      case aten::instance_norm:
      case aten::batch_norm:
      case aten::layer_norm: {
        copyScalarTypeAndDeviceToOutput(getInputTensorType(node, 0), node);
        break;
      }
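      // The two batch_norm backward variants take a constant output_mask
      // argument (at different input positions) that selects which of the
      // grad outputs are actually produced.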
      case aten::_batch_norm_impl_index_backward:
      case aten::native_batch_norm_backward: {
        int mask_index = -1;
        if (node->kind() ==
            c10::Symbol::fromQualString(
                "aten::_batch_norm_impl_index_backward")) {
          mask_index = 10;
        } else if (
            node->kind() ==
            c10::Symbol::fromQualString("aten::native_batch_norm_backward")) {
          mask_index = 9;
        } else {
          TORCH_INTERNAL_ASSERT(
              false, "unidentified node kind", node->kind().toDisplayString());
        }
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list =
            constant_as<c10::List<bool>>(node->input(mask_index));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(),
            "Missing output mask for batch_norm_backward");
        std::vector<int> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<int>(value));
        }

        auto grad_input_type = getInputTensorType(node, 1);
        if (output_mask[0]) {
          copyScalarTypeAndDeviceToOutput(grad_input_type, node, 0);
        }

        if (output_mask[1]) {
          if (auto weight_type = getInputTensorType(node, 3, true)) {
            auto acc_weight_type =
                weight_type->withScalarType(toAccumulateType(weight_type));
            copyScalarTypeAndDeviceToOutput(acc_weight_type, node, 1);
          }
        }

        // TODO: Use shape information from weight tensor
        // OR get dtype information for bias tensor
        if (output_mask[2]) {
          auto bias_type = TensorType::create(
              toAccumulateType(grad_input_type),
              *grad_input_type->device(),
              c10::nullopt,
              c10::nullopt);
          copyScalarTypeAndDeviceToOutput(bias_type, node, 2);
        }
        break;
      }
      case aten::_batch_norm_impl_index: {
        auto out_type = getInputTensorType(node, 0);
        copyScalarTypeAndDeviceToOutput(out_type, node, 0);

        auto mean_invstd_type = TensorType::create(
            toAccumulateType(out_type),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        copyScalarTypeAndDeviceToOutput(mean_invstd_type, node, 1);
        copyScalarTypeAndDeviceToOutput(mean_invstd_type, node, 2);

        // TODO: not that it matters, but mark the right type here;
        auto reserve_type = TensorType::create(
            *out_type->scalarType(),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        copyScalarTypeAndDeviceToOutput(reserve_type, node, 3);
        node->output(4)->setType(IntType::get());
        break;
      }
      case aten::native_batch_norm:
      case aten::native_layer_norm: {
        auto out_type = getInputTensorType(node, 0);
        copyScalarTypeAndDeviceToOutput(out_type, node, 0);

        auto mean_invstd_type = TensorType::create(
            toAccumulateType(out_type),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);
        copyScalarTypeAndDeviceToOutput(mean_invstd_type, node, 1);
        copyScalarTypeAndDeviceToOutput(mean_invstd_type, node, 2);
        break;
      }
      case aten::native_layer_norm_backward: {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list = constant_as<c10::List<bool>>(node->input(7));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(),
            "Missing output mask for layer_norm_backward");
        std::vector<int> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<int>(value));
        }

        if (output_mask[0]) {
          copyScalarTypeAndDeviceToOutput(getInputTensorType(node, 0), node, 0);
        }

        if (output_mask[1]) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          if (auto weight_type = getInputTensorType(node, 5, true)) {
            copyScalarTypeAndDeviceToOutput(weight_type, node, 1);
          }
        }

        if (output_mask[2]) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          if (auto bias_type = getInputTensorType(node, 6, true)) {
            copyScalarTypeAndDeviceToOutput(bias_type, node, 2);
          }
        }
        break;
      }
      case aten::log_softmax:
      case aten::softmax: {
        auto out_type = getInputTensorType(node, 0);

        // accept dtype input to `aten::softmax` node
        if (!node->input(2)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          if (auto opt_ivalue = toIValue(node->input(2))) {
            out_type = out_type->withScalarType(opt_ivalue->toScalarType());
          }
        }
        copyScalarTypeAndDeviceToOutput(out_type, node);
        break;
      }
      case aten::_softmax: {
        auto out_type = getInputTensorType(node, 0);

        const auto half_to_float = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            half_to_float.has_value(),
            "half_to_float bool doesn't have a value.");
        if (half_to_float.value()) {
          out_type = out_type->withScalarType(at::ScalarType::Float);
        }

        copyScalarTypeAndDeviceToOutput(out_type, node);
        break;
      }
      case aten::_log_softmax_backward_data:
      case aten::_softmax_backward_data: {
        auto out_type = getInputTensorType(node, 0);
        if (auto opt_ivalue = toIValue(node->input(3))) {
          out_type = out_type->withScalarType(opt_ivalue->toScalarType());
        }
        copyScalarTypeAndDeviceToOutput(out_type, node);
        break;
      }
      case aten::amax:
      case aten::amin:
      case aten::mean:
      case aten::sum: {
        auto out_type = getInputTensorType(node, 0);

        // accept dtype input to `aten::sum` and `aten::mean` nodes
        if (node->kind() == aten::mean || node->kind() == aten::sum) {
          if (!node->input(3)->type()->isSubtypeOf(
                  static_cast<c10::TypePtr>(NoneType::get()))) {
            if (auto opt_ivalue = toIValue(node->input(3))) {
              out_type = out_type->withScalarType(opt_ivalue->toScalarType());
            }
          }
        }
        const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
        const auto keepdim = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            dims.has_value() && keepdim.has_value(),
            "Shape inference cannot handle options.");
        unary_reduce_type(node, out_type, dims->vec(), keepdim.value());
        break;
      }
      case aten::std:
      case aten::var: {
        auto out_type = getInputTensorType(node, 0);
        const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
        const auto keepdim = constant_as<bool>(node->input(3));
        TORCH_CHECK(
            dims.has_value() && keepdim.has_value(),
            "Shape inference cannot handle options.");
        unary_reduce_type(node, out_type, dims->vec(), keepdim.value());
        break;
      }
      case aten::sum_to_size:
      case aten::_grad_sum_to_size: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        copyScalarTypeAndDeviceToOutput(out_type->withDim(c10::nullopt), node);
        break;
      }
      case prim::expand_copy:
      case prim::expand_as_copy:
      case prim::flatten_copy:
      case prim::permute_copy:
      case prim::reshape_copy:
      case prim::squeeze_copy:
      case prim::t_copy:
      case prim::transpose_copy:
      case prim::unsqueeze_copy:
      case prim::view_copy: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        copyScalarTypeAndDeviceToOutput(out_type, node);
        break;
      }
      case aten::type_as: {
        const auto type0 = getInputTensorType(node, 0);
        const auto type1 = getInputTensorType(node, 1);
        copyScalarTypeAndDeviceToOutput(
            type0->withScalarType(type1->scalarType()), node);
        break;
      }
      case aten::to:
      case aten::_to_copy: {
        const auto type0 = getInputTensorType(node, 0);
        const auto out_dtype = toIValue(node->input(1));
        if (out_dtype.has_value() && out_dtype->isInt()) {
          copyScalarTypeAndDeviceToOutput(
              type0->withScalarType(out_dtype->toScalarType()), node);
        } else {
          TORCH_CHECK(
              !out_dtype.has_value() || out_dtype->isNone(),
              "dtype for cast unrecognized ",
              out_dtype->tagKind());
          copyScalarTypeAndDeviceToOutput(type0, node);
        }
        break;
      }
      case prim::add_optional: {
        const auto type0 = getInputTensorType(node, 0);
        // const auto type1 = getInputTensorType(node, 1, true);
        // note: add_optional is supposed to replace an in-place add on input0,
        // so we just directly forward the dtype
        TORCH_CHECK(type0 != nullptr);
        copyScalarTypeAndDeviceToOutput(type0, node);
        break;
      }
      case aten::_autocast_to_reduced_precision: {
        const auto in_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(in_type),
            "Type and device propagation has failed, or was not provided enough information.");
        const auto in_device = in_type->device();
        const auto cuda_enabled = constant_as<bool>(node->input(1));
        const auto cpu_enabled = constant_as<bool>(node->input(2));
        const auto cuda_dtype = constant_as<c10::ScalarType>(node->input(3));
        const auto cpu_dtype = constant_as<c10::ScalarType>(node->input(4));
        TORCH_CHECK(
            cuda_enabled.has_value() && cpu_enabled.has_value() &&
                cuda_dtype.has_value() && cpu_dtype.has_value(),
            "_autocast_to_reduced_precision requires all scalar inputs to be constant.");
        if (in_type->scalarType() == at::ScalarType::Float) {
          if (in_device->is_cuda() && cuda_enabled.value()) {
            copyScalarTypeAndDeviceToOutput(
                in_type->withScalarType(cuda_dtype.value()), node);
            break;
          } else if (in_device->is_cpu() && cpu_enabled.value()) {
            copyScalarTypeAndDeviceToOutput(
                in_type->withScalarType(cpu_dtype.value()), node);
            break;
          }
        }
        copyScalarTypeAndDeviceToOutput(in_type, node);
        break;
      }
      case aten::_autocast_to_full_precision: {
        const auto in_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(in_type),
            "Type and device propagation has failed, or was not provided enough information.");
        const auto in_scalar_type = in_type->scalarType();
        const auto in_device = in_type->device();
        const auto cuda_enabled = constant_as<bool>(node->input(1));
        const auto cpu_enabled = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            cuda_enabled.has_value() && cpu_enabled.has_value(),
            "_autocast_to_full_precision requires the enable flags to be constant.");

        if ((in_scalar_type == at::ScalarType::Half ||
             in_scalar_type == at::ScalarType::BFloat16) &&
            ((in_device->is_cuda() && cuda_enabled.value()) ||
             (in_device->is_cpu() && cpu_enabled.value()))) {
          copyScalarTypeAndDeviceToOutput(
              in_type->withScalarType(at::ScalarType::Float), node);
        } else {
          copyScalarTypeAndDeviceToOutput(in_type, node);
        }
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "type inference failed, unrecognized operation encountered:",
            node->kind().toDisplayString());
        // TODO: generate a proper error log, as this probably means something
        // unexpected happened.
        break;
    }
  }

  void run() {
    PropagateOnBlock(graph_->block());
  }

 protected:
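  // Forwards the input tensor's scalar type and device to the output
  // unchanged.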
  void unary_type(Node* node) {
    auto op = getInputTensorType(node, 0, false);
    copyScalarTypeAndDeviceToOutput(op, node);
  }

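  // Like unary_type, but runs the input through the float-op promotion rules
  // (e.g. integral inputs are promoted to a floating point type).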
  void unary_float_type(Node* node) {
    auto op = getInputTensorType(node, 0, false);
    copyScalarTypeAndDeviceToOutput(
        computeTypes(TypePromotion::float_op_config, {op}),
        *op->device(),
        node);
  }

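  // Stamps the (possibly dtype-overridden) input type onto the reduction
  // output. `dims` and `keepdim` are currently unused here, since only scalar
  // type and device are propagated.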
  void unary_reduce_type(
      Node* node,
      const TensorTypePtr& op,
      const std::vector<int64_t>& dims,
      bool keepdim) {
    TORCH_CHECK(
        hasTypeAndDevice(op),
        "Type and device propagation has failed, or was not provided enough information.");
    copyScalarTypeAndDeviceToOutput(op, node);
  }

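  // Computes the promoted scalar type of the two operands under the given
  // promotion config, taking the device from whichever operand is a tensor.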
  void binary_type(
      Node* node,
      TypePromotionConfig config = TypePromotion::default_op_config) {
    auto op0 = node->input(0)->type();
    auto op1 = node->input(1)->type();
    auto op0_tensor_type = op0->cast<TensorType>();
    auto op1_tensor_type = op1->cast<TensorType>();
    TORCH_CHECK(
        hasTypeAndDevice(op0_tensor_type) || hasTypeAndDevice(op1_tensor_type),
        "At least one operand must be a tensor.");
    auto ptr = (op0_tensor_type != nullptr) ? op0_tensor_type : op1_tensor_type;
    copyScalarTypeAndDeviceToOutput(
        computeTypes(config, {op0, op1}), *ptr->device(), node);
  }

  // TODO: we should comply with codegen type promotion.
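  // Promotes the scalar types of two (possibly optional) tensor operands with
  // c10::promoteTypes, optionally overriding the result with `scalar_type`
  // (used by the comparison ops, which always produce Bool). When `node` is
  // non-null the result is also written to the node's output; the promoted
  // TensorType is returned either way so callers such as aten::addcmul can
  // chain it.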
  TensorTypePtr binary_broadcast_type(
      Node* node,
      TensorTypePtr const& op0,
      TensorTypePtr const& op1,
      c10::optional<at::ScalarType> scalar_type = c10::nullopt) {
    TensorTypePtr out;
    TORCH_CHECK(
        op0 != nullptr || op1 != nullptr,
        "Scalar operations on binary broadcast type, not supported yet.");

    c10::ScalarType promoted_scalar_type;
    c10::optional<c10::Device> device;
    if (op0 != nullptr && op1 != nullptr) {
      TORCH_CHECK(
          hasTypeAndDevice(op0) && hasTypeAndDevice(op1),
          "Type and device propagation has failed, or was not provided enough information.");
      promoted_scalar_type = scalar_type.has_value()
          ? *scalar_type
          : c10::promoteTypes(*op0->scalarType(), *op1->scalarType());
      device = *op0->device();
    } else {
      auto ptr = (op0 != nullptr) ? op0 : op1;
      TORCH_CHECK(
          hasTypeAndDevice(ptr),
          "Type and device propagation has failed, or was not provided enough information.");
      promoted_scalar_type =
          scalar_type.has_value() ? *scalar_type : *ptr->scalarType();
      device = *ptr->device();
    }
    if (node != nullptr) {
      copyScalarTypeAndDeviceToOutput(promoted_scalar_type, device, node);
    }

    return TensorType::create(
        promoted_scalar_type, device, c10::nullopt, c10::nullopt);
  }

 private:
  std::shared_ptr<Graph> graph_;
};

} // namespace

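// A minimal usage sketch for the pass below (illustrative only; how the graph
// is obtained is up to the caller, and its inputs are assumed to already
// carry scalar type / device information):
//
//   std::shared_ptr<Graph> graph = ...; // e.g. a parsed or traced JIT graph
//   TypePropagate(graph);               // stamps dtype/device onto the
//                                       // intermediate and output tensors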
void TypePropagate(std::shared_ptr<Graph>& graph) {
  FUSER_PERF_SCOPE("TypePropagate");
  GRAPH_DUMP("Before TypePropagate: ", graph);
  NaiveTypePropagator(graph).run();
  GRAPH_DUMP("After TypePropagate: ", graph);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch