1 | #include <ATen/core/symbol.h> |
2 | #include <c10/util/Exception.h> |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/jit/ir/alias_analysis.h> |
5 | #include <torch/csrc/jit/ir/constants.h> |
6 | #include <torch/csrc/jit/ir/ir.h> |
7 | #include <torch/csrc/jit/ir/ir_views.h> |
8 | #include <torch/csrc/jit/jit_log.h> |
9 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
10 | #include <torch/csrc/jit/passes/constant_pooling.h> |
11 | #include <torch/csrc/jit/passes/constant_propagation.h> |
12 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
13 | #include <torch/csrc/jit/passes/integer_value_refinement.h> |
14 | #include <torch/csrc/jit/passes/loop_unrolling.h> |
15 | #include <torch/csrc/jit/passes/lower_tuples.h> |
16 | #include <torch/csrc/jit/passes/peephole.h> |
17 | #include <torch/csrc/jit/passes/peephole_list_idioms.h> |
18 | #include <torch/csrc/jit/passes/peephole_non_tensor.h> |
19 | #include <torch/csrc/jit/passes/remove_mutation.h> |
20 | #include <torch/csrc/jit/passes/shape_analysis.h> |
21 | #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> |
22 | #include <torch/csrc/jit/passes/symbolic_shape_cache.h> |
23 | #include <torch/csrc/jit/passes/tensorexpr_fuser.h> |
24 | #include <torch/csrc/jit/runtime/exception_message.h> |
25 | #include <torch/csrc/jit/runtime/symbolic_shape_registry.h> |
26 | #include <torch/csrc/utils/memory.h> |
27 | #include <algorithm> |
28 | #include <memory> |
29 | #include <numeric> |
30 | #include <unordered_map> |
31 | #include <utility> |
32 | #include <vector> |
33 | |
34 | /* |
35 | XXX: this is still in prototype phase and has much work left to do, including |
36 | but not limited to: |
37 | - Refactor APIs |
38 | - Add decent coverage of common ops |
39 | - Add shape analysis pass on Graph that handles Loops |
40 | - Allow concurrent reads to the operator map |
41 | - Supporting returning partially evaluated shape compute graph |
42 | */ |
43 | |
// When enabled (see setSymbolicShapeAnalysisTestMode), complete tensor shapes
// are withheld from shape functions so that the partial-evaluation pipeline is
// exercised on symbolic shape information instead (see tensorShapeArg).
static bool symbolic_shape_analysis_test_mode = false;
45 | |
46 | namespace torch { |
47 | namespace jit { |
48 | |
49 | // This is similar to c10::SymbolicShape, but instead of either having |
50 | // a concrete dimension or a symbolic dimension, an argument may be: |
51 | // - A Symbolic Dimension |
52 | // - A Constant Integer |
53 | // - Neither of the above. The third case can occur due to inputs to |
54 | // ops like view that accept negative values. Maintaining the distinction |
55 | // between an unknown symbolic dimension and an unknown integer allows |
56 | // us to optimize out comparisons to values < 0 (symbolic shapes are always >= |
57 | // 0) For example, a call like graph(%y: Tensor(SS(-1), 10, 10), %inp: int): |
58 | // %five: int = prim::Constant[value=5]() |
59 | // %zero: int = prim::Constant[value=0]() |
60 | // %1 : int = aten::size(%y, %zero) |
61 | // %2 : int[] = prim::ListConstruct(%five, %1, %inp) |
62 | // %y.2: Tensor(5, SS(-1), (New Symbolic Shape)) = aten::view(%y, %2) |
63 | // |
64 | // x.view([5, y.size(0), inp]) |
65 | // will have inputs equal to [5, SS(-1), c10::nullopt] |
66 | |
67 | struct ShapeArg |
68 | : public std:: |
69 | pair<c10::optional<c10::ShapeSymbol>, c10::optional<int64_t>> { |
70 | using pair::pair; |
71 | |
72 | static ShapeArg unknownInteger() { |
73 | return ShapeArg(); |
74 | } |
75 | |
76 | ShapeArg(int64_t int_value) { |
77 | this->first = c10::nullopt; |
78 | this->second = int_value; |
79 | } |
80 | |
81 | ShapeArg(c10::ShapeSymbol ss) { |
82 | if (ss.is_static()) { |
83 | this->first = c10::nullopt; |
84 | this->second = ss.value(); |
85 | } else { |
86 | this->first = ss; |
87 | this->second = c10::nullopt; |
88 | } |
89 | } |
90 | |
91 | c10::optional<int64_t> asConstantInt() const { |
92 | return this->second; |
93 | } |
94 | |
95 | c10::optional<c10::ShapeSymbol> asShapeSymbol() const { |
96 | return this->first; |
97 | } |
98 | |
99 | private: |
100 | ShapeArg() { |
101 | this->first = c10::nullopt; |
102 | this->second = c10::nullopt; |
103 | } |
104 | }; |
105 | |
106 | std::ostream& operator<<(std::ostream& out, const ShapeArg& sa) { |
107 | if (auto val = sa.asConstantInt()) { |
108 | out << *val; |
109 | } else if (auto ss = sa.asShapeSymbol()) { |
110 | out << *ss; |
111 | } else { |
112 | out << "UNK" ; |
113 | } |
114 | return out; |
115 | } |
116 | |
117 | struct ShapeArguments { |
118 | // Superset of SymbolicShape, with additional support for unknown, nonsymbolic |
119 | // vals |
120 | public: |
121 | ShapeArguments(const c10::SymbolicShape& ss) { |
122 | has_dim_ = ss.rank().has_value(); |
123 | if (has_dim_) { |
124 | for (size_t i = 0; i < *ss.rank(); ++i) { |
125 | maybe_shape_symbols_.emplace_back(ss.at(i)); |
126 | } |
127 | } |
128 | } |
129 | |
130 | ShapeArguments(std::vector<ShapeArg> ss) |
131 | : has_dim_(true), maybe_shape_symbols_(std::move(ss)) {} |
132 | |
133 | bool has_dim() const { |
134 | return has_dim_; |
135 | } |
136 | |
137 | int64_t len() const { |
138 | TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim" ) |
139 | return (int64_t)maybe_shape_symbols_.size(); |
140 | } |
141 | |
142 | const ShapeArg at(size_t i) const { |
143 | TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim" ) |
144 | return maybe_shape_symbols_.at(i); |
145 | } |
146 | |
147 | private: |
148 | bool has_dim_; |
149 | std::vector<ShapeArg> maybe_shape_symbols_; |
150 | }; |
151 | |
152 | std::ostream& operator<<(std::ostream& os, const ShapeArguments& sa) { |
153 | if (!sa.has_dim()) { |
154 | os << "(UNKNOWN DIM)" ; |
155 | return os; |
156 | } |
157 | |
158 | os << "(" ; |
159 | for (size_t i = 0; i < sa.len(); i++) { |
160 | os << sa.at(i); |
161 | } |
162 | os << ")" ; |
163 | |
164 | return os; |
165 | } |
166 | |
167 | bool setSymbolicShapeAnalysisTestMode(bool value) { |
168 | bool old_value = symbolic_shape_analysis_test_mode; |
169 | symbolic_shape_analysis_test_mode = value; |
170 | return old_value; |
171 | } |
172 | |
// Queries whether test mode (see setSymbolicShapeAnalysisTestMode) is active.
bool symbolicShapeAnalysisTestModeEnabled() {
  return symbolic_shape_analysis_test_mode;
}
176 | |
177 | using SSArgument = c10::variant<ShapeArguments, IValue>; |
178 | |
179 | std::ostream& operator<<(std::ostream& out, const SSArgument& sa) { |
180 | if (const IValue* iv = c10::get_if<IValue>(&sa)) { |
181 | out << *iv; |
182 | } else { |
183 | out << c10::get<ShapeArguments>(sa); |
184 | } |
185 | return out; |
186 | } |
187 | |
188 | namespace { |
189 | |
190 | bool isListOfInts(const TypePtr& type) { |
191 | return type->cast<ListType>() && |
192 | type->cast<ListType>()->getElementType()->cast<IntType>(); |
193 | } |
194 | |
195 | bool isListOfListOfInts(const TypePtr& type) { |
196 | // Allows List[Optional[List[Int]]] |
197 | if (!type->cast<ListType>()) { |
198 | return false; |
199 | } |
200 | TypePtr element_type = type->cast<ListType>()->getElementType(); |
201 | if (element_type->cast<OptionalType>()) { |
202 | element_type = element_type->cast<OptionalType>()->getElementType(); |
203 | } |
204 | return isListOfInts(element_type); |
205 | } |
206 | |
207 | bool isListOfTensors(const TypePtr& type) { |
208 | return type->cast<ListType>() && |
209 | type->cast<ListType>()->getElementType()->cast<TensorType>(); |
210 | } |
211 | |
212 | c10::optional<size_t> normIndex(int64_t index, size_t len) { |
213 | if (index < 0) { |
214 | index = index + len; |
215 | } |
216 | if (index >= 0 && index < static_cast<int64_t>(len)) { |
217 | return index; |
218 | } else { |
219 | return c10::nullopt; |
220 | } |
221 | } |
222 | |
// Runs one round of cleanup/optimization passes over a shape compute graph.
// Returns true when any change-reporting pass modified the graph, so callers
// can iterate to a fixed point. Note: the graph is taken by value because
// several of the passes below require a non-const std::shared_ptr<Graph>&.
bool shapeGraphCleanupPasses(std::shared_ptr<Graph> graph) {
  // TODO: lower simple tuples ?
  bool made_change = RemoveListMutation(graph);
  made_change |= UnrollConstantLoops(graph);
  made_change |= ConstantPropagation(graph);
  made_change |= PeepholeOptimizeNonTensor(graph);
  made_change |= PeepholeOptimizeListIdioms(graph, /*refine_list_len*/ true);
  made_change |= RefineIntegerValues(graph);
  made_change |= ConstantPropagation(graph);
  // todo add return change for constant pooling
  // ConstantPooling and EliminateDeadCode do not report whether they changed
  // the graph, so they are excluded from `made_change`.
  ConstantPooling(graph);
  made_change |= EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
  return made_change;
}
238 | |
// Replaces every use of `v` with a constant holding `val`. The constant is
// inserted at the beginning of v's owning block so that it dominates all uses.
void replaceWithIValue(Value* v, IValue val) {
  WithInsertPoint guard(*v->node()->owningBlock()->nodes().begin());
  v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
}
243 | |
244 | c10::SymbolicShape ( |
245 | Value* list, |
246 | std::unordered_map<Value*, int64_t>& symbolic_shape_values, |
247 | const AliasDb& db) { |
248 | if (list->node()->kind() == prim::Constant) { |
249 | auto int_list = toIValue(list)->toIntVector(); |
250 | return c10::SymbolicShape(int_list); |
251 | } |
252 | // We need a list construct or a constant output |
253 | // that is not written to in order to analyze the output shape |
254 | if (list->node()->kind() != prim::ListConstruct || db.hasWriters(list)) { |
255 | GRAPH_DEBUG("Could not extract shape" ); |
256 | return c10::SymbolicShape(); |
257 | } |
258 | Node* list_construct = list->node(); |
259 | std::vector<c10::optional<int64_t>> output_shape; |
260 | for (Value* input : list_construct->inputs()) { |
261 | if (symbolic_shape_values.count(input)) { |
262 | output_shape.emplace_back(symbolic_shape_values[input]); |
263 | } else { |
264 | output_shape.push_back(constant_as<int64_t>(input)); |
265 | } |
266 | } |
267 | return c10::SymbolicShape(output_shape); |
268 | } |
269 | |
// Symbolic Shape Analysis works through iteratively partially evaluating
// a TorchScript shape compute graph by inputting properties from input
// Tensors. We can substitute in properties like `len(x)` and `x[1]`
// if they are statically known on the input Tensors. We can also use
// assertions like `assert len(x) == 4` in order to refine the input
// length and unroll loops over its elements. We iteratively optimize and
// substitute in properties until we are unable to make any further
// optimizations. Finally, we try to extract Tensor properties from the output.
// For instance `return [1, 2, inp[2] + 1, inp[3]]` we know that the output
// will be length 4 with first two dimensions equal to 1 and 2. We can also
// deduce that the 4th dimension has the same symbolic shape as inp[3], which
// means that we do not know its concrete value statically but we can assign
// sets of tensor dimensions which must be equal at runtime.
283 | |
284 | struct SymbolicShapeOpAnalyzer { |
285 | std::shared_ptr<Graph> shape_compute_graph_; |
286 | const FunctionSchema* schema_; |
287 | std::vector<SSArgument> inputs_; |
288 | |
289 | // For the case where we have a JIT graph, |
290 | // subsititute optional types for their component types |
291 | // if the type is known. This doesn't need to be done |
292 | // for known IValues. |
293 | void refineInputUnionTypes(const Node* parent_graph_node) { |
294 | for (size_t op_in_index = 0; |
295 | op_in_index < shape_compute_graph_->inputs().size(); |
296 | op_in_index++) { |
297 | auto type = parent_graph_node->input(op_in_index)->type(); |
298 | if (auto opt_type = shape_compute_graph_->inputs() |
299 | .at(op_in_index) |
300 | ->type() |
301 | ->cast<OptionalType>()) { |
302 | // None will get handled with constant substitution later |
303 | if (!type->cast<OptionalType>() && |
304 | !NoneType::get()->isSubtypeOf(*type)) { |
305 | shape_compute_graph_->inputs() |
306 | .at(op_in_index) |
307 | ->setType(opt_type->getElementType()); |
308 | } |
309 | } else if (shape_compute_graph_->inputs() |
310 | .at(op_in_index) |
311 | ->type() |
312 | ->cast<NumberType>()) { |
313 | shape_compute_graph_->inputs().at(op_in_index)->setType(type); |
314 | } |
315 | } |
316 | } |
317 | |
318 | // We handle non-constant values in the shape propagation step |
319 | void substituteConstantInputs() { |
320 | if (shape_compute_graph_->inputs().empty()) { |
321 | return; |
322 | } |
323 | |
324 | bool seen_tensor_list = false; |
325 | |
326 | size_t op_in_index = 0; |
327 | while (op_in_index < shape_compute_graph_->inputs().size()) { |
328 | Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index); |
329 | if (!isListOfListOfInts(graph_in_var->type())) { |
330 | op_in_index++; |
331 | continue; |
332 | } |
333 | |
334 | // Modifying the graph where _node is part of to not use the tensor |
335 | // construct |
336 | |
337 | // When we have partially evaluate a list of Tensors like cat(tensor[]) |
338 | // We have a few problems: |
339 | // - optimizing out calls to the length of the list: len(tensors) |
340 | // - resolving accesses of the list to the tensor symbolic sizes the |
341 | // corresponding list element We can solve both of these problems by |
342 | // replacing the partial evaluation of cat([x, y]) def cat(tensors: |
343 | // List[List[int]], dim: int) |
344 | // body |
345 | // with |
346 | // def cat(x, y, dim: int) |
347 | // tensors = [x, y] |
348 | // body |
349 | TORCH_INTERNAL_ASSERT( |
350 | !seen_tensor_list, |
351 | "SSA doesn't handle case with multiple tensor lists" ) |
352 | seen_tensor_list = true; |
353 | |
354 | uint64_t li_length = inputs_.size() - (schema_->arguments().size() - 1); |
355 | std::vector<Value*> li_inputs; |
356 | |
357 | TypePtr element_type = |
358 | graph_in_var->type()->cast<ListType>()->getElementType(); |
359 | for (size_t j = op_in_index; j < op_in_index + li_length; ++j) { |
360 | auto new_inp = shape_compute_graph_->insertInput(op_in_index + j); |
361 | new_inp->setType(element_type); |
362 | li_inputs.push_back(new_inp); |
363 | } |
364 | WithInsertPoint guard(*shape_compute_graph_->block()->nodes().begin()); |
365 | auto new_li = shape_compute_graph_->insertNode( |
366 | shape_compute_graph_->createList(element_type, li_inputs)); |
367 | graph_in_var->replaceAllUsesWith(new_li->output()); |
368 | shape_compute_graph_->eraseInput(op_in_index + li_length); |
369 | } |
370 | |
371 | TORCH_INTERNAL_ASSERT( |
372 | shape_compute_graph_->inputs().size() <= inputs_.size(), |
373 | "Shape Compute Graph expected to have less inputs than actual inputs" ); //? |
374 | |
375 | for (size_t op_in_index = 0; |
376 | op_in_index < shape_compute_graph_->inputs().size(); |
377 | op_in_index++) { |
378 | SSArgument& argument = inputs_[op_in_index]; |
379 | Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index); |
380 | |
381 | if (IValue* cur_val = c10::get_if<IValue>(&argument)) { |
382 | GRAPH_DEBUG("Substituting constant input " , *cur_val); |
383 | replaceWithIValue(graph_in_var, *cur_val); |
384 | } else { |
385 | auto cur_arg = c10::get<ShapeArguments>(argument); |
386 | if (cur_arg.has_dim()) { |
387 | graph_in_var->setType(ListType::ofInts()); |
388 | } |
389 | } |
390 | } |
391 | } |
392 | |
393 | void substituteSymbolicProperties( |
394 | std::unordered_map<Value*, int64_t>* symbolic_shape_values) { |
395 | // clang-format off |
396 | // here we iteratively substitute properties of the node's input tensors |
397 | // into the shape compute graph. we can substitute constants into the |
398 | // like len(inp) or inp[0] if the tensor has a fixed length or a fixed |
399 | // first dimension. we also try to resolve symbolic shapes of the same |
400 | // symbolic value to the same Value * in the shape compute graph. |
401 | // for the shape logic: |
402 | // dim1 = inp1[0] |
403 | // dim2 = inp2[0] |
404 | // return dim1 if dim2 == 1 else dim2 |
405 | // if we see that inp1[0] and inp2[0] both have the same symbolic shape |
406 | // value, then it is a valid transformation to replace dim2 with dim1 or |
407 | // vice versa. to do this we collect all Value * for a particular symbolic |
408 | // shape. Then, we replace all Value * within that set with their dominator. |
409 | // In the example above, this allows us to infer that the output will be the |
410 | // symbolic dimension value of dim1. |
411 | |
412 | // if `symbolic_shape_values` is not null, record list accesses |
413 | // which resolve to symbolic dimension values with their concrete symbolic |
414 | // shape value. Because symbolic dimensions are represented as negative numbers and |
415 | // are not real values, inserting them as constants in the graph would invalidate |
416 | // the graph for further use. Instead, we keep track of what their value would be |
417 | // for extracting output shapes. |
418 | // clang-format on |
419 | |
420 | std::unordered_map<int64_t, std::vector<Value*>> symbolic_shape_map; |
421 | |
422 | TORCH_INTERNAL_ASSERT( |
423 | inputs_.size() >= shape_compute_graph_->inputs().size(), |
424 | "Missing Arg for Shape Graph" ); |
425 | for (int64_t index = 0; index < shape_compute_graph_->inputs().size(); |
426 | index++) { |
427 | auto shape_arguments = c10::get_if<ShapeArguments>(&inputs_[index]); |
428 | if (!shape_arguments || !shape_arguments->has_dim()) { |
429 | continue; |
430 | } |
431 | // Add support for testing symbolic shapes with dynamic dims |
432 | |
433 | for (const Use& use : shape_compute_graph_->inputs().at(index)->uses()) { |
434 | // TODO: either decompose composite ops like slice or add handling here |
435 | switch (use.user->kind()) { |
436 | case aten::len: { |
437 | size_t len = shape_arguments->len(); |
438 | replaceWithIValue(use.user->output(), static_cast<int64_t>(len)); |
439 | } break; |
440 | case aten::__getitem__: { |
441 | auto index = constant_as<int64_t>(use.user->inputs().at(1)); |
442 | if (!index) { |
443 | continue; |
444 | } |
445 | auto norm_index = normIndex(*index, shape_arguments->len()); |
446 | if (!norm_index) { |
447 | continue; |
448 | } |
449 | auto shape_arg = shape_arguments->at(*norm_index); |
450 | if (auto const_int = shape_arg.asConstantInt()) { |
451 | replaceWithIValue(use.user->output(), const_int); |
452 | continue; |
453 | } |
454 | auto maybe_shape_symbol = shape_arg.asShapeSymbol(); |
455 | if (!maybe_shape_symbol) { |
456 | continue; |
457 | } |
458 | auto shape_symbol = *maybe_shape_symbol; |
459 | if (symbolic_shape_values) { |
460 | symbolic_shape_values->emplace( |
461 | use.user->output(), shape_symbol.value()); |
462 | } else { |
463 | int64_t symbolic_index = shape_symbol.value(); |
464 | symbolic_shape_map[symbolic_index].push_back(use.user->output()); |
465 | } |
466 | for (const auto& sym_uses : use.user->output()->uses()) { |
467 | auto k = sym_uses.user->kind(); |
468 | if (k != aten::ge && k != aten::le && k != aten::ne && |
469 | k != aten::eq && k != aten::lt && k != aten::gt) { |
470 | break; |
471 | } |
472 | auto other_index = 1 - sym_uses.offset; |
473 | auto other_value = |
474 | constant_as<int64_t>(sym_uses.user->input(other_index)); |
475 | if (!other_value) { |
476 | continue; |
477 | } |
478 | |
479 | // check for dim >= 0, 0 <= dim |
480 | // dim >= 0 |
481 | if (k == aten::ge && *other_value == 0 && other_index == 1) { |
482 | replaceWithIValue(sym_uses.user->output(), true); |
483 | continue; |
484 | } |
485 | // 0 <= dim |
486 | if (k == aten::le && *other_value == 0 && other_index == 0) { |
487 | replaceWithIValue(sym_uses.user->output(), true); |
488 | continue; |
489 | } |
490 | |
491 | // check for dim comparisons to negative number |
492 | if (*other_value >= 0) { |
493 | continue; |
494 | } |
495 | if (k == aten::eq || k == aten::ne) { |
496 | // True if: |
497 | // -2 != {Positive} |
498 | replaceWithIValue(sym_uses.user->output(), k == aten::ne); |
499 | } else { |
500 | // True if: |
501 | // -2 <= / < {Positive} |
502 | // {Positive} >= / > {-2} |
503 | bool true_val = |
504 | ((other_index == 0 && (k == aten::le || k == aten::lt)) || |
505 | (other_index == 1 && (k == aten::ge || k == aten::gt))); |
506 | replaceWithIValue(sym_uses.user->output(), true_val); |
507 | } |
508 | } |
509 | } |
510 | } |
511 | } |
512 | |
513 | for (const auto& symbolic_set : symbolic_shape_map) { |
514 | mergeSymbolicShapeSets(symbolic_set.second); |
515 | } |
516 | } |
517 | } |
518 | |
519 | void mergeSymbolicShapeSets(const std::vector<Value*>& symbolic_set) { |
520 | // `symbolic_set` represents a set of Value * which are all equal |
521 | // to each other. Here, we optimize the graph by replacing values |
522 | // in the set with other dominating values. |
523 | // in the following example, where a, b and c are all in the same |
524 | // symbolic set: |
525 | // if cond: |
526 | // a = li[0] |
527 | // b = li[1] |
528 | // return [a, b] |
529 | // else: |
530 | // c = li[0] |
531 | // return [c, c] |
532 | // we can replace `b` with `a` because it is dominated by `a`, |
533 | // but we cannot replace `c` with another dominating value |
534 | |
535 | // there are ways to compute this more efficiently but typically number of |
536 | // Values for each symbolic set is low and this is cheap to run |
537 | for (const auto i : c10::irange(symbolic_set.size())) { |
538 | Value* v = symbolic_set[i]; |
539 | Value* dominating_value = v; |
540 | for (const auto& sym_set : symbolic_set) { |
541 | if (dominating_value->node()->isDominatedBy(sym_set->node())) { |
542 | dominating_value = sym_set; |
543 | } |
544 | } |
545 | if (dominating_value != v) { |
546 | v->replaceAllUsesWith(dominating_value); |
547 | } |
548 | } |
549 | } |
550 | |
551 | std::vector<c10::SymbolicShape> propagateShapesInGraph() { |
552 | bool made_change = true; |
553 | constexpr size_t MAX_ATTEMPTS = 8; |
554 | for (int attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS; |
555 | attempt_num++) { |
556 | // symbolic shape concrete values are only used in final shape extraction |
557 | GRAPH_DUMP("Before substitution: " , shape_compute_graph_); |
558 | substituteSymbolicProperties(/*symbolic_shape_values*/ nullptr); |
559 | GRAPH_DUMP("Before Opt: " , shape_compute_graph_); |
560 | made_change = shapeGraphCleanupPasses(shape_compute_graph_); |
561 | } |
562 | std::unordered_map<Value*, int64_t> symbolic_shape_values; |
563 | substituteSymbolicProperties(&symbolic_shape_values); |
564 | GRAPH_DUMP("Done with partial evaluation" , shape_compute_graph_); |
565 | |
566 | return extractOutputShape(symbolic_shape_values); |
567 | } |
568 | |
569 | std::vector<c10::SymbolicShape> ( |
570 | std::unordered_map<Value*, int64_t>& symbolic_shape_values) { |
571 | TORCH_INTERNAL_ASSERT( |
572 | shape_compute_graph_->outputs().size() == schema_->returns().size()); |
573 | // TODO: would be nice if there were easy facility to look at uses and see |
574 | // if they are all pure instead of instanting db. |
575 | auto res = std::vector<c10::SymbolicShape>(); |
576 | AliasDb db(shape_compute_graph_); |
577 | for (size_t i = 0; i < shape_compute_graph_->outputs().size(); ++i) { |
578 | auto output = shape_compute_graph_->outputs().at(i); |
579 | auto type = output->type(); |
580 | TORCH_INTERNAL_ASSERT(isListOfInts(type)); |
581 | c10::SymbolicShape ss = |
582 | extractListShape(output, symbolic_shape_values, db); |
583 | GRAPH_DEBUG("Extracted Output: " , ss); |
584 | res.push_back(ss); |
585 | } |
586 | return res; |
587 | } |
588 | |
589 | public: |
590 | SymbolicShapeOpAnalyzer(const FunctionSchema* schema) : schema_(schema) { |
591 | shape_compute_graph_ = nullptr; |
592 | if (!schema_) { |
593 | return; |
594 | } |
595 | auto maybe_graph = shapeComputeGraphForSchema(*schema_); |
596 | if (!maybe_graph) { |
597 | return; |
598 | } |
599 | shape_compute_graph_ = (*maybe_graph)->copy(); |
600 | } |
601 | |
602 | SymbolicShapeOpAnalyzer( |
603 | const FunctionSchema* schema, |
604 | std::shared_ptr<Graph> graph) |
605 | : schema_(schema) { |
606 | shape_compute_graph_ = graph->copy(); |
607 | } |
608 | |
609 | c10::optional<std::vector<c10::SymbolicShape>> run( |
610 | std::vector<SSArgument>& inputs) { |
611 | if (!shape_compute_graph_) { |
612 | return c10::nullopt; |
613 | } |
614 | inputs_ = inputs; |
615 | substituteConstantInputs(); |
616 | GRAPH_DEBUG(inputs_) |
617 | return propagateShapesInGraph(); |
618 | } |
619 | |
620 | std::shared_ptr<Graph> getShapeComputeGraph() { |
621 | return shape_compute_graph_; |
622 | } |
623 | }; |
624 | |
625 | SSArgument tensorShapeArg(Value* tensor_v) { |
626 | auto tt = tensor_v->type()->expect<TensorType>(); |
627 | c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes(); |
628 | |
629 | // for testing, we don't insert complete tensor shapes and rely on our |
630 | // partial evaluation pipeline to propagate information. |
631 | // this is a good proxy for our ability to propagate non-complete shape |
632 | // information. |
633 | if (symbolic_shapes.isComplete() && !symbolic_shape_analysis_test_mode) { |
634 | return IValue(tt->sizes().concrete_sizes()); |
635 | } |
636 | if (toIValue(tensor_v)) { |
637 | auto size = constant_as<at::Tensor>(tensor_v)->sizes(); |
638 | if (!symbolic_shape_analysis_test_mode) { |
639 | return IValue(size); |
640 | } else { |
641 | return c10::SymbolicShape(size); |
642 | } |
643 | } |
644 | return symbolic_shapes; |
645 | } |
646 | |
// Builds the SSArgument list for node `n`: one entry per node input, except
// that Tensor lists are flattened into one entry per element (shape compute
// graphs take tensor lists variadically — see substituteConstantInputs).
// Inputs that cannot be analyzed become unranked SymbolicShapes.
std::vector<SSArgument> getNodeInputShapes(Node* n, const AliasDb& db) {
  // TODO: fix the List of integers implementation, and
  // extract out the shape changes, otherwise this is complete
  // NB: shape compute graphs may have less inputs than their node
  // counterparts to allow e.g. sharing one single unary definition
  // so iterate on # of shape inputs
  // We make lists of Tensor inputs variadic, which results in
  // offset between a node index and its corresponding graph index
  std::vector<SSArgument> input_shapes = std::vector<SSArgument>();

  for (size_t node_index = 0; node_index < n->inputs().size(); ++node_index) {
    auto type = n->input(node_index)->type();

    // Tensor input: use its (symbolic or concrete) sizes.
    if (type->castRaw<TensorType>()) {
      input_shapes.push_back(tensorShapeArg(n->input(node_index)));
      continue;
    }
    if (isListOfTensors(type)) {
      // waiting for more use cases to decide on best generalization
      if (n->input(node_index)->node()->kind() == prim::Constant) {
        auto ival = toIValue(n->input(node_index));
        for (const auto& ten : ival->toTensorVector()) {
          input_shapes.emplace_back(c10::List<int64_t>(ten.sizes()));
        }
      } else if (
          n->input(node_index)->node()->kind() == prim::ListConstruct &&
          !db.hasWriters(n->input(node_index))) {
        auto li_construct_node = n->input(node_index)->node();
        for (size_t j = 0; j < li_construct_node->inputs().size(); ++j) {
          input_shapes.push_back(tensorShapeArg(li_construct_node->input(j)));
        }
      } else {
        TORCH_INTERNAL_ASSERT(false, "Unhandled List, we shouldn't get here");
      }
      continue;
    }
    // Any other constant input is passed through as a concrete IValue.
    if (auto ival = toIValue(n->input(node_index))) {
      input_shapes.emplace_back(*ival);
      continue;
    }
    // Non-constant int[] input (equivalent to isListOfInts).
    if (type->cast<ListType>() &&
        type->cast<ListType>()->getElementType()->cast<IntType>()) {
      auto input_src_node = n->input(node_index)->node();
      if (input_src_node->kind() == prim::ListConstruct &&
          !db.hasWriters(n->input(node_index))) {
        // it is a very common in graphs to see patterns like:
        // z = x.view(y.size())
        // or:
        // z = x.view(1, 10, y.size(0), y.size(1))
        // We want to propagate symbolic dimensions and concrete sizes
        // from y to z. To do this we try to associate symbolic dimensions
        // or concrete sizes with the integer list inputs that have a
        // constructor taken from constants or y.size() or y.size(0)
        auto list_construct = n->input(node_index)->node();
        std::vector<ShapeArg> shape;
        for (Value* v : list_construct->inputs()) {
          if (auto constant = constant_as<int64_t>(v)) {
            shape.emplace_back(*constant);
          } else if (v->node()->kind() == aten::size) {
            auto const_index = constant_as<int64_t>(v->node()->input(1));
            auto tt = v->node()->input(0)->type()->expect<TensorType>();
            auto ss = tt->symbolic_sizes();
            if (!ss.rank() || !const_index) {
              // if we are getting a size of a tensor, it is an unknown
              // symbolic dimension instead of an unknown integer (must be
              // >=0)
              shape.emplace_back(at::ShapeSymbol::newSymbol());
              continue;
            }
            auto norm_index = normIndex(*const_index, *ss.rank());
            if (!norm_index) {
              shape.emplace_back(at::ShapeSymbol::newSymbol());
              continue;
            }
            shape.emplace_back(ss[*norm_index]);
          } else {
            shape.emplace_back(ShapeArg::unknownInteger());
          }
        }
        input_shapes.emplace_back(ShapeArguments(shape));
        continue;
      }
      // x.size() used directly as the int[]: reuse x's symbolic sizes.
      if (input_src_node->kind() == aten::size &&
          !db.hasWriters(n->input(node_index))) {
        auto ten_inp = input_src_node->input();
        auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
        input_shapes.emplace_back(ss);
        continue;
      }
    }
    // Fallback: unknown input, represented as an unranked shape.
    GRAPH_DEBUG(
        "Unhandled input: ",
        n->kind().toDisplayString(),
        " arg num: ",
        node_index);
    input_shapes.emplace_back(c10::SymbolicShape());
  }
  // >= rather than == because tensor lists expand to multiple entries.
  TORCH_INTERNAL_ASSERT(
      input_shapes.size() >= n->inputs().size(),
      "input_shapes size: ",
      input_shapes.size(),
      " n inputs size: ",
      n->inputs().size());
  return input_shapes;
}
752 | |
753 | void applyOutputShapeToGraph( |
754 | Node* node, |
755 | const std::vector<c10::SymbolicShape>& output_shapes) { |
756 | TORCH_INTERNAL_ASSERT( |
757 | node->outputs().size() == output_shapes.size(), |
758 | "Output shape size mismatch" ); |
759 | for (size_t i = 0; i < output_shapes.size(); ++i) { |
760 | auto& ss = output_shapes.at(i); |
761 | node->output(i)->setType( |
762 | node->output(i)->type()->expect<TensorType>()->withSymbolicShapes(ss)); |
763 | } |
764 | } |
765 | |
// Propagates symbolic shapes onto the outputs of `n` using the shape compute
// graph registered for n's schema. Returns the partially evaluated shape
// compute graph, or nullptr when n has no schema or no registered shape
// function. Output types are only updated when the analysis succeeds.
std::shared_ptr<Graph> PropagateShapesWithShapeFunction(
    Node* n,
    const AliasDb& db) {
  const FunctionSchema* func_schema = n->maybeSchema();
  if (!func_schema) {
    return nullptr;
  }
  auto op_analyzer = SymbolicShapeOpAnalyzer(func_schema);
  if (!op_analyzer.getShapeComputeGraph()) {
    return nullptr;
  }
  auto input_shapes = getNodeInputShapes(n, db);
  // Must run before `run` so Optional/Number graph input types are refined
  // against the actual node input types.
  op_analyzer.refineInputUnionTypes(n);

  if (auto output_shapes = op_analyzer.run(input_shapes)) {
    applyOutputShapeToGraph(n, *output_shapes);
  }

  return op_analyzer.getShapeComputeGraph();
}
786 | |
787 | c10::SymbolicShape combine_bounds( |
788 | c10::SymbolicShape& lower_bound, |
789 | c10::SymbolicShape& upper_bound) { |
790 | // TODO: At some point we might want to add support for dynamic dims |
791 | TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank()); |
792 | if (lower_bound.rank() == c10::nullopt) { |
793 | return c10::SymbolicShape(); |
794 | } |
795 | std::vector<c10::ShapeSymbol> merged_shapes; |
796 | for (int i = 0; i < lower_bound.rank(); i++) { |
797 | // TODO: Merge equivalent expressions (not needed for current use case) |
798 | if (lower_bound[i] == upper_bound[i]) { |
799 | merged_shapes.push_back(lower_bound[i]); |
800 | } else { |
801 | merged_shapes.push_back(c10::ShapeSymbol::newSymbol()); |
802 | } |
803 | } |
804 | return c10::SymbolicShape(std::move(merged_shapes)); |
805 | } |
806 | |
// Builds one large "stitched" shape compute graph for the node region
// [beg, end) of the enclosing graph by partially evaluating each node's
// shape function and splicing the results together. The stitched graph
// is executable before `beg` and computes the runtime value of every
// symbolic shape dimension appearing in the region.
struct SymbolicShapeGraphAnalyzer {
  SymbolicShapeGraphAnalyzer(
      std::shared_ptr<Graph>& graph,
      Node* beg,
      Node* end)
      : graph_(graph), beg_(beg), end_(end) {
    // The region must be a contiguous range inside a single block.
    TORCH_INTERNAL_ASSERT(
        beg_->owningBlock() == end_->owningBlock() && end_->isAfter(beg_));
  }

  // Main entry point. On success, returns the stitched shape compute graph
  // plus the mapping from enclosing-graph Values to its inputs and from
  // its outputs to symbolic dimension ids. Returns c10::nullopt when any
  // node in the region cannot be handled (no shape graph, non-tensor
  // output, unknown rank, or an unsupported list use).
  c10::optional<ShapeComputeGraphMapping> run() {
    AliasDb db(graph_);
    std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs =
        propagateShapesAndGatherPartialEvalShapeGraphs(db);

    auto stitched_shape_compute_graph = std::make_shared<Graph>();
    // We want to build up a computational graph which computes all shapes
    // we dont know statically - that is, all symbolic shapes within
    // the region [beg, end). it must be executable before beg.
    // TODO: dont require dimensions of tensors to be set AOT ?

    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      if (curr->kind() == prim::Constant) {
        continue;
      }
      // TODO: generalize logic to for other tensor input ops when they are
      // added
      if (curr->kind() == prim::ListConstruct) {
        // Tensor lists are only understood when every use feeds aten::cat
        // (unpacked specially in joinPartialEvaluatedShapeGraphToLargeShapeGraph).
        auto uses = curr->output()->uses();
        if (!std::all_of(uses.begin(), uses.end(), [](const Use& use) {
              return use.user->kind() == aten::cat;
            })) {
          GRAPH_DEBUG("Non cat list use " , getHeader(curr));
          return c10::nullopt;
        }
        continue;
      }

      // No partially evaluated shape graph could be produced for this node.
      if (!partial_evaluated_graphs.count(curr)) {
        GRAPH_DEBUG("No graph " , getHeader(curr));
        return c10::nullopt;
      }

      auto outputs = curr->outputs();
      for (Value* v : outputs) {
        auto tt = v->type()->cast<TensorType>();
        if (!tt) {
          GRAPH_DEBUG("Non tensor node" , getHeader(curr));
          return c10::nullopt;
        }
        auto symbolic_sizes = tt->symbolic_sizes();
        // TODO: dont require # of dimensions of tensors set ?
        if (!symbolic_sizes.rank()) {
          GRAPH_DEBUG("No rank on output " , getHeader(curr));
          return c10::nullopt;
        }
      }
      auto partial_eval_graph = partial_evaluated_graphs[curr];
      joinPartialEvaluatedShapeGraphToLargeShapeGraph(
          curr, partial_eval_graph, stitched_shape_compute_graph);
    }

    // Run cleanup passes to a fixed point, capped at MAX_ITER iterations
    // to guarantee termination.
    size_t MAX_ITER = 8;
    bool made_change = true;
    size_t i = 0;
    while (i < MAX_ITER && made_change) {
      i++;
      made_change = shapeGraphCleanupPasses(stitched_shape_compute_graph);
    }

    // for any output that is duplicated, the symbolic shape must be equal
    // take the symbolic shape that is generated first and get equivalent ones
    std::unordered_map<int64_t, int64_t> discovered_sym_shape_equalities;
    std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim;
    std::vector<size_t> erase_indices;

    for (size_t i = 0; i < stitched_shape_compute_graph->outputs().size();
         ++i) {
      Value* output = stitched_shape_compute_graph->outputs().at(i);
      // this Value is already contained, so the symbolic shape for i must be
      // equal to the symbolic shape at the existing index
      if (graph_output_to_symbolic_shape_dim.count(output)) {
        auto curr_sym_shape = output_index_to_symbolic_shape_[i];
        auto existing_sym_shape = graph_output_to_symbolic_shape_dim[output];
        discovered_sym_shape_equalities[curr_sym_shape] = existing_sym_shape;
        erase_indices.push_back(i);
      } else {
        graph_output_to_symbolic_shape_dim[output] =
            output_index_to_symbolic_shape_[i];
      }
    }
    // Erase back-to-front so earlier recorded indices remain valid.
    for (int64_t i = erase_indices.size() - 1; i >= 0; i--) {
      stitched_shape_compute_graph->eraseOutput(erase_indices[i]);
    }
    // Drop inputs that ended up unused after cleanup and deduplication,
    // keeping the enclosing-graph mapping in sync.
    for (size_t i = 0; i < stitched_shape_compute_graph->inputs().size();) {
      if (!stitched_shape_compute_graph->inputs().at(i)->hasUses()) {
        enclosing_graph_value_to_shape_graph_input_.erase(
            stitched_shape_compute_graph->inputs().at(i));
        stitched_shape_compute_graph->eraseInput(i);
      } else {
        ++i;
      }
    }

    updateGraphWithSymbolicShapeEqualities(discovered_sym_shape_equalities);
    return ShapeComputeGraphMapping(
        std::move(stitched_shape_compute_graph),
        enclosing_graph_value_to_shape_graph_input_,
        std::move(graph_output_to_symbolic_shape_dim));
  }

  // Rewrite tensor types within [beg, end) so that symbolic dimensions
  // discovered to be equal all use a single canonical symbol id.
  void updateGraphWithSymbolicShapeEqualities(
      std::unordered_map<int64_t, int64_t>& sym_shape_equalities) {
    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      for (size_t i = 0; i < curr->outputs().size(); ++i) {
        auto output = curr->output(i);
        auto tt = output->type()->cast<TensorType>();
        if (!tt || !tt->symbolic_sizes().rank()) {
          continue;
        }
        bool changed = false;
        std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
        auto new_sizes =
            c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
              auto value = shape.value();
              if (sym_shape_equalities.count(value)) {
                changed = true;
                return sym_shape_equalities[value];
              }
              return value;
            });
        if (changed) {
          output->setType(
              tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
        }
      }
    }
  }

  // Register `output` as a new output of the stitched graph and record the
  // bidirectional mapping between the new output index and the symbolic
  // dimension id it computes.
  void registerStitchedComputeOutput(
      std::shared_ptr<Graph> stitched_shape_compute_graph,
      Value* output,
      int64_t symbolic_shape) {
    stitched_shape_compute_graph->registerOutput(output);
    output_index_to_symbolic_shape_
        [stitched_shape_compute_graph->outputs().size() - 1] = symbolic_shape;
    symbolic_shape_value_to_graph_output_[symbolic_shape] =
        stitched_shape_compute_graph->outputs().at(
            stitched_shape_compute_graph->outputs().size() - 1);
  }

  void joinPartialEvaluatedShapeGraphToLargeShapeGraph(
      Node* curr,
      std::shared_ptr<Graph> partial_eval_graph,
      std::shared_ptr<Graph> stitched_shape_compute_graph) {
    // we are building up the large shape compute graph by iteratively
    // combining partially evaluated individual node shape graphs.

    // We need to maintain two mappings, one from non-Tensor inputs in the
    // enclosing graph to their equivalent mappings within the large shape
    // compute graph, and one from symbolic shape dimension to new node output

    // When we add a new tensor node, we do two things:
    // 1: record a mapping from the tensor node output to its shape in the
    //    partial eval graph
    // 2: add each symbolic shape dimension that we have not already added
    //    as a output to the large shape compute graph

    // Once we are done stitching together all partial eval'd graphs, we can
    // cleanup the graph and remove the unneeded complete shapes as outputs,
    // leaving us only compute for calculating the runtime value of symbolic
    // dimensions

    std::vector<Value*> node_inputs;
    // TODO: generalize logic
    if (curr->kind() == aten::cat) {
      // aten::cat's shape function takes the unpacked list elements plus
      // the dim argument, so flatten the producing ListConstruct here.
      TORCH_INTERNAL_ASSERT(
          curr->input(0)->node()->kind() == prim::ListConstruct);
      for (Value* v : curr->input(0)->node()->inputs()) {
        node_inputs.push_back(v);
      }
      node_inputs.push_back(curr->namedInput("dim" ));
    } else {
      for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
        node_inputs.push_back(curr->input(i));
      }
    }

    std::vector<Value*> partial_eval_inputs;
    for (size_t i = 0; i < node_inputs.size(); ++i) {
      auto node_input = node_inputs[i];
      auto existing_graph_mapping =
          enclosing_graph_value_to_shape_graph_input_.find(node_input);
      if (existing_graph_mapping !=
          enclosing_graph_value_to_shape_graph_input_.end()) {
        partial_eval_inputs.push_back(existing_graph_mapping->second);
      } else {
        // First time we see this enclosing-graph value: mirror it as an
        // input of the stitched graph.
        Value* shape_graph_input =
            stitched_shape_compute_graph->addInput()->copyMetadata(
                partial_eval_graph->inputs().at(i));
        enclosing_graph_value_to_shape_graph_input_[node_input] =
            shape_graph_input;
        partial_eval_inputs.push_back(shape_graph_input);
      }
      // make sure all symbolic dimensions in the graph we are creating are
      // computed in the partial eval graph
      if (auto tt = node_input->type()->cast<TensorType>()) {
        if (!tt->symbolic_sizes().rank()) {
          continue;
        }
        auto rank = *tt->symbolic_sizes().rank();
        for (size_t j = 0; j < rank; ++j) {
          auto shape = tt->symbolic_sizes()[j];
          if (shape.is_static() ||
              symbolic_shape_value_to_graph_output_.count(shape.value())) {
            continue;
          }
          // Emit `input[j]` in the stitched graph and register it as the
          // output that computes this symbolic dimension.
          auto input = enclosing_graph_value_to_shape_graph_input_[node_input];
          WithInsertPoint guard(stitched_shape_compute_graph->block());
          auto index = stitched_shape_compute_graph->insertConstant(
              static_cast<int64_t>(j));
          auto li_index = stitched_shape_compute_graph->insert(
              aten::__getitem__, {input, index});
          registerStitchedComputeOutput(
              stitched_shape_compute_graph, li_index, shape.value());
        }
      }
    }

    // Splice the partial eval graph into the stitched graph, remapping its
    // inputs to the values collected above.
    WithInsertPoint guard(stitched_shape_compute_graph->block());
    std::unordered_map<Value*, Value*> value_map;
    insertGraph(
        *stitched_shape_compute_graph,
        *partial_eval_graph,
        partial_eval_inputs,
        value_map);

    for (size_t i = 0; i < curr->outputs().size(); ++i) {
      Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];
      enclosing_graph_value_to_shape_graph_input_[curr->output(i)] =
          new_list_output;

      TORCH_INTERNAL_ASSERT(
          new_list_output->node()->kind() == prim::ListConstruct ||
          new_list_output->node()->kind() == prim::Constant);
      TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());

      auto symbolic_sizes =
          curr->output(i)->type()->expect<TensorType>()->symbolic_sizes();
      TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());

      // NOTE(review): this inner `i` (dimension index) shadows the outer
      // `i` (output index); the outer index is not read past this point,
      // but a rename would aid readability.
      for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
        if (symbolic_sizes[i].is_static()) {
          continue;
        }
        int64_t symbolic_shape = symbolic_sizes[i].value();
        if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
          continue;
        }
        registerStitchedComputeOutput(
            stitched_shape_compute_graph,
            new_list_output->node()->input(i),
            symbolic_shape);
      }
    }
  }

  // Run shape propagation over [beg, end) and collect, per node, the
  // partially evaluated shape compute graph (when one was produced).
  std::unordered_map<Node*, std::shared_ptr<Graph>>
  propagateShapesAndGatherPartialEvalShapeGraphs(AliasDb& db) {
    std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs;
    for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
      auto curr = *it;
      if (auto maybe_graph = PropagateShapesWithShapeFunction(curr, db)) {
        partial_evaluated_graphs[curr] = maybe_graph;
      }
    }
    return partial_evaluated_graphs;
  }

  // Enclosing-graph Value -> corresponding input of the stitched graph.
  std::unordered_map<Value*, Value*>
      enclosing_graph_value_to_shape_graph_input_;
  // Symbolic dimension id -> stitched-graph output computing it.
  std::unordered_map<int64_t, Value*> symbolic_shape_value_to_graph_output_;
  // Stitched-graph output index -> symbolic dimension id it computes.
  std::unordered_map<size_t, int64_t> output_index_to_symbolic_shape_;

  std::shared_ptr<Graph>& graph_;
  Node* beg_;
  Node* end_;
};
1098 | |
1099 | void PropagateShapesOnBlock(Block* b, const AliasDb& db) { |
1100 | for (Node* n : b->nodes()) { |
1101 | // TODO: handle loop |
1102 | if (n->kind() == prim::If) { |
1103 | IfView if_v(n); |
1104 | PropagateShapesOnBlock(if_v.thenBlock(), db); |
1105 | PropagateShapesOnBlock(if_v.elseBlock(), db); |
1106 | mergeTypes(if_v.thenOutputs(), if_v.elseOutputs(), if_v.outputs()); |
1107 | } else if (n->maybeSchema()) { |
1108 | PropagateShapesWithShapeFunction(n, db); |
1109 | } else if (n->kind() == prim::TupleConstruct) { |
1110 | auto orig_type = n->output()->type()->expect<TupleType>(); |
1111 | auto new_types = fmap(n->inputs(), [](Value* v) { return v->type(); }); |
1112 | n->output()->setType( |
1113 | orig_type->createWithContained(std::move(new_types))); |
1114 | } |
1115 | } |
1116 | } |
1117 | } // namespace |
1118 | |
1119 | void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph) { |
1120 | AliasDb db(graph); |
1121 | PropagateShapesOnBlock(graph->block(), db); |
1122 | } |
1123 | |
1124 | c10::optional<ShapeComputeGraphMapping> |
1125 | PropagateShapesAndBuildLargeShapeComputeGraph( |
1126 | std::shared_ptr<Graph>& graph, |
1127 | Node* beg, |
1128 | Node* end) { |
1129 | return SymbolicShapeGraphAnalyzer(graph, beg, end).run(); |
1130 | } |
1131 | |
1132 | TORCH_API c10::optional<std::vector<c10::SymbolicShape>> |
1133 | calculateSymbolicShapesOnOp( |
1134 | const FunctionSchema* schema, |
1135 | const std::vector<SSAInput>& inputs) { |
1136 | auto bounded_graphs = boundedGraphsForSchema(*schema); |
1137 | auto has_shape_compute = shapeComputeGraphForSchema(*schema) != c10::nullopt; |
1138 | if (!has_shape_compute && bounded_graphs == c10::nullopt) { |
1139 | // Avoid doing all this work for functions that don't have a |
1140 | // supported schema |
1141 | return c10::nullopt; |
1142 | } |
1143 | |
1144 | if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) { |
1145 | return cached_ret_vec; |
1146 | } |
1147 | |
1148 | std::vector<SSArgument> ssa_args; |
1149 | for (auto& arg : inputs) { |
1150 | if (const IValue* ival = c10::get_if<IValue>(&arg)) { |
1151 | ssa_args.emplace_back(*ival); |
1152 | } else { |
1153 | const c10::SymbolicShape* ss = c10::get_if<c10::SymbolicShape>(&arg); |
1154 | ssa_args.emplace_back(ShapeArguments(*ss)); |
1155 | } |
1156 | } |
1157 | // Handle bounded shape option |
1158 | if (bounded_graphs) { |
1159 | auto lower_bound = |
1160 | SymbolicShapeOpAnalyzer(schema, bounded_graphs->lower_bound); |
1161 | auto lower_bound_res = lower_bound.run(ssa_args); |
1162 | auto upper_bound = |
1163 | SymbolicShapeOpAnalyzer(schema, bounded_graphs->upper_bound); |
1164 | auto upper_bound_res = upper_bound.run(ssa_args); |
1165 | // Stitch together the values |
1166 | if (lower_bound_res.has_value() && upper_bound_res.has_value()) { |
1167 | TORCH_INTERNAL_ASSERT(lower_bound_res->size() == upper_bound_res->size()); |
1168 | auto merged_res = std::vector<c10::SymbolicShape>(); |
1169 | for (size_t i = 0; i < lower_bound_res->size(); i++) { |
1170 | merged_res.push_back( |
1171 | combine_bounds(lower_bound_res->at(i), upper_bound_res->at(i))); |
1172 | } |
1173 | cache_shape_function(schema, inputs, merged_res); |
1174 | return merged_res; |
1175 | } |
1176 | return c10::nullopt; |
1177 | } |
1178 | |
1179 | auto op_analyzer = SymbolicShapeOpAnalyzer(schema); |
1180 | auto res = op_analyzer.run(ssa_args); |
1181 | if (res.has_value()) { |
1182 | cache_shape_function(schema, inputs, res.value()); |
1183 | } |
1184 | return res; |
1185 | } |
1186 | |
1187 | } // namespace jit |
1188 | } // namespace torch |
1189 | |