1#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
2
3static const torch::lazy::DimensionNode* DimCast(torch::lazy::Output output) {
4 return dynamic_cast<const torch::lazy::DimensionNode*>(output.node);
5}
6
7namespace torch {
8namespace lazy {
9
10TSOpVector SizeNode::Lower(
11 std::shared_ptr<torch::jit::GraphFunction> function,
12 TSLoweringContext* loctx) const {
13 std::vector<torch::jit::NamedValue> arguments;
14 std::vector<torch::jit::NamedValue> kwarguments;
15 arguments.reserve(2);
16 auto index = loctx->graph()->insertConstant(static_cast<int64_t>(this->dim_));
17 arguments.emplace_back(loctx->GetOutputOp(operand(0)));
18 arguments.emplace_back(index);
19 torch::lazy::TSOpVector size_out =
20 torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
21 TORCH_CHECK_EQ(size_out.size(), 1);
22 return size_out;
23}
24
25SizeNode::SizeNode(Value input, size_t dim)
26 : TsNode(
27 OpKind{c10::Symbol::fromQualString("aten::size")},
28 {input},
29 std::vector<Shape>{},
30 1,
31 MHash(dim)),
32 dim_(dim){};
33
34int64_t SizeNode::getStaticValue() const {
35 return dynamic_cast<const TsNode*>(operand(0).node)->shape(0).size(dim_);
36}
37bool SizeNode::isSymbolic() const {
38 auto symbolic_vec =
39 dynamic_cast<const TsNode*>(operand(0).node)->shape(0).is_symbolic();
40 if (!symbolic_vec.has_value()) {
41 return true;
42 }
43 return symbolic_vec->at(dim_);
44}
45
46std::string SizeNode::ToString() const {
47 return "SizeNode";
48}
49
50SizeAdd::SizeAdd(Value a, Value b)
51 : TsNode(
52 OpKind{c10::Symbol::fromQualString("aten::add")},
53 {a, b},
54 std::vector<Shape>{},
55 1){};
56
57int64_t SizeAdd::getStaticValue() const {
58 return DimCast(operand(0))->getStaticValue() +
59 DimCast(operand(1))->getStaticValue();
60}
61
62bool SizeAdd::isSymbolic() const {
63 return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
64}
65
66std::string SizeAdd::ToString() const {
67 return "SizeAdd";
68}
69
70SizeMul::SizeMul(Value a, Value b)
71 : TsNode(
72 OpKind{c10::Symbol::fromQualString("aten::mul")},
73 {a, b},
74 std::vector<Shape>{},
75 1){};
76
77int64_t SizeMul::getStaticValue() const {
78 return DimCast(operand(0))->getStaticValue() *
79 DimCast(operand(1))->getStaticValue();
80}
81
82bool SizeMul::isSymbolic() const {
83 return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
84}
85
86std::string SizeMul::ToString() const {
87 return "SizeMul";
88}
89
90SizeDiv::SizeDiv(Value a, Value b)
91 : TsNode(
92 OpKind{c10::Symbol::fromQualString("aten::div")},
93 {a, b},
94 std::vector<Shape>{},
95 1){};
96
97int64_t SizeDiv::getStaticValue() const {
98 TORCH_CHECK(
99 DimCast(operand(1))->getStaticValue() != 0,
100 "Can't divide a dimension by zero");
101 return DimCast(operand(0))->getStaticValue() /
102 DimCast(operand(1))->getStaticValue();
103}
104
105bool SizeDiv::isSymbolic() const {
106 return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
107}
108
109std::string SizeDiv::ToString() const {
110 return "SizeDiv";
111}
112
113} // namespace lazy
114} // namespace torch
115