#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/integer_value_refinement.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/peephole_non_tensor.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/utils/memory.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

/*
XXX: this is still in prototype phase and has much work left to do, including
but not limited to:
- Refactor APIs
- Add decent coverage of common ops
- Add shape analysis pass on Graph that handles Loops
- Allow concurrent reads to the operator map
- Support returning partially evaluated shape compute graphs
*/

static bool symbolic_shape_analysis_test_mode = false;

namespace torch {
namespace jit {

// This is similar to c10::SymbolicShape, but instead of either having
// a concrete dimension or a symbolic dimension, an argument may be:
// - A Symbolic Dimension
// - A Constant Integer
// - Neither of the above. The third case can occur due to inputs to
// ops like view that accept negative values. Maintaining the distinction
// between an unknown symbolic dimension and an unknown integer allows
// us to optimize out comparisons to values < 0 (symbolic shapes are always
// >= 0). For example, for a call like y.view([5, y.size(0), inp]) in
// graph(%y: Tensor(SS(-1), 10, 10), %inp: int):
//   %five: int = prim::Constant[value=5]()
//   %zero: int = prim::Constant[value=0]()
//   %1 : int = aten::size(%y, %zero)
//   %2 : int[] = prim::ListConstruct(%five, %1, %inp)
//   %y.2: Tensor(5, SS(-1), (New Symbolic Shape)) = aten::view(%y, %2)
//
// the arguments to the view shape function will be [5, SS(-1), c10::nullopt].

struct ShapeArg
    : public std::
          pair<c10::optional<c10::ShapeSymbol>, c10::optional<int64_t>> {
  using pair::pair;

  static ShapeArg unknownInteger() {
    return ShapeArg();
  }

  ShapeArg(int64_t int_value) {
    this->first = c10::nullopt;
    this->second = int_value;
  }

  ShapeArg(c10::ShapeSymbol ss) {
    if (ss.is_static()) {
      this->first = c10::nullopt;
      this->second = ss.value();
    } else {
      this->first = ss;
      this->second = c10::nullopt;
    }
  }

  c10::optional<int64_t> asConstantInt() const {
    return this->second;
  }

  c10::optional<c10::ShapeSymbol> asShapeSymbol() const {
    return this->first;
  }

 private:
  ShapeArg() {
    this->first = c10::nullopt;
    this->second = c10::nullopt;
  }
};
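
// Illustrative sketch (not exercised by the pass itself) of how the three
// ShapeArg states are constructed; the values shown are hypothetical:
//   ShapeArg a(int64_t(5));                     // constant int: asConstantInt() == 5
//   ShapeArg b(c10::ShapeSymbol::newSymbol());  // symbolic dim: asShapeSymbol() is set
//   ShapeArg c = ShapeArg::unknownInteger();    // neither: both accessors return nullopt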

std::ostream& operator<<(std::ostream& out, const ShapeArg& sa) {
  if (auto val = sa.asConstantInt()) {
    out << *val;
  } else if (auto ss = sa.asShapeSymbol()) {
    out << *ss;
  } else {
    out << "UNK";
  }
  return out;
}

struct ShapeArguments {
  // Superset of SymbolicShape, with additional support for unknown,
  // nonsymbolic vals
 public:
  ShapeArguments(const c10::SymbolicShape& ss) {
    has_dim_ = ss.rank().has_value();
    if (has_dim_) {
      for (size_t i = 0; i < *ss.rank(); ++i) {
        maybe_shape_symbols_.emplace_back(ss.at(i));
      }
    }
  }

  ShapeArguments(std::vector<ShapeArg> ss)
      : has_dim_(true), maybe_shape_symbols_(std::move(ss)) {}

  bool has_dim() const {
    return has_dim_;
  }

  int64_t len() const {
    TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim")
    return (int64_t)maybe_shape_symbols_.size();
  }

  const ShapeArg at(size_t i) const {
    TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim")
    return maybe_shape_symbols_.at(i);
  }

 private:
  bool has_dim_;
  std::vector<ShapeArg> maybe_shape_symbols_;
};

std::ostream& operator<<(std::ostream& os, const ShapeArguments& sa) {
  if (!sa.has_dim()) {
    os << "(UNKNOWN DIM)";
    return os;
  }

  os << "(";
  for (size_t i = 0; i < sa.len(); i++) {
    os << sa.at(i);
  }
  os << ")";

  return os;
}
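
// Illustrative sketch of ShapeArguments construction (hypothetical values):
//   c10::SymbolicShape ss(std::vector<c10::ShapeSymbol>{
//       c10::ShapeSymbol::fromStaticSize(3), c10::ShapeSymbol::newSymbol()});
//   ShapeArguments args(ss);
//   // args.has_dim() == true, args.len() == 2,
//   // args.at(0).asConstantInt() == 3, args.at(1).asShapeSymbol() is set.
// A SymbolicShape with unknown rank yields has_dim() == false.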

bool setSymbolicShapeAnalysisTestMode(bool value) {
  bool old_value = symbolic_shape_analysis_test_mode;
  symbolic_shape_analysis_test_mode = value;
  return old_value;
}

bool symbolicShapeAnalysisTestModeEnabled() {
  return symbolic_shape_analysis_test_mode;
}

using SSArgument = c10::variant<ShapeArguments, IValue>;

std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
  if (const IValue* iv = c10::get_if<IValue>(&sa)) {
    out << *iv;
  } else {
    out << c10::get<ShapeArguments>(sa);
  }
  return out;
}

namespace {

bool isListOfInts(const TypePtr& type) {
  return type->cast<ListType>() &&
      type->cast<ListType>()->getElementType()->cast<IntType>();
}

bool isListOfListOfInts(const TypePtr& type) {
  // Allows List[Optional[List[Int]]]
  if (!type->cast<ListType>()) {
    return false;
  }
  TypePtr element_type = type->cast<ListType>()->getElementType();
  if (element_type->cast<OptionalType>()) {
    element_type = element_type->cast<OptionalType>()->getElementType();
  }
  return isListOfInts(element_type);
}

bool isListOfTensors(const TypePtr& type) {
  return type->cast<ListType>() &&
      type->cast<ListType>()->getElementType()->cast<TensorType>();
}

c10::optional<size_t> normIndex(int64_t index, size_t len) {
  if (index < 0) {
    index = index + len;
  }
  if (index >= 0 && index < static_cast<int64_t>(len)) {
    return index;
  } else {
    return c10::nullopt;
  }
}
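
// Illustrative behavior of normIndex (hypothetical values):
//   normIndex(-1, 4) -> 3             (negative indices wrap around)
//   normIndex(2, 4)  -> 2
//   normIndex(4, 4)  -> c10::nullopt  (out of range)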

bool shapeGraphCleanupPasses(std::shared_ptr<Graph> graph) {
  // TODO: lower simple tuples ?
  bool made_change = RemoveListMutation(graph);
  made_change |= UnrollConstantLoops(graph);
  made_change |= ConstantPropagation(graph);
  made_change |= PeepholeOptimizeNonTensor(graph);
  made_change |= PeepholeOptimizeListIdioms(graph, /*refine_list_len*/ true);
  made_change |= RefineIntegerValues(graph);
  made_change |= ConstantPropagation(graph);
  // TODO: ConstantPooling does not report whether it made a change,
  // so it is not reflected in made_change
  ConstantPooling(graph);
  made_change |= EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
  return made_change;
}
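
// For intuition, a minimal sketch of the effect of these passes on a partially
// evaluated shape graph (hypothetical IR):
//   %len : int = prim::Constant[value=4]()   // substituted len(inp)
//   %cond : bool = aten::eq(%len, %four)
//   %r : int[] = prim::If(%cond) ...
// constant propagation folds %cond to True and collapses the If to its
// then-branch, and dead code elimination removes what is left over.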

void replaceWithIValue(Value* v, IValue val) {
  WithInsertPoint guard(*v->node()->owningBlock()->nodes().begin());
  v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
}

c10::SymbolicShape extractListShape(
    Value* list,
    std::unordered_map<Value*, int64_t>& symbolic_shape_values,
    const AliasDb& db) {
  if (list->node()->kind() == prim::Constant) {
    auto int_list = toIValue(list)->toIntVector();
    return c10::SymbolicShape(int_list);
  }
  // We need a list construct or a constant output
  // that is not written to in order to analyze the output shape
  if (list->node()->kind() != prim::ListConstruct || db.hasWriters(list)) {
    GRAPH_DEBUG("Could not extract shape");
    return c10::SymbolicShape();
  }
  Node* list_construct = list->node();
  std::vector<c10::optional<int64_t>> output_shape;
  for (Value* input : list_construct->inputs()) {
    if (symbolic_shape_values.count(input)) {
      output_shape.emplace_back(symbolic_shape_values[input]);
    } else {
      output_shape.push_back(constant_as<int64_t>(input));
    }
  }
  return c10::SymbolicShape(output_shape);
}
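
// Illustrative sketch (hypothetical shape-graph output): for
//   %out : int[] = prim::ListConstruct(%two, %dim1, %unknown)
// where %two is the constant 2, %dim1 was recorded in symbolic_shape_values
// as SS(-3), and %unknown is neither, the extracted shape is roughly
// (2, SS(-3), SS(new)): a concrete dim, the recorded symbolic dim, and a
// fresh symbol for the unknown dim.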

// Symbolic Shape Analysis works through iteratively partially evaluating
// a TorchScript shape compute graph by substituting in properties from input
// Tensors. We can substitute in properties like `len(x)` and `x[1]`
// if they are statically known on the input Tensors. We can also use
// assertions like `assert len(x) == 4` in order to refine the input
// length and unroll loops over its elements. We iteratively optimize and
// substitute in properties until we are unable to make any further
// optimizations. Finally, we try to extract Tensor properties from the
// output. For instance, given `return [1, 2, inp[2] + 1, inp[3]]` we know
// that the output will be length 4 with the first two dimensions equal to
// 1 and 2. We can also deduce that the 4th dimension has the same symbolic
// shape as inp[3], which means that we do not know its concrete value
// statically but we can assign sets of tensor dimensions which must be
// equal at runtime.
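
// Worked sketch (illustrative, not a definitive trace): for a registered
// shape function along the lines of
//   def unary(self: List[int]) -> List[int]:
//       out: List[int] = []
//       for elem in self:
//           out.append(elem)
//       return out
// with an input tensor of type Tensor(SS(-2), 4), we substitute len(self) = 2
// and self[1] = 4, unroll the loop, and the graph reduces to
//   return [self[0], 4]
// from which we extract the output shape (SS(-2), 4), reusing the input's
// symbolic dimension for self[0].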

struct SymbolicShapeOpAnalyzer {
  std::shared_ptr<Graph> shape_compute_graph_;
  const FunctionSchema* schema_;
  std::vector<SSArgument> inputs_;

  // For the case where we have a JIT graph,
  // substitute optional types for their component types
  // if the type is known. This doesn't need to be done
  // for known IValues.
  void refineInputUnionTypes(const Node* parent_graph_node) {
    for (size_t op_in_index = 0;
         op_in_index < shape_compute_graph_->inputs().size();
         op_in_index++) {
      auto type = parent_graph_node->input(op_in_index)->type();
      if (auto opt_type = shape_compute_graph_->inputs()
                              .at(op_in_index)
                              ->type()
                              ->cast<OptionalType>()) {
        // None will get handled with constant substitution later
        if (!type->cast<OptionalType>() &&
            !NoneType::get()->isSubtypeOf(*type)) {
          shape_compute_graph_->inputs()
              .at(op_in_index)
              ->setType(opt_type->getElementType());
        }
      } else if (shape_compute_graph_->inputs()
                     .at(op_in_index)
                     ->type()
                     ->cast<NumberType>()) {
        shape_compute_graph_->inputs().at(op_in_index)->setType(type);
      }
    }
  }

  // We handle non-constant values in the shape propagation step
  void substituteConstantInputs() {
    if (shape_compute_graph_->inputs().empty()) {
      return;
    }

    bool seen_tensor_list = false;

    size_t op_in_index = 0;
    while (op_in_index < shape_compute_graph_->inputs().size()) {
      Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);
      if (!isListOfListOfInts(graph_in_var->type())) {
        op_in_index++;
        continue;
      }

      // Modify the shape graph so that it no longer consumes the tensor list
      // as a single input.
      //
      // When we partially evaluate a list of Tensors, as in cat(Tensor[]),
      // we have a few problems:
      // - optimizing out calls to the length of the list: len(tensors)
      // - resolving accesses of the list to the symbolic sizes of the
      //   corresponding list element
      // We can solve both of these problems by replacing the partial
      // evaluation of cat([x, y]):
      //   def cat(tensors: List[List[int]], dim: int):
      //       body
      // with:
      //   def cat(x, y, dim: int):
      //       tensors = [x, y]
      //       body
      TORCH_INTERNAL_ASSERT(
          !seen_tensor_list,
          "SSA doesn't handle case with multiple tensor lists")
      seen_tensor_list = true;

      uint64_t li_length = inputs_.size() - (schema_->arguments().size() - 1);
      std::vector<Value*> li_inputs;

      TypePtr element_type =
          graph_in_var->type()->cast<ListType>()->getElementType();
      for (size_t j = op_in_index; j < op_in_index + li_length; ++j) {
        auto new_inp = shape_compute_graph_->insertInput(op_in_index + j);
        new_inp->setType(element_type);
        li_inputs.push_back(new_inp);
      }
      WithInsertPoint guard(*shape_compute_graph_->block()->nodes().begin());
      auto new_li = shape_compute_graph_->insertNode(
          shape_compute_graph_->createList(element_type, li_inputs));
      graph_in_var->replaceAllUsesWith(new_li->output());
      shape_compute_graph_->eraseInput(op_in_index + li_length);
    }

    TORCH_INTERNAL_ASSERT(
        shape_compute_graph_->inputs().size() <= inputs_.size(),
        "Shape Compute Graph expected to have less inputs than actual inputs"); //?

    for (size_t op_in_index = 0;
         op_in_index < shape_compute_graph_->inputs().size();
         op_in_index++) {
      SSArgument& argument = inputs_[op_in_index];
      Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);

      if (IValue* cur_val = c10::get_if<IValue>(&argument)) {
        GRAPH_DEBUG("Substituting constant input ", *cur_val);
        replaceWithIValue(graph_in_var, *cur_val);
      } else {
        auto cur_arg = c10::get<ShapeArguments>(argument);
        if (cur_arg.has_dim()) {
          graph_in_var->setType(ListType::ofInts());
        }
      }
    }
  }

  void substituteSymbolicProperties(
      std::unordered_map<Value*, int64_t>* symbolic_shape_values) {
    // clang-format off
    // here we iteratively substitute properties of the node's input tensors
    // into the shape compute graph. we can substitute constants into the
    // graph, like len(inp) or inp[0], if the tensor has a fixed length or a
    // fixed first dimension. we also try to resolve symbolic shapes of the
    // same symbolic value to the same Value * in the shape compute graph.
    // for the shape logic:
    // dim1 = inp1[0]
    // dim2 = inp2[0]
    // return dim1 if dim2 == 1 else dim2
    // if we see that inp1[0] and inp2[0] both have the same symbolic shape
    // value, then it is a valid transformation to replace dim2 with dim1 or
    // vice versa. to do this we collect all Value * for a particular symbolic
    // shape. Then, we replace all Value * within that set with their dominator.
    // In the example above, this allows us to infer that the output will be the
    // symbolic dimension value of dim1.

    // if `symbolic_shape_values` is not null, record list accesses
    // which resolve to symbolic dimension values with their concrete symbolic
    // shape value. Because symbolic dimensions are represented as negative numbers and
    // are not real values, inserting them as constants in the graph would invalidate
    // the graph for further use. Instead, we keep track of what their value would be
    // for extracting output shapes.
    // clang-format on

    std::unordered_map<int64_t, std::vector<Value*>> symbolic_shape_map;

    TORCH_INTERNAL_ASSERT(
        inputs_.size() >= shape_compute_graph_->inputs().size(),
        "Missing Arg for Shape Graph");
    for (int64_t index = 0; index < shape_compute_graph_->inputs().size();
         index++) {
      auto shape_arguments = c10::get_if<ShapeArguments>(&inputs_[index]);
      if (!shape_arguments || !shape_arguments->has_dim()) {
        continue;
      }
      // Add support for testing symbolic shapes with dynamic dims

      for (const Use& use : shape_compute_graph_->inputs().at(index)->uses()) {
        // TODO: either decompose composite ops like slice or add handling here
        switch (use.user->kind()) {
          case aten::len: {
            size_t len = shape_arguments->len();
            replaceWithIValue(use.user->output(), static_cast<int64_t>(len));
          } break;
          case aten::__getitem__: {
            auto index = constant_as<int64_t>(use.user->inputs().at(1));
            if (!index) {
              continue;
            }
            auto norm_index = normIndex(*index, shape_arguments->len());
            if (!norm_index) {
              continue;
            }
            auto shape_arg = shape_arguments->at(*norm_index);
            if (auto const_int = shape_arg.asConstantInt()) {
              replaceWithIValue(use.user->output(), const_int);
              continue;
            }
            auto maybe_shape_symbol = shape_arg.asShapeSymbol();
            if (!maybe_shape_symbol) {
              continue;
            }
            auto shape_symbol = *maybe_shape_symbol;
            if (symbolic_shape_values) {
              symbolic_shape_values->emplace(
                  use.user->output(), shape_symbol.value());
            } else {
              int64_t symbolic_index = shape_symbol.value();
              symbolic_shape_map[symbolic_index].push_back(use.user->output());
            }
            for (const auto& sym_uses : use.user->output()->uses()) {
              auto k = sym_uses.user->kind();
              if (k != aten::ge && k != aten::le && k != aten::ne &&
                  k != aten::eq && k != aten::lt && k != aten::gt) {
                break;
              }
              auto other_index = 1 - sym_uses.offset;
              auto other_value =
                  constant_as<int64_t>(sym_uses.user->input(other_index));
              if (!other_value) {
                continue;
              }

              // check for dim >= 0, 0 <= dim
              // dim >= 0
              if (k == aten::ge && *other_value == 0 && other_index == 1) {
                replaceWithIValue(sym_uses.user->output(), true);
                continue;
              }
              // 0 <= dim
              if (k == aten::le && *other_value == 0 && other_index == 0) {
                replaceWithIValue(sym_uses.user->output(), true);
                continue;
              }

              // check for dim comparisons to negative number
              if (*other_value >= 0) {
                continue;
              }
              if (k == aten::eq || k == aten::ne) {
                // True if:
                // -2 != {Positive}
                replaceWithIValue(sym_uses.user->output(), k == aten::ne);
              } else {
                // True if:
                // -2 <= / < {Positive}
                // {Positive} >= / > {-2}
                bool true_val =
                    ((other_index == 0 && (k == aten::le || k == aten::lt)) ||
                     (other_index == 1 && (k == aten::ge || k == aten::gt)));
                replaceWithIValue(sym_uses.user->output(), true_val);
              }
            }
          }
        }
      }

      for (const auto& symbolic_set : symbolic_shape_map) {
        mergeSymbolicShapeSets(symbolic_set.second);
      }
    }
  }

  void mergeSymbolicShapeSets(const std::vector<Value*>& symbolic_set) {
    // `symbolic_set` represents a set of Value * which are all equal
    // to each other. Here, we optimize the graph by replacing values
    // in the set with other dominating values.
    // in the following example, where a, b and c are all in the same
    // symbolic set:
    // if cond:
    //    a = li[0]
    //    b = li[1]
    //    return [a, b]
    // else:
    //    c = li[0]
    //    return [c, c]
    // we can replace `b` with `a` because it is dominated by `a`,
    // but we cannot replace `c` with another dominating value

    // there are ways to compute this more efficiently but typically number of
    // Values for each symbolic set is low and this is cheap to run
    for (const auto i : c10::irange(symbolic_set.size())) {
      Value* v = symbolic_set[i];
      Value* dominating_value = v;
      for (const auto& sym_set : symbolic_set) {
        if (dominating_value->node()->isDominatedBy(sym_set->node())) {
          dominating_value = sym_set;
        }
      }
      if (dominating_value != v) {
        v->replaceAllUsesWith(dominating_value);
      }
    }
  }

  std::vector<c10::SymbolicShape> propagateShapesInGraph() {
    bool made_change = true;
    constexpr size_t MAX_ATTEMPTS = 8;
    for (int attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS;
         attempt_num++) {
      // symbolic shape concrete values are only used in final shape extraction
      GRAPH_DUMP("Before substitution: ", shape_compute_graph_);
      substituteSymbolicProperties(/*symbolic_shape_values*/ nullptr);
      GRAPH_DUMP("Before Opt: ", shape_compute_graph_);
      made_change = shapeGraphCleanupPasses(shape_compute_graph_);
    }
    std::unordered_map<Value*, int64_t> symbolic_shape_values;
    substituteSymbolicProperties(&symbolic_shape_values);
    GRAPH_DUMP("Done with partial evaluation", shape_compute_graph_);

    return extractOutputShape(symbolic_shape_values);
  }

  std::vector<c10::SymbolicShape> extractOutputShape(
      std::unordered_map<Value*, int64_t>& symbolic_shape_values) {
    TORCH_INTERNAL_ASSERT(
        shape_compute_graph_->outputs().size() == schema_->returns().size());
    // TODO: would be nice if there were an easy facility to look at uses and
    // see if they are all pure instead of instantiating the AliasDb.
    auto res = std::vector<c10::SymbolicShape>();
    AliasDb db(shape_compute_graph_);
    for (size_t i = 0; i < shape_compute_graph_->outputs().size(); ++i) {
      auto output = shape_compute_graph_->outputs().at(i);
      auto type = output->type();
      TORCH_INTERNAL_ASSERT(isListOfInts(type));
      c10::SymbolicShape ss =
          extractListShape(output, symbolic_shape_values, db);
      GRAPH_DEBUG("Extracted Output: ", ss);
      res.push_back(ss);
    }
    return res;
  }

 public:
  SymbolicShapeOpAnalyzer(const FunctionSchema* schema) : schema_(schema) {
    shape_compute_graph_ = nullptr;
    if (!schema_) {
      return;
    }
    auto maybe_graph = shapeComputeGraphForSchema(*schema_);
    if (!maybe_graph) {
      return;
    }
    shape_compute_graph_ = (*maybe_graph)->copy();
  }

  SymbolicShapeOpAnalyzer(
      const FunctionSchema* schema,
      std::shared_ptr<Graph> graph)
      : schema_(schema) {
    shape_compute_graph_ = graph->copy();
  }

  c10::optional<std::vector<c10::SymbolicShape>> run(
      std::vector<SSArgument>& inputs) {
    if (!shape_compute_graph_) {
      return c10::nullopt;
    }
    inputs_ = inputs;
    substituteConstantInputs();
    GRAPH_DEBUG(inputs_)
    return propagateShapesInGraph();
  }

  std::shared_ptr<Graph> getShapeComputeGraph() {
    return shape_compute_graph_;
  }
};

SSArgument tensorShapeArg(Value* tensor_v) {
  auto tt = tensor_v->type()->expect<TensorType>();
  c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();

  // In test mode, we don't insert complete tensor shapes and instead rely on
  // our partial evaluation pipeline to propagate information.
  // This is a good proxy for our ability to propagate non-complete shape
  // information.
  if (symbolic_shapes.isComplete() && !symbolic_shape_analysis_test_mode) {
    return IValue(tt->sizes().concrete_sizes());
  }
  if (toIValue(tensor_v)) {
    auto size = constant_as<at::Tensor>(tensor_v)->sizes();
    if (!symbolic_shape_analysis_test_mode) {
      return IValue(size);
    } else {
      return c10::SymbolicShape(size);
    }
  }
  return symbolic_shapes;
}
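
// Illustrative sketch of the SSArgument produced for different tensor types
// (hypothetical values, test mode off):
//   Tensor(2, 3)          -> IValue([2, 3])        (complete, becomes a constant)
//   Tensor(SS(-2), 3)     -> ShapeArguments((SS(-2), 3))
//   Tensor (unknown rank) -> ShapeArguments with has_dim() == false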

std::vector<SSArgument> getNodeInputShapes(Node* n, const AliasDb& db) {
  // TODO: fix the List of integers implementation, and
  // extract out the shape changes, otherwise this is complete
  // NB: shape compute graphs may have fewer inputs than their node
  // counterparts to allow e.g. sharing one single unary definition,
  // so iterate on the # of shape inputs.
  // We make lists of Tensor inputs variadic, which results in an
  // offset between a node index and its corresponding graph index.
  std::vector<SSArgument> input_shapes = std::vector<SSArgument>();

  for (size_t node_index = 0; node_index < n->inputs().size(); ++node_index) {
    auto type = n->input(node_index)->type();

    if (type->castRaw<TensorType>()) {
      input_shapes.push_back(tensorShapeArg(n->input(node_index)));
      continue;
    }
    if (isListOfTensors(type)) {
      // waiting for more use cases to decide on best generalization
      if (n->input(node_index)->node()->kind() == prim::Constant) {
        auto ival = toIValue(n->input(node_index));
        for (const auto& ten : ival->toTensorVector()) {
          input_shapes.emplace_back(c10::List<int64_t>(ten.sizes()));
        }
      } else if (
          n->input(node_index)->node()->kind() == prim::ListConstruct &&
          !db.hasWriters(n->input(node_index))) {
        auto li_construct_node = n->input(node_index)->node();
        for (size_t j = 0; j < li_construct_node->inputs().size(); ++j) {
          input_shapes.push_back(tensorShapeArg(li_construct_node->input(j)));
        }
      } else {
        TORCH_INTERNAL_ASSERT(false, "Unhandled List, we shouldn't get here");
      }
      continue;
    }
    if (auto ival = toIValue(n->input(node_index))) {
      input_shapes.emplace_back(*ival);
      continue;
    }
    if (type->cast<ListType>() &&
        type->cast<ListType>()->getElementType()->cast<IntType>()) {
      auto input_src_node = n->input(node_index)->node();
      if (input_src_node->kind() == prim::ListConstruct &&
          !db.hasWriters(n->input(node_index))) {
        // it is very common in graphs to see patterns like:
        // z = x.view(y.size())
        // or:
        // z = x.view(1, 10, y.size(0), y.size(1))
        // We want to propagate symbolic dimensions and concrete sizes
        // from y to z. To do this we try to associate symbolic dimensions
        // or concrete sizes with the integer list inputs that have a
        // constructor taken from constants or y.size() or y.size(0)
        auto list_construct = n->input(node_index)->node();
        std::vector<ShapeArg> shape;
        for (Value* v : list_construct->inputs()) {
          if (auto constant = constant_as<int64_t>(v)) {
            shape.emplace_back(*constant);
          } else if (v->node()->kind() == aten::size) {
            auto const_index = constant_as<int64_t>(v->node()->input(1));
            auto tt = v->node()->input(0)->type()->expect<TensorType>();
            auto ss = tt->symbolic_sizes();
            if (!ss.rank() || !const_index) {
              // if we are getting a size of a tensor, it is an unknown
              // symbolic dimension instead of an unknown integer (must be
              // >=0)
              shape.emplace_back(at::ShapeSymbol::newSymbol());
              continue;
            }
            auto norm_index = normIndex(*const_index, *ss.rank());
            if (!norm_index) {
              shape.emplace_back(at::ShapeSymbol::newSymbol());
              continue;
            }
            shape.emplace_back(ss[*norm_index]);
          } else {
            shape.emplace_back(ShapeArg::unknownInteger());
          }
        }
        input_shapes.emplace_back(ShapeArguments(shape));
        continue;
      }
      if (input_src_node->kind() == aten::size &&
          !db.hasWriters(n->input(node_index))) {
        auto ten_inp = input_src_node->input();
        auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
        input_shapes.emplace_back(ss);
        continue;
      }
    }
    GRAPH_DEBUG(
        "Unhandled input: ",
        n->kind().toDisplayString(),
        " arg num: ",
        node_index);
    input_shapes.emplace_back(c10::SymbolicShape());
  }
  TORCH_INTERNAL_ASSERT(
      input_shapes.size() >= n->inputs().size(),
      "input_shapes size: ",
      input_shapes.size(),
      " n inputs size: ",
      n->inputs().size());
  return input_shapes;
}
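
// Illustrative sketch (hypothetical graph): for a node
//   %z : Tensor = aten::view(%x, %sizes)
// where %x : Tensor(SS(-2), 10) and %sizes = prim::ListConstruct(%one, %1)
// with %1 = aten::size(%y, 0) and %y : Tensor(SS(-3), 4), the returned
// arguments are roughly
//   [ShapeArguments((SS(-2), 10)), ShapeArguments((1, SS(-3)))].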

void applyOutputShapeToGraph(
    Node* node,
    const std::vector<c10::SymbolicShape>& output_shapes) {
  TORCH_INTERNAL_ASSERT(
      node->outputs().size() == output_shapes.size(),
      "Output shape size mismatch");
  for (size_t i = 0; i < output_shapes.size(); ++i) {
    auto& ss = output_shapes.at(i);
    node->output(i)->setType(
        node->output(i)->type()->expect<TensorType>()->withSymbolicShapes(ss));
  }
}

std::shared_ptr<Graph> PropagateShapesWithShapeFunction(
    Node* n,
    const AliasDb& db) {
  const FunctionSchema* func_schema = n->maybeSchema();
  if (!func_schema) {
    return nullptr;
  }
  auto op_analyzer = SymbolicShapeOpAnalyzer(func_schema);
  if (!op_analyzer.getShapeComputeGraph()) {
    return nullptr;
  }
  auto input_shapes = getNodeInputShapes(n, db);
  op_analyzer.refineInputUnionTypes(n);

  if (auto output_shapes = op_analyzer.run(input_shapes)) {
    applyOutputShapeToGraph(n, *output_shapes);
  }

  return op_analyzer.getShapeComputeGraph();
}

c10::SymbolicShape combine_bounds(
    c10::SymbolicShape& lower_bound,
    c10::SymbolicShape& upper_bound) {
  // TODO: At some point we might want to add support for dynamic dims
  TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank());
  if (lower_bound.rank() == c10::nullopt) {
    return c10::SymbolicShape();
  }
  std::vector<c10::ShapeSymbol> merged_shapes;
  for (int i = 0; i < lower_bound.rank(); i++) {
    // TODO: Merge equivalent expressions (not needed for current use case)
    if (lower_bound[i] == upper_bound[i]) {
      merged_shapes.push_back(lower_bound[i]);
    } else {
      merged_shapes.push_back(c10::ShapeSymbol::newSymbol());
    }
  }
  return c10::SymbolicShape(std::move(merged_shapes));
}
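
// Illustrative sketch (hypothetical bounds): given a lower bound (5, SS(-2), 4)
// and an upper bound (5, SS(-2), 9), dimensions that agree are kept and the
// disagreeing dimension is replaced with a fresh symbol, yielding roughly
// (5, SS(-2), SS(new)).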

struct SymbolicShapeGraphAnalyzer {
  SymbolicShapeGraphAnalyzer(
      std::shared_ptr<Graph>& graph,
      Node* beg,
      Node* end)
      : graph_(graph), beg_(beg), end_(end) {
    TORCH_INTERNAL_ASSERT(
        beg_->owningBlock() == end_->owningBlock() && end_->isAfter(beg_));
  }

  c10::optional<ShapeComputeGraphMapping> run() {
    AliasDb db(graph_);
    std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs =
        propagateShapesAndGatherPartialEvalShapeGraphs(db);

    auto stitched_shape_compute_graph = std::make_shared<Graph>();
    // We want to build up a computational graph which computes all shapes
    // we don't know statically - that is, all symbolic shapes within
    // the region [beg, end). It must be executable before beg.
    // TODO: don't require dimensions of tensors to be set AOT ?

    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      if (curr->kind() == prim::Constant) {
        continue;
      }
      // TODO: generalize logic for other tensor input ops when they are
      // added
      if (curr->kind() == prim::ListConstruct) {
        auto uses = curr->output()->uses();
        if (!std::all_of(uses.begin(), uses.end(), [](const Use& use) {
              return use.user->kind() == aten::cat;
            })) {
          GRAPH_DEBUG("Non cat list use ", getHeader(curr));
          return c10::nullopt;
        }
        continue;
      }

      if (!partial_evaluated_graphs.count(curr)) {
        GRAPH_DEBUG("No graph ", getHeader(curr));
        return c10::nullopt;
      }

      auto outputs = curr->outputs();
      for (Value* v : outputs) {
        auto tt = v->type()->cast<TensorType>();
        if (!tt) {
          GRAPH_DEBUG("Non tensor node", getHeader(curr));
          return c10::nullopt;
        }
        auto symbolic_sizes = tt->symbolic_sizes();
        // TODO: don't require # of dimensions of tensors set ?
        if (!symbolic_sizes.rank()) {
          GRAPH_DEBUG("No rank on output ", getHeader(curr));
          return c10::nullopt;
        }
      }
      auto partial_eval_graph = partial_evaluated_graphs[curr];
      joinPartialEvaluatedShapeGraphToLargeShapeGraph(
          curr, partial_eval_graph, stitched_shape_compute_graph);
    }

    size_t MAX_ITER = 8;
    bool made_change = true;
    size_t i = 0;
    while (i < MAX_ITER && made_change) {
      i++;
      made_change = shapeGraphCleanupPasses(stitched_shape_compute_graph);
    }

    // For any output that is duplicated, the symbolic shapes must be equal.
    // Take the symbolic shape that is generated first and map equivalent
    // ones to it.
    std::unordered_map<int64_t, int64_t> discovered_sym_shape_equalities;
    std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim;
    std::vector<size_t> erase_indices;

    for (size_t i = 0; i < stitched_shape_compute_graph->outputs().size();
         ++i) {
      Value* output = stitched_shape_compute_graph->outputs().at(i);
      // this Value is already contained, so the symbolic shape for i must be
      // equal to the symbolic shape at the existing index
      if (graph_output_to_symbolic_shape_dim.count(output)) {
        auto curr_sym_shape = output_index_to_symbolic_shape_[i];
        auto existing_sym_shape = graph_output_to_symbolic_shape_dim[output];
        discovered_sym_shape_equalities[curr_sym_shape] = existing_sym_shape;
        erase_indices.push_back(i);
      } else {
        graph_output_to_symbolic_shape_dim[output] =
            output_index_to_symbolic_shape_[i];
      }
    }
    for (int64_t i = erase_indices.size() - 1; i >= 0; i--) {
      stitched_shape_compute_graph->eraseOutput(erase_indices[i]);
    }
    for (size_t i = 0; i < stitched_shape_compute_graph->inputs().size();) {
      if (!stitched_shape_compute_graph->inputs().at(i)->hasUses()) {
        enclosing_graph_value_to_shape_graph_input_.erase(
            stitched_shape_compute_graph->inputs().at(i));
        stitched_shape_compute_graph->eraseInput(i);
      } else {
        ++i;
      }
    }

    updateGraphWithSymbolicShapeEqualities(discovered_sym_shape_equalities);
    return ShapeComputeGraphMapping(
        std::move(stitched_shape_compute_graph),
        enclosing_graph_value_to_shape_graph_input_,
        std::move(graph_output_to_symbolic_shape_dim));
  }

  void updateGraphWithSymbolicShapeEqualities(
      std::unordered_map<int64_t, int64_t>& sym_shape_equalities) {
    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      for (size_t i = 0; i < curr->outputs().size(); ++i) {
        auto output = curr->output(i);
        auto tt = output->type()->cast<TensorType>();
        if (!tt || !tt->symbolic_sizes().rank()) {
          continue;
        }
        bool changed = false;
        std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
        auto new_sizes =
            c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
              auto value = shape.value();
              if (sym_shape_equalities.count(value)) {
                changed = true;
                return sym_shape_equalities[value];
              }
              return value;
            });
        if (changed) {
          output->setType(
              tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
        }
      }
    }
  }

  void registerStitchedComputeOutput(
      std::shared_ptr<Graph> stitched_shape_compute_graph,
      Value* output,
      int64_t symbolic_shape) {
    stitched_shape_compute_graph->registerOutput(output);
    output_index_to_symbolic_shape_
        [stitched_shape_compute_graph->outputs().size() - 1] = symbolic_shape;
    symbolic_shape_value_to_graph_output_[symbolic_shape] =
        stitched_shape_compute_graph->outputs().at(
            stitched_shape_compute_graph->outputs().size() - 1);
  }

  void joinPartialEvaluatedShapeGraphToLargeShapeGraph(
      Node* curr,
      std::shared_ptr<Graph> partial_eval_graph,
      std::shared_ptr<Graph> stitched_shape_compute_graph) {
    // we are building up the large shape compute graph by iteratively
    // combining partially evaluated individual node shape graphs.

    // We need to maintain two mappings, one from non-Tensor inputs in the
    // enclosing graph to their equivalent mappings within the large shape
    // compute graph, and one from symbolic shape dimension to new node output.

    // When we add a new tensor node, we do two things:
    // 1: record a mapping from the tensor node output to its shape in the
    //    partial eval graph
    // 2: add each symbolic shape dimension that we have not already added as
    //    an output to the large shape compute graph

    // Once we are done stitching together all partial eval'd graphs, we can
    // cleanup the graph and remove the unneeded complete shapes as outputs,
    // leaving us with only the compute for calculating the runtime value of
    // symbolic dimensions.

    std::vector<Value*> node_inputs;
    // TODO: generalize logic
    if (curr->kind() == aten::cat) {
      TORCH_INTERNAL_ASSERT(
          curr->input(0)->node()->kind() == prim::ListConstruct);
      for (Value* v : curr->input(0)->node()->inputs()) {
        node_inputs.push_back(v);
      }
      node_inputs.push_back(curr->namedInput("dim"));
    } else {
      for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
        node_inputs.push_back(curr->input(i));
      }
    }

    std::vector<Value*> partial_eval_inputs;
    for (size_t i = 0; i < node_inputs.size(); ++i) {
      auto node_input = node_inputs[i];
      auto existing_graph_mapping =
          enclosing_graph_value_to_shape_graph_input_.find(node_input);
      if (existing_graph_mapping !=
          enclosing_graph_value_to_shape_graph_input_.end()) {
        partial_eval_inputs.push_back(existing_graph_mapping->second);
      } else {
        Value* shape_graph_input =
            stitched_shape_compute_graph->addInput()->copyMetadata(
                partial_eval_graph->inputs().at(i));
        enclosing_graph_value_to_shape_graph_input_[node_input] =
            shape_graph_input;
        partial_eval_inputs.push_back(shape_graph_input);
      }
      // make sure all symbolic dimensions in the graph we are creating are
      // computed in the partial eval graph
      if (auto tt = node_input->type()->cast<TensorType>()) {
        if (!tt->symbolic_sizes().rank()) {
          continue;
        }
        auto rank = *tt->symbolic_sizes().rank();
        for (size_t j = 0; j < rank; ++j) {
          auto shape = tt->symbolic_sizes()[j];
          if (shape.is_static() ||
              symbolic_shape_value_to_graph_output_.count(shape.value())) {
            continue;
          }
          auto input = enclosing_graph_value_to_shape_graph_input_[node_input];
          WithInsertPoint guard(stitched_shape_compute_graph->block());
          auto index = stitched_shape_compute_graph->insertConstant(
              static_cast<int64_t>(j));
          auto li_index = stitched_shape_compute_graph->insert(
              aten::__getitem__, {input, index});
          registerStitchedComputeOutput(
              stitched_shape_compute_graph, li_index, shape.value());
        }
      }
    }

    WithInsertPoint guard(stitched_shape_compute_graph->block());
    std::unordered_map<Value*, Value*> value_map;
    insertGraph(
        *stitched_shape_compute_graph,
        *partial_eval_graph,
        partial_eval_inputs,
        value_map);

    for (size_t i = 0; i < curr->outputs().size(); ++i) {
      Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];
      enclosing_graph_value_to_shape_graph_input_[curr->output(i)] =
          new_list_output;

      TORCH_INTERNAL_ASSERT(
          new_list_output->node()->kind() == prim::ListConstruct ||
          new_list_output->node()->kind() == prim::Constant);
      TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());

      auto symbolic_sizes =
          curr->output(i)->type()->expect<TensorType>()->symbolic_sizes();
      TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());

      for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
        if (symbolic_sizes[i].is_static()) {
          continue;
        }
        int64_t symbolic_shape = symbolic_sizes[i].value();
        if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
          continue;
        }
        registerStitchedComputeOutput(
            stitched_shape_compute_graph,
            new_list_output->node()->input(i),
            symbolic_shape);
      }
    }
  }

  std::unordered_map<Node*, std::shared_ptr<Graph>>
  propagateShapesAndGatherPartialEvalShapeGraphs(AliasDb& db) {
    std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs;
    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      if (auto maybe_graph = PropagateShapesWithShapeFunction(curr, db)) {
        partial_evaluated_graphs[curr] = maybe_graph;
      }
    }
    return partial_evaluated_graphs;
  }

  std::unordered_map<Value*, Value*>
      enclosing_graph_value_to_shape_graph_input_;
  std::unordered_map<int64_t, Value*> symbolic_shape_value_to_graph_output_;
  std::unordered_map<size_t, int64_t> output_index_to_symbolic_shape_;

  std::shared_ptr<Graph>& graph_;
  Node* beg_;
  Node* end_;
};

void PropagateShapesOnBlock(Block* b, const AliasDb& db) {
  for (Node* n : b->nodes()) {
    // TODO: handle loop
    if (n->kind() == prim::If) {
      IfView if_v(n);
      PropagateShapesOnBlock(if_v.thenBlock(), db);
      PropagateShapesOnBlock(if_v.elseBlock(), db);
      mergeTypes(if_v.thenOutputs(), if_v.elseOutputs(), if_v.outputs());
    } else if (n->maybeSchema()) {
      PropagateShapesWithShapeFunction(n, db);
    } else if (n->kind() == prim::TupleConstruct) {
      auto orig_type = n->output()->type()->expect<TupleType>();
      auto new_types = fmap(n->inputs(), [](Value* v) { return v->type(); });
      n->output()->setType(
          orig_type->createWithContained(std::move(new_types)));
    }
  }
}
} // namespace

void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph) {
  AliasDb db(graph);
  PropagateShapesOnBlock(graph->block(), db);
}

c10::optional<ShapeComputeGraphMapping>
PropagateShapesAndBuildLargeShapeComputeGraph(
    std::shared_ptr<Graph>& graph,
    Node* beg,
    Node* end) {
  return SymbolicShapeGraphAnalyzer(graph, beg, end).run();
}

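// Illustrative sketch of calling the op-level entry point below (hedged,
// hypothetical usage; SSAInput is a variant of IValue and c10::SymbolicShape):
//   auto op = getOperatorForLiteral(
//       "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor");
//   std::vector<SSAInput> inputs = {
//       c10::SymbolicShape(std::vector<c10::ShapeSymbol>{
//           c10::ShapeSymbol::newSymbol(),
//           c10::ShapeSymbol::fromStaticSize(8)}),
//       c10::SymbolicShape(std::vector<c10::ShapeSymbol>{
//           c10::ShapeSymbol::fromStaticSize(8)})};
//   auto res = calculateSymbolicShapesOnOp(&op->schema(), inputs);
//   // res, if present, holds one SymbolicShape per op output.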
TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& inputs) {
  auto bounded_graphs = boundedGraphsForSchema(*schema);
  auto has_shape_compute = shapeComputeGraphForSchema(*schema) != c10::nullopt;
  if (!has_shape_compute && bounded_graphs == c10::nullopt) {
    // Avoid doing all this work for functions that don't have a
    // supported schema
    return c10::nullopt;
  }

  if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) {
    return cached_ret_vec;
  }

  std::vector<SSArgument> ssa_args;
  for (auto& arg : inputs) {
    if (const IValue* ival = c10::get_if<IValue>(&arg)) {
      ssa_args.emplace_back(*ival);
    } else {
      const c10::SymbolicShape* ss = c10::get_if<c10::SymbolicShape>(&arg);
      ssa_args.emplace_back(ShapeArguments(*ss));
    }
  }
  // Handle bounded shape option
  if (bounded_graphs) {
    auto lower_bound =
        SymbolicShapeOpAnalyzer(schema, bounded_graphs->lower_bound);
    auto lower_bound_res = lower_bound.run(ssa_args);
    auto upper_bound =
        SymbolicShapeOpAnalyzer(schema, bounded_graphs->upper_bound);
    auto upper_bound_res = upper_bound.run(ssa_args);
    // Stitch together the values
    if (lower_bound_res.has_value() && upper_bound_res.has_value()) {
      TORCH_INTERNAL_ASSERT(lower_bound_res->size() == upper_bound_res->size());
      auto merged_res = std::vector<c10::SymbolicShape>();
      for (size_t i = 0; i < lower_bound_res->size(); i++) {
        merged_res.push_back(
            combine_bounds(lower_bound_res->at(i), upper_bound_res->at(i)));
      }
      cache_shape_function(schema, inputs, merged_res);
      return merged_res;
    }
    return c10::nullopt;
  }

  auto op_analyzer = SymbolicShapeOpAnalyzer(schema);
  auto res = op_analyzer.run(ssa_args);
  if (res.has_value()) {
    cache_shape_function(schema, inputs, res.value());
  }
  return res;
}

} // namespace jit
} // namespace torch