1#pragma once
2
3#include <ATen/core/symbol.h>
4
5#include <functional>
6#include <memory>
7#include <set>
8#include <string>
9#include <unordered_map>
10#include <unordered_set>
11#include <utility>
12#include <vector>
13
14#include <c10/core/ScalarType.h>
15#include <c10/util/Flags.h>
16#include <torch/csrc/lazy/core/dynamic_ir.h>
17#include <torch/csrc/lazy/core/hash.h>
18#include <torch/csrc/lazy/core/ir.h>
19#include <torch/csrc/lazy/core/ir_metadata.h>
20#include <torch/csrc/lazy/ts_backend/ts_node.h>
21
22C10_DECLARE_bool(ltc_enable_dynamic_shapes);
23
24namespace torch {
25namespace lazy {
26
/**
 * The goal of "dynamic" Nodes is to patch a hole in our tracing.
 * Previously, if a user called `sizes` on a Tensor, the result would leak
 * out of our tracing system, because `sizes` returns a concrete torch.Size
 * (or an int when a single dimension is requested). To prevent this, we
 * introduce DimensionNode, a new type of Node that abstracts the operation
 * of getting the dimensions of a Tensor.
 *
 * Consider the following example:
 * ```
 * numel = x.shape[0] * x.shape[1]
 * ```
 *
 * Here, each `x.shape[i]` will be a SizeNode (a subclass of DimensionNode),
 * and the multiplication of the two SizeNodes will be represented by a
 * SizeMul (also a subclass of DimensionNode). Through this, we can prevent
 * `numel` from being represented as a plain Python int and thus baked into
 * the graph as a constant.
 */
46
47// Represents the result of calling `size` on a Tensor
48class TORCH_API SizeNode : public TsNode, public DimensionNode {
49 public:
50 SizeNode(Value input, size_t dim);
51 int64_t getStaticValue() const override;
52 bool isSymbolic() const override;
53 std::string ToString() const override;
54 size_t dim_ = 0;
55 torch::lazy::TSOpVector Lower(
56 std::shared_ptr<torch::jit::GraphFunction> function,
57 TSLoweringContext* loctx) const override;
58};
59
60class TORCH_API SizeAdd : public TsNode, public DimensionNode {
61 public:
62 SizeAdd(Value a, Value b);
63 int64_t getStaticValue() const override;
64 bool isSymbolic() const override;
65 std::string ToString() const override;
66};
67
68class TORCH_API SizeMul : public TsNode, public DimensionNode {
69 public:
70 SizeMul(Value a, Value b);
71 int64_t getStaticValue() const override;
72 bool isSymbolic() const override;
73 std::string ToString() const override;
74};
75
76class TORCH_API SizeDiv : public TsNode, public DimensionNode {
77 public:
78 SizeDiv(Value a, Value b);
79 int64_t getStaticValue() const override;
80 bool isSymbolic() const override;
81 std::string ToString() const override;
82};
83
84} // namespace lazy
85} // namespace torch
86