1#pragma once
2
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>

#include <string>
#include <unordered_map>
8
9namespace torch {
10namespace jit {
11
12// Takes in a TensorExprGraph of static shapes and generalizes the input shapes
13// to symbolic dimensions. Dimensions of value 1 will be preserved, otherwise
14// dimensions with the same value will be bucketed to the same symbolic shape.
15// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
16// From there, runs symbolic shape inference on the graph, and creates a
17// versionining if in the graph with prim::TensorExprDynamicGuard checking if
18// the inputs at runtime match the Generalized Symbolic Shapes that are inputs
19// to the TE Kernel. The computate to calculate all symbolic dimensions is
20// inlined in to the if block with the TE Kernel. All Sym Dim Value* are
21// appended to the end of the TE Kernel Graph/Node inputs, and the Node is
22// augmented with a integer list attr `symbolic_shape_inputs` that gives the
23// mapping from Value * -> Symbolic Shape int64_t value. For more lengthy IR
24// examples and walkthrough look at ShapeAnalysisTest.DynamicShapesFusion in
25// `test_shape_analysis` Returns True on Success, False on Failure, can fail if
26// shape propagation fails to propagate # of dims or if complete shapes on
27// inputs not set
28
29TORCH_API bool GenerateGuard(
30 Node* tensorexpr_graph_node,
31 bool add_composed_op = false);
32
33TORCH_API void runTensorExprDynamicGroup(const Code& code, Stack& stack);
34
enum class StrideInput {
  // Whole-tensor descriptors. A tensor records whether it is contiguous as a
  // built-in property, so querying `is_contiguous()` or
  // `is_contiguous(memory_format=channels_last)` is faster than walking its
  // sizes/strides manually. Tensors with one of these properties are
  // described by a single value:
  TENSOR_CONT = 0,
  TENSOR_CONT_CHANNELS_LAST = 1,

  // Per-dimension descriptors: one of these values is stored per dimension.
  S_ONE = 2, // STRIDE_ONE: packed, stride of 1
  S_CONT = 3, // STRIDE_CONTIGUOUS: stride[i] == stride[i + 1] * sizes[i + 1]
  S_TRAN_CONT = 4, // STRIDE_TRANSPOSED_CONTIGUOUS:
                   //   stride[i] == stride[i - 1] * sizes[i - 1]
  S_AS_ARG = 5, // STRIDE_AS_ARG: stride passed in as a runtime value
};
50
51TORCH_API std::string toString(StrideInput si);
52TORCH_API StrideInput strideInputFromString(const std::string& si);
53
54} // namespace jit
55} // namespace torch
56