1 | #include <torch/csrc/lazy/ts_backend/dynamic_ir.h> |
2 | |
3 | static const torch::lazy::DimensionNode* DimCast(torch::lazy::Output output) { |
4 | return dynamic_cast<const torch::lazy::DimensionNode*>(output.node); |
5 | } |
6 | |
7 | namespace torch { |
8 | namespace lazy { |
9 | |
10 | TSOpVector SizeNode::Lower( |
11 | std::shared_ptr<torch::jit::GraphFunction> function, |
12 | TSLoweringContext* loctx) const { |
13 | std::vector<torch::jit::NamedValue> arguments; |
14 | std::vector<torch::jit::NamedValue> kwarguments; |
15 | arguments.reserve(2); |
16 | auto index = loctx->graph()->insertConstant(static_cast<int64_t>(this->dim_)); |
17 | arguments.emplace_back(loctx->GetOutputOp(operand(0))); |
18 | arguments.emplace_back(index); |
19 | torch::lazy::TSOpVector size_out = |
20 | torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); |
21 | TORCH_CHECK_EQ(size_out.size(), 1); |
22 | return size_out; |
23 | } |
24 | |
25 | SizeNode::SizeNode(Value input, size_t dim) |
26 | : TsNode( |
27 | OpKind{c10::Symbol::fromQualString("aten::size" )}, |
28 | {input}, |
29 | std::vector<Shape>{}, |
30 | 1, |
31 | MHash(dim)), |
32 | dim_(dim){}; |
33 | |
34 | int64_t SizeNode::getStaticValue() const { |
35 | return dynamic_cast<const TsNode*>(operand(0).node)->shape(0).size(dim_); |
36 | } |
37 | bool SizeNode::isSymbolic() const { |
38 | auto symbolic_vec = |
39 | dynamic_cast<const TsNode*>(operand(0).node)->shape(0).is_symbolic(); |
40 | if (!symbolic_vec.has_value()) { |
41 | return true; |
42 | } |
43 | return symbolic_vec->at(dim_); |
44 | } |
45 | |
46 | std::string SizeNode::ToString() const { |
47 | return "SizeNode" ; |
48 | } |
49 | |
50 | SizeAdd::SizeAdd(Value a, Value b) |
51 | : TsNode( |
52 | OpKind{c10::Symbol::fromQualString("aten::add" )}, |
53 | {a, b}, |
54 | std::vector<Shape>{}, |
55 | 1){}; |
56 | |
57 | int64_t SizeAdd::getStaticValue() const { |
58 | return DimCast(operand(0))->getStaticValue() + |
59 | DimCast(operand(1))->getStaticValue(); |
60 | } |
61 | |
62 | bool SizeAdd::isSymbolic() const { |
63 | return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic(); |
64 | } |
65 | |
66 | std::string SizeAdd::ToString() const { |
67 | return "SizeAdd" ; |
68 | } |
69 | |
70 | SizeMul::SizeMul(Value a, Value b) |
71 | : TsNode( |
72 | OpKind{c10::Symbol::fromQualString("aten::mul" )}, |
73 | {a, b}, |
74 | std::vector<Shape>{}, |
75 | 1){}; |
76 | |
77 | int64_t SizeMul::getStaticValue() const { |
78 | return DimCast(operand(0))->getStaticValue() * |
79 | DimCast(operand(1))->getStaticValue(); |
80 | } |
81 | |
82 | bool SizeMul::isSymbolic() const { |
83 | return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic(); |
84 | } |
85 | |
86 | std::string SizeMul::ToString() const { |
87 | return "SizeMul" ; |
88 | } |
89 | |
90 | SizeDiv::SizeDiv(Value a, Value b) |
91 | : TsNode( |
92 | OpKind{c10::Symbol::fromQualString("aten::div" )}, |
93 | {a, b}, |
94 | std::vector<Shape>{}, |
95 | 1){}; |
96 | |
97 | int64_t SizeDiv::getStaticValue() const { |
98 | TORCH_CHECK( |
99 | DimCast(operand(1))->getStaticValue() != 0, |
100 | "Can't divide a dimension by zero" ); |
101 | return DimCast(operand(0))->getStaticValue() / |
102 | DimCast(operand(1))->getStaticValue(); |
103 | } |
104 | |
105 | bool SizeDiv::isSymbolic() const { |
106 | return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic(); |
107 | } |
108 | |
109 | std::string SizeDiv::ToString() const { |
110 | return "SizeDiv" ; |
111 | } |
112 | |
113 | } // namespace lazy |
114 | } // namespace torch |
115 | |