#pragma once

#include <ATen/core/symbol.h>

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <c10/core/ScalarType.h>
#include <c10/util/Flags.h>
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {
/**
 * The goal of "dynamic" Nodes is to patch a hole in our tracing.
 * Previously, if a user called `sizes` on a Tensor, the result would
 * leak out of our tracing system, since `sizes` returns a torch.Size
 * or an int. To prevent this from happening, we introduce
 * DimensionNode, a new type of Node that abstracts the operation of
 * getting the dimensions of a Tensor.
 *
 * Consider the following example:
 * ```
 * numel = x.shape[0] * x.shape[1]
 * ```
 *
 * Here, `x.shape[i]` will be a SizeNode (a subclass of DimensionNode),
 * and the multiplication of the two SizeNodes will be represented by
 * a SizeMul (also a subclass of DimensionNode). Through this, we can
 * prevent `numel` from being represented as a Python int and thus
 * burned into the Graph.
 */
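// For intuition, a hedged sketch of the trace this produces (the
// constructor-style notation below is illustrative, not the actual IR
// builder API): evaluating the example above could record a graph
// fragment along the lines of
//
//   s0 = SizeNode(x, /*dim=*/0)
//   s1 = SizeNode(x, /*dim=*/1)
//   numel = SizeMul(s0, s1)
//
// so `numel` remains a Node in the trace rather than a constant baked
// into the graph.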

class TORCH_API DimensionNode {
 public:
  virtual bool isSymbolic() const {
    return false;
  }
  virtual int64_t getDynamicValue() const {
    TORCH_CHECK(false, "NYI");
  }
  virtual int64_t getStaticValue() const {
    TORCH_CHECK(false, "NYI");
  }
  virtual ~DimensionNode() = default;
};
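
// A minimal illustrative subclass (a hedged sketch; `ConstantDimension`
// is a hypothetical name, not part of this API) showing how a backend
// might implement the interface for a statically known dimension:
//
//   class ConstantDimension : public DimensionNode {
//    public:
//     explicit ConstantDimension(int64_t value) : value_(value) {}
//     bool isSymbolic() const override {
//       return false; // the value is known at trace time
//     }
//     int64_t getStaticValue() const override {
//       return value_;
//     }
//     int64_t getDynamicValue() const override {
//       return value_; // static and dynamic values coincide here
//     }
//
//    private:
//     int64_t value_;
//   };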

} // namespace lazy
} // namespace torch