1 | #pragma once |
2 | |
3 | #include <ATen/core/symbol.h> |
4 | |
5 | #include <functional> |
6 | #include <memory> |
7 | #include <set> |
8 | #include <string> |
9 | #include <unordered_map> |
10 | #include <unordered_set> |
11 | #include <utility> |
12 | #include <vector> |
13 | |
14 | #include <c10/core/ScalarType.h> |
15 | #include <c10/util/Flags.h> |
16 | #include <torch/csrc/lazy/core/dynamic_ir.h> |
17 | #include <torch/csrc/lazy/core/hash.h> |
18 | #include <torch/csrc/lazy/core/ir.h> |
19 | #include <torch/csrc/lazy/core/ir_metadata.h> |
20 | #include <torch/csrc/lazy/ts_backend/ts_node.h> |
21 | |
22 | C10_DECLARE_bool(ltc_enable_dynamic_shapes); |
23 | |
24 | namespace torch { |
25 | namespace lazy { |
26 | |
/**
 * The goal of "dynamic" Nodes is to patch a hole in our tracing.
 * Previously, if a user called `size()` on a Tensor, the result would leak
 * out of our tracing system, because `size()` returns a concrete torch.Size
 * or int. To prevent this from happening, we introduce DimensionNode, a new
 * type of Node that abstracts the operation of getting the dimensions of a
 * Tensor.
 *
 * Consider the following example:
 * ```
 * numel = x.shape[0] * x.shape[1]
 * ```
 *
 * Here, `x.shape[i]` will be a SizeNode (a subclass of DimensionNode),
 * and the multiplication of the two SizeNodes will be represented by
 * a SizeMul (also a subclass of DimensionNode). Through this, we can
 * prevent `numel` from being represented as a Python int and thus
 * burned into the Graph.
 */
46 | |
47 | // Represents the result of calling `size` on a Tensor |
48 | class TORCH_API SizeNode : public TsNode, public DimensionNode { |
49 | public: |
50 | SizeNode(Value input, size_t dim); |
51 | int64_t getStaticValue() const override; |
52 | bool isSymbolic() const override; |
53 | std::string ToString() const override; |
54 | size_t dim_ = 0; |
55 | torch::lazy::TSOpVector Lower( |
56 | std::shared_ptr<torch::jit::GraphFunction> function, |
57 | TSLoweringContext* loctx) const override; |
58 | }; |
59 | |
60 | class TORCH_API SizeAdd : public TsNode, public DimensionNode { |
61 | public: |
62 | SizeAdd(Value a, Value b); |
63 | int64_t getStaticValue() const override; |
64 | bool isSymbolic() const override; |
65 | std::string ToString() const override; |
66 | }; |
67 | |
68 | class TORCH_API SizeMul : public TsNode, public DimensionNode { |
69 | public: |
70 | SizeMul(Value a, Value b); |
71 | int64_t getStaticValue() const override; |
72 | bool isSymbolic() const override; |
73 | std::string ToString() const override; |
74 | }; |
75 | |
76 | class TORCH_API SizeDiv : public TsNode, public DimensionNode { |
77 | public: |
78 | SizeDiv(Value a, Value b); |
79 | int64_t getStaticValue() const override; |
80 | bool isSymbolic() const override; |
81 | std::string ToString() const override; |
82 | }; |
83 | |
84 | } // namespace lazy |
85 | } // namespace torch |
86 | |