1 | #pragma once |
2 | |
3 | #include <torch/csrc/Export.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/passes/symbolic_shape_analysis.h> |
6 | |
7 | #include <unordered_map> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | // Takes in a TensorExprGraph of static shapes and generalizes the input shapes |
13 | // to symbolic dimensions. Dimensions of value 1 will be preserved, otherwise |
14 | // dimensions with the same value will be bucketed to the same symbolic shape. |
15 | // E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1) |
16 | // From there, runs symbolic shape inference on the graph, and creates a |
17 | // versionining if in the graph with prim::TensorExprDynamicGuard checking if |
18 | // the inputs at runtime match the Generalized Symbolic Shapes that are inputs |
19 | // to the TE Kernel. The computate to calculate all symbolic dimensions is |
20 | // inlined in to the if block with the TE Kernel. All Sym Dim Value* are |
21 | // appended to the end of the TE Kernel Graph/Node inputs, and the Node is |
22 | // augmented with a integer list attr `symbolic_shape_inputs` that gives the |
23 | // mapping from Value * -> Symbolic Shape int64_t value. For more lengthy IR |
24 | // examples and walkthrough look at ShapeAnalysisTest.DynamicShapesFusion in |
25 | // `test_shape_analysis` Returns True on Success, False on Failure, can fail if |
26 | // shape propagation fails to propagate # of dims or if complete shapes on |
27 | // inputs not set |
28 | |
29 | TORCH_API bool GenerateGuard( |
30 | Node* tensorexpr_graph_node, |
31 | bool add_composed_op = false); |
32 | |
33 | TORCH_API void runTensorExprDynamicGroup(const Code& code, Stack& stack); |
34 | |
35 | enum class StrideInput { |
36 | // Tensors natively store whether they are contiguous or not as a property |
37 | // this makes it faster to query `is_contiguous` or |
38 | // `is_contiguous(memory_format=channels_last)` |
39 | // than looping through the sizes/strides yourself |
40 | // For tensors with these properties, we only store one value: |
41 | TENSOR_CONT, |
42 | TENSOR_CONT_CHANNELS_LAST, |
43 | // now, we describe other cases, where there is one stride enum |
44 | // per dimension |
45 | S_ONE, // STRIDE_ONE: packed |
46 | S_CONT, // STRIDE_CONTIGUOUS: stride[i + 1] * sizes[i + 1] |
47 | S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i-1] * sizes[i-1] |
48 | S_AS_ARG, // STRIDE_AS_ARG: stride passed in as runtime value |
49 | }; |
50 | |
51 | TORCH_API std::string toString(StrideInput si); |
52 | TORCH_API StrideInput strideInputFromString(const std::string& si); |
53 | |
54 | } // namespace jit |
55 | } // namespace torch |
56 | |