1#include <torch/csrc/jit/passes/shape_analysis.h>
2
3#include <c10/util/Exception.h>
4#include <c10/util/irange.h>
5#include <torch/csrc/jit/frontend/error_report.h>
6#include <torch/csrc/jit/ir/alias_analysis.h>
7#include <torch/csrc/jit/ir/constants.h>
8#include <torch/csrc/jit/ir/ir.h>
9#include <torch/csrc/jit/ir/ir_views.h>
10#include <torch/csrc/jit/passes/utils/op_registry.h>
11#include <torch/csrc/jit/runtime/exception_message.h>
12#include <torch/csrc/jit/runtime/operator.h>
13
14#include <torch/csrc/autograd/variable.h>
15
16#include <ATen/DeviceGuard.h>
17#include <ATen/ExpandUtils.h>
18#include <ATen/core/symbol.h>
19
20#ifndef AT_PER_OPERATOR_HEADERS
21#include <ATen/Functions.h>
22#else
23#include <ATen/ops/empty_strided.h>
24#endif
25
26#include <exception>
27#include <iostream>
28#include <memory>
29#include <utility>
30#include <vector>
31
32namespace torch {
33namespace jit {
34
35bool mergeTypes(
36 ArrayRef<Value*> lhs,
37 ArrayRef<Value*> rhs,
38 ArrayRef<Value*> outputs) {
39 AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
40 bool changed = false;
41 for (const auto i : c10::irange(lhs.size())) {
42 auto old_output_type = outputs[i]->type();
43 auto new_type =
44 unifyTypes(lhs[i]->type(), rhs[i]->type(), /*default_to_union=*/true);
45 AT_ASSERT(new_type);
46 outputs[i]->setType(*new_type);
47 if (*old_output_type != *outputs[i]->type())
48 changed = true;
49 }
50 return changed;
51}
52
53void applyTypes(ArrayRef<Value*> src, ArrayRef<Value*> dst) {
54 AT_ASSERT(src.size() == dst.size());
55 for (const auto i : c10::irange(src.size())) {
56 dst[i]->setType(src[i]->type());
57 }
58}
59
60void PropertyPropBase::propagateBlock(Block* block, bool insert_expands) {
61 for (Node* node : block->nodes()) {
62 try {
63 propagateNode(node, insert_expands);
64 } catch (propagation_error& e) {
65 setUnshapedType(node);
66 } catch (std::exception& e) {
67 throw ErrorReport(node->sourceRange())
68 << ExceptionMessage(e)
69 << "\nThe above operation failed shape propagation in this context";
70 }
71 }
72}
73
74void PropertyPropBase::processIf(Node* node) {
75 auto then_block = node->blocks().at(0);
76 auto else_block = node->blocks().at(1);
77 propagateBlock(then_block);
78 propagateBlock(else_block);
79 mergeTypes(then_block->outputs(), else_block->outputs(), node->outputs());
80}
81
82void PropertyPropBase::processLoop(Node* node) {
83 LoopView loop(node);
84 // propagate counter type
85 loop.currentTripCount()->setType(loop.maxTripCount()->type());
86 applyTypes(loop.carriedInputs(), loop.bodyCarriedInputs());
87
88 do {
89 propagateBlock(loop.bodyBlock(), /*insert_expands=*/false);
90 // note: inserting expands is unsafe at this point, we don't know
91 // if the types are stable yet, so the arguments to expand may change
92 } while (mergeTypes(
93 loop.bodyCarriedInputs(),
94 loop.bodyCarriedOutputs(),
95 loop.bodyCarriedInputs()));
96
97 // now that the types are stable, we can insert the expands
98 propagateBlock(loop.bodyBlock(), /*insert_expands=*/true);
99 applyTypes(loop.bodyCarriedInputs(), loop.carriedOutputs());
100}
101
102void PropertyPropBase::setUnshapedType(Value* o) {
103 o->setType(unshapedType(o->type()));
104}
105
106void PropertyPropBase::setUnshapedType(Node* node) {
107 for (auto o : node->outputs()) {
108 setUnshapedType(o);
109 }
110}
111
112namespace prim {
113using namespace ::c10::prim;
114}
115
116#define SHAPE_ASSERT(cond) \
117 if (!(cond)) \
118 throw propagation_error()
119
120namespace {
121
122bool isValidArgumentForRunning(Value* v) {
123 // allow constants
124 if (toIValue(v))
125 return true;
126 if (TensorTypePtr tt = v->type()->cast<TensorType>()) {
127 if (!tt->scalarType()) {
128 return false;
129 }
130 return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false);
131 }
132 return v->type()->isSubtypeOf(*FloatType::get());
133}
134
135bool isValidReturnForRunning(Value* v) {
136 return v->type()->isSubtypeOf(*TensorType::get()) ||
137 v->type()->isSubtypeOf(*NumberType::get());
138}
139
140bool containsTensorType(const TypePtr& t) {
141 auto n_contained = t->containedTypes().size();
142 if (n_contained == 1) {
143 return t->containedTypes().at(0)->isSubtypeOf(*TensorType::get());
144 } else if (n_contained > 1) {
145 return std::any_of(
146 t->containedTypes().begin(),
147 t->containedTypes().end(),
148 containsTensorType);
149 }
150 return false;
151}
152
153// for each node in the schema with type Tensor, extract the T type
154// returns c10::nullopt if any Tensor in the schema does not have a known
155// shape ignores non-tensor in the list of inputs
156c10::optional<std::vector<TensorTypePtr>> gatherTensorTypes(
157 Node* node,
158 bool complete = false) {
159 std::vector<TensorTypePtr> tensor_types;
160
161 auto schema_opt = node->maybeSchema();
162 if (!schema_opt) {
163 return c10::nullopt;
164 }
165 auto& schema = *schema_opt;
166 auto& args = schema.arguments();
167 // can't handle varargs primitives because we don't know what should be a
168 // Tensor
169 if (schema.is_vararg()) {
170 return c10::nullopt;
171 }
172 for (const auto i : c10::irange(args.size())) {
173 if (args[i].type()->isSubtypeOf(*ListType::ofTensors())) {
174 return c10::nullopt;
175 } else if (args[i].type()->isSubtypeOf(*TensorType::get())) {
176 if (auto type = node->input(i)->type()->cast<TensorType>()) {
177 if (complete && !type->isComplete()) {
178 return c10::nullopt;
179 }
180 tensor_types.push_back(type);
181 } else {
182 return c10::nullopt;
183 }
184 } else /* non-tensor type */ {
185 continue;
186 }
187 }
188 return tensor_types;
189}
190
191int64_t wrapDim(int64_t dim, at::IntArrayRef sizes) {
192 if (dim < 0) {
193 dim += (int64_t)sizes.size();
194 }
195 return dim;
196}
197
198c10::ScalarType unionScalarTypes(
199 c10::ScalarType original,
200 c10::ScalarType next) {
201 if (original == c10::ScalarType::Undefined) {
202 return next;
203 } else {
204 return c10::promoteTypes(original, next);
205 }
206}
207
208// Promotes result types for arithmetic operations on Tensor operands using
209// new type promotion logic. See tensor_attributes.rst for details.
210// This doesn't handle the case of arithmetic ops with Scalar arguments (when
211// `Tensor.getUnsafeTensorImpl()->is_wrapped_nubmer()` would return true)
212c10::optional<c10::ScalarType> getPromotedTypeForArithmeticOp(Node* node) {
213 c10::ScalarType dimmed = c10::ScalarType::Undefined;
214 c10::ScalarType zerodim = c10::ScalarType::Undefined;
215 // binary arithmetic ops, more than 2 args is alpha.
216 for (const auto i : c10::irange(2)) {
217 auto dtt = node->inputs()[i]->type()->expect<TensorType>();
218 auto inputDtype = dtt->scalarType();
219 if (!dtt || !inputDtype) {
220 return c10::nullopt;
221 }
222 if (dtt->dim() && *dtt->dim() > 0) {
223 dimmed = unionScalarTypes(dimmed, *inputDtype);
224 } else if (!isFloatingType(dimmed)) {
225 // if no dimensions
226 zerodim = unionScalarTypes(zerodim, *inputDtype);
227 }
228 }
229 // if a tensor with dimensions is already of the highest category, don't
230 // need to check zero-dim tensors.
231 if (isFloatingType(dimmed)) {
232 return dimmed;
233 }
234 // int_tensor * zero_dim_floating -> floating_tensor
235 if (isIntegralType(dimmed, false) && isFloatingType(zerodim)) {
236 return zerodim;
237 }
238 // bool_tensor * non_bool_scalar -> non_bool_tensor
239 if (c10::ScalarType::Bool == dimmed &&
240 c10::ScalarType::Undefined != zerodim) {
241 return zerodim;
242 }
243 // types of dimensioned tensors generally take precedence over zero-dim
244 // tensors if not promoting due to category. e.g.:
245 // int_tensor * long -> int_tensor
246 if (c10::ScalarType::Undefined != dimmed) {
247 return dimmed;
248 }
249
250 // no dimmed tensors. e.g. zero_dim_tensor + zero_dim_tensor.
251 return zerodim;
252}
253
254class ShapePropagator : public PropertyPropBase {
255 public:
256 explicit ShapePropagator(const std::shared_ptr<Graph>& graph)
257 : PropertyPropBase(graph), aliasDb_(graph) {
258 collectResizeSet(graph->block());
259 }
260
261 private:
262 ValueSet resized_alias_set;
263 const AliasDb aliasDb_;
264
265 bool resizesInput(Node* n) {
266 static std::unordered_set<Symbol> resize_ops{
267 aten::resize_,
268 aten::resize_as_,
269 aten::copy_,
270 aten::set_,
271 aten::unsqueeze_,
272 aten::t_,
273 aten::transpose_,
274 };
275
276 if (resize_ops.count(n->kind()))
277 return true;
278
279 if (!n->maybeSchema())
280 return false;
281
282 // ops which take the result and write to input "out"
283 if (auto out_arg_index = n->schema().argumentIndexWithName("out")) {
284 auto arg = n->schema().arguments().at(*out_arg_index);
285 return arg.kwarg_only() && arg.type()->isSubtypeOf(*TensorType::get());
286 }
287 return false;
288 }
289
290 void collectResizeSet(Block* block) {
291 for (Node* n : block->nodes()) {
292 for (Block* b : n->blocks()) {
293 collectResizeSet(b);
294 }
295 if (resizesInput(n)) {
296 for (const auto input : n->inputs()) {
297 if (aliasDb_.writesToAlias(n, {input})) {
298 resized_alias_set.insert(input);
299 }
300 }
301 }
302 }
303 }
304
305 IValue representativeValue(Value* v) {
306 TypePtr type_ = v->type();
307 // if the value is actually constant, just use it!
308 if (auto iv = toIValue(v)) {
309 return *iv;
310 }
311 if (TensorTypePtr type = type_->cast<TensorType>()) {
312 if (type->isComplete()) {
313 at::DeviceGuard device_guard(*type->device());
314 return at::empty_strided(
315 *type->sizes().concrete_sizes(),
316 *type->strides().concrete_sizes(),
317 at::TensorOptions(*type->device())
318 .dtype(*type->scalarType()))
319 .zero_();
320 }
321 // fallthrough
322 } else if (type_->isSubtypeOf(*FloatType::get())) {
323 return 0.f;
324 }
325 // we should not get here because isValidArgumentForRunning should have
326 // prevented it
327 std::stringstream ss;
328 ss << "unable to create representative value for: " << type_->str()
329 << ". File a bug report";
330 throw std::runtime_error(ss.str());
331 }
332
333 void broadcastBinary(
334 Node* node,
335 std::vector<TensorTypePtr>& types,
336 size_t idx1,
337 size_t idx2) {
338 auto expected_size = at::infer_size(
339 *types[idx1]->sizes().concrete_sizes(),
340 *types[idx2]->sizes().concrete_sizes());
341 auto broadcast = [&](size_t input_idx) {
342 TensorTypePtr input_type = types.at(input_idx);
343 if (input_type->sizes() == expected_size)
344 return;
345 auto graph = node->owningGraph();
346 WithInsertPoint point_guard{node};
347 Node* expand = graph
348 ->create(
349 aten::expand,
350 {node->inputs().at(input_idx),
351 graph->insertConstant(expected_size),
352 graph->insertConstant(false)})
353 ->insertBefore(node);
354 propagateNode(expand);
355 node->replaceInput(input_idx, expand->output());
356 };
357 broadcast(idx1);
358 broadcast(idx2);
359 types[0] = node->inputs().at(idx1)->type()->expect<TensorType>();
360 types[1] = node->inputs().at(idx2)->type()->expect<TensorType>();
361 }
362
363 OperatorSet cannot_propagate_shape_by_running_it = {
364 "aten::inverse(Tensor self) -> Tensor",
365 };
366
367 // Check if this node depends on a value that has been mutated previously. If
368 // it has, then it's not safe to run this node in isolation, since we don't
369 // know whether the dependency has been executed.
370 std::unordered_map<Node*, bool> dependsOnMutationMemo_;
371 bool dependsOnMutation(Node* node) {
372 if (dependsOnMutationMemo_.count(node) != 0) {
373 return dependsOnMutationMemo_[node];
374 }
375
376 if (aliasDb_.hasWriters(node)) {
377 // If something could have written to a value used by this node, we can't
378 // guarantee the result is the same when running it in isolation.
379 dependsOnMutationMemo_[node] = true;
380 return true;
381 }
382
383 // recursively check the producers of its inputs. We need to do this if the
384 // mutable value has been laundered through a pure function:
385 // a += 1
386 // c = a + b
387 // d = c + 1
388 // In this case, `d` cares whether `a` has been mutated even though it's not
389 // a direct input.
390 auto depends = false;
391 for (auto input : node->inputs()) {
392 depends |= dependsOnMutation(input->node());
393 }
394
395 dependsOnMutationMemo_[node] = depends;
396 return depends;
397 }
398
399 bool canPropagateShapeByRunningIt(Node* node) {
400 if (node->isMemberOf(cannot_propagate_shape_by_running_it)) {
401 return false;
402 }
403
404 if (dependsOnMutation(node)) {
405 return false;
406 }
407
408 bool valid_args = std::all_of(
409 node->inputs().begin(),
410 node->inputs().end(),
411 isValidArgumentForRunning);
412 if (!valid_args)
413 return false;
414
415 bool valid_returns = std::all_of(
416 node->outputs().begin(),
417 node->outputs().end(),
418 isValidReturnForRunning);
419 if (!valid_returns)
420 return false;
421
422 return true;
423 }
424
425 // If there's no Tensor in outputs, e.g float / float,
426 // we don't need to propagate shape.
427 bool DoesntRefineOutputs(Node* node) {
428 auto outputs = node->outputs();
429 for (auto& out : outputs) {
430 if (containsTensorType(out->type())) {
431 return false;
432 }
433 }
434 return true;
435 }
436
437 bool PropagateShapeOnNodeByRunningIt(Node* node, Operation op = nullptr) {
438 if (!canPropagateShapeByRunningIt(node))
439 return false;
440
441 if (!op)
442 op = node->getOperation();
443
444 Stack stack;
445
446 for (auto input : node->inputs()) {
447 stack.push_back(representativeValue(input));
448 }
449
450 // XXX: we're not catching any exceptions from the op for now. This
451 // is to uncover any mistakes we could make when editing this code,
452 // and eventually it shouldn't matter, because this phase should be
453 // preceded by schema checking.
454 op(stack);
455
456 AT_ASSERT(stack.size() == node->outputs().size());
457 for (const auto i : c10::irange(stack.size())) {
458 // some ops may have mixed tensor/primitive outputs
459 // for primitives, we don't need to change the type because it is already
460 // its most constrained form.
461 auto tensor_type = node->outputs()[i]->type()->cast<TensorType>();
462 if (stack[i].isTensor() && tensor_type) {
463 // gradient information isn't always available or part of represenative
464 // inputs, maintain original grad property
465 auto tensor_grad = tensor_type->requiresGrad();
466 node->outputs()[i]->setType(TensorType::create(stack[i].toTensor())
467 ->withRequiresGrad(tensor_grad));
468 }
469 }
470 return true;
471 }
472
473 void PropagateCatShape(Node* cat_node) {
474 static const auto propagate_complete =
475 [](Node* node, at::ArrayRef<Value*> tensors) -> bool {
476 auto input_types =
477 fmap(tensors, [](Value* v) { return v->type()->cast<TensorType>(); });
478 if (!std::all_of(
479 input_types.begin(),
480 input_types.end(),
481 [](const TensorTypePtr& tp) {
482 return tp != nullptr && tp->isComplete();
483 })) {
484 return false;
485 }
486 if (!node->is_constant(attr::dim))
487 return false;
488 std::vector<int64_t> sizes = *input_types[0]->sizes().concrete_sizes();
489 const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
490 const int64_t ndim = (int64_t)sizes.size();
491
492 if (dim < 0 || dim >= ndim)
493 return false;
494
495 sizes[dim] = 0;
496 for (auto& tp : input_types) {
497 auto tp_sizes = tp->sizes().concrete_sizes().value();
498 if (sizes.size() != tp_sizes.size())
499 return false;
500 for (const auto i : c10::irange(ndim)) {
501 if (sizes[i] != tp_sizes[i] && i != dim) {
502 return false;
503 }
504 }
505 sizes[dim] += tp_sizes[dim];
506 }
507 node->output()->setType(input_types[0]->withSizes(sizes));
508 return true;
509 };
510 static const auto propagate = [](Node* node,
511 at::ArrayRef<Value*> tensors) -> bool {
512 for (Value* v : tensors) {
513 if (auto type = v->type()->cast<TensorType>()) {
514 node->output()->setType(type->dimensionedOnly());
515 return true;
516 }
517 }
518 return false;
519 };
520 auto list_node =
521 ((cat_node->kind() == prim::FusedConcat)
522 ? cat_node
523 : cat_node->namedInput(attr::tensors)->node());
524 if (list_node->kind() == prim::ListConstruct ||
525 cat_node->kind() == prim::FusedConcat) {
526 auto tensors = list_node->inputs();
527 if (!tensors.empty()) {
528 // NOLINTNEXTLINE(bugprone-branch-clone)
529 if (propagate_complete(cat_node, tensors)) {
530 return;
531 } else if (propagate(cat_node, tensors)) {
532 return;
533 }
534 }
535 }
536 setUnshapedType(cat_node);
537 }
538
539 void propagateTorchTensorShape(Node* node) {
540 auto input_type = node->inputs().at(0)->type();
541
542 size_t dims = 0;
543 auto input_base_type = input_type;
544 auto list_type = input_type->cast<ListType>();
545 while (list_type) {
546 dims++;
547 input_base_type = list_type->getElementType();
548 list_type = input_base_type->cast<ListType>();
549 }
550
551 at::optional<at::ScalarType> default_type =
552 tryScalarTypeFromJitType(*input_base_type);
553 if (auto grad_index = node->schema().argumentIndexWithName("dtype")) {
554 auto inp = toIValue(node->inputs().at(*grad_index));
555 if (inp == c10::nullopt) {
556 return;
557 } else if (!inp->isNone()) {
558 default_type = inp->toScalarType();
559 }
560 }
561
562 at::Device default_device = at::kCPU;
563 if (auto device_index = node->schema().argumentIndexWithName("device")) {
564 auto inp = toIValue(node->inputs().at(*device_index));
565 if (inp == c10::nullopt) {
566 return;
567 } else if (!inp->isNone()) {
568 default_device = inp->toDevice();
569 }
570 }
571 node->output()->setType(TensorType::create(
572 default_type, default_device, dims, /*requires_grad=*/c10::nullopt));
573 }
574
575 // returns whether any such values were found
576 bool setUnshapedTypeIfAliasResizedSet(at::ArrayRef<Value*> vs) {
577 bool in_resize = false;
578 for (auto v : vs) {
579 if (aliasDb_.mayAlias(ValueSet{v}, resized_alias_set)) {
580 setUnshapedType(v);
581 in_resize = true;
582 }
583 }
584 return in_resize;
585 }
586
587 void propagateNode(Node* node, bool insert_expands = true) override {
588 // Certain ops like resize_ change the input tensors size. Because our
589 // analysis is flow invariant, we set any Tensor that can alias a resized
590 // Tensor to the base Tensor Type without size information.
591 if (setUnshapedTypeIfAliasResizedSet(node->inputs())) {
592 return setUnshapedType(node);
593 }
594
595 // These don't require the types, and have complicated schema. Return early
596 // after we process them.
597 switch (node->kind()) {
598 case prim::If:
599 return processIf(node);
600 case prim::Loop: {
601 return processLoop(node);
602 }
603 case aten::Bool:
604 case aten::Int:
605 case aten::Float:
606 case aten::ScalarImplicit:
607 case aten::FloatImplicit:
608 case aten::IntImplicit:
609 return; // correct num type is already set
610 case prim::NumToTensor: {
611 TypePtr typ = node->input()->type();
612 if (typ->isSubtypeOf(*IntType::get()) ||
613 typ->isSubtypeOf(*BoolType::get())) {
614 node->output()->setType(TensorType::create(
615 at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
616 } else if (node->input()->type()->isSubtypeOf(*FloatType::get())) {
617 node->output()->setType(TensorType::create(
618 at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
619 }
620 return;
621 }
622 case aten::tensor:
623 case aten::as_tensor: {
624 // as_tensor has an overloaded schema and can either have a tensor or
625 // a list as the first input, if the input is a tensor, we delegate
626 // the shape propagation in PropagateTensorShapeOnNode
627 if (node->inputs().at(0)->type()->isSubtypeOf(*TensorType::get())) {
628 break;
629 }
630 return propagateTorchTensorShape(node);
631 }
632 case prim::TupleConstruct: {
633 // We refresh the tuple type, because the input types could have been
634 // refined.
635 auto orig_type = node->output()->type()->expect<TupleType>();
636 auto new_types =
637 fmap(node->inputs(), [](Value* v) { return v->type(); });
638 node->output()->setType(
639 orig_type->createWithContained(std::move(new_types)));
640 return;
641 }
642 case prim::TupleUnpack: {
643 auto tuple_type = node->input()->type()->cast<TupleType>();
644 AT_ASSERT(
645 tuple_type &&
646 tuple_type->elements().size() == node->outputs().size());
647 auto elems = tuple_type->elements();
648 for (size_t i = 0; i < node->outputs().size(); ++i) {
649 node->output(i)->setType(elems[i]);
650 }
651 return;
652 }
653 case prim::Constant: {
654 if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
655 node->output()->inferTypeFrom(node->t(attr::value));
656 }
657 return;
658 }
659 case prim::unchecked_unwrap_optional: {
660 // If we have specialized the optional type to the element type,
661 // we want to pass it down. We write this as input.isSubtypeOf(output)
662 // to be sure that we don't screw up nested optionals.
663 if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
664 node->output()->setType(node->input()->type());
665 }
666 return;
667 }
668 case prim::ConstantChunk: {
669 Value* tensor = node->input();
670 if (auto type = tensor->type()->cast<TensorType>()) {
671 type = type->dimensionedOnly();
672 for (Value* output : node->outputs()) {
673 output->setType(type);
674 }
675 } else {
676 setUnshapedType(node);
677 }
678 return;
679 }
680 case prim::grad: {
681 auto tt = node->input()->type()->expect<TensorType>();
682 // grad may be undefined
683 // requires_grad may be required
684 auto grad_type = TensorType::get()->withPossiblyUndefined();
685 node->output()->setType(std::move(grad_type));
686 return;
687 }
688 case prim::CallFunction:
689 case prim::CallMethod:
690 case prim::AutogradZero: {
691 setUnshapedType(node);
692 return;
693 }
694 case prim::GetAttr: {
695 auto cls = node->input()->type()->expect<ClassType>();
696 // propagate any type specializations encoded in the type of the class
697 node->output()->setType(cls->getAttribute(node->s(attr::name)));
698 return;
699 }
700 case aten::_unwrap_optional: {
701 // If we have specialized the optional type to the element type,
702 // we want to pass it down. We write this as input.isSubtypeOf(output)
703 // to be sure that we don't screw up nested optionals.
704 if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
705 node->output()->setType(node->input()->type());
706 }
707 return;
708 }
709 default:
710 break; // fall-through
711 }
712
713 if (node->hasSideEffects()) {
714 return;
715 }
716
717 if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
718 node->kind() == prim::FusedConcat) {
719 return PropagateCatShape(node);
720 }
721
722 if (auto maybe_complete_types =
723 gatherTensorTypes(node, /*complete=*/true)) {
724 if (PropagateCompleteShapeOnNode(
725 node, insert_expands, std::move(*maybe_complete_types))) {
726 return;
727 }
728 }
729
730 if (PropagateTensorShapeOnNode(node, insert_expands)) {
731 return;
732 }
733
734 if (DoesntRefineOutputs(node)) {
735 return;
736 }
737
738 if (PropagateShapeOnNodeByRunningIt(node)) {
739 return;
740 }
741 return setUnshapedType(node);
742 }
743
744 static c10::optional<size_t> determineListSize(Value* list) {
745 AT_ASSERT(list->type()->cast<ListType>());
746 if (auto shape = constant_as<c10::List<int64_t>>(list)) {
747 return shape->size();
748 }
749 auto input_node = list->node();
750 if (input_node->kind() == prim::ListConstruct) {
751 return input_node->inputs().size();
752 }
753 return c10::nullopt;
754 }
755
756 // is it ok to try to run the op
757 // If an input is a constant, then we assume that the input is valid
758 // and we can try to run it.
759 // Otherwise:
760 // Integral typed _inputs_ are often an indicator that we're indexing into
761 // a tensor, so we should special-case these ops in the shape propagation.
762 // Additionally, passing in a zero representative tensor into an integer
763 // division op causes divide-by-zero errors
764 // _Outputs_ must be tensors or primitives
765 // We will call inferTypeFrom on the tensors, and ignore the primitives.
766 // However, we allow primitive returns because we want to support mixed
767 // primitive/tensor outputs.
768
769 bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
770 static const auto broadcast =
771 [](std::vector<TensorTypePtr>& tensor_types,
772 c10::optional<at::ScalarType> t) -> TensorTypePtr {
773 if (tensor_types.size() == 1) {
774 return tensor_types[0]->dimensionedOnly()->withScalarType(t);
775 }
776 AT_ASSERT(!tensor_types.empty());
777 auto any_type = tensor_types[0];
778 auto max_dims = any_type->dim();
779 for (auto& type : tensor_types) {
780 if (!max_dims || !type->dim()) {
781 max_dims = c10::nullopt;
782 } else {
783 max_dims = std::max(*max_dims, *type->dim());
784 }
785 }
786 return TensorType::create(
787 t,
788 any_type->device(),
789 max_dims,
790 /*requires_grad=*/c10::nullopt);
791 };
792
793 using type_vec_t = std::vector<TensorTypePtr>;
794 // Formula is expected to return a vector of length equal to the number of
795 // tensor outputs of the node, or an empty vector which implies that it
796 // failed to propagate.
797 using formula_t = std::function<type_vec_t(Node*)>;
798 static std::mutex shape_formulas_mutex;
799 static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas;
800 struct register_formula_for {
801 register_formula_for(OperatorSet operators, formula_t formula) {
802 std::unique_lock<std::mutex> lock{shape_formulas_mutex};
803 shape_formulas.emplace_back(std::move(operators), std::move(formula));
804 }
805 };
806
807 // Requirements:
808 // dims : preserved
809 // scalar type : preserved
810 // device : preserved
811 // tensor inputs : 1
812 // tensor outputs : 1
813 // Additionally:
814 // - First input should be the only tensor input
815 static const register_formula_for simple_unary_ops{
816 {
817 "aten::acos(Tensor self) -> Tensor",
818 "aten::neg(Tensor self) -> Tensor",
819 "aten::t(Tensor self) -> Tensor",
820 "aten::sigmoid(Tensor self) -> Tensor",
821 "aten::logit(Tensor self, float? eps=None) -> Tensor",
822 "aten::tanh(Tensor self) -> Tensor",
823 "aten::relu(Tensor self) -> Tensor",
824 "aten::asin(Tensor self) -> Tensor",
825 "aten::atan(Tensor self) -> Tensor",
826 "aten::ceil(Tensor self) -> Tensor",
827 "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
828 "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
829 "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
830 "aten::celu(Tensor self, Scalar alpha) -> Tensor",
831 "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
832 "aten::clamp_max(Tensor self, Scalar max) -> Tensor",
833 "aten::clamp_min(Tensor self, Scalar min) -> Tensor",
834 "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
835 "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
836 "aten::cos(Tensor self) -> Tensor",
837 "aten::cosh(Tensor self) -> Tensor",
838 "aten::digamma(Tensor self) -> Tensor",
839 "aten::dropout(Tensor input, float p, bool train) -> Tensor",
840 "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
841 "aten::erf(Tensor self) -> Tensor",
842 "aten::erfc(Tensor self) -> Tensor",
843 "aten::erfinv(Tensor self) -> Tensor",
844 "aten::exp(Tensor self) -> Tensor",
845 "aten::expm1(Tensor self) -> Tensor",
846 "aten::log(Tensor self) -> Tensor",
847 "aten::log10(Tensor self) -> Tensor",
848 "aten::log1p(Tensor self) -> Tensor",
849 "aten::log2(Tensor self) -> Tensor",
850 "aten::log_sigmoid(Tensor self) -> Tensor",
851 "aten::floor(Tensor self) -> Tensor",
852 "aten::frac(Tensor self) -> Tensor",
853 "aten::flip(Tensor self, int[] dims) -> Tensor",
854 "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
855 "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
856 "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
857 "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
858 "aten::glu(Tensor self, int dim) -> Tensor",
859 "aten::inverse(Tensor self) -> Tensor",
860 "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
861 "aten::lgamma(Tensor self) -> Tensor",
862 "aten::mvlgamma(Tensor self, int p) -> Tensor",
863 "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
864 "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
865 "aten::permute(Tensor self, int[] dims) -> Tensor",
866 "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)",
867 "aten::pinverse(Tensor self, float rcond) -> Tensor",
868 "aten::reciprocal(Tensor self) -> Tensor",
869 "aten::relu(Tensor self) -> Tensor",
870 "aten::round(Tensor self) -> Tensor",
871 "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
872 "aten::rsqrt(Tensor self) -> Tensor",
873 "aten::selu(Tensor self) -> Tensor",
874 "aten::gelu(Tensor self, *, str approximate='none') -> Tensor",
875 "aten::sigmoid(Tensor self) -> Tensor",
876 "aten::sign(Tensor self) -> Tensor",
877 "aten::sin(Tensor self) -> Tensor",
878 "aten::sinh(Tensor self) -> Tensor",
879 "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor",
880 "aten::softshrink(Tensor self, Scalar lambd) -> Tensor",
881 "aten::sqrt(Tensor self) -> Tensor",
882 "aten::tan(Tensor self) -> Tensor",
883 "aten::tanh(Tensor self) -> Tensor",
884 "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
885 "aten::transpose(Tensor self, int dim0, int dim1) -> Tensor",
886 "aten::tril(Tensor self, int diagonal) -> Tensor",
887 "aten::triu(Tensor self, int diagonal) -> Tensor",
888 "aten::trunc(Tensor self) -> Tensor",
889 "aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
890 "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
891 "aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor",
892 "aten::alias(Tensor self) -> Tensor",
893 },
894 [](Node* node) -> type_vec_t {
895 auto input_type = node->input(0)->type()->cast<TensorType>();
896 return input_type ? type_vec_t{input_type->dimensionedOnly()}
897 : type_vec_t{};
898 }};
899
900 // Requirements:
901 // dims : preserved
902 // scalar type : preserved, except complex maps to float
903 // device : preserved
904 // tensor inputs : 1
905 // tensor outputs : 1
906 // Additionally:
907 // - First input should be the only tensor input
908 static const register_formula_for simple_unary_ops_complex_to_float{
909 {
910 "aten::abs(Tensor self) -> Tensor",
911 },
912 [](Node* node) -> type_vec_t {
913 auto input_type = node->input(0)->type()->cast<TensorType>();
914
915 // Maps complex -> float
916 if (input_type->scalarType()) {
917 const auto scalar_type = *(input_type->scalarType());
918 if (isComplexType(scalar_type)) {
919 const auto out_type = c10::toRealValueType(scalar_type);
920 return type_vec_t{
921 input_type->dimensionedOnly()->withScalarType(out_type)};
922 }
923 }
924
925 return input_type ? type_vec_t{input_type->dimensionedOnly()}
926 : type_vec_t{};
927 }};
928
929 // Requirements:
930 // dims : broadcast all tensor args
931 // scalar type : promoted from input dtypes
932 // device : always matching and preserved
933 // tensor inputs : *
934 // tensor outputs : 1
935 static const register_formula_for broadcasting_ops_arithmetic{
936 {
937 // Tensor-Tensor operators
938 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
939 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
940 "aten::mul(Tensor self, Tensor other) -> Tensor",
941 "aten::div(Tensor self, Tensor other) -> Tensor",
942 },
943 [](Node* node) -> type_vec_t {
944 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
945 AT_ASSERT(maybe_tensor_types->size() >= 2);
946 auto dtype = getPromotedTypeForArithmeticOp(node);
947 return {broadcast(*maybe_tensor_types, dtype)};
948 }
949 return {};
950 }};
951
952 // Requirements:
953 // dims : broadcast all tensor args
954 // scalar type : always matching and preserved
955 // device : always matching and preserved
956 // tensor inputs : *
957 // tensor outputs : 1
958 static const register_formula_for broadcasting_ops{
959 {
960 "aten::pow(Tensor self, Tensor exponent) -> Tensor",
961 "aten::fmod(Tensor self, Tensor other) -> Tensor",
962 "aten::remainder(Tensor self, Tensor other) -> Tensor",
963 "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
964 "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
965 "aten::max(Tensor self, Tensor other) -> Tensor",
966 "aten::min(Tensor self, Tensor other) -> Tensor",
967 "aten::__and__(Tensor self, Tensor other) -> Tensor",
968 "aten::__or__(Tensor self, Tensor other) -> Tensor",
969 "aten::__xor__(Tensor self, Tensor other) -> Tensor",
970 "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
971 "aten::__rshift__(Tensor self, Tensor other) -> Tensor",
972 "aten::__iand__(Tensor self, Tensor other) -> Tensor",
973 "aten::__ior__(Tensor self, Tensor other) -> Tensor",
974 "aten::__ixor__(Tensor self, Tensor other) -> Tensor",
975 "aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
976 "aten::__irshift__(Tensor self, Tensor other) -> Tensor",
977
978 // Ops with Tensor-Tensor overloads only
979 "aten::atan2(Tensor self, Tensor other) -> Tensor",
980 },
981 [](Node* node) -> type_vec_t {
982 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
983 AT_ASSERT(maybe_tensor_types->size() >= 2);
984 auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
985 auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
986 if (!first_scalar_type || !second_scalar_type) {
987 return {};
988 }
989 size_t arg_for_type = 0;
990 if (c10::promoteTypes(*first_scalar_type, *second_scalar_type) !=
991 first_scalar_type) {
992 arg_for_type = 1;
993 }
994 auto t = (*maybe_tensor_types)[arg_for_type]->scalarType();
995 return {broadcast(*maybe_tensor_types, *t)};
996 }
997 return {};
998 }};
999
1000 static const register_formula_for fused_accum_binary_ops{
1001 {
1002 // Non-binary ops
1003 "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
1004 "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
1005 },
1006 [](Node* node) -> type_vec_t {
1007 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1008 auto dtype = (*maybe_tensor_types)[0]->scalarType();
1009 if (!dtype) {
1010 return {};
1011 }
1012 return {broadcast(*maybe_tensor_types, *dtype)};
1013 }
1014 return {};
1015 }};
1016
1017 static const register_formula_for broadcasting_tensor_scalar_ops_arithmetic{
1018 {
1019 // Tensor-Scalar operators
1020 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
1021 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
1022 "aten::mul(Tensor self, Scalar other) -> Tensor",
1023 "aten::div(Tensor self, Scalar other) -> Tensor",
1024 },
1025 [](Node* node) -> type_vec_t {
1026 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1027 auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
1028 auto second_scalar_type =
1029 tryScalarTypeFromJitType(*node->inputs()[1]->type());
1030 if (!first_scalar_type || !second_scalar_type) {
1031 return {};
1032 }
1033 if (isIntegralType(*first_scalar_type, false) &&
1034 isFloatingType(*second_scalar_type)) {
1035 auto default_dtype =
1036 at::typeMetaToScalarType(caffe2::get_default_dtype());
1037 return {broadcast(*maybe_tensor_types, default_dtype)};
1038 }
1039 if (c10::ScalarType::Bool == *first_scalar_type &&
1040 c10::ScalarType::Bool != *second_scalar_type) {
1041 auto result_type =
1042 c10::promoteTypes(*first_scalar_type, *second_scalar_type);
1043 return {broadcast(*maybe_tensor_types, result_type)};
1044 }
1045 return {broadcast(*maybe_tensor_types, first_scalar_type)};
1046 }
1047 return {};
1048 }};
1049
1050 // NB: we always take the scalar type of the Tensor
1051 static const register_formula_for broadcasting_tensor_scalar_ops{
1052 {
1053
1054 "aten::pow(Tensor self, Scalar exponent) -> Tensor",
1055 "aten::fmod(Tensor self, Scalar other) -> Tensor",
1056 "aten::remainder(Tensor self, Scalar other) -> Tensor",
1057 "aten::pow(Scalar self, Tensor exponent) -> Tensor",
1058 "aten::__and__(Tensor self, Scalar other) -> Tensor",
1059 "aten::__or__(Tensor self, Scalar other) -> Tensor",
1060 "aten::__xor__(Tensor self, Scalar other) -> Tensor",
1061 "aten::__lshift__(Tensor self, Scalar other) -> Tensor",
1062 "aten::__rshift__(Tensor self, Scalar other) -> Tensor",
1063 "aten::__iand__(Tensor self, Scalar other) -> Tensor",
1064 "aten::__ior__(Tensor self, Scalar other) -> Tensor",
1065 "aten::__ixor__(Tensor self, Scalar other) -> Tensor",
1066 "aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
1067 "aten::__irshift__(Tensor self, Scalar other) -> Tensor",
1068 },
1069 [](Node* node) -> type_vec_t {
1070 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1071 return {broadcast(
1072 *maybe_tensor_types, (*maybe_tensor_types)[0]->scalarType())};
1073 }
1074 return {};
1075 }};
1076
1077 // aten::where is special in that its return type is the second argument's
1078 // (self) type rather than the that of condition
1079 static const register_formula_for where_op{
1080 {
1081 "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
1082 },
1083 [](Node* node) -> type_vec_t {
1084 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1085 return {broadcast(
1086 *maybe_tensor_types, (*maybe_tensor_types)[1]->scalarType())};
1087 }
1088 return {};
1089 }};
1090
1091 static const auto any_tensor_type = [](Node* node) -> TensorTypePtr {
1092 for (Value* input : node->inputs()) {
1093 if (auto type = input->type()->cast<TensorType>()) {
1094 if (type->dim().has_value()) {
1095 return type;
1096 }
1097 }
1098 }
1099 return nullptr;
1100 };
1101
1102 // Requirements:
1103 // dims : always matching and preserved
1104 // scalar type : always matching and preserved
1105 // device : always matching and preserved
1106 // tensor inputs : 2
1107 // tensor outputs : 1
1108 static const register_formula_for binary_ops_strict_match{
1109 {
1110 "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
1111 "aten::mm(Tensor self, Tensor mat2) -> Tensor",
1112 "aten::bmm(Tensor self, Tensor mat2) -> Tensor",
1113 },
1114 [](Node* node) -> type_vec_t {
1115 if (auto type = any_tensor_type(node)) {
1116 return {std::move(type)};
1117 }
1118 return {};
1119 }};
1120
1121 // Requirements:
1122 // dims : all tensor args are broadcast
1123 // scalar type : byte/uint8
1124 // device : always matching and preserved
1125 // tensor inputs : *
1126 // tensor outputs : 1
1127 static const register_formula_for comparison_ops{
1128 {
1129 "aten::lt(Tensor self, Tensor other) -> Tensor",
1130 "aten::le(Tensor self, Tensor other) -> Tensor",
1131 "aten::gt(Tensor self, Tensor other) -> Tensor",
1132 "aten::ge(Tensor self, Tensor other) -> Tensor",
1133 "aten::eq(Tensor self, Tensor other) -> Tensor",
1134 "aten::ne(Tensor self, Tensor other) -> Tensor",
1135 "aten::lt(Tensor self, Scalar other) -> Tensor",
1136 "aten::le(Tensor self, Scalar other) -> Tensor",
1137 "aten::gt(Tensor self, Scalar other) -> Tensor",
1138 "aten::ge(Tensor self, Scalar other) -> Tensor",
1139 "aten::eq(Tensor self, Scalar other) -> Tensor",
1140 "aten::ne(Tensor self, Scalar other) -> Tensor",
1141 },
1142 [](Node* node) -> type_vec_t {
1143 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1144 return {broadcast(*maybe_tensor_types, at::kBool)};
1145 }
1146 return {};
1147 }};
1148
1149 static const register_formula_for nn_ops_first_input_formula{
1150 *nn_ops_first_input_preserving(), [](Node* node) -> type_vec_t {
1151 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1152 return {type->dimensionedOnly()};
1153 }
1154 return {};
1155 }};
1156
1157 // Requirements:
1158 // dims : 0
1159 // scalar type : preserved
1160 // device : preserved
1161 // tensor inputs : 1
1162 // tensor outputs : 1
1163 // Additionally:
1164 // - First input should be the only tensor input
1165 static const register_formula_for all_reduce_ops{
1166 {
1167 "aten::det(Tensor self) -> Tensor",
1168 "aten::logdet(Tensor self) -> Tensor",
1169 "aten::max(Tensor self) -> Tensor",
1170 "aten::min(Tensor self) -> Tensor",
1171 "aten::median(Tensor self) -> Tensor",
1172 "aten::nanmedian(Tensor self) -> Tensor",
1173 "aten::norm(Tensor self, Scalar p) -> Tensor",
1174 "aten::std(Tensor self, bool unbiased) -> Tensor",
1175 "aten::trace(Tensor self) -> Tensor",
1176 "aten::var(Tensor self, bool unbiased) -> Tensor",
1177 "aten::all(Tensor self) -> Tensor",
1178 "aten::any(Tensor self) -> Tensor",
1179 },
1180 [](Node* node) -> type_vec_t {
1181 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1182 return {type->withDim(0)};
1183 }
1184 return {};
1185 }};
1186
1187 // Requirements:
1188 // dims : 0
1189 // scalar type : dtype if specified, else preserved
1190 // device : preserved
1191 // tensor inputs : 1
1192 // tensor outputs : 1
1193 // Additionally:
1194 // - First input should be the only tensor input
1195 static const register_formula_for reduce_ops_with_opt_dtype{
1196 {"aten::mean(Tensor self, *, int? dtype) -> Tensor"},
1197 [](Node* node) -> type_vec_t {
1198 at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1199 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1200 auto ret = type->withDim(0);
1201 if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
1202 return {ret->withScalarType(maybe_dtype_option->toScalarType())};
1203 } else {
1204 return {std::move(ret)};
1205 }
1206 }
1207 return {};
1208 }};
1209
1210 // Requirements:
1211 // dims : 0
1212 // scalar type : dtype if specified, else preserved if floating point,
1213 // otherwise long/int64 device : preserved tensor inputs : 1
1214 // tensor outputs : 1
1215 // Additionally:
1216 // - First input should be the only tensor input
1217 static const register_formula_for
1218 all_reduce_ops_with_integer_upcast_and_dtype{
1219 {
1220 "aten::sum(Tensor self, *, int? dtype) -> Tensor",
1221 "aten::prod(Tensor self, *, int? dtype) -> Tensor",
1222 },
1223 [](Node* node) -> type_vec_t {
1224 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1225 type = type->withDim(0);
1226 at::optional<IValue> maybe_dtype_option =
1227 node->get(attr::dtype);
1228 if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
1229 return {
1230 type->withScalarType(maybe_dtype_option->toScalarType())};
1231 }
1232 if (type->scalarType()) {
1233 return {
1234 at::isFloatingType(*type->scalarType())
1235 ? std::move(type)
1236 : type->withScalarType(at::kLong)};
1237 } else {
1238 return {std::move(type)};
1239 }
1240 }
1241 return {};
1242 }};
1243
1244 static const auto reduce_op_handler = [](Node* node,
1245 int64_t num_reduced_dim = 0,
1246 bool upcast_integer = false,
1247 c10::optional<IValue> opt_dtype =
1248 c10::nullopt) -> type_vec_t {
1249 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1250 if (!type->scalarType() || !type->dim()) {
1251 return {};
1252 }
1253 if (opt_dtype && !opt_dtype->isNone()) {
1254 type = type->withScalarType(opt_dtype->toScalarType());
1255 } else if (upcast_integer && !at::isFloatingType(*type->scalarType())) {
1256 type = type->withScalarType(at::kLong);
1257 }
1258 // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
1259 if (*type->dim() >= num_reduced_dim && num_reduced_dim > 0) {
1260 return {type->withDim(*type->dim() - num_reduced_dim)};
1261 } else {
1262 return {std::move(type)};
1263 }
1264 }
1265 return {};
1266 };
1267
1268 static const auto multidim_reduce_with_keepdim =
1269 [](Node* node,
1270 int64_t num_reduced_dim,
1271 bool upcast_integer) -> type_vec_t {
1272 auto maybe_keepdim = node->get<bool>(attr::keepdim);
1273 if (!maybe_keepdim)
1274 return {};
1275 return reduce_op_handler(
1276 node, *maybe_keepdim ? 0 : num_reduced_dim, upcast_integer);
1277 };
1278
1279 // Requirements:
1280 // dims : 0 if dim is None, otherwise preserved if keepdim ==
1281 // false or 1 smaller otherwise scalar type : preserved device :
1282 // preserved tensor inputs : 1 tensor outputs : 1
1283 // Additionally:
1284 // - First input should be the only tensor input
1285 // - Has a bool keepdim argument
1286 static const register_formula_for argminmax{
1287 {
1288 "aten::argmax(Tensor self, int? dim, bool keepdim) -> Tensor",
1289 "aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor",
1290 },
1291 [](Node* node) -> type_vec_t {
1292 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1293 if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) {
1294 return {type->withDim(0)};
1295 } else {
1296 return multidim_reduce_with_keepdim(
1297 node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
1298 }
1299 }
1300 return {};
1301 }};
1302
1303 // Requirements:
1304 // dims : preserved if keepdim == false, 1 smaller otherwise
1305 // scalar type : preserved for first output, byte/uint8 for second
1306 // output if exists device : preserved tensor inputs : 1 tensor
1307 // outputs : 1 or 2
1308 // Additionally:
1309 // - First input should be the only tensor input
1310 // - Has a bool keepdim argument
1311 static const register_formula_for dim_reduce_ops{
1312 {
1313 "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
1314 "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
1315
1316 // Ops returning indices as second output
1317 "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
1318 "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1319 "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1320 "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1321 "aten::nanmedian(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1322 "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1323 },
1324 [](Node* node) -> type_vec_t {
1325 // NB: Note that while this function is generally meant to be used
1326 // with ops that have a single output, we will fix up its return right
1327 // below.
1328 auto output_types = multidim_reduce_with_keepdim(
1329 node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
1330 if (!output_types.empty() && node->outputs().size() == 2) {
1331 output_types.push_back(
1332 output_types.back()->withScalarType(at::kLong));
1333 }
1334 return output_types;
1335 }};
1336
1337 // Requirements:
1338 // dims : preserved if keepdim == false, 1 smaller otherwise
1339 // scalar type : dtype if specified. preserved if floating point,
1340 // otherwise long/int64 device : preserved tensor inputs : 1
1341 // tensor outputs : 1
1342 // Additionally:
1343 // - First input should be the only tensor input
1344 // - has a bool keepdim argument
1345 static const register_formula_for dim_reduce_ops_with_integer_upcast{
1346 {
1347 "aten::prod(Tensor self, int dim, bool keepdim, *, int? dtype) -> Tensor",
1348 },
1349 [](Node* node) -> type_vec_t {
1350 auto maybe_keepdim = node->get<bool>(attr::keepdim);
1351 at::optional<IValue> opt_dtype = node->get(attr::dtype);
1352 return reduce_op_handler(
1353 node,
1354 /*num_reduce_dim=*/*maybe_keepdim ? 0 : 1,
1355 /*integer_upcast=*/true,
1356 std::move(opt_dtype));
1357 }};
1358
1359 // Requirements:
1360 // dims : preserved
1361 // scalar type : dtype if specified, preserved if floating point,
1362 // otherwise long/int64
1363 // device : preserved
1364 // tensor inputs : 1
1365 // tensor outputs : 1
1366 // Additionally:
1367 // - First input should be the only tensor input
1368 static const register_formula_for dim_reduce_ops_dtype{
1369 {"aten::cumprod(Tensor self, int dim, *, int? dtype) -> Tensor",
1370 "aten::cumsum(Tensor self, int dim, *, int? dtype) -> Tensor",
1371 "aten::log_softmax(Tensor self, int dim, int? dtype) -> Tensor"},
1372 [](Node* node) -> type_vec_t {
1373 at::optional<IValue> opt_dtype = node->get(attr::dtype);
1374 return reduce_op_handler(
1375 node,
1376 /*num_reduce_dim=*/0,
1377 /*integer_upcast=*/true,
1378 std::move(opt_dtype));
1379 }};
1380
1381 // Requirements:
1382 // dims : preserved
1383 // scalar type : dtype if specified, otherwise preserved
1384 // device : preserved
1385 // tensor inputs : 1
1386 // tensor outputs : 1
1387 // Additionally:
1388 // - has bool keepdim and int[] dim arguments
1389 static const register_formula_for register_softmax{
1390 {"aten::softmax(Tensor self, int dim, int? dtype) -> Tensor"},
1391 [](Node* node) -> type_vec_t {
1392 at::optional<IValue> opt_dtype = node->get(attr::dtype);
1393 return reduce_op_handler(
1394 node,
1395 /*num_reduced_dim=*/0,
1396 /*upcast_integer=*/false,
1397 std::move(opt_dtype));
1398 }};
1399
1400 static const auto factory_with_ndim =
1401 [](Node* node, int dim, at::ScalarType default_dtype) -> type_vec_t {
1402 at::optional<IValue> maybe_layout_option = node->get(attr::layout);
1403 if (!maybe_layout_option)
1404 return {};
1405
1406 at::optional<IValue> maybe_device_option = node->get(attr::device);
1407 if (!maybe_device_option)
1408 return {};
1409 auto device =
1410 (maybe_device_option->isNone() ? at::kCPU
1411 : maybe_device_option->toDevice());
1412
1413 at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1414 if (!maybe_dtype_option)
1415 return {};
1416 auto dtype =
1417 (maybe_dtype_option->isNone() ? default_dtype
1418 : maybe_dtype_option->toScalarType());
1419
1420 return {TensorType::create(
1421 dtype, device, dim, /*requires_grad=*/c10::nullopt)};
1422 };
1423
1424 static const auto factory_like_with_ndim = [](Node* node,
1425 int dim) -> type_vec_t {
1426 auto tt = node->input(0)->type()->expect<TensorType>();
1427 auto in_type = tt->scalarType();
1428 auto in_dev = tt->device();
1429
1430 at::optional<IValue> maybe_layout_option = node->get(attr::layout);
1431 if (!maybe_layout_option)
1432 return {};
1433
1434 at::optional<IValue> maybe_device_option = node->get(attr::device);
1435 if (!maybe_device_option)
1436 return {};
1437
1438 if (!maybe_device_option->isNone()) {
1439 in_dev = maybe_device_option->toDevice();
1440 }
1441
1442 at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1443 if (!maybe_dtype_option)
1444 return {};
1445
1446 if (!maybe_dtype_option->isNone()) {
1447 in_type = maybe_dtype_option->toScalarType();
1448 }
1449
1450 return {TensorType::create(
1451 in_type, in_dev, dim, /*requires_grad=*/c10::nullopt)};
1452 };
1453
1454 // Requirements:
1455 // dims : preserved
1456 // scalar type : equal to value of dtype
1457 // device : equal to value of device
1458 // tensor inputs : 1
1459 // tensor outputs : 1
1460 // Additionally:
1461 // - has ScalarType dtype, Layeout layout and Device device arguments
1462 static const register_formula_for like_factories_with_options{
1463 {
1464 "aten::empty_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1465 "aten::full_like(Tensor self, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1466 "aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1467 "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1468 "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1469 "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1470 "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1471 "aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1472 },
1473 [](Node* node) -> type_vec_t {
1474 if (auto type =
1475 node->namedInput(attr::self)->type()->cast<TensorType>()) {
1476 if (type->dim()) {
1477 return factory_like_with_ndim(node, (int)*type->dim());
1478 }
1479 }
1480 return {};
1481 }};
1482
1483 // Requirements:
1484 // dims : equal to number of elements in size
1485 // scalar type : equal to value of dtype
1486 // device : equal to value of device
1487 // tensor inputs : 1
1488 // tensor outputs : 1
1489 // Additionally:
1490 // - has int[] size, ScalarType dtype, Layeout layout and Device device
1491 // arguments
1492 static const register_formula_for size_factories_with_options{
1493 {
1494 "aten::empty(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory, MemoryFormat? memory_format=contiguous_format) -> Tensor",
1495 "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1496 "aten::ones(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1497 "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1498 "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1499 "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1500 },
1501 [](Node* node) -> type_vec_t {
1502 if (auto maybe_size = node->get<c10::List<int64_t>>(attr::size)) {
1503 return factory_with_ndim(
1504 node, (int)maybe_size->size(), at::kDouble);
1505 }
1506 return {};
1507 }};
1508
1509 static const register_formula_for randint{
1510 {
1511 "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1512 "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1513 },
1514 [](Node* node) -> type_vec_t {
1515 if (auto maybe_size = node->get<c10::List<int64_t>>(attr::size)) {
1516 return factory_with_ndim(node, (int)maybe_size->size(), at::kLong);
1517 }
1518 return {};
1519 }};
1520
1521 static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
1522 switch (node->kind()) {
1523 case aten::_cast_Byte:
1524 return at::kByte;
1525 case aten::_cast_Char:
1526 return at::kChar;
1527 case aten::_cast_Double:
1528 return at::kDouble;
1529 case aten::_cast_Float:
1530 return at::kFloat;
1531 case aten::_cast_Half:
1532 return at::kHalf;
1533 case aten::_cast_Int:
1534 return at::kInt;
1535 case aten::_cast_Long:
1536 return at::kLong;
1537 case aten::_cast_Short:
1538 return at::kShort;
1539 default:
1540 AT_ASSERTM(
1541 false,
1542 "unknown node kind in get_cast_scalar_type: ",
1543 node->kind().toQualString());
1544 }
1545 };
1546 static const register_formula_for cast_ops{
1547 {
1548 "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
1549 "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
1550 "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
1551 "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
1552 "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
1553 "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
1554 "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
1555 "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
1556 },
1557 [](Node* node) -> type_vec_t {
1558 if (auto type =
1559 node->namedInput(attr::self)->type()->cast<TensorType>()) {
1560 return {type->withScalarType(get_cast_scalar_type(node))};
1561 }
1562 return {};
1563 }};
1564
1565 // First, try to match one of the registered formulas to their operator
1566 // sets.
1567 for (auto& entry : shape_formulas) {
1568 if (node->isMemberOf(entry.first)) {
1569 auto types = entry.second(node);
1570 if (types.empty()) {
1571 return false;
1572 } else {
1573 auto outputs = node->outputs();
1574 AT_ASSERT(types.size() == outputs.size());
1575 for (const auto i : c10::irange(types.size())) {
1576 AT_ASSERT(outputs[i]->type()->isSubtypeOf(*TensorType::get()));
1577 outputs[i]->setType(types[i]);
1578 }
1579 return true;
1580 }
1581 }
1582 }
1583
1584 // This section implements shape prop for an assorted set of nodes that only
1585 // need partial information about their input types.
1586 const auto input_type = [node](size_t index) {
1587 auto result = node->input(index)->type()->cast<TensorType>();
1588 if (result) {
1589 result = result->dimensionedOnly();
1590 }
1591 return result;
1592 };
1593 if (node->matches(
1594 "aten::masked_select(Tensor self, Tensor mask) -> Tensor")) {
1595 if (auto type = input_type(0)) {
1596 node->output()->setType(type->withDim(1));
1597 return true;
1598 }
1599 } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
1600 if (auto type = input_type(0)) {
1601 node->output()->setType(type->withRequiresGrad(false));
1602 return true;
1603 }
1604 } else if (
1605 node->matches(
1606 "aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)")) {
1607 if (auto type = input_type(0)) {
1608 if (type->scalarType() == at::kHalf) {
1609 type = type->withScalarType(at::kFloat);
1610 }
1611 type = type->withDim(1);
1612 node->outputs()[0]->setType(type);
1613 node->outputs()[1]->setType(std::move(type));
1614 return true;
1615 }
1616 } else if (node->matches(
1617 "aten::dot(Tensor self, Tensor tensor) -> Tensor")) {
1618 if (auto type = any_tensor_type(node)) {
1619 node->output()->setType(type->withDim(0));
1620 return true;
1621 }
1622 } else if (
1623 node->matches("aten::mv(Tensor self, Tensor vec) -> Tensor") ||
1624 node->matches(
1625 "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor")) {
1626 if (auto type = any_tensor_type(node)) {
1627 node->output()->setType(type->withDim(1));
1628 return true;
1629 }
1630 } else if (
1631 node->matches(
1632 "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1633 node->matches(
1634 "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1635 node->matches(
1636 "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1637 if (auto type = any_tensor_type(node)) {
1638 node->output()->setType(type->withDim(2));
1639 return true;
1640 }
1641 } else if (
1642 node->matches(
1643 "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1644 if (auto type = any_tensor_type(node)) {
1645 node->output()->setType(type->withDim(3));
1646 return true;
1647 }
1648 } else if (
1649 node->matches(
1650 "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")) {
1651 auto type = input_type(0);
1652 auto index_type = input_type(1);
1653 // index_select behaves very weirdly when self.dim() == 0. It allows both
1654 // 0D and 1D indices, and returns a value that has as many dimensions as
1655 // index.
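      // For example: a 2-D self keeps its rank regardless of the index, while
      // a 0-D self paired with a 1-D index produces a 1-D result.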
1656 if (type && index_type && type->dim()) {
1657 if (*type->dim() == 0) {
1658 node->output()->setType(type->withDim(index_type->dim()));
1659 } else {
1660 node->output()->setType(std::move(type));
1661 }
1662 return true;
1663 }
1664 } else if (
1665 node->matches(
1666 "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
1667 auto type = input_type(0);
1668 auto index_type = input_type(1);
1669 // Gather has this annoying edge case where index always needs to match
1670 // the number of dims of self, **except** when self is 1D and index is 0D
1671 // in which case we return a 0D output.
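      // For example: a 3-D index (paired with a 3-D self) gives a 3-D output,
      // while a 0-D index gives a 0-D output.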
1672 if (type && index_type && index_type->dim()) {
1673 if (*index_type->dim() == 0) {
1674 node->output()->setType(type->withDim(0));
1675 } else {
1676 node->output()->setType(std::move(type));
1677 }
1678 return true;
1679 }
1680 } else if (
1681 node->matches(
1682 "aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
1683 auto weight_type = input_type(0);
1684 auto indices_type = input_type(1);
1685 if (weight_type && indices_type && indices_type->dim()) {
1686 node->output()->setType(weight_type->withDim(*indices_type->dim() + 1));
1687 return true;
1688 }
1689 } else if (
1690 node->matches(
1691 "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor")) {
1692 if (auto type = input_type(0)) {
1693 node->output()->setType(std::move(type));
1694 return true;
1695 }
1696 if (auto type = input_type(1)) {
1697 node->output()->setType(std::move(type));
1698 return true;
1699 }
1700 } else if (
1701 node->matches(
1702 "aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor")) {
1703 if (auto type = any_tensor_type(node)) {
1704 node->output()->setType(type->withDim(0));
1705 return true;
1706 }
1707 }
1708
1709 // The code below implements formulas that need type information for all
1710 // their tensor inputs, and have exactly one output.
1711 std::vector<TensorTypePtr> tensor_types;
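    // Shared helper: if the size/shape list argument has a statically known
    // length, the output rank is that length (e.g. a size argument of
    // [2, 3, 4] implies a 3-D output); otherwise nothing can be inferred.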
1712 static const auto reshape_prop =
1713 [](Node* node,
1714 Symbol shape_input,
1715 const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
1716 if (auto list_size = determineListSize(node->namedInput(shape_input))) {
1717 return tensor_types.at(0)->withDim(*list_size);
1718 }
1719 return nullptr;
1720 };
1721 const auto getSingleOutputType = [&]() -> TypePtr {
1722 if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
1723 return tensor_types.at(0)->withScalarType(
1724 tensor_types.at(1)->scalarType());
1725 } else if (
1726 node->matches(
1727 "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
1728 node->matches(
1729 "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
1730 node->matches(
1731 "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)")) {
1732 return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
1733 } else if (
1734 node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
1735 node->matches(
1736 "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
1737 node->matches(
1738 "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
1739 return reshape_prop(node, attr::size, tensor_types);
1740 } else if (
1741 node->matches(
1742 "aten::as_tensor(Tensor data, *, ScalarType? dtype, Device? device) -> Tensor")) {
1743 TypePtr input_type = node->inputs().at(0)->type();
1744 if (auto type = input_type->cast<TensorType>()) {
1745 if (type->scalarType() && type->device()) {
1746 at::ScalarType default_type = *type->scalarType();
1747 c10::Device default_device = *type->device();
1748 if (auto dtype_index =
1749 node->schema().argumentIndexWithName("dtype")) {
1750 auto inp = toIValue(node->inputs().at(*dtype_index));
1751 if (inp == c10::nullopt) {
1752 return nullptr;
1753 }
1754 if (!inp->isNone()) {
1755 default_type = inp->toScalarType();
1756 }
1757 }
1758 if (auto device_index =
1759 node->schema().argumentIndexWithName("device")) {
1760 auto inp = toIValue(node->inputs().at(*device_index));
1761 if (inp == c10::nullopt) {
1762 return nullptr;
1763 }
1764 if (!inp->isNone()) {
1765 default_device = inp->toDevice();
1766 }
1767 }
1768 node->output()->setType(TensorType::create(
1769 default_type,
1770 default_device,
1771 type->dim(),
1772 /*requires_grad=*/c10::nullopt));
1773 }
1774 }
1775 return nullptr;
1776 } else if (
1777 node->matches(
1778 "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)")) {
1779 return reshape_prop(node, attr::shape, tensor_types);
1780 } else if (node->matches(
1781 "aten::repeat(Tensor self, int[] repeats) -> Tensor")) {
1782 return reshape_prop(node, attr::repeats, tensor_types);
1783 } else if (node->matches(
1784 "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
1785 auto& t = tensor_types.at(0);
1786 if (!t->dim()) {
1787 return t;
1788 }
1789 return t->withDim(*t->dim() + 1);
1790 } else if (
1791 node->matches(
1792 "aten::select(Tensor self, int dim, int index) -> Tensor") ||
1793 node->matches(
1794 "aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor")) {
1795 auto& t = tensor_types.at(0);
1796 return t->dim() && *t->dim() > 0 ? t->withDim(*t->dim() - 1) : nullptr;
1797 } else if (node->matches(
1798 "aten::matmul(Tensor self, Tensor other) -> Tensor")) {
1799 if (!tensor_types.at(0)->dim() || !tensor_types.at(1)->dim()) {
1800 return nullptr;
1801 }
1802 int dim1 = *tensor_types.at(0)->dim();
1803 int dim2 = *tensor_types.at(1)->dim();
1804 if (dim1 == 1 && dim2 == 1) {
1805 // Dot product
1806 return tensor_types.at(0)->withDim(0);
1807 // NOLINTNEXTLINE(bugprone-branch-clone)
1808 } else if (dim1 == 2 && dim2 == 2) {
1809 // Matrix multiply
1810 return tensor_types.at(0);
1811 } else if (dim1 == 1 && dim2 == 2) {
1812 // Unsqueeze + matrix multiply + squeeze
1813 return tensor_types.at(0);
1814 } else if (dim1 == 2 && dim2 == 1) {
1815 // Matrix vector multiply
1816 return tensor_types.at(1);
1817 } else {
1818 // Batched matrix multiply (possibly with squeeze + unsqueeze if one
1819 // argument is 1D)
1820 auto type = broadcast(tensor_types, tensor_types[0]->scalarType());
1821 if (dim1 == 1 || dim2 == 1) {
1822 type = type->withDim(type->dim().value() - 1);
1823 }
1824 return type;
1825 }
1826 } else if (node->matches("aten::nonzero(Tensor self) -> Tensor")) {
1827 return tensor_types.at(0)->dimensionedOnly()->withScalarType(at::kLong);
1828 } else if (node->matches(
1829 "aten::take(Tensor self, Tensor index) -> Tensor")) {
1830 return tensor_types.at(1)->dimensionedOnly()->withScalarType(
1831 tensor_types.at(0)->scalarType());
1832 } else if (node->matches(
1833 "aten::diagflat(Tensor self, int offset) -> Tensor")) {
1834 return tensor_types.at(0)->withDim(2);
1835 } else if (node->matches(
1836 "aten::diag(Tensor self, int diagonal) -> Tensor")) {
1837 auto& t = tensor_types.at(0);
1838 if (t->dim() && *t->dim() == 1) {
1839 return t->withDim(2);
1840 } else if (t->dim() && *t->dim() == 2) {
1841 return t->withDim(1);
1842 } else {
1843 return nullptr;
1844 }
1845 } else if (
1846 node->matches(
1847 "aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor")) {
1848 auto& t = tensor_types.at(0);
1849 if (!t->dim()) {
1850 return nullptr;
1851 }
1852 return t->withDim(*t->dim() + 1);
1853 } else if (node->matches(
1854 "aten::polygamma(int n, Tensor self) -> Tensor")) {
1855 return tensor_types.at(0);
1856 }
1857 return nullptr;
1858 };
1859 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1860 tensor_types = std::move(*maybe_tensor_types);
1861 } else {
1862 return false;
1863 }
1864 if (node->outputs().size() == 1) {
1865 if (auto type = getSingleOutputType()) {
1866 node->output()->setType(std::move(type));
1867 return true;
1868 }
1869 }
1870 return false;
1871 }
1872
1873 bool PropagateCompleteShapeOnNode(
1874 Node* node,
1875 bool insert_expands,
1876 std::vector<TensorTypePtr> tensor_types) {
    // For expensive ops we can directly encode their shape propagation here;
    // otherwise we fall back to running a fake version of the op to get a
    // quick and dirty propagation.
1880 if (node->matches(
1881 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1882 node->matches(
1883 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1884 node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
1885 // These nodes handle tensors of different shapes internally, so there's
1886 // no need to insert explicit expand nodes.
1887 return PropagateShapeOnNodeByRunningIt(node);
1888 } else if (node->matches(
1889 "aten::div(Tensor self, Tensor other) -> Tensor")) {
      // "div" handles tensors of different shapes internally, so there's no
      // need to insert explicit expand nodes.
      // This branch could be merged with the one above, but "div" is not
      // always safe to run by itself due to integer divide-by-zero, so we
      // fake the execution by running the "mul" operation instead.
1895 auto op = getOperatorForLiteral(
1896 "aten::mul(Tensor self, Tensor other) -> Tensor")
1897 ->getOperation();
1898 return PropagateShapeOnNodeByRunningIt(node, std::move(op));
1899 } else if (node->matches(
1900 "aten::pow(Tensor self, Scalar exponent) -> Tensor")) {
1901 node->output()->setType(tensor_types.at(0));
1902 return true;
1903 } else if (
1904 node->matches(
1905 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1906 node->matches(
1907 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1908 node->matches("aten::div(Tensor self, Scalar other) -> Tensor") ||
1909 node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
      auto first_scalar_type = tensor_types[0]->scalarType();
1911 auto second_scalar_type =
1912 tryScalarTypeFromJitType(*node->inputs()[1]->type());
1913 if (!first_scalar_type || !second_scalar_type) {
1914 return false;
1915 }
1916 if (isIntegralType(*first_scalar_type, false) &&
1917 isFloatingType(*second_scalar_type)) {
1918 auto default_dtype =
1919 at::typeMetaToScalarType(caffe2::get_default_dtype());
1920 auto type = tensor_types[0]->withScalarType(default_dtype);
1921 node->output()->setType(std::move(type));
1922 return true;
1923 }
1924 if (c10::ScalarType::Bool == *first_scalar_type &&
1925 c10::ScalarType::Bool != *second_scalar_type) {
1926 auto result_type =
1927 c10::promoteTypes(*first_scalar_type, *second_scalar_type);
1928 auto type = tensor_types[0]->withScalarType(result_type);
1929 node->output()->setType(std::move(type));
1930 return true;
1931 }
1932 auto type = tensor_types[0]->withScalarType(first_scalar_type);
1933 node->output()->setType(std::move(type));
1934 return true;
1935 } else if (
1936 insert_expands &&
1937 (node->matches("aten::pow(Tensor self, Tensor exponent) -> Tensor") ||
1938 node->matches("aten::min(Tensor self, Tensor other) -> Tensor") ||
1939 node->matches("aten::max(Tensor self, Tensor other) -> Tensor") ||
1940 node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") ||
1941 node->matches("aten::le(Tensor self, Tensor other) -> Tensor") ||
1942 node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") ||
1943 node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") ||
1944 node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") ||
1945 node->matches("aten::ne(Tensor self, Tensor other) -> Tensor"))) {
1946 // Binary broadcasting ops
      // NB: beyond inserting the expands we don't compute the output types
      // ourselves, because the type casting logic in the scalar cases is
      // non-trivial. It's better to just run the op.
1950 broadcastBinary(node, tensor_types, 0, 1);
1951 return PropagateShapeOnNodeByRunningIt(node);
1952 } else if (
1953 node->matches(
1954 "aten::logit(Tensor self, float? eps = None) -> Tensor") ||
1955 node->matches("aten::neg(Tensor self) -> Tensor") ||
1956 node->matches("aten::sigmoid(Tensor self) -> Tensor") ||
1957 node->matches("aten::tanh(Tensor self) -> Tensor")) {
1958 node->output()->setType(tensor_types.at(0)->contiguous());
1959 return true;
1960 } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
1961 auto lhs_type = tensor_types.at(0);
1962 auto rhs_type = tensor_types.at(1);
1963 auto lhs_sizes = lhs_type->sizes().concrete_sizes().value();
1964 auto rhs_sizes = rhs_type->sizes().concrete_sizes().value();
1965 SHAPE_ASSERT(
1966 *lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2);
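      // An [n, m] x [m, p] matrix multiply produces an [n, p] result.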
1967 node->output()->setType(TensorType::createContiguous(
1968 *lhs_type->scalarType(),
1969 *lhs_type->device(),
1970 at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]}));
1971 return true;
1972 } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
1973 auto tp = tensor_types.at(0);
1974 auto sizes = tp->sizes().concrete_sizes().value();
1975 auto strides = tp->strides().concrete_sizes().value();
1976 SHAPE_ASSERT(sizes.size() == 2);
1977 std::swap(sizes.at(0), sizes.at(1));
1978 std::swap(strides.at(0), strides.at(1));
1979 node->output()->setType(tp->withSizesStrides(sizes, strides));
1980 return true;
1981 } else if (
1982 node->matches(
1983 "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
1984 /*const_inputs=*/{attr::dim, attr::length})) {
1985 auto tp = tensor_types.at(0);
1986 auto sizes = tp->sizes().concrete_sizes().value();
1987 int64_t dim = node->get<int64_t>(attr::dim).value();
1988 int64_t length = node->get<int64_t>(attr::length).value();
1989 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
1990 sizes.at(dim) = length;
1991 node->output()->setType(
1992 tp->withSizesStrides(sizes, tp->strides().concrete_sizes().value()));
1993 return true;
1994 } else if (node->matches(
1995 "aten::sum(Tensor self, *, int? dtype) -> Tensor")) {
1996 node->output()->setType(tensor_types.at(0)->withSizes({}));
1997 return true;
1998 } else if (
1999 node->matches(
2000 "aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor",
2001 /*const_inputs=*/{attr::dim, attr::keepdim})) {
2002 auto& tp = tensor_types.at(0);
2003 auto sizes = tp->sizes().concrete_sizes().value();
2004 auto dims = node->get<c10::List<int64_t>>(attr::dim).value();
2005 bool keepdim = node->get<bool>(attr::keepdim).value();
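      // Visit dims from the back so that, for an ascending dim list, erasing
      // a reduced dimension doesn't shift the indices of the dims still to be
      // processed.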
2006 std::reverse(dims.begin(), dims.end());
2007 for (int64_t dim : dims) {
2008 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
2009 if (keepdim) {
2010 sizes.at(dim) = 1;
2011 } else {
2012 sizes.erase(sizes.begin() + dim);
2013 }
2014 }
2015 node->output()->setType(tp->withSizes(sizes));
2016 return true;
2017 } else if (node->matches(
2018 "aten::squeeze(Tensor self, int dim) -> Tensor",
2019 /*const_inputs=*/attr::dim)) {
2020 auto& tp = tensor_types.at(0);
2021 auto sizes = tp->sizes().concrete_sizes().value();
2022 auto strides = tp->strides().concrete_sizes().value();
2023 int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
2024 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
2025 if (sizes.at(dim) == 1) {
2026 sizes.erase(sizes.begin() + dim);
2027 strides.erase(strides.begin() + dim);
2028 }
2029 node->output()->setType(tp->withSizesStrides(sizes, strides));
2030 return true;
2031 } else if (node->matches(
2032 "aten::unsqueeze(Tensor self, int dim) -> Tensor",
2033 /*const_inputs=*/attr::dim)) {
2034 auto& tp = tensor_types.at(0);
2035 auto sizes = tp->sizes().concrete_sizes().value();
2036 auto strides = tp->strides().concrete_sizes().value();
2037 int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
2038 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
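      // Pick the new size-1 dim's stride so it is consistent with its
      // neighbours: sizes[dim] * strides[dim], or 1 when inserting past the
      // last dimension.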
2039 int64_t new_stride = dim >= static_cast<int64_t>(sizes.size())
2040 ? 1
2041 : sizes.at(dim) * strides.at(dim);
2042 sizes.insert(sizes.begin() + dim, 1);
2043 strides.insert(strides.begin() + dim, new_stride);
2044 node->output()->setType(tp->withSizesStrides(sizes, strides));
2045 return true;
2046 } else if (node->matches(
2047 "aten::view(Tensor self, int[] size) -> Tensor",
2048 /*const_inputs=*/attr::size)) {
2049 auto sizes = node->get<c10::List<int64_t>>(attr::size).value();
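      // At most one -1 entry is allowed; it is inferred from the element
      // count below (e.g. viewing a 12-element tensor as [-1, 4] yields
      // sizes [3, 4]).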
2050 bool inferred = false;
      size_t inferred_idx = 0;
2053 int64_t size_product = 1;
2054 for (const auto i : c10::irange(sizes.size())) {
2055 if (sizes.get(i) == -1) {
2056 if (inferred)
2057 throw propagation_error();
2058 inferred = true;
2059 inferred_idx = i;
2060 } else {
2061 size_product *= sizes.get(i);
2062 }
2063 }
2064
2065 if (inferred) {
2066 SHAPE_ASSERT(size_product != 0);
2067 size_t numel = 1;
2068 auto concrete_sizes =
2069 tensor_types.at(0)->sizes().concrete_sizes().value();
2070 for (int64_t s : concrete_sizes)
2071 numel *= s;
2072 int64_t inferred_size = numel / size_product;
2073 sizes[inferred_idx] = inferred_size;
2074 }
2075 node->output()->setType(tensor_types.at(0)->withSizes(sizes.vec()));
2076 return true;
2077 } else if (node->matches(
2078 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
2079 if (tensor_types.at(0)->scalarType() ==
2080 tensor_types.at(1)->scalarType()) {
2081 node->output()->setType(node->namedInput(attr::self)->type());
2082 } else {
2083 // This will be a copy, so the result will be contiguous
2084 node->output()->setType(tensor_types.at(1)->withSizes(
2085 tensor_types.at(0)->sizes().concrete_sizes().value()));
2086 }
2087 return true;
2088 } else if (
2089 node->matches(
2090 "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
2091 /*const_inputs=*/attr::size)) {
2092 auto tp = tensor_types.at(0);
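      // inferExpandGeometry computes the broadcasted sizes and the strides
      // (zero along broadcasted dimensions) of the expanded view.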
2093 auto sizesAndStrides = at::inferExpandGeometry_dimvector(
2094 tp->sizes().concrete_sizes().value(),
2095 tp->strides().concrete_sizes().value(),
2096 node->get<c10::List<int64_t>>(attr::size).value().vec());
2097 node->output()->setType(
2098 tp->withSizesStrides(sizesAndStrides.sizes, sizesAndStrides.strides));
2099 return true;
2100 } else if (
2101 node->matches(
2102 "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
2103 /*const_inputs=*/attr::dim)) {
2104 auto ten = tensor_types.at(0);
2105 auto index = tensor_types.at(1);
2106 int64_t dim = node->get<int64_t>(attr::dim).value();
2107 SHAPE_ASSERT(*index->sizes().size() == 1);
2108 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size());
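      // The output keeps self's sizes except along `dim`, which takes the
      // length of the index tensor.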
2109 std::vector<int64_t> sizes = ten->sizes().concrete_sizes().value();
2110 sizes[dim] = index->sizes()[0].value();
2111 node->output()->setType(ten->withSizes(sizes));
2112 return true;
2113 } else if (node->matches(
2114 "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
2115 /*const_inputs=*/{attr::chunks, attr::dim})) {
2116 auto input_type = tensor_types.at(0);
2117 auto sizes = input_type->sizes().concrete_sizes().value();
2118 auto strides = input_type->strides().concrete_sizes().value();
2119 int64_t dim = node->get<int64_t>(attr::dim).value();
2120 int64_t chunks = node->get<int64_t>(attr::chunks).value();
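      // Give every output the floor(size / chunks) extent along `dim`; if the
      // split is uneven, the last output is re-typed below with the
      // remainder.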
2121 sizes[dim] /= chunks;
2122 for (Value* output : node->outputs()) {
2123 output->setType(input_type->withSizesStrides(sizes, strides));
2124 }
2125 if (*input_type->sizes()[dim] % chunks != 0) {
2126 sizes[dim] = *input_type->sizes()[dim] % chunks;
2127 node->outputs().back()->setType(
2128 input_type->withSizesStrides(sizes, strides));
2129 }
2130 return true;
2131 } else if (node->kind() == ::c10::onnx::Shape) {
2132 SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1);
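      // onnx::Shape yields a 1-D int64 CPU tensor whose single dimension is
      // the input's rank.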
2133 std::vector<int64_t> dim_vec = {
2134 (int64_t)*tensor_types.at(0)->sizes().size()};
2135 at::IntArrayRef dims(dim_vec);
2136 node->output()->setType(
2137 TensorType::createContiguous(at::kLong, at::kCPU, dims));
2138 return true;
2139 } else if (node->kind() == ::c10::onnx::Reshape) {
2140 setUnshapedType(node);
2141 return true;
2142 }
2143 setUnshapedType(node);
2144 return false;
2145 }
2146};
2147} // anonymous namespace
2148
2149void PropagateInputShapes(const std::shared_ptr<Graph>& graph) {
2150 ShapePropagator(graph).propagateBlock(graph->block());
2151}
2152
2153namespace {
2154
2155using TypeCache = std::unordered_map<TypePtr, TypePtr>;
2156
2157TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache);
2158
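// Recursively strips shape information: any tensor type collapses to the
// unrefined TensorType, and container types are rebuilt with unshaped
// contained types.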
2159TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
2160 if (type->isSubtypeOf(*TensorType::get())) {
2161 return TensorType::get();
2162 }
2163 at::ArrayRef<TypePtr> contained = type->containedTypes();
2164 if (contained.empty()) {
2165 return type;
2166 }
2167 std::vector<TypePtr> unshaped_contained_types;
2168 for (const auto& contained_type : contained) {
2169 unshaped_contained_types.push_back(
2170 getOrCreateUnshapedType(contained_type, unshaped_type_cache));
2171 }
2172 return type->withContained(std::move(unshaped_contained_types));
2173}
2174
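// Memoized wrapper around unshapedTypeImpl so repeated (and recursive)
// queries for the same TypePtr reuse the cached result.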
2175TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache) {
2176 auto maybe_cached_type = unshaped_type_cache.find(type);
2177 if (maybe_cached_type != unshaped_type_cache.end()) {
2178 return maybe_cached_type->second;
2179 }
2180 auto unshaped_type = unshapedTypeImpl(type, unshaped_type_cache);
2181 unshaped_type_cache[type] = unshaped_type;
2182 return unshaped_type;
2183}
2184
2185void EraseShapeInformation(
2186 const std::shared_ptr<Graph>& graph,
2187 TypeCache& unshaped_type_cache);
2188
2189void EraseShapeInformation(
2190 at::ArrayRef<Value*> vals,
2191 TypeCache& unshaped_type_cache) {
2192 for (Value* v : vals) {
2193 v->setType(getOrCreateUnshapedType(v->type(), unshaped_type_cache));
2194 }
2195}
2196
2197void EraseShapeInformation(Block* b, TypeCache& unshaped_type_cache) {
2198 EraseShapeInformation(b->inputs(), unshaped_type_cache);
2199 EraseShapeInformation(b->outputs(), unshaped_type_cache);
2200 for (Node* n : b->nodes()) {
2201 EraseShapeInformation(n->outputs(), unshaped_type_cache);
2202 for (Block* sb : n->blocks()) {
2203 EraseShapeInformation(sb, unshaped_type_cache);
2204 }
2205 if (n->hasAttribute(attr::Subgraph)) {
2206 EraseShapeInformation(n->g(attr::Subgraph), unshaped_type_cache);
2207 }
2208 }
2209}
2210
2211void EraseShapeInformation(
2212 const std::shared_ptr<Graph>& graph,
2213 TypeCache& unshaped_type_cache) {
2214 EraseShapeInformation(graph->block(), unshaped_type_cache);
2215}
2216
2217} // anonymous namespace
2218
2219void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
2220 TypeCache unshaped_type_cache;
2221 EraseShapeInformation(graph->block(), unshaped_type_cache);
2222}
2223} // namespace jit
2224} // namespace torch
2225