1 | #pragma once |
2 | |
3 | #include <ATen/core/symbol.h> |
4 | |
5 | #include <functional> |
6 | #include <memory> |
7 | #include <set> |
8 | #include <string> |
9 | #include <unordered_map> |
10 | #include <unordered_set> |
11 | #include <utility> |
12 | #include <vector> |
13 | |
14 | #include <c10/core/ScalarType.h> |
15 | #include <c10/util/Flags.h> |
16 | #include <torch/csrc/lazy/core/hash.h> |
17 | #include <torch/csrc/lazy/core/ir.h> |
18 | #include <torch/csrc/lazy/core/ir_metadata.h> |
19 | #include <torch/csrc/lazy/ts_backend/ts_node.h> |
20 | |
21 | namespace torch { |
22 | namespace lazy { |
23 | |
24 | /** |
25 | * The goal of "dynamic" Nodes is to patch a hole in our tracing. |
26 | * Previously, if a user called `sizes` on a Tensor, it would leak out |
27 | * of our tracing system, as `sizes` returns a torch.Size or an int. To |
28 | * prevent this from happening, we introduce DimensionNode, a new type |
29 | * of Node that abstracts the operation of getting the dimensions of a |
30 | * Tensor. |
31 | * |
32 | * Consider the following example: |
33 | * ``` |
34 | * numel = x.shape()[0] * x.shape()[1] |
35 | * ``` |
36 | * |
37 | * Here, `x.shape()[i]` will be a SizeNode (subclass of DimensionNode), |
38 | * and the multiplication of the two SizeNodes will be represented by |
39 | * a SizeMul (also a subclass of DimensionNode). Through this, we can |
40 | * prevent `numel` from being represented as a Python int and thus |
41 | * burned into the Graph. |
42 | */ |
43 | |
44 | class TORCH_API DimensionNode { |
45 | public: |
46 | virtual bool isSymbolic() const { |
47 | return false; |
48 | }; |
49 | virtual int64_t getDynamicValue() const { |
50 | TORCH_CHECK(false, "NYI" ); |
51 | }; |
52 | virtual int64_t getStaticValue() const { |
53 | TORCH_CHECK(false, "NYI" ); |
54 | }; |
55 | virtual ~DimensionNode() = default; |
56 | }; |
57 | |
58 | } // namespace lazy |
59 | } // namespace torch |
60 | |