1 | #include <torch/csrc/jit/passes/shape_analysis.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/jit/frontend/error_report.h> |
6 | #include <torch/csrc/jit/ir/alias_analysis.h> |
7 | #include <torch/csrc/jit/ir/constants.h> |
8 | #include <torch/csrc/jit/ir/ir.h> |
9 | #include <torch/csrc/jit/ir/ir_views.h> |
10 | #include <torch/csrc/jit/passes/utils/op_registry.h> |
11 | #include <torch/csrc/jit/runtime/exception_message.h> |
12 | #include <torch/csrc/jit/runtime/operator.h> |
13 | |
14 | #include <torch/csrc/autograd/variable.h> |
15 | |
16 | #include <ATen/DeviceGuard.h> |
17 | #include <ATen/ExpandUtils.h> |
18 | #include <ATen/core/symbol.h> |
19 | |
20 | #ifndef AT_PER_OPERATOR_HEADERS |
21 | #include <ATen/Functions.h> |
22 | #else |
23 | #include <ATen/ops/empty_strided.h> |
24 | #endif |
25 | |
26 | #include <exception> |
27 | #include <iostream> |
28 | #include <memory> |
29 | #include <utility> |
30 | #include <vector> |
31 | |
32 | namespace torch { |
33 | namespace jit { |
34 | |
35 | bool mergeTypes( |
36 | ArrayRef<Value*> lhs, |
37 | ArrayRef<Value*> rhs, |
38 | ArrayRef<Value*> outputs) { |
39 | AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size()); |
40 | bool changed = false; |
41 | for (const auto i : c10::irange(lhs.size())) { |
42 | auto old_output_type = outputs[i]->type(); |
43 | auto new_type = |
44 | unifyTypes(lhs[i]->type(), rhs[i]->type(), /*default_to_union=*/true); |
45 | AT_ASSERT(new_type); |
46 | outputs[i]->setType(*new_type); |
47 | if (*old_output_type != *outputs[i]->type()) |
48 | changed = true; |
49 | } |
50 | return changed; |
51 | } |
52 | |
53 | void applyTypes(ArrayRef<Value*> src, ArrayRef<Value*> dst) { |
54 | AT_ASSERT(src.size() == dst.size()); |
55 | for (const auto i : c10::irange(src.size())) { |
56 | dst[i]->setType(src[i]->type()); |
57 | } |
58 | } |
59 | |
60 | void PropertyPropBase::propagateBlock(Block* block, bool insert_expands) { |
61 | for (Node* node : block->nodes()) { |
62 | try { |
63 | propagateNode(node, insert_expands); |
64 | } catch (propagation_error& e) { |
65 | setUnshapedType(node); |
66 | } catch (std::exception& e) { |
67 | throw ErrorReport(node->sourceRange()) |
68 | << ExceptionMessage(e) |
69 | << "\nThe above operation failed shape propagation in this context" ; |
70 | } |
71 | } |
72 | } |
73 | |
74 | void PropertyPropBase::processIf(Node* node) { |
75 | auto then_block = node->blocks().at(0); |
76 | auto else_block = node->blocks().at(1); |
77 | propagateBlock(then_block); |
78 | propagateBlock(else_block); |
79 | mergeTypes(then_block->outputs(), else_block->outputs(), node->outputs()); |
80 | } |
81 | |
82 | void PropertyPropBase::processLoop(Node* node) { |
83 | LoopView loop(node); |
84 | // propagate counter type |
85 | loop.currentTripCount()->setType(loop.maxTripCount()->type()); |
86 | applyTypes(loop.carriedInputs(), loop.bodyCarriedInputs()); |
87 | |
88 | do { |
89 | propagateBlock(loop.bodyBlock(), /*insert_expands=*/false); |
90 | // note: inserting expands is unsafe at this point, we don't know |
91 | // if the types are stable yet, so the arguments to expand may change |
92 | } while (mergeTypes( |
93 | loop.bodyCarriedInputs(), |
94 | loop.bodyCarriedOutputs(), |
95 | loop.bodyCarriedInputs())); |
96 | |
97 | // now that the types are stable, we can insert the expands |
98 | propagateBlock(loop.bodyBlock(), /*insert_expands=*/true); |
99 | applyTypes(loop.bodyCarriedInputs(), loop.carriedOutputs()); |
100 | } |
101 | |
102 | void PropertyPropBase::setUnshapedType(Value* o) { |
103 | o->setType(unshapedType(o->type())); |
104 | } |
105 | |
106 | void PropertyPropBase::setUnshapedType(Node* node) { |
107 | for (auto o : node->outputs()) { |
108 | setUnshapedType(o); |
109 | } |
110 | } |
111 | |
112 | namespace prim { |
113 | using namespace ::c10::prim; |
114 | } |
115 | |
// Throws propagation_error when `cond` is false. propagateBlock catches
// propagation_error and falls back to unshaped output types for that node.
// The do { ... } while (false) wrapper makes `SHAPE_ASSERT(x);` a single
// statement. A bare `if` here would silently bind a following `else` in an
// unbraced if/else (the dangling-else problem).
#define SHAPE_ASSERT(cond)       \
  do {                           \
    if (!(cond)) {               \
      throw propagation_error(); \
    }                            \
  } while (false)
119 | |
120 | namespace { |
121 | |
122 | bool isValidArgumentForRunning(Value* v) { |
123 | // allow constants |
124 | if (toIValue(v)) |
125 | return true; |
126 | if (TensorTypePtr tt = v->type()->cast<TensorType>()) { |
127 | if (!tt->scalarType()) { |
128 | return false; |
129 | } |
130 | return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false); |
131 | } |
132 | return v->type()->isSubtypeOf(*FloatType::get()); |
133 | } |
134 | |
135 | bool isValidReturnForRunning(Value* v) { |
136 | return v->type()->isSubtypeOf(*TensorType::get()) || |
137 | v->type()->isSubtypeOf(*NumberType::get()); |
138 | } |
139 | |
140 | bool containsTensorType(const TypePtr& t) { |
141 | auto n_contained = t->containedTypes().size(); |
142 | if (n_contained == 1) { |
143 | return t->containedTypes().at(0)->isSubtypeOf(*TensorType::get()); |
144 | } else if (n_contained > 1) { |
145 | return std::any_of( |
146 | t->containedTypes().begin(), |
147 | t->containedTypes().end(), |
148 | containsTensorType); |
149 | } |
150 | return false; |
151 | } |
152 | |
153 | // for each node in the schema with type Tensor, extract the T type |
154 | // returns c10::nullopt if any Tensor in the schema does not have a known |
155 | // shape ignores non-tensor in the list of inputs |
156 | c10::optional<std::vector<TensorTypePtr>> gatherTensorTypes( |
157 | Node* node, |
158 | bool complete = false) { |
159 | std::vector<TensorTypePtr> tensor_types; |
160 | |
161 | auto schema_opt = node->maybeSchema(); |
162 | if (!schema_opt) { |
163 | return c10::nullopt; |
164 | } |
165 | auto& schema = *schema_opt; |
166 | auto& args = schema.arguments(); |
167 | // can't handle varargs primitives because we don't know what should be a |
168 | // Tensor |
169 | if (schema.is_vararg()) { |
170 | return c10::nullopt; |
171 | } |
172 | for (const auto i : c10::irange(args.size())) { |
173 | if (args[i].type()->isSubtypeOf(*ListType::ofTensors())) { |
174 | return c10::nullopt; |
175 | } else if (args[i].type()->isSubtypeOf(*TensorType::get())) { |
176 | if (auto type = node->input(i)->type()->cast<TensorType>()) { |
177 | if (complete && !type->isComplete()) { |
178 | return c10::nullopt; |
179 | } |
180 | tensor_types.push_back(type); |
181 | } else { |
182 | return c10::nullopt; |
183 | } |
184 | } else /* non-tensor type */ { |
185 | continue; |
186 | } |
187 | } |
188 | return tensor_types; |
189 | } |
190 | |
191 | int64_t wrapDim(int64_t dim, at::IntArrayRef sizes) { |
192 | if (dim < 0) { |
193 | dim += (int64_t)sizes.size(); |
194 | } |
195 | return dim; |
196 | } |
197 | |
198 | c10::ScalarType unionScalarTypes( |
199 | c10::ScalarType original, |
200 | c10::ScalarType next) { |
201 | if (original == c10::ScalarType::Undefined) { |
202 | return next; |
203 | } else { |
204 | return c10::promoteTypes(original, next); |
205 | } |
206 | } |
207 | |
208 | // Promotes result types for arithmetic operations on Tensor operands using |
209 | // new type promotion logic. See tensor_attributes.rst for details. |
210 | // This doesn't handle the case of arithmetic ops with Scalar arguments (when |
211 | // `Tensor.getUnsafeTensorImpl()->is_wrapped_nubmer()` would return true) |
212 | c10::optional<c10::ScalarType> getPromotedTypeForArithmeticOp(Node* node) { |
213 | c10::ScalarType dimmed = c10::ScalarType::Undefined; |
214 | c10::ScalarType zerodim = c10::ScalarType::Undefined; |
215 | // binary arithmetic ops, more than 2 args is alpha. |
216 | for (const auto i : c10::irange(2)) { |
217 | auto dtt = node->inputs()[i]->type()->expect<TensorType>(); |
218 | auto inputDtype = dtt->scalarType(); |
219 | if (!dtt || !inputDtype) { |
220 | return c10::nullopt; |
221 | } |
222 | if (dtt->dim() && *dtt->dim() > 0) { |
223 | dimmed = unionScalarTypes(dimmed, *inputDtype); |
224 | } else if (!isFloatingType(dimmed)) { |
225 | // if no dimensions |
226 | zerodim = unionScalarTypes(zerodim, *inputDtype); |
227 | } |
228 | } |
229 | // if a tensor with dimensions is already of the highest category, don't |
230 | // need to check zero-dim tensors. |
231 | if (isFloatingType(dimmed)) { |
232 | return dimmed; |
233 | } |
234 | // int_tensor * zero_dim_floating -> floating_tensor |
235 | if (isIntegralType(dimmed, false) && isFloatingType(zerodim)) { |
236 | return zerodim; |
237 | } |
238 | // bool_tensor * non_bool_scalar -> non_bool_tensor |
239 | if (c10::ScalarType::Bool == dimmed && |
240 | c10::ScalarType::Undefined != zerodim) { |
241 | return zerodim; |
242 | } |
243 | // types of dimensioned tensors generally take precedence over zero-dim |
244 | // tensors if not promoting due to category. e.g.: |
245 | // int_tensor * long -> int_tensor |
246 | if (c10::ScalarType::Undefined != dimmed) { |
247 | return dimmed; |
248 | } |
249 | |
250 | // no dimmed tensors. e.g. zero_dim_tensor + zero_dim_tensor. |
251 | return zerodim; |
252 | } |
253 | |
254 | class ShapePropagator : public PropertyPropBase { |
255 | public: |
256 | explicit ShapePropagator(const std::shared_ptr<Graph>& graph) |
257 | : PropertyPropBase(graph), aliasDb_(graph) { |
258 | collectResizeSet(graph->block()); |
259 | } |
260 | |
261 | private: |
262 | ValueSet resized_alias_set; |
263 | const AliasDb aliasDb_; |
264 | |
265 | bool resizesInput(Node* n) { |
266 | static std::unordered_set<Symbol> resize_ops{ |
267 | aten::resize_, |
268 | aten::resize_as_, |
269 | aten::copy_, |
270 | aten::set_, |
271 | aten::unsqueeze_, |
272 | aten::t_, |
273 | aten::transpose_, |
274 | }; |
275 | |
276 | if (resize_ops.count(n->kind())) |
277 | return true; |
278 | |
279 | if (!n->maybeSchema()) |
280 | return false; |
281 | |
282 | // ops which take the result and write to input "out" |
283 | if (auto out_arg_index = n->schema().argumentIndexWithName("out" )) { |
284 | auto arg = n->schema().arguments().at(*out_arg_index); |
285 | return arg.kwarg_only() && arg.type()->isSubtypeOf(*TensorType::get()); |
286 | } |
287 | return false; |
288 | } |
289 | |
290 | void collectResizeSet(Block* block) { |
291 | for (Node* n : block->nodes()) { |
292 | for (Block* b : n->blocks()) { |
293 | collectResizeSet(b); |
294 | } |
295 | if (resizesInput(n)) { |
296 | for (const auto input : n->inputs()) { |
297 | if (aliasDb_.writesToAlias(n, {input})) { |
298 | resized_alias_set.insert(input); |
299 | } |
300 | } |
301 | } |
302 | } |
303 | } |
304 | |
305 | IValue representativeValue(Value* v) { |
306 | TypePtr type_ = v->type(); |
307 | // if the value is actually constant, just use it! |
308 | if (auto iv = toIValue(v)) { |
309 | return *iv; |
310 | } |
311 | if (TensorTypePtr type = type_->cast<TensorType>()) { |
312 | if (type->isComplete()) { |
313 | at::DeviceGuard device_guard(*type->device()); |
314 | return at::empty_strided( |
315 | *type->sizes().concrete_sizes(), |
316 | *type->strides().concrete_sizes(), |
317 | at::TensorOptions(*type->device()) |
318 | .dtype(*type->scalarType())) |
319 | .zero_(); |
320 | } |
321 | // fallthrough |
322 | } else if (type_->isSubtypeOf(*FloatType::get())) { |
323 | return 0.f; |
324 | } |
325 | // we should not get here because isValidArgumentForRunning should have |
326 | // prevented it |
327 | std::stringstream ss; |
328 | ss << "unable to create representative value for: " << type_->str() |
329 | << ". File a bug report" ; |
330 | throw std::runtime_error(ss.str()); |
331 | } |
332 | |
333 | void broadcastBinary( |
334 | Node* node, |
335 | std::vector<TensorTypePtr>& types, |
336 | size_t idx1, |
337 | size_t idx2) { |
338 | auto expected_size = at::infer_size( |
339 | *types[idx1]->sizes().concrete_sizes(), |
340 | *types[idx2]->sizes().concrete_sizes()); |
341 | auto broadcast = [&](size_t input_idx) { |
342 | TensorTypePtr input_type = types.at(input_idx); |
343 | if (input_type->sizes() == expected_size) |
344 | return; |
345 | auto graph = node->owningGraph(); |
346 | WithInsertPoint point_guard{node}; |
347 | Node* expand = graph |
348 | ->create( |
349 | aten::expand, |
350 | {node->inputs().at(input_idx), |
351 | graph->insertConstant(expected_size), |
352 | graph->insertConstant(false)}) |
353 | ->insertBefore(node); |
354 | propagateNode(expand); |
355 | node->replaceInput(input_idx, expand->output()); |
356 | }; |
357 | broadcast(idx1); |
358 | broadcast(idx2); |
359 | types[0] = node->inputs().at(idx1)->type()->expect<TensorType>(); |
360 | types[1] = node->inputs().at(idx2)->type()->expect<TensorType>(); |
361 | } |
362 | |
363 | OperatorSet cannot_propagate_shape_by_running_it = { |
364 | "aten::inverse(Tensor self) -> Tensor" , |
365 | }; |
366 | |
367 | // Check if this node depends on a value that has been mutated previously. If |
368 | // it has, then it's not safe to run this node in isolation, since we don't |
369 | // know whether the dependency has been executed. |
370 | std::unordered_map<Node*, bool> dependsOnMutationMemo_; |
371 | bool dependsOnMutation(Node* node) { |
372 | if (dependsOnMutationMemo_.count(node) != 0) { |
373 | return dependsOnMutationMemo_[node]; |
374 | } |
375 | |
376 | if (aliasDb_.hasWriters(node)) { |
377 | // If something could have written to a value used by this node, we can't |
378 | // guarantee the result is the same when running it in isolation. |
379 | dependsOnMutationMemo_[node] = true; |
380 | return true; |
381 | } |
382 | |
383 | // recursively check the producers of its inputs. We need to do this if the |
384 | // mutable value has been laundered through a pure function: |
385 | // a += 1 |
386 | // c = a + b |
387 | // d = c + 1 |
388 | // In this case, `d` cares whether `a` has been mutated even though it's not |
389 | // a direct input. |
390 | auto depends = false; |
391 | for (auto input : node->inputs()) { |
392 | depends |= dependsOnMutation(input->node()); |
393 | } |
394 | |
395 | dependsOnMutationMemo_[node] = depends; |
396 | return depends; |
397 | } |
398 | |
399 | bool canPropagateShapeByRunningIt(Node* node) { |
400 | if (node->isMemberOf(cannot_propagate_shape_by_running_it)) { |
401 | return false; |
402 | } |
403 | |
404 | if (dependsOnMutation(node)) { |
405 | return false; |
406 | } |
407 | |
408 | bool valid_args = std::all_of( |
409 | node->inputs().begin(), |
410 | node->inputs().end(), |
411 | isValidArgumentForRunning); |
412 | if (!valid_args) |
413 | return false; |
414 | |
415 | bool valid_returns = std::all_of( |
416 | node->outputs().begin(), |
417 | node->outputs().end(), |
418 | isValidReturnForRunning); |
419 | if (!valid_returns) |
420 | return false; |
421 | |
422 | return true; |
423 | } |
424 | |
425 | // If there's no Tensor in outputs, e.g float / float, |
426 | // we don't need to propagate shape. |
427 | bool DoesntRefineOutputs(Node* node) { |
428 | auto outputs = node->outputs(); |
429 | for (auto& out : outputs) { |
430 | if (containsTensorType(out->type())) { |
431 | return false; |
432 | } |
433 | } |
434 | return true; |
435 | } |
436 | |
437 | bool PropagateShapeOnNodeByRunningIt(Node* node, Operation op = nullptr) { |
438 | if (!canPropagateShapeByRunningIt(node)) |
439 | return false; |
440 | |
441 | if (!op) |
442 | op = node->getOperation(); |
443 | |
444 | Stack stack; |
445 | |
446 | for (auto input : node->inputs()) { |
447 | stack.push_back(representativeValue(input)); |
448 | } |
449 | |
450 | // XXX: we're not catching any exceptions from the op for now. This |
451 | // is to uncover any mistakes we could make when editing this code, |
452 | // and eventually it shouldn't matter, because this phase should be |
453 | // preceded by schema checking. |
454 | op(stack); |
455 | |
456 | AT_ASSERT(stack.size() == node->outputs().size()); |
457 | for (const auto i : c10::irange(stack.size())) { |
458 | // some ops may have mixed tensor/primitive outputs |
459 | // for primitives, we don't need to change the type because it is already |
460 | // its most constrained form. |
461 | auto tensor_type = node->outputs()[i]->type()->cast<TensorType>(); |
462 | if (stack[i].isTensor() && tensor_type) { |
463 | // gradient information isn't always available or part of represenative |
464 | // inputs, maintain original grad property |
465 | auto tensor_grad = tensor_type->requiresGrad(); |
466 | node->outputs()[i]->setType(TensorType::create(stack[i].toTensor()) |
467 | ->withRequiresGrad(tensor_grad)); |
468 | } |
469 | } |
470 | return true; |
471 | } |
472 | |
// Sets the output type of aten::cat (tensors taken from its `tensors` list
// input) or prim::FusedConcat (tensors are the node's own inputs).
// When that list comes from a prim::ListConstruct (or the node is a
// FusedConcat) and is non-empty:
//  1. propagate_complete is tried: exact output sizes, when possible;
//  2. otherwise propagate: the first TensorType-typed input's
//     dimensionedOnly() type.
// If neither applies, the output is made unshaped.
void PropagateCatShape(Node* cat_node) {
  // Exact inference. Returns false (output untouched) unless every input is
  // a complete TensorType, `dim` is a constant that is in range after
  // wrapping, and all inputs have the same rank and agree on every size
  // except along `dim`. On success the output is the first input's type
  // with the `dim` extent replaced by the sum of all inputs' extents.
  static const auto propagate_complete =
      [](Node* node, at::ArrayRef<Value*> tensors) -> bool {
    auto input_types =
        fmap(tensors, [](Value* v) { return v->type()->cast<TensorType>(); });
    if (!std::all_of(
            input_types.begin(),
            input_types.end(),
            [](const TensorTypePtr& tp) {
              return tp != nullptr && tp->isComplete();
            })) {
      return false;
    }
    if (!node->is_constant(attr::dim))
      return false;
    std::vector<int64_t> sizes = *input_types[0]->sizes().concrete_sizes();
    const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
    const int64_t ndim = (int64_t)sizes.size();

    if (dim < 0 || dim >= ndim)
      return false;

    // Accumulate the concatenated extent along `dim`, starting from zero
    // (the first input is added again in the loop below).
    sizes[dim] = 0;
    for (auto& tp : input_types) {
      auto tp_sizes = tp->sizes().concrete_sizes().value();
      if (sizes.size() != tp_sizes.size())
        return false;
      for (const auto i : c10::irange(ndim)) {
        if (sizes[i] != tp_sizes[i] && i != dim) {
          return false;
        }
      }
      sizes[dim] += tp_sizes[dim];
    }
    node->output()->setType(input_types[0]->withSizes(sizes));
    return true;
  };
  // Fallback: take the first input typed as a TensorType and keep only its
  // dimensionedOnly() form. Returns false if no input is a TensorType.
  static const auto propagate = [](Node* node,
                                   at::ArrayRef<Value*> tensors) -> bool {
    for (Value* v : tensors) {
      if (auto type = v->type()->cast<TensorType>()) {
        node->output()->setType(type->dimensionedOnly());
        return true;
      }
    }
    return false;
  };
  // For FusedConcat the tensors are the node's own inputs; for aten::cat
  // they are the inputs of the node that produced the `tensors` list.
  auto list_node =
      ((cat_node->kind() == prim::FusedConcat)
           ? cat_node
           : cat_node->namedInput(attr::tensors)->node());
  if (list_node->kind() == prim::ListConstruct ||
      cat_node->kind() == prim::FusedConcat) {
    auto tensors = list_node->inputs();
    if (!tensors.empty()) {
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (propagate_complete(cat_node, tensors)) {
        return;
      } else if (propagate(cat_node, tensors)) {
        return;
      }
    }
  }
  setUnshapedType(cat_node);
}
538 | |
539 | void propagateTorchTensorShape(Node* node) { |
540 | auto input_type = node->inputs().at(0)->type(); |
541 | |
542 | size_t dims = 0; |
543 | auto input_base_type = input_type; |
544 | auto list_type = input_type->cast<ListType>(); |
545 | while (list_type) { |
546 | dims++; |
547 | input_base_type = list_type->getElementType(); |
548 | list_type = input_base_type->cast<ListType>(); |
549 | } |
550 | |
551 | at::optional<at::ScalarType> default_type = |
552 | tryScalarTypeFromJitType(*input_base_type); |
553 | if (auto grad_index = node->schema().argumentIndexWithName("dtype" )) { |
554 | auto inp = toIValue(node->inputs().at(*grad_index)); |
555 | if (inp == c10::nullopt) { |
556 | return; |
557 | } else if (!inp->isNone()) { |
558 | default_type = inp->toScalarType(); |
559 | } |
560 | } |
561 | |
562 | at::Device default_device = at::kCPU; |
563 | if (auto device_index = node->schema().argumentIndexWithName("device" )) { |
564 | auto inp = toIValue(node->inputs().at(*device_index)); |
565 | if (inp == c10::nullopt) { |
566 | return; |
567 | } else if (!inp->isNone()) { |
568 | default_device = inp->toDevice(); |
569 | } |
570 | } |
571 | node->output()->setType(TensorType::create( |
572 | default_type, default_device, dims, /*requires_grad=*/c10::nullopt)); |
573 | } |
574 | |
575 | // returns whether any such values were found |
576 | bool setUnshapedTypeIfAliasResizedSet(at::ArrayRef<Value*> vs) { |
577 | bool in_resize = false; |
578 | for (auto v : vs) { |
579 | if (aliasDb_.mayAlias(ValueSet{v}, resized_alias_set)) { |
580 | setUnshapedType(v); |
581 | in_resize = true; |
582 | } |
583 | } |
584 | return in_resize; |
585 | } |
586 | |
// Infers and sets the output types of `node`. Strategies are tried in order:
//  1. If any input may alias a value resized in place somewhere in the graph,
//     those inputs and all outputs are made unshaped.
//  2. Node kinds with dedicated handling in the switch below (control flow,
//     scalar conversions, tuples, constants, ...) are handled and return.
//  3. Nodes with side effects are left untouched.
//  4. aten::cat / prim::FusedConcat go to PropagateCatShape.
//  5. If all Tensor inputs have complete types, PropagateCompleteShapeOnNode.
//  6. PropagateTensorShapeOnNode (which consults registered shape formulas).
//  7. If DoesntRefineOutputs says there is nothing to refine, outputs are
//     left as they are.
//  8. PropagateShapeOnNodeByRunningIt (execute on representative inputs).
//  9. Otherwise all outputs are made unshaped.
// `insert_expands` is forwarded to the complete/tensor shape helpers;
// presumably it allows inserting explicit expand nodes (processLoop passes
// false while loop-carried types may still change).
void propagateNode(Node* node, bool insert_expands = true) override {
  // Certain ops like resize_ change the input tensors size. Because our
  // analysis is flow invariant, we set any Tensor that can alias a resized
  // Tensor to the base Tensor Type without size information.
  if (setUnshapedTypeIfAliasResizedSet(node->inputs())) {
    return setUnshapedType(node);
  }

  // These don't require the types, and have complicated schema. Return early
  // after we process them.
  switch (node->kind()) {
    case prim::If:
      return processIf(node);
    case prim::Loop: {
      return processLoop(node);
    }
    case aten::Bool:
    case aten::Int:
    case aten::Float:
    case aten::ScalarImplicit:
    case aten::FloatImplicit:
    case aten::IntImplicit:
      return; // correct num type is already set
    case prim::NumToTensor: {
      // int/bool -> 0-dim Long CPU tensor; float -> 0-dim Double CPU tensor.
      // Any other input type leaves the output type unchanged.
      TypePtr typ = node->input()->type();
      if (typ->isSubtypeOf(*IntType::get()) ||
          typ->isSubtypeOf(*BoolType::get())) {
        node->output()->setType(TensorType::create(
            at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
      } else if (node->input()->type()->isSubtypeOf(*FloatType::get())) {
        node->output()->setType(TensorType::create(
            at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
      }
      return;
    }
    case aten::tensor:
    case aten::as_tensor: {
      // as_tensor has an overloaded schema and can either have a tensor or
      // a list as the first input, if the input is a tensor, we delegate
      // the shape propagation in PropagateTensorShapeOnNode
      if (node->inputs().at(0)->type()->isSubtypeOf(*TensorType::get())) {
        break;
      }
      return propagateTorchTensorShape(node);
    }
    case prim::TupleConstruct: {
      // We refresh the tuple type, because the input types could have been
      // refined.
      auto orig_type = node->output()->type()->expect<TupleType>();
      auto new_types =
          fmap(node->inputs(), [](Value* v) { return v->type(); });
      node->output()->setType(
          orig_type->createWithContained(std::move(new_types)));
      return;
    }
    case prim::TupleUnpack: {
      // Each output takes the corresponding element type of the input tuple.
      auto tuple_type = node->input()->type()->cast<TupleType>();
      AT_ASSERT(
          tuple_type &&
          tuple_type->elements().size() == node->outputs().size());
      auto elems = tuple_type->elements();
      for (size_t i = 0; i < node->outputs().size(); ++i) {
        node->output(i)->setType(elems[i]);
      }
      return;
    }
    case prim::Constant: {
      // Tensor constants: derive the output type from the stored tensor.
      if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
        node->output()->inferTypeFrom(node->t(attr::value));
      }
      return;
    }
    case prim::unchecked_unwrap_optional: {
      // If we have specialized the optional type to the element type,
      // we want to pass it down. We write this as input.isSubtypeOf(output)
      // to be sure that we don't screw up nested optionals.
      if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
        node->output()->setType(node->input()->type());
      }
      return;
    }
    case prim::ConstantChunk: {
      // Every chunk gets the input's dimensionedOnly() type; a non-tensor
      // input makes all outputs unshaped.
      Value* tensor = node->input();
      if (auto type = tensor->type()->cast<TensorType>()) {
        type = type->dimensionedOnly();
        for (Value* output : node->outputs()) {
          output->setType(type);
        }
      } else {
        setUnshapedType(node);
      }
      return;
    }
    case prim::grad: {
      // NOTE(review): `tt` is otherwise unused; the expect<> call presumably
      // still acts as a check that the input is a Tensor.
      auto tt = node->input()->type()->expect<TensorType>();
      // grad may be undefined
      // requires_grad may be required
      auto grad_type = TensorType::get()->withPossiblyUndefined();
      node->output()->setType(std::move(grad_type));
      return;
    }
    case prim::CallFunction:
    case prim::CallMethod:
    case prim::AutogradZero: {
      // Not analyzed: outputs are made unshaped.
      setUnshapedType(node);
      return;
    }
    case prim::GetAttr: {
      auto cls = node->input()->type()->expect<ClassType>();
      // propagate any type specializations encoded in the type of the class
      node->output()->setType(cls->getAttribute(node->s(attr::name)));
      return;
    }
    case aten::_unwrap_optional: {
      // If we have specialized the optional type to the element type,
      // we want to pass it down. We write this as input.isSubtypeOf(output)
      // to be sure that we don't screw up nested optionals.
      if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
        node->output()->setType(node->input()->type());
      }
      return;
    }
    default:
      break; // fall-through
  }

  // Nodes with side effects keep their current output types.
  if (node->hasSideEffects()) {
    return;
  }

  if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
      node->kind() == prim::FusedConcat) {
    return PropagateCatShape(node);
  }

  // Most precise path first: all Tensor inputs have complete types.
  if (auto maybe_complete_types =
          gatherTensorTypes(node, /*complete=*/true)) {
    if (PropagateCompleteShapeOnNode(
            node, insert_expands, std::move(*maybe_complete_types))) {
      return;
    }
  }

  if (PropagateTensorShapeOnNode(node, insert_expands)) {
    return;
  }

  if (DoesntRefineOutputs(node)) {
    return;
  }

  if (PropagateShapeOnNodeByRunningIt(node)) {
    return;
  }
  // Nothing worked: drop shape information from all outputs.
  return setUnshapedType(node);
}
743 | |
744 | static c10::optional<size_t> determineListSize(Value* list) { |
745 | AT_ASSERT(list->type()->cast<ListType>()); |
746 | if (auto shape = constant_as<c10::List<int64_t>>(list)) { |
747 | return shape->size(); |
748 | } |
749 | auto input_node = list->node(); |
750 | if (input_node->kind() == prim::ListConstruct) { |
751 | return input_node->inputs().size(); |
752 | } |
753 | return c10::nullopt; |
754 | } |
755 | |
756 | // is it ok to try to run the op |
757 | // If an input is a constant, then we assume that the input is valid |
758 | // and we can try to run it. |
759 | // Otherwise: |
760 | // Integral typed _inputs_ are often an indicator that we're indexing into |
761 | // a tensor, so we should special-case these ops in the shape propagation. |
762 | // Additionally, passing in a zero representative tensor into an integer |
763 | // division op causes divide-by-zero errors |
764 | // _Outputs_ must be tensors or primitives |
765 | // We will call inferTypeFrom on the tensors, and ignore the primitives. |
766 | // However, we allow primitive returns because we want to support mixed |
767 | // primitive/tensor outputs. |
768 | |
769 | bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) { |
770 | static const auto broadcast = |
771 | [](std::vector<TensorTypePtr>& tensor_types, |
772 | c10::optional<at::ScalarType> t) -> TensorTypePtr { |
773 | if (tensor_types.size() == 1) { |
774 | return tensor_types[0]->dimensionedOnly()->withScalarType(t); |
775 | } |
776 | AT_ASSERT(!tensor_types.empty()); |
777 | auto any_type = tensor_types[0]; |
778 | auto max_dims = any_type->dim(); |
779 | for (auto& type : tensor_types) { |
780 | if (!max_dims || !type->dim()) { |
781 | max_dims = c10::nullopt; |
782 | } else { |
783 | max_dims = std::max(*max_dims, *type->dim()); |
784 | } |
785 | } |
786 | return TensorType::create( |
787 | t, |
788 | any_type->device(), |
789 | max_dims, |
790 | /*requires_grad=*/c10::nullopt); |
791 | }; |
792 | |
793 | using type_vec_t = std::vector<TensorTypePtr>; |
794 | // Formula is expected to return a vector of length equal to the number of |
795 | // tensor outputs of the node, or an empty vector which implies that it |
796 | // failed to propagate. |
797 | using formula_t = std::function<type_vec_t(Node*)>; |
798 | static std::mutex shape_formulas_mutex; |
799 | static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas; |
800 | struct register_formula_for { |
801 | register_formula_for(OperatorSet operators, formula_t formula) { |
802 | std::unique_lock<std::mutex> lock{shape_formulas_mutex}; |
803 | shape_formulas.emplace_back(std::move(operators), std::move(formula)); |
804 | } |
805 | }; |
806 | |
807 | // Requirements: |
808 | // dims : preserved |
809 | // scalar type : preserved |
810 | // device : preserved |
811 | // tensor inputs : 1 |
812 | // tensor outputs : 1 |
813 | // Additionally: |
814 | // - First input should be the only tensor input |
815 | static const register_formula_for simple_unary_ops{ |
816 | { |
817 | "aten::acos(Tensor self) -> Tensor" , |
818 | "aten::neg(Tensor self) -> Tensor" , |
819 | "aten::t(Tensor self) -> Tensor" , |
820 | "aten::sigmoid(Tensor self) -> Tensor" , |
821 | "aten::logit(Tensor self, float? eps=None) -> Tensor" , |
822 | "aten::tanh(Tensor self) -> Tensor" , |
823 | "aten::relu(Tensor self) -> Tensor" , |
824 | "aten::asin(Tensor self) -> Tensor" , |
825 | "aten::atan(Tensor self) -> Tensor" , |
826 | "aten::ceil(Tensor self) -> Tensor" , |
827 | "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor" , |
828 | "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)" , |
829 | "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor" , |
830 | "aten::celu(Tensor self, Scalar alpha) -> Tensor" , |
831 | "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor" , |
832 | "aten::clamp_max(Tensor self, Scalar max) -> Tensor" , |
833 | "aten::clamp_min(Tensor self, Scalar min) -> Tensor" , |
834 | "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor" , |
835 | "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor" , |
836 | "aten::cos(Tensor self) -> Tensor" , |
837 | "aten::cosh(Tensor self) -> Tensor" , |
838 | "aten::digamma(Tensor self) -> Tensor" , |
839 | "aten::dropout(Tensor input, float p, bool train) -> Tensor" , |
840 | "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor" , |
841 | "aten::erf(Tensor self) -> Tensor" , |
842 | "aten::erfc(Tensor self) -> Tensor" , |
843 | "aten::erfinv(Tensor self) -> Tensor" , |
844 | "aten::exp(Tensor self) -> Tensor" , |
845 | "aten::expm1(Tensor self) -> Tensor" , |
846 | "aten::log(Tensor self) -> Tensor" , |
847 | "aten::log10(Tensor self) -> Tensor" , |
848 | "aten::log1p(Tensor self) -> Tensor" , |
849 | "aten::log2(Tensor self) -> Tensor" , |
850 | "aten::log_sigmoid(Tensor self) -> Tensor" , |
851 | "aten::floor(Tensor self) -> Tensor" , |
852 | "aten::frac(Tensor self) -> Tensor" , |
853 | "aten::flip(Tensor self, int[] dims) -> Tensor" , |
854 | "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor" , |
855 | "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor" , |
856 | "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor" , |
857 | "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor" , |
858 | "aten::glu(Tensor self, int dim) -> Tensor" , |
859 | "aten::inverse(Tensor self) -> Tensor" , |
860 | "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor" , |
861 | "aten::lgamma(Tensor self) -> Tensor" , |
862 | "aten::mvlgamma(Tensor self, int p) -> Tensor" , |
863 | "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor" , |
864 | "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor" , |
865 | "aten::permute(Tensor self, int[] dims) -> Tensor" , |
866 | "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)" , |
867 | "aten::pinverse(Tensor self, float rcond) -> Tensor" , |
868 | "aten::reciprocal(Tensor self) -> Tensor" , |
869 | "aten::relu(Tensor self) -> Tensor" , |
870 | "aten::round(Tensor self) -> Tensor" , |
871 | "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor" , |
872 | "aten::rsqrt(Tensor self) -> Tensor" , |
873 | "aten::selu(Tensor self) -> Tensor" , |
874 | "aten::gelu(Tensor self, *, str approximate='none') -> Tensor" , |
875 | "aten::sigmoid(Tensor self) -> Tensor" , |
876 | "aten::sign(Tensor self) -> Tensor" , |
877 | "aten::sin(Tensor self) -> Tensor" , |
878 | "aten::sinh(Tensor self) -> Tensor" , |
879 | "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor" , |
880 | "aten::softshrink(Tensor self, Scalar lambd) -> Tensor" , |
881 | "aten::sqrt(Tensor self) -> Tensor" , |
882 | "aten::tan(Tensor self) -> Tensor" , |
883 | "aten::tanh(Tensor self) -> Tensor" , |
884 | "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor" , |
885 | "aten::transpose(Tensor self, int dim0, int dim1) -> Tensor" , |
886 | "aten::tril(Tensor self, int diagonal) -> Tensor" , |
887 | "aten::triu(Tensor self, int diagonal) -> Tensor" , |
888 | "aten::trunc(Tensor self) -> Tensor" , |
889 | "aten::rot90(Tensor self, int k, int[] dims) -> Tensor" , |
890 | "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor" , |
891 | "aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor" , |
892 | "aten::alias(Tensor self) -> Tensor" , |
893 | }, |
894 | [](Node* node) -> type_vec_t { |
895 | auto input_type = node->input(0)->type()->cast<TensorType>(); |
896 | return input_type ? type_vec_t{input_type->dimensionedOnly()} |
897 | : type_vec_t{}; |
898 | }}; |
899 | |
900 | // Requirements: |
901 | // dims : preserved |
902 | // scalar type : preserved, except complex maps to float |
903 | // device : preserved |
904 | // tensor inputs : 1 |
905 | // tensor outputs : 1 |
906 | // Additionally: |
907 | // - First input should be the only tensor input |
908 | static const register_formula_for simple_unary_ops_complex_to_float{ |
909 | { |
910 | "aten::abs(Tensor self) -> Tensor" , |
911 | }, |
912 | [](Node* node) -> type_vec_t { |
913 | auto input_type = node->input(0)->type()->cast<TensorType>(); |
914 | |
915 | // Maps complex -> float |
916 | if (input_type->scalarType()) { |
917 | const auto scalar_type = *(input_type->scalarType()); |
918 | if (isComplexType(scalar_type)) { |
919 | const auto out_type = c10::toRealValueType(scalar_type); |
920 | return type_vec_t{ |
921 | input_type->dimensionedOnly()->withScalarType(out_type)}; |
922 | } |
923 | } |
924 | |
925 | return input_type ? type_vec_t{input_type->dimensionedOnly()} |
926 | : type_vec_t{}; |
927 | }}; |
928 | |
929 | // Requirements: |
930 | // dims : broadcast all tensor args |
931 | // scalar type : promoted from input dtypes |
932 | // device : always matching and preserved |
933 | // tensor inputs : * |
934 | // tensor outputs : 1 |
935 | static const register_formula_for broadcasting_ops_arithmetic{ |
936 | { |
937 | // Tensor-Tensor operators |
938 | "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" , |
939 | "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" , |
940 | "aten::mul(Tensor self, Tensor other) -> Tensor" , |
941 | "aten::div(Tensor self, Tensor other) -> Tensor" , |
942 | }, |
943 | [](Node* node) -> type_vec_t { |
944 | if (auto maybe_tensor_types = gatherTensorTypes(node)) { |
945 | AT_ASSERT(maybe_tensor_types->size() >= 2); |
946 | auto dtype = getPromotedTypeForArithmeticOp(node); |
947 | return {broadcast(*maybe_tensor_types, dtype)}; |
948 | } |
949 | return {}; |
950 | }}; |
951 | |
952 | // Requirements: |
953 | // dims : broadcast all tensor args |
954 | // scalar type : always matching and preserved |
955 | // device : always matching and preserved |
956 | // tensor inputs : * |
957 | // tensor outputs : 1 |
958 | static const register_formula_for broadcasting_ops{ |
959 | { |
960 | "aten::pow(Tensor self, Tensor exponent) -> Tensor" , |
961 | "aten::fmod(Tensor self, Tensor other) -> Tensor" , |
962 | "aten::remainder(Tensor self, Tensor other) -> Tensor" , |
963 | "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor" , |
964 | "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor" , |
965 | "aten::max(Tensor self, Tensor other) -> Tensor" , |
966 | "aten::min(Tensor self, Tensor other) -> Tensor" , |
967 | "aten::__and__(Tensor self, Tensor other) -> Tensor" , |
968 | "aten::__or__(Tensor self, Tensor other) -> Tensor" , |
969 | "aten::__xor__(Tensor self, Tensor other) -> Tensor" , |
970 | "aten::__lshift__(Tensor self, Tensor other) -> Tensor" , |
971 | "aten::__rshift__(Tensor self, Tensor other) -> Tensor" , |
972 | "aten::__iand__(Tensor self, Tensor other) -> Tensor" , |
973 | "aten::__ior__(Tensor self, Tensor other) -> Tensor" , |
974 | "aten::__ixor__(Tensor self, Tensor other) -> Tensor" , |
975 | "aten::__ilshift__(Tensor self, Tensor other) -> Tensor" , |
976 | "aten::__irshift__(Tensor self, Tensor other) -> Tensor" , |
977 | |
978 | // Ops with Tensor-Tensor overloads only |
979 | "aten::atan2(Tensor self, Tensor other) -> Tensor" , |
980 | }, |
981 | [](Node* node) -> type_vec_t { |
982 | if (auto maybe_tensor_types = gatherTensorTypes(node)) { |
983 | AT_ASSERT(maybe_tensor_types->size() >= 2); |
984 | auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType(); |
985 | auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType(); |
986 | if (!first_scalar_type || !second_scalar_type) { |
987 | return {}; |
988 | } |
989 | size_t arg_for_type = 0; |
990 | if (c10::promoteTypes(*first_scalar_type, *second_scalar_type) != |
991 | first_scalar_type) { |
992 | arg_for_type = 1; |
993 | } |
994 | auto t = (*maybe_tensor_types)[arg_for_type]->scalarType(); |
995 | return {broadcast(*maybe_tensor_types, *t)}; |
996 | } |
997 | return {}; |
998 | }}; |
999 | |
1000 | static const register_formula_for fused_accum_binary_ops{ |
1001 | { |
1002 | // Non-binary ops |
1003 | "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor" , |
1004 | "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor" , |
1005 | }, |
1006 | [](Node* node) -> type_vec_t { |
1007 | if (auto maybe_tensor_types = gatherTensorTypes(node)) { |
1008 | auto dtype = (*maybe_tensor_types)[0]->scalarType(); |
1009 | if (!dtype) { |
1010 | return {}; |
1011 | } |
1012 | return {broadcast(*maybe_tensor_types, *dtype)}; |
1013 | } |
1014 | return {}; |
1015 | }}; |
1016 | |
1017 | static const register_formula_for broadcasting_tensor_scalar_ops_arithmetic{ |
1018 | { |
1019 | // Tensor-Scalar operators |
1020 | "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
1021 | "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
1022 | "aten::mul(Tensor self, Scalar other) -> Tensor" , |
1023 | "aten::div(Tensor self, Scalar other) -> Tensor" , |
1024 | }, |
1025 | [](Node* node) -> type_vec_t { |
1026 | if (auto maybe_tensor_types = gatherTensorTypes(node)) { |
1027 | auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType(); |
1028 | auto second_scalar_type = |
1029 | tryScalarTypeFromJitType(*node->inputs()[1]->type()); |
1030 | if (!first_scalar_type || !second_scalar_type) { |
1031 | return {}; |
1032 | } |
1033 | if (isIntegralType(*first_scalar_type, false) && |
1034 | isFloatingType(*second_scalar_type)) { |
1035 | auto default_dtype = |
1036 | at::typeMetaToScalarType(caffe2::get_default_dtype()); |
1037 | return {broadcast(*maybe_tensor_types, default_dtype)}; |
1038 | } |
1039 | if (c10::ScalarType::Bool == *first_scalar_type && |
1040 | c10::ScalarType::Bool != *second_scalar_type) { |
1041 | auto result_type = |
1042 | c10::promoteTypes(*first_scalar_type, *second_scalar_type); |
1043 | return {broadcast(*maybe_tensor_types, result_type)}; |
1044 | } |
1045 | return {broadcast(*maybe_tensor_types, first_scalar_type)}; |
1046 | } |
1047 | return {}; |
1048 | }}; |
1049 | |
1050 | // NB: we always take the scalar type of the Tensor |
1051 | static const register_formula_for broadcasting_tensor_scalar_ops{ |
1052 | { |
1053 | |
1054 | "aten::pow(Tensor self, Scalar exponent) -> Tensor" , |
1055 | "aten::fmod(Tensor self, Scalar other) -> Tensor" , |
1056 | "aten::remainder(Tensor self, Scalar other) -> Tensor" , |
1057 | "aten::pow(Scalar self, Tensor exponent) -> Tensor" , |
1058 | "aten::__and__(Tensor self, Scalar other) -> Tensor" , |
1059 | "aten::__or__(Tensor self, Scalar other) -> Tensor" , |
1060 | "aten::__xor__(Tensor self, Scalar other) -> Tensor" , |
1061 | "aten::__lshift__(Tensor self, Scalar other) -> Tensor" , |
1062 | "aten::__rshift__(Tensor self, Scalar other) -> Tensor" , |
1063 | "aten::__iand__(Tensor self, Scalar other) -> Tensor" , |
1064 | "aten::__ior__(Tensor self, Scalar other) -> Tensor" , |
1065 | "aten::__ixor__(Tensor self, Scalar other) -> Tensor" , |
1066 | "aten::__ilshift__(Tensor self, Scalar other) -> Tensor" , |
1067 | "aten::__irshift__(Tensor self, Scalar other) -> Tensor" , |
1068 | }, |
1069 | [](Node* node) -> type_vec_t { |
1070 | if (auto maybe_tensor_types = gatherTensorTypes(node)) { |
1071 | return {broadcast( |
1072 | *maybe_tensor_types, (*maybe_tensor_types)[0]->scalarType())}; |
1073 | } |
1074 | return {}; |
1075 | }}; |
1076 | |
1077 | // aten::where is special in that its return type is the second argument's |
1078 | // (self) type rather than the that of condition |
1079 | static const register_formula_for where_op{ |
1080 | { |
1081 | "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor" , |
1082 | }, |
1083 | [](Node* node) -> type_vec_t { |
1084 | if (auto maybe_tensor_types = gatherTensorTypes(node)) { |
1085 | return {broadcast( |
1086 | *maybe_tensor_types, (*maybe_tensor_types)[1]->scalarType())}; |
1087 | } |
1088 | return {}; |
1089 | }}; |
1090 | |
1091 | static const auto any_tensor_type = [](Node* node) -> TensorTypePtr { |
1092 | for (Value* input : node->inputs()) { |
1093 | if (auto type = input->type()->cast<TensorType>()) { |
1094 | if (type->dim().has_value()) { |
1095 | return type; |
1096 | } |
1097 | } |
1098 | } |
1099 | return nullptr; |
1100 | }; |
1101 | |
1102 | // Requirements: |
1103 | // dims : always matching and preserved |
1104 | // scalar type : always matching and preserved |
1105 | // device : always matching and preserved |
1106 | // tensor inputs : 2 |
1107 | // tensor outputs : 1 |
1108 | static const register_formula_for binary_ops_strict_match{ |
1109 | { |
1110 | "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor" , |
1111 | "aten::mm(Tensor self, Tensor mat2) -> Tensor" , |
1112 | "aten::bmm(Tensor self, Tensor mat2) -> Tensor" , |
1113 | }, |
1114 | [](Node* node) -> type_vec_t { |
1115 | if (auto type = any_tensor_type(node)) { |
1116 | return {std::move(type)}; |
1117 | } |
1118 | return {}; |
1119 | }}; |
1120 | |
1121 | // Requirements: |
1122 | // dims : all tensor args are broadcast |
1123 | // scalar type : byte/uint8 |
1124 | // device : always matching and preserved |
1125 | // tensor inputs : * |
1126 | // tensor outputs : 1 |
1127 | static const register_formula_for comparison_ops{ |
1128 | { |
1129 | "aten::lt(Tensor self, Tensor other) -> Tensor" , |
1130 | "aten::le(Tensor self, Tensor other) -> Tensor" , |
1131 | "aten::gt(Tensor self, Tensor other) -> Tensor" , |
1132 | "aten::ge(Tensor self, Tensor other) -> Tensor" , |
1133 | "aten::eq(Tensor self, Tensor other) -> Tensor" , |
1134 | "aten::ne(Tensor self, Tensor other) -> Tensor" , |
1135 | "aten::lt(Tensor self, Scalar other) -> Tensor" , |
1136 | "aten::le(Tensor self, Scalar other) -> Tensor" , |
1137 | "aten::gt(Tensor self, Scalar other) -> Tensor" , |
1138 | "aten::ge(Tensor self, Scalar other) -> Tensor" , |
1139 | "aten::eq(Tensor self, Scalar other) -> Tensor" , |
1140 | "aten::ne(Tensor self, Scalar other) -> Tensor" , |
1141 | }, |
1142 | [](Node* node) -> type_vec_t { |
1143 | if (auto maybe_tensor_types = gatherTensorTypes(node)) { |
1144 | return {broadcast(*maybe_tensor_types, at::kBool)}; |
1145 | } |
1146 | return {}; |
1147 | }}; |
1148 | |
1149 | static const register_formula_for nn_ops_first_input_formula{ |
1150 | *nn_ops_first_input_preserving(), [](Node* node) -> type_vec_t { |
1151 | if (auto type = node->input(0)->type()->cast<TensorType>()) { |
1152 | return {type->dimensionedOnly()}; |
1153 | } |
1154 | return {}; |
1155 | }}; |
1156 | |
1157 | // Requirements: |
1158 | // dims : 0 |
1159 | // scalar type : preserved |
1160 | // device : preserved |
1161 | // tensor inputs : 1 |
1162 | // tensor outputs : 1 |
1163 | // Additionally: |
1164 | // - First input should be the only tensor input |
1165 | static const register_formula_for all_reduce_ops{ |
1166 | { |
1167 | "aten::det(Tensor self) -> Tensor" , |
1168 | "aten::logdet(Tensor self) -> Tensor" , |
1169 | "aten::max(Tensor self) -> Tensor" , |
1170 | "aten::min(Tensor self) -> Tensor" , |
1171 | "aten::median(Tensor self) -> Tensor" , |
1172 | "aten::nanmedian(Tensor self) -> Tensor" , |
1173 | "aten::norm(Tensor self, Scalar p) -> Tensor" , |
1174 | "aten::std(Tensor self, bool unbiased) -> Tensor" , |
1175 | "aten::trace(Tensor self) -> Tensor" , |
1176 | "aten::var(Tensor self, bool unbiased) -> Tensor" , |
1177 | "aten::all(Tensor self) -> Tensor" , |
1178 | "aten::any(Tensor self) -> Tensor" , |
1179 | }, |
1180 | [](Node* node) -> type_vec_t { |
1181 | if (auto type = node->input(0)->type()->cast<TensorType>()) { |
1182 | return {type->withDim(0)}; |
1183 | } |
1184 | return {}; |
1185 | }}; |
1186 | |
1187 | // Requirements: |
1188 | // dims : 0 |
1189 | // scalar type : dtype if specified, else preserved |
1190 | // device : preserved |
1191 | // tensor inputs : 1 |
1192 | // tensor outputs : 1 |
1193 | // Additionally: |
1194 | // - First input should be the only tensor input |
1195 | static const register_formula_for reduce_ops_with_opt_dtype{ |
1196 | {"aten::mean(Tensor self, *, int? dtype) -> Tensor" }, |
1197 | [](Node* node) -> type_vec_t { |
1198 | at::optional<IValue> maybe_dtype_option = node->get(attr::dtype); |
1199 | if (auto type = node->input(0)->type()->cast<TensorType>()) { |
1200 | auto ret = type->withDim(0); |
1201 | if (maybe_dtype_option && !maybe_dtype_option->isNone()) { |
1202 | return {ret->withScalarType(maybe_dtype_option->toScalarType())}; |
1203 | } else { |
1204 | return {std::move(ret)}; |
1205 | } |
1206 | } |
1207 | return {}; |
1208 | }}; |
1209 | |
1210 | // Requirements: |
1211 | // dims : 0 |
1212 | // scalar type : dtype if specified, else preserved if floating point, |
1213 | // otherwise long/int64 device : preserved tensor inputs : 1 |
1214 | // tensor outputs : 1 |
1215 | // Additionally: |
1216 | // - First input should be the only tensor input |
1217 | static const register_formula_for |
1218 | all_reduce_ops_with_integer_upcast_and_dtype{ |
1219 | { |
1220 | "aten::sum(Tensor self, *, int? dtype) -> Tensor" , |
1221 | "aten::prod(Tensor self, *, int? dtype) -> Tensor" , |
1222 | }, |
1223 | [](Node* node) -> type_vec_t { |
1224 | if (auto type = node->input(0)->type()->cast<TensorType>()) { |
1225 | type = type->withDim(0); |
1226 | at::optional<IValue> maybe_dtype_option = |
1227 | node->get(attr::dtype); |
1228 | if (maybe_dtype_option && !maybe_dtype_option->isNone()) { |
1229 | return { |
1230 | type->withScalarType(maybe_dtype_option->toScalarType())}; |
1231 | } |
1232 | if (type->scalarType()) { |
1233 | return { |
1234 | at::isFloatingType(*type->scalarType()) |
1235 | ? std::move(type) |
1236 | : type->withScalarType(at::kLong)}; |
1237 | } else { |
1238 | return {std::move(type)}; |
1239 | } |
1240 | } |
1241 | return {}; |
1242 | }}; |
1243 | |
1244 | static const auto reduce_op_handler = [](Node* node, |
1245 | int64_t num_reduced_dim = 0, |
1246 | bool upcast_integer = false, |
1247 | c10::optional<IValue> opt_dtype = |
1248 | c10::nullopt) -> type_vec_t { |
1249 | if (auto type = node->input(0)->type()->cast<TensorType>()) { |
1250 | if (!type->scalarType() || !type->dim()) { |
1251 | return {}; |
1252 | } |
1253 | if (opt_dtype && !opt_dtype->isNone()) { |
1254 | type = type->withScalarType(opt_dtype->toScalarType()); |
1255 | } else if (upcast_integer && !at::isFloatingType(*type->scalarType())) { |
1256 | type = type->withScalarType(at::kLong); |
1257 | } |
1258 | // NOLINTNEXTLINE(clang-diagnostic-sign-compare) |
1259 | if (*type->dim() >= num_reduced_dim && num_reduced_dim > 0) { |
1260 | return {type->withDim(*type->dim() - num_reduced_dim)}; |
1261 | } else { |
1262 | return {std::move(type)}; |
1263 | } |
1264 | } |
1265 | return {}; |
1266 | }; |
1267 | |
1268 | static const auto multidim_reduce_with_keepdim = |
1269 | [](Node* node, |
1270 | int64_t num_reduced_dim, |
1271 | bool upcast_integer) -> type_vec_t { |
1272 | auto maybe_keepdim = node->get<bool>(attr::keepdim); |
1273 | if (!maybe_keepdim) |
1274 | return {}; |
1275 | return reduce_op_handler( |
1276 | node, *maybe_keepdim ? 0 : num_reduced_dim, upcast_integer); |
1277 | }; |
1278 | |
1279 | // Requirements: |
1280 | // dims : 0 if dim is None, otherwise preserved if keepdim == |
1281 | // false or 1 smaller otherwise scalar type : preserved device : |
1282 | // preserved tensor inputs : 1 tensor outputs : 1 |
1283 | // Additionally: |
1284 | // - First input should be the only tensor input |
1285 | // - Has a bool keepdim argument |
1286 | static const register_formula_for argminmax{ |
1287 | { |
1288 | "aten::argmax(Tensor self, int? dim, bool keepdim) -> Tensor" , |
1289 | "aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor" , |
1290 | }, |
1291 | [](Node* node) -> type_vec_t { |
1292 | if (auto type = node->input(0)->type()->cast<TensorType>()) { |
1293 | if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) { |
1294 | return {type->withDim(0)}; |
1295 | } else { |
1296 | return multidim_reduce_with_keepdim( |
1297 | node, /*num_reduced_dim=*/1, /*upcast_integer=*/false); |
1298 | } |
1299 | } |
1300 | return {}; |
1301 | }}; |
1302 | |
1303 | // Requirements: |
1304 | // dims : preserved if keepdim == false, 1 smaller otherwise |
1305 | // scalar type : preserved for first output, byte/uint8 for second |
1306 | // output if exists device : preserved tensor inputs : 1 tensor |
1307 | // outputs : 1 or 2 |
1308 | // Additionally: |
1309 | // - First input should be the only tensor input |
1310 | // - Has a bool keepdim argument |
1311 | static const register_formula_for dim_reduce_ops{ |
1312 | { |
1313 | "aten::all(Tensor self, int dim, bool keepdim) -> Tensor" , |
1314 | "aten::any(Tensor self, int dim, bool keepdim) -> Tensor" , |
1315 | |
1316 | // Ops returning indices as second output |
1317 | "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)" , |
1318 | "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)" , |
1319 | "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)" , |
1320 | "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)" , |
1321 | "aten::nanmedian(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)" , |
1322 | "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)" , |
1323 | }, |
1324 | [](Node* node) -> type_vec_t { |
1325 | // NB: Note that while this function is generally meant to be used |
1326 | // with ops that have a single output, we will fix up its return right |
1327 | // below. |
1328 | auto output_types = multidim_reduce_with_keepdim( |
1329 | node, /*num_reduced_dim=*/1, /*upcast_integer=*/false); |
1330 | if (!output_types.empty() && node->outputs().size() == 2) { |
1331 | output_types.push_back( |
1332 | output_types.back()->withScalarType(at::kLong)); |
1333 | } |
1334 | return output_types; |
1335 | }}; |
1336 | |
1337 | // Requirements: |
1338 | // dims : preserved if keepdim == false, 1 smaller otherwise |
1339 | // scalar type : dtype if specified. preserved if floating point, |
1340 | // otherwise long/int64 device : preserved tensor inputs : 1 |
1341 | // tensor outputs : 1 |
1342 | // Additionally: |
1343 | // - First input should be the only tensor input |
1344 | // - has a bool keepdim argument |
1345 | static const register_formula_for dim_reduce_ops_with_integer_upcast{ |
1346 | { |
1347 | "aten::prod(Tensor self, int dim, bool keepdim, *, int? dtype) -> Tensor" , |
1348 | }, |
1349 | [](Node* node) -> type_vec_t { |
1350 | auto maybe_keepdim = node->get<bool>(attr::keepdim); |
1351 | at::optional<IValue> opt_dtype = node->get(attr::dtype); |
1352 | return reduce_op_handler( |
1353 | node, |
1354 | /*num_reduce_dim=*/*maybe_keepdim ? 0 : 1, |
1355 | /*integer_upcast=*/true, |
1356 | std::move(opt_dtype)); |
1357 | }}; |
1358 | |
1359 | // Requirements: |
1360 | // dims : preserved |
1361 | // scalar type : dtype if specified, preserved if floating point, |
1362 | // otherwise long/int64 |
1363 | // device : preserved |
1364 | // tensor inputs : 1 |
1365 | // tensor outputs : 1 |
1366 | // Additionally: |
1367 | // - First input should be the only tensor input |
1368 | static const register_formula_for dim_reduce_ops_dtype{ |
1369 | {"aten::cumprod(Tensor self, int dim, *, int? dtype) -> Tensor" , |
1370 | "aten::cumsum(Tensor self, int dim, *, int? dtype) -> Tensor" , |
1371 | "aten::log_softmax(Tensor self, int dim, int? dtype) -> Tensor" }, |
1372 | [](Node* node) -> type_vec_t { |
1373 | at::optional<IValue> opt_dtype = node->get(attr::dtype); |
1374 | return reduce_op_handler( |
1375 | node, |
1376 | /*num_reduce_dim=*/0, |
1377 | /*integer_upcast=*/true, |
1378 | std::move(opt_dtype)); |
1379 | }}; |
1380 | |
1381 | // Requirements: |
1382 | // dims : preserved |
1383 | // scalar type : dtype if specified, otherwise preserved |
1384 | // device : preserved |
1385 | // tensor inputs : 1 |
1386 | // tensor outputs : 1 |
1387 | // Additionally: |
1388 | // - has bool keepdim and int[] dim arguments |
1389 | static const register_formula_for register_softmax{ |
1390 | {"aten::softmax(Tensor self, int dim, int? dtype) -> Tensor" }, |
1391 | [](Node* node) -> type_vec_t { |
1392 | at::optional<IValue> opt_dtype = node->get(attr::dtype); |
1393 | return reduce_op_handler( |
1394 | node, |
1395 | /*num_reduced_dim=*/0, |
1396 | /*upcast_integer=*/false, |
1397 | std::move(opt_dtype)); |
1398 | }}; |
1399 | |
1400 | static const auto factory_with_ndim = |
1401 | [](Node* node, int dim, at::ScalarType default_dtype) -> type_vec_t { |
1402 | at::optional<IValue> maybe_layout_option = node->get(attr::layout); |
1403 | if (!maybe_layout_option) |
1404 | return {}; |
1405 | |
1406 | at::optional<IValue> maybe_device_option = node->get(attr::device); |
1407 | if (!maybe_device_option) |
1408 | return {}; |
1409 | auto device = |
1410 | (maybe_device_option->isNone() ? at::kCPU |
1411 | : maybe_device_option->toDevice()); |
1412 | |
1413 | at::optional<IValue> maybe_dtype_option = node->get(attr::dtype); |
1414 | if (!maybe_dtype_option) |
1415 | return {}; |
1416 | auto dtype = |
1417 | (maybe_dtype_option->isNone() ? default_dtype |
1418 | : maybe_dtype_option->toScalarType()); |
1419 | |
1420 | return {TensorType::create( |
1421 | dtype, device, dim, /*requires_grad=*/c10::nullopt)}; |
1422 | }; |
1423 | |
1424 | static const auto factory_like_with_ndim = [](Node* node, |
1425 | int dim) -> type_vec_t { |
1426 | auto tt = node->input(0)->type()->expect<TensorType>(); |
1427 | auto in_type = tt->scalarType(); |
1428 | auto in_dev = tt->device(); |
1429 | |
1430 | at::optional<IValue> maybe_layout_option = node->get(attr::layout); |
1431 | if (!maybe_layout_option) |
1432 | return {}; |
1433 | |
1434 | at::optional<IValue> maybe_device_option = node->get(attr::device); |
1435 | if (!maybe_device_option) |
1436 | return {}; |
1437 | |
1438 | if (!maybe_device_option->isNone()) { |
1439 | in_dev = maybe_device_option->toDevice(); |
1440 | } |
1441 | |
1442 | at::optional<IValue> maybe_dtype_option = node->get(attr::dtype); |
1443 | if (!maybe_dtype_option) |
1444 | return {}; |
1445 | |
1446 | if (!maybe_dtype_option->isNone()) { |
1447 | in_type = maybe_dtype_option->toScalarType(); |
1448 | } |
1449 | |
1450 | return {TensorType::create( |
1451 | in_type, in_dev, dim, /*requires_grad=*/c10::nullopt)}; |
1452 | }; |
1453 | |
1454 | // Requirements: |
1455 | // dims : preserved |
1456 | // scalar type : equal to value of dtype |
1457 | // device : equal to value of device |
1458 | // tensor inputs : 1 |
1459 | // tensor outputs : 1 |
1460 | // Additionally: |
1461 | // - has ScalarType dtype, Layeout layout and Device device arguments |
1462 | static const register_formula_for like_factories_with_options{ |
1463 | { |
1464 | "aten::empty_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
1465 | "aten::full_like(Tensor self, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
1466 | "aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
1467 | "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
1468 | "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
1469 | "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
1470 | "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
1471 | "aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor" , |
1472 | }, |
1473 | [](Node* node) -> type_vec_t { |
1474 | if (auto type = |
1475 | node->namedInput(attr::self)->type()->cast<TensorType>()) { |
1476 | if (type->dim()) { |
1477 | return factory_like_with_ndim(node, (int)*type->dim()); |
1478 | } |
1479 | } |
1480 | return {}; |
1481 | }}; |
1482 | |
1483 | // Requirements: |
1484 | // dims : equal to number of elements in size |
1485 | // scalar type : equal to value of dtype |
1486 | // device : equal to value of device |
1487 | // tensor inputs : 1 |
1488 | // tensor outputs : 1 |
1489 | // Additionally: |
1490 | // - has int[] size, ScalarType dtype, Layeout layout and Device device |
1491 | // arguments |
1492 | static const register_formula_for size_factories_with_options{ |
1493 | { |
1494 | "aten::empty(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory, MemoryFormat? memory_format=contiguous_format) -> Tensor" , |
1495 | "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor" , |
1496 | "aten::ones(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor" , |
1497 | "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor" , |
1498 | "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor" , |
1499 | "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor" , |
1500 | }, |
1501 | [](Node* node) -> type_vec_t { |
1502 | if (auto maybe_size = node->get<c10::List<int64_t>>(attr::size)) { |
1503 | return factory_with_ndim( |
1504 | node, (int)maybe_size->size(), at::kDouble); |
1505 | } |
1506 | return {}; |
1507 | }}; |
1508 | |
1509 | static const register_formula_for randint{ |
1510 | { |
1511 | "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor" , |
1512 | "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor" , |
1513 | }, |
1514 | [](Node* node) -> type_vec_t { |
1515 | if (auto maybe_size = node->get<c10::List<int64_t>>(attr::size)) { |
1516 | return factory_with_ndim(node, (int)maybe_size->size(), at::kLong); |
1517 | } |
1518 | return {}; |
1519 | }}; |
1520 | |
1521 | static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType { |
1522 | switch (node->kind()) { |
1523 | case aten::_cast_Byte: |
1524 | return at::kByte; |
1525 | case aten::_cast_Char: |
1526 | return at::kChar; |
1527 | case aten::_cast_Double: |
1528 | return at::kDouble; |
1529 | case aten::_cast_Float: |
1530 | return at::kFloat; |
1531 | case aten::_cast_Half: |
1532 | return at::kHalf; |
1533 | case aten::_cast_Int: |
1534 | return at::kInt; |
1535 | case aten::_cast_Long: |
1536 | return at::kLong; |
1537 | case aten::_cast_Short: |
1538 | return at::kShort; |
1539 | default: |
1540 | AT_ASSERTM( |
1541 | false, |
1542 | "unknown node kind in get_cast_scalar_type: " , |
1543 | node->kind().toQualString()); |
1544 | } |
1545 | }; |
1546 | static const register_formula_for cast_ops{ |
1547 | { |
1548 | "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor" , |
1549 | "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor" , |
1550 | "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor" , |
1551 | "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor" , |
1552 | "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor" , |
1553 | "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor" , |
1554 | "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor" , |
1555 | "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor" , |
1556 | }, |
1557 | [](Node* node) -> type_vec_t { |
1558 | if (auto type = |
1559 | node->namedInput(attr::self)->type()->cast<TensorType>()) { |
1560 | return {type->withScalarType(get_cast_scalar_type(node))}; |
1561 | } |
1562 | return {}; |
1563 | }}; |
1564 | |
1565 | // First, try to match one of the registered formulas to their operator |
1566 | // sets. |
1567 | for (auto& entry : shape_formulas) { |
1568 | if (node->isMemberOf(entry.first)) { |
1569 | auto types = entry.second(node); |
1570 | if (types.empty()) { |
1571 | return false; |
1572 | } else { |
1573 | auto outputs = node->outputs(); |
1574 | AT_ASSERT(types.size() == outputs.size()); |
1575 | for (const auto i : c10::irange(types.size())) { |
1576 | AT_ASSERT(outputs[i]->type()->isSubtypeOf(*TensorType::get())); |
1577 | outputs[i]->setType(types[i]); |
1578 | } |
1579 | return true; |
1580 | } |
1581 | } |
1582 | } |
1583 | |
1584 | // This section implements shape prop for an assorted set of nodes that only |
1585 | // need partial information about their input types. |
1586 | const auto input_type = [node](size_t index) { |
1587 | auto result = node->input(index)->type()->cast<TensorType>(); |
1588 | if (result) { |
1589 | result = result->dimensionedOnly(); |
1590 | } |
1591 | return result; |
1592 | }; |
1593 | if (node->matches( |
1594 | "aten::masked_select(Tensor self, Tensor mask) -> Tensor" )) { |
1595 | if (auto type = input_type(0)) { |
1596 | node->output()->setType(type->withDim(1)); |
1597 | return true; |
1598 | } |
1599 | } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)" )) { |
1600 | if (auto type = input_type(0)) { |
1601 | node->output()->setType(type->withRequiresGrad(false)); |
1602 | return true; |
1603 | } |
1604 | } else if ( |
1605 | node->matches( |
1606 | "aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)" )) { |
1607 | if (auto type = input_type(0)) { |
1608 | if (type->scalarType() == at::kHalf) { |
1609 | type = type->withScalarType(at::kFloat); |
1610 | } |
1611 | type = type->withDim(1); |
1612 | node->outputs()[0]->setType(type); |
1613 | node->outputs()[1]->setType(std::move(type)); |
1614 | return true; |
1615 | } |
1616 | } else if (node->matches( |
1617 | "aten::dot(Tensor self, Tensor tensor) -> Tensor" )) { |
1618 | if (auto type = any_tensor_type(node)) { |
1619 | node->output()->setType(type->withDim(0)); |
1620 | return true; |
1621 | } |
1622 | } else if ( |
1623 | node->matches("aten::mv(Tensor self, Tensor vec) -> Tensor" ) || |
1624 | node->matches( |
1625 | "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor" )) { |
1626 | if (auto type = any_tensor_type(node)) { |
1627 | node->output()->setType(type->withDim(1)); |
1628 | return true; |
1629 | } |
1630 | } else if ( |
1631 | node->matches( |
1632 | "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor" ) || |
1633 | node->matches( |
1634 | "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor" ) || |
1635 | node->matches( |
1636 | "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) -> Tensor" )) { |
1637 | if (auto type = any_tensor_type(node)) { |
1638 | node->output()->setType(type->withDim(2)); |
1639 | return true; |
1640 | } |
1641 | } else if ( |
1642 | node->matches( |
1643 | "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor" )) { |
1644 | if (auto type = any_tensor_type(node)) { |
1645 | node->output()->setType(type->withDim(3)); |
1646 | return true; |
1647 | } |
1648 | } else if ( |
1649 | node->matches( |
1650 | "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor" )) { |
1651 | auto type = input_type(0); |
1652 | auto index_type = input_type(1); |
1653 | // index_select behaves very weirdly when self.dim() == 0. It allows both |
1654 | // 0D and 1D indices, and returns a value that has as many dimensions as |
1655 | // index. |
1656 | if (type && index_type && type->dim()) { |
1657 | if (*type->dim() == 0) { |
1658 | node->output()->setType(type->withDim(index_type->dim())); |
1659 | } else { |
1660 | node->output()->setType(std::move(type)); |
1661 | } |
1662 | return true; |
1663 | } |
1664 | } else if ( |
1665 | node->matches( |
1666 | "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor" )) { |
1667 | auto type = input_type(0); |
1668 | auto index_type = input_type(1); |
1669 | // Gather has this annoying edge case where index always needs to match |
1670 | // the number of dims of self, **except** when self is 1D and index is 0D |
1671 | // in which case we return a 0D output. |
1672 | if (type && index_type && index_type->dim()) { |
1673 | if (*index_type->dim() == 0) { |
1674 | node->output()->setType(type->withDim(0)); |
1675 | } else { |
1676 | node->output()->setType(std::move(type)); |
1677 | } |
1678 | return true; |
1679 | } |
1680 | } else if ( |
1681 | node->matches( |
1682 | "aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor" )) { |
1683 | auto weight_type = input_type(0); |
1684 | auto indices_type = input_type(1); |
1685 | if (weight_type && indices_type && indices_type->dim()) { |
1686 | node->output()->setType(weight_type->withDim(*indices_type->dim() + 1)); |
1687 | return true; |
1688 | } |
1689 | } else if ( |
1690 | node->matches( |
1691 | "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor" )) { |
1692 | if (auto type = input_type(0)) { |
1693 | node->output()->setType(std::move(type)); |
1694 | return true; |
1695 | } |
1696 | if (auto type = input_type(1)) { |
1697 | node->output()->setType(std::move(type)); |
1698 | return true; |
1699 | } |
1700 | } else if ( |
1701 | node->matches( |
1702 | "aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor" )) { |
1703 | if (auto type = any_tensor_type(node)) { |
1704 | node->output()->setType(type->withDim(0)); |
1705 | return true; |
1706 | } |
1707 | } |
1708 | |
1709 | // The code below implements formulas that need type information for all |
1710 | // their tensor inputs, and have exactly one output. |
1711 | std::vector<TensorTypePtr> tensor_types; |
1712 | static const auto reshape_prop = |
1713 | [](Node* node, |
1714 | Symbol shape_input, |
1715 | const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr { |
1716 | if (auto list_size = determineListSize(node->namedInput(shape_input))) { |
1717 | return tensor_types.at(0)->withDim(*list_size); |
1718 | } |
1719 | return nullptr; |
1720 | }; |
1721 | const auto getSingleOutputType = [&]() -> TypePtr { |
1722 | if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor" )) { |
1723 | return tensor_types.at(0)->withScalarType( |
1724 | tensor_types.at(1)->scalarType()); |
1725 | } else if ( |
1726 | node->matches( |
1727 | "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)" ) || |
1728 | node->matches( |
1729 | "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)" ) || |
1730 | node->matches( |
1731 | "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)" )) { |
1732 | return tensor_types.at(0)->withDim(tensor_types.at(1)->dim()); |
1733 | } else if ( |
1734 | node->matches("aten::view(Tensor self, int[] size) -> Tensor" ) || |
1735 | node->matches( |
1736 | "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor" ) || |
1737 | node->matches( |
1738 | "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor" )) { |
1739 | return reshape_prop(node, attr::size, tensor_types); |
1740 | } else if ( |
1741 | node->matches( |
1742 | "aten::as_tensor(Tensor data, *, ScalarType? dtype, Device? device) -> Tensor" )) { |
1743 | TypePtr input_type = node->inputs().at(0)->type(); |
1744 | if (auto type = input_type->cast<TensorType>()) { |
1745 | if (type->scalarType() && type->device()) { |
1746 | at::ScalarType default_type = *type->scalarType(); |
1747 | c10::Device default_device = *type->device(); |
1748 | if (auto dtype_index = |
1749 | node->schema().argumentIndexWithName("dtype" )) { |
1750 | auto inp = toIValue(node->inputs().at(*dtype_index)); |
1751 | if (inp == c10::nullopt) { |
1752 | return nullptr; |
1753 | } |
1754 | if (!inp->isNone()) { |
1755 | default_type = inp->toScalarType(); |
1756 | } |
1757 | } |
1758 | if (auto device_index = |
1759 | node->schema().argumentIndexWithName("device" )) { |
1760 | auto inp = toIValue(node->inputs().at(*device_index)); |
1761 | if (inp == c10::nullopt) { |
1762 | return nullptr; |
1763 | } |
1764 | if (!inp->isNone()) { |
1765 | default_device = inp->toDevice(); |
1766 | } |
1767 | } |
1768 | node->output()->setType(TensorType::create( |
1769 | default_type, |
1770 | default_device, |
1771 | type->dim(), |
1772 | /*requires_grad=*/c10::nullopt)); |
1773 | } |
1774 | } |
1775 | return nullptr; |
1776 | } else if ( |
1777 | node->matches( |
1778 | "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)" )) { |
1779 | return reshape_prop(node, attr::shape, tensor_types); |
1780 | } else if (node->matches( |
1781 | "aten::repeat(Tensor self, int[] repeats) -> Tensor" )) { |
1782 | return reshape_prop(node, attr::repeats, tensor_types); |
1783 | } else if (node->matches( |
1784 | "aten::unsqueeze(Tensor self, int dim) -> Tensor" )) { |
1785 | auto& t = tensor_types.at(0); |
1786 | if (!t->dim()) { |
1787 | return t; |
1788 | } |
1789 | return t->withDim(*t->dim() + 1); |
1790 | } else if ( |
1791 | node->matches( |
1792 | "aten::select(Tensor self, int dim, int index) -> Tensor" ) || |
1793 | node->matches( |
1794 | "aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor" )) { |
1795 | auto& t = tensor_types.at(0); |
1796 | return t->dim() && *t->dim() > 0 ? t->withDim(*t->dim() - 1) : nullptr; |
1797 | } else if (node->matches( |
1798 | "aten::matmul(Tensor self, Tensor other) -> Tensor" )) { |
1799 | if (!tensor_types.at(0)->dim() || !tensor_types.at(1)->dim()) { |
1800 | return nullptr; |
1801 | } |
1802 | int dim1 = *tensor_types.at(0)->dim(); |
1803 | int dim2 = *tensor_types.at(1)->dim(); |
1804 | if (dim1 == 1 && dim2 == 1) { |
1805 | // Dot product |
1806 | return tensor_types.at(0)->withDim(0); |
1807 | // NOLINTNEXTLINE(bugprone-branch-clone) |
1808 | } else if (dim1 == 2 && dim2 == 2) { |
1809 | // Matrix multiply |
1810 | return tensor_types.at(0); |
1811 | } else if (dim1 == 1 && dim2 == 2) { |
1812 | // Unsqueeze + matrix multiply + squeeze |
1813 | return tensor_types.at(0); |
1814 | } else if (dim1 == 2 && dim2 == 1) { |
1815 | // Matrix vector multiply |
1816 | return tensor_types.at(1); |
1817 | } else { |
1818 | // Batched matrix multiply (possibly with squeeze + unsqueeze if one |
1819 | // argument is 1D) |
1820 | auto type = broadcast(tensor_types, tensor_types[0]->scalarType()); |
1821 | if (dim1 == 1 || dim2 == 1) { |
1822 | type = type->withDim(type->dim().value() - 1); |
1823 | } |
1824 | return type; |
1825 | } |
1826 | } else if (node->matches("aten::nonzero(Tensor self) -> Tensor" )) { |
1827 | return tensor_types.at(0)->dimensionedOnly()->withScalarType(at::kLong); |
1828 | } else if (node->matches( |
1829 | "aten::take(Tensor self, Tensor index) -> Tensor" )) { |
1830 | return tensor_types.at(1)->dimensionedOnly()->withScalarType( |
1831 | tensor_types.at(0)->scalarType()); |
1832 | } else if (node->matches( |
1833 | "aten::diagflat(Tensor self, int offset) -> Tensor" )) { |
1834 | return tensor_types.at(0)->withDim(2); |
1835 | } else if (node->matches( |
1836 | "aten::diag(Tensor self, int diagonal) -> Tensor" )) { |
1837 | auto& t = tensor_types.at(0); |
1838 | if (t->dim() && *t->dim() == 1) { |
1839 | return t->withDim(2); |
1840 | } else if (t->dim() && *t->dim() == 2) { |
1841 | return t->withDim(1); |
1842 | } else { |
1843 | return nullptr; |
1844 | } |
1845 | } else if ( |
1846 | node->matches( |
1847 | "aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor" )) { |
1848 | auto& t = tensor_types.at(0); |
1849 | if (!t->dim()) { |
1850 | return nullptr; |
1851 | } |
1852 | return t->withDim(*t->dim() + 1); |
1853 | } else if (node->matches( |
1854 | "aten::polygamma(int n, Tensor self) -> Tensor" )) { |
1855 | return tensor_types.at(0); |
1856 | } |
1857 | return nullptr; |
1858 | }; |
1859 | if (auto maybe_tensor_types = gatherTensorTypes(node)) { |
1860 | tensor_types = std::move(*maybe_tensor_types); |
1861 | } else { |
1862 | return false; |
1863 | } |
1864 | if (node->outputs().size() == 1) { |
1865 | if (auto type = getSingleOutputType()) { |
1866 | node->output()->setType(std::move(type)); |
1867 | return true; |
1868 | } |
1869 | } |
1870 | return false; |
1871 | } |
1872 | |
1873 | bool PropagateCompleteShapeOnNode( |
1874 | Node* node, |
1875 | bool insert_expands, |
1876 | std::vector<TensorTypePtr> tensor_types) { |
1877 | // For expensive ops we can directly encode their shape propagation |
1878 | // here, otherwise we fallback to running a fake version of the op |
1879 | // to get a quick and dirty propagation. |
1880 | if (node->matches( |
1881 | "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" ) || |
1882 | node->matches( |
1883 | "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor" ) || |
1884 | node->matches("aten::mul(Tensor self, Tensor other) -> Tensor" )) { |
1885 | // These nodes handle tensors of different shapes internally, so there's |
1886 | // no need to insert explicit expand nodes. |
1887 | return PropagateShapeOnNodeByRunningIt(node); |
1888 | } else if (node->matches( |
1889 | "aten::div(Tensor self, Tensor other) -> Tensor" )) { |
1890 | // "div" handle tensors of different shapes internally, so there's no need |
1891 | // to insert explicit expand nodes. |
1892 | // Note that this function could be merged to the one above , but "div" is |
1893 | // not always safe to run by itself due to integer divide-by-zero. |
1894 | // We fake the execution by running "mul" operation instead. |
1895 | auto op = getOperatorForLiteral( |
1896 | "aten::mul(Tensor self, Tensor other) -> Tensor" ) |
1897 | ->getOperation(); |
1898 | return PropagateShapeOnNodeByRunningIt(node, std::move(op)); |
1899 | } else if (node->matches( |
1900 | "aten::pow(Tensor self, Scalar exponent) -> Tensor" )) { |
1901 | node->output()->setType(tensor_types.at(0)); |
1902 | return true; |
1903 | } else if ( |
1904 | node->matches( |
1905 | "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor" ) || |
1906 | node->matches( |
1907 | "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor" ) || |
1908 | node->matches("aten::div(Tensor self, Scalar other) -> Tensor" ) || |
1909 | node->matches("aten::mul(Tensor self, Scalar other) -> Tensor" )) { |
1910 | auto first_scalar_type = (tensor_types)[0]->scalarType(); |
1911 | auto second_scalar_type = |
1912 | tryScalarTypeFromJitType(*node->inputs()[1]->type()); |
1913 | if (!first_scalar_type || !second_scalar_type) { |
1914 | return false; |
1915 | } |
1916 | if (isIntegralType(*first_scalar_type, false) && |
1917 | isFloatingType(*second_scalar_type)) { |
1918 | auto default_dtype = |
1919 | at::typeMetaToScalarType(caffe2::get_default_dtype()); |
1920 | auto type = tensor_types[0]->withScalarType(default_dtype); |
1921 | node->output()->setType(std::move(type)); |
1922 | return true; |
1923 | } |
1924 | if (c10::ScalarType::Bool == *first_scalar_type && |
1925 | c10::ScalarType::Bool != *second_scalar_type) { |
1926 | auto result_type = |
1927 | c10::promoteTypes(*first_scalar_type, *second_scalar_type); |
1928 | auto type = tensor_types[0]->withScalarType(result_type); |
1929 | node->output()->setType(std::move(type)); |
1930 | return true; |
1931 | } |
1932 | auto type = tensor_types[0]->withScalarType(first_scalar_type); |
1933 | node->output()->setType(std::move(type)); |
1934 | return true; |
1935 | } else if ( |
1936 | insert_expands && |
1937 | (node->matches("aten::pow(Tensor self, Tensor exponent) -> Tensor" ) || |
1938 | node->matches("aten::min(Tensor self, Tensor other) -> Tensor" ) || |
1939 | node->matches("aten::max(Tensor self, Tensor other) -> Tensor" ) || |
1940 | node->matches("aten::lt(Tensor self, Tensor other) -> Tensor" ) || |
1941 | node->matches("aten::le(Tensor self, Tensor other) -> Tensor" ) || |
1942 | node->matches("aten::gt(Tensor self, Tensor other) -> Tensor" ) || |
1943 | node->matches("aten::ge(Tensor self, Tensor other) -> Tensor" ) || |
1944 | node->matches("aten::eq(Tensor self, Tensor other) -> Tensor" ) || |
1945 | node->matches("aten::ne(Tensor self, Tensor other) -> Tensor" ))) { |
1946 | // Binary broadcasting ops |
1947 | // NB: we don't handle the nodes in any other way (note the lack of |
1948 | // return!), because the type casting logic in scalar cases is |
1949 | // non-trivial. It's better to just run them. |
1950 | broadcastBinary(node, tensor_types, 0, 1); |
1951 | return PropagateShapeOnNodeByRunningIt(node); |
1952 | } else if ( |
1953 | node->matches( |
1954 | "aten::logit(Tensor self, float? eps = None) -> Tensor" ) || |
1955 | node->matches("aten::neg(Tensor self) -> Tensor" ) || |
1956 | node->matches("aten::sigmoid(Tensor self) -> Tensor" ) || |
1957 | node->matches("aten::tanh(Tensor self) -> Tensor" )) { |
1958 | node->output()->setType(tensor_types.at(0)->contiguous()); |
1959 | return true; |
1960 | } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor" )) { |
1961 | auto lhs_type = tensor_types.at(0); |
1962 | auto rhs_type = tensor_types.at(1); |
1963 | auto lhs_sizes = lhs_type->sizes().concrete_sizes().value(); |
1964 | auto rhs_sizes = rhs_type->sizes().concrete_sizes().value(); |
1965 | SHAPE_ASSERT( |
1966 | *lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2); |
1967 | node->output()->setType(TensorType::createContiguous( |
1968 | *lhs_type->scalarType(), |
1969 | *lhs_type->device(), |
1970 | at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]})); |
1971 | return true; |
1972 | } else if (node->matches("aten::t(Tensor self) -> Tensor" )) { |
1973 | auto tp = tensor_types.at(0); |
1974 | auto sizes = tp->sizes().concrete_sizes().value(); |
1975 | auto strides = tp->strides().concrete_sizes().value(); |
1976 | SHAPE_ASSERT(sizes.size() == 2); |
1977 | std::swap(sizes.at(0), sizes.at(1)); |
1978 | std::swap(strides.at(0), strides.at(1)); |
1979 | node->output()->setType(tp->withSizesStrides(sizes, strides)); |
1980 | return true; |
1981 | } else if ( |
1982 | node->matches( |
1983 | "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor" , |
1984 | /*const_inputs=*/{attr::dim, attr::length})) { |
1985 | auto tp = tensor_types.at(0); |
1986 | auto sizes = tp->sizes().concrete_sizes().value(); |
1987 | int64_t dim = node->get<int64_t>(attr::dim).value(); |
1988 | int64_t length = node->get<int64_t>(attr::length).value(); |
1989 | SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size()); |
1990 | sizes.at(dim) = length; |
1991 | node->output()->setType( |
1992 | tp->withSizesStrides(sizes, tp->strides().concrete_sizes().value())); |
1993 | return true; |
1994 | } else if (node->matches( |
1995 | "aten::sum(Tensor self, *, int? dtype) -> Tensor" )) { |
1996 | node->output()->setType(tensor_types.at(0)->withSizes({})); |
1997 | return true; |
1998 | } else if ( |
1999 | node->matches( |
2000 | "aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor" , |
2001 | /*const_inputs=*/{attr::dim, attr::keepdim})) { |
2002 | auto& tp = tensor_types.at(0); |
2003 | auto sizes = tp->sizes().concrete_sizes().value(); |
2004 | auto dims = node->get<c10::List<int64_t>>(attr::dim).value(); |
2005 | bool keepdim = node->get<bool>(attr::keepdim).value(); |
2006 | std::reverse(dims.begin(), dims.end()); |
2007 | for (int64_t dim : dims) { |
2008 | SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size()); |
2009 | if (keepdim) { |
2010 | sizes.at(dim) = 1; |
2011 | } else { |
2012 | sizes.erase(sizes.begin() + dim); |
2013 | } |
2014 | } |
2015 | node->output()->setType(tp->withSizes(sizes)); |
2016 | return true; |
2017 | } else if (node->matches( |
2018 | "aten::squeeze(Tensor self, int dim) -> Tensor" , |
2019 | /*const_inputs=*/attr::dim)) { |
2020 | auto& tp = tensor_types.at(0); |
2021 | auto sizes = tp->sizes().concrete_sizes().value(); |
2022 | auto strides = tp->strides().concrete_sizes().value(); |
2023 | int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes); |
2024 | SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size()); |
2025 | if (sizes.at(dim) == 1) { |
2026 | sizes.erase(sizes.begin() + dim); |
2027 | strides.erase(strides.begin() + dim); |
2028 | } |
2029 | node->output()->setType(tp->withSizesStrides(sizes, strides)); |
2030 | return true; |
2031 | } else if (node->matches( |
2032 | "aten::unsqueeze(Tensor self, int dim) -> Tensor" , |
2033 | /*const_inputs=*/attr::dim)) { |
2034 | auto& tp = tensor_types.at(0); |
2035 | auto sizes = tp->sizes().concrete_sizes().value(); |
2036 | auto strides = tp->strides().concrete_sizes().value(); |
2037 | int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes); |
2038 | SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size()); |
2039 | int64_t new_stride = dim >= static_cast<int64_t>(sizes.size()) |
2040 | ? 1 |
2041 | : sizes.at(dim) * strides.at(dim); |
2042 | sizes.insert(sizes.begin() + dim, 1); |
2043 | strides.insert(strides.begin() + dim, new_stride); |
2044 | node->output()->setType(tp->withSizesStrides(sizes, strides)); |
2045 | return true; |
2046 | } else if (node->matches( |
2047 | "aten::view(Tensor self, int[] size) -> Tensor" , |
2048 | /*const_inputs=*/attr::size)) { |
2049 | auto sizes = node->get<c10::List<int64_t>>(attr::size).value(); |
2050 | bool inferred = false; |
2051 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
2052 | size_t inferred_idx; |
2053 | int64_t size_product = 1; |
2054 | for (const auto i : c10::irange(sizes.size())) { |
2055 | if (sizes.get(i) == -1) { |
2056 | if (inferred) |
2057 | throw propagation_error(); |
2058 | inferred = true; |
2059 | inferred_idx = i; |
2060 | } else { |
2061 | size_product *= sizes.get(i); |
2062 | } |
2063 | } |
2064 | |
2065 | if (inferred) { |
2066 | SHAPE_ASSERT(size_product != 0); |
2067 | size_t numel = 1; |
2068 | auto concrete_sizes = |
2069 | tensor_types.at(0)->sizes().concrete_sizes().value(); |
2070 | for (int64_t s : concrete_sizes) |
2071 | numel *= s; |
2072 | int64_t inferred_size = numel / size_product; |
2073 | sizes[inferred_idx] = inferred_size; |
2074 | } |
2075 | node->output()->setType(tensor_types.at(0)->withSizes(sizes.vec())); |
2076 | return true; |
2077 | } else if (node->matches( |
2078 | "aten::type_as(Tensor self, Tensor other) -> Tensor" )) { |
2079 | if (tensor_types.at(0)->scalarType() == |
2080 | tensor_types.at(1)->scalarType()) { |
2081 | node->output()->setType(node->namedInput(attr::self)->type()); |
2082 | } else { |
2083 | // This will be a copy, so the result will be contiguous |
2084 | node->output()->setType(tensor_types.at(1)->withSizes( |
2085 | tensor_types.at(0)->sizes().concrete_sizes().value())); |
2086 | } |
2087 | return true; |
2088 | } else if ( |
2089 | node->matches( |
2090 | "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor" , |
2091 | /*const_inputs=*/attr::size)) { |
2092 | auto tp = tensor_types.at(0); |
2093 | auto sizesAndStrides = at::inferExpandGeometry_dimvector( |
2094 | tp->sizes().concrete_sizes().value(), |
2095 | tp->strides().concrete_sizes().value(), |
2096 | node->get<c10::List<int64_t>>(attr::size).value().vec()); |
2097 | node->output()->setType( |
2098 | tp->withSizesStrides(sizesAndStrides.sizes, sizesAndStrides.strides)); |
2099 | return true; |
2100 | } else if ( |
2101 | node->matches( |
2102 | "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor" , |
2103 | /*const_inputs=*/attr::dim)) { |
2104 | auto ten = tensor_types.at(0); |
2105 | auto index = tensor_types.at(1); |
2106 | int64_t dim = node->get<int64_t>(attr::dim).value(); |
2107 | SHAPE_ASSERT(*index->sizes().size() == 1); |
2108 | SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size()); |
2109 | std::vector<int64_t> sizes = ten->sizes().concrete_sizes().value(); |
2110 | sizes[dim] = index->sizes()[0].value(); |
2111 | node->output()->setType(ten->withSizes(sizes)); |
2112 | return true; |
2113 | } else if (node->matches( |
2114 | "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]" , |
2115 | /*const_inputs=*/{attr::chunks, attr::dim})) { |
2116 | auto input_type = tensor_types.at(0); |
2117 | auto sizes = input_type->sizes().concrete_sizes().value(); |
2118 | auto strides = input_type->strides().concrete_sizes().value(); |
2119 | int64_t dim = node->get<int64_t>(attr::dim).value(); |
2120 | int64_t chunks = node->get<int64_t>(attr::chunks).value(); |
2121 | sizes[dim] /= chunks; |
2122 | for (Value* output : node->outputs()) { |
2123 | output->setType(input_type->withSizesStrides(sizes, strides)); |
2124 | } |
2125 | if (*input_type->sizes()[dim] % chunks != 0) { |
2126 | sizes[dim] = *input_type->sizes()[dim] % chunks; |
2127 | node->outputs().back()->setType( |
2128 | input_type->withSizesStrides(sizes, strides)); |
2129 | } |
2130 | return true; |
2131 | } else if (node->kind() == ::c10::onnx::Shape) { |
2132 | SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1); |
2133 | std::vector<int64_t> dim_vec = { |
2134 | (int64_t)*tensor_types.at(0)->sizes().size()}; |
2135 | at::IntArrayRef dims(dim_vec); |
2136 | node->output()->setType( |
2137 | TensorType::createContiguous(at::kLong, at::kCPU, dims)); |
2138 | return true; |
2139 | } else if (node->kind() == ::c10::onnx::Reshape) { |
2140 | setUnshapedType(node); |
2141 | return true; |
2142 | } |
2143 | setUnshapedType(node); |
2144 | return false; |
2145 | } |
2146 | }; |
2147 | } // anonymous namespace |
2148 | |
2149 | void PropagateInputShapes(const std::shared_ptr<Graph>& graph) { |
2150 | ShapePropagator(graph).propagateBlock(graph->block()); |
2151 | } |
2152 | |
2153 | namespace { |
2154 | |
2155 | using TypeCache = std::unordered_map<TypePtr, TypePtr>; |
2156 | |
2157 | TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache); |
2158 | |
2159 | TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) { |
2160 | if (type->isSubtypeOf(*TensorType::get())) { |
2161 | return TensorType::get(); |
2162 | } |
2163 | at::ArrayRef<TypePtr> contained = type->containedTypes(); |
2164 | if (contained.empty()) { |
2165 | return type; |
2166 | } |
2167 | std::vector<TypePtr> unshaped_contained_types; |
2168 | for (const auto& contained_type : contained) { |
2169 | unshaped_contained_types.push_back( |
2170 | getOrCreateUnshapedType(contained_type, unshaped_type_cache)); |
2171 | } |
2172 | return type->withContained(std::move(unshaped_contained_types)); |
2173 | } |
2174 | |
2175 | TypePtr getOrCreateUnshapedType(TypePtr type, TypeCache& unshaped_type_cache) { |
2176 | auto maybe_cached_type = unshaped_type_cache.find(type); |
2177 | if (maybe_cached_type != unshaped_type_cache.end()) { |
2178 | return maybe_cached_type->second; |
2179 | } |
2180 | auto unshaped_type = unshapedTypeImpl(type, unshaped_type_cache); |
2181 | unshaped_type_cache[type] = unshaped_type; |
2182 | return unshaped_type; |
2183 | } |
2184 | |
2185 | void EraseShapeInformation( |
2186 | const std::shared_ptr<Graph>& graph, |
2187 | TypeCache& unshaped_type_cache); |
2188 | |
2189 | void EraseShapeInformation( |
2190 | at::ArrayRef<Value*> vals, |
2191 | TypeCache& unshaped_type_cache) { |
2192 | for (Value* v : vals) { |
2193 | v->setType(getOrCreateUnshapedType(v->type(), unshaped_type_cache)); |
2194 | } |
2195 | } |
2196 | |
2197 | void EraseShapeInformation(Block* b, TypeCache& unshaped_type_cache) { |
2198 | EraseShapeInformation(b->inputs(), unshaped_type_cache); |
2199 | EraseShapeInformation(b->outputs(), unshaped_type_cache); |
2200 | for (Node* n : b->nodes()) { |
2201 | EraseShapeInformation(n->outputs(), unshaped_type_cache); |
2202 | for (Block* sb : n->blocks()) { |
2203 | EraseShapeInformation(sb, unshaped_type_cache); |
2204 | } |
2205 | if (n->hasAttribute(attr::Subgraph)) { |
2206 | EraseShapeInformation(n->g(attr::Subgraph), unshaped_type_cache); |
2207 | } |
2208 | } |
2209 | } |
2210 | |
2211 | void EraseShapeInformation( |
2212 | const std::shared_ptr<Graph>& graph, |
2213 | TypeCache& unshaped_type_cache) { |
2214 | EraseShapeInformation(graph->block(), unshaped_type_cache); |
2215 | } |
2216 | |
2217 | } // anonymous namespace |
2218 | |
2219 | void EraseShapeInformation(const std::shared_ptr<Graph>& graph) { |
2220 | TypeCache unshaped_type_cache; |
2221 | EraseShapeInformation(graph->block(), unshaped_type_cache); |
2222 | } |
2223 | } // namespace jit |
2224 | } // namespace torch |
2225 | |