1 | #pragma once |
2 | |
3 | #include <c10/core/ScalarType.h> |
4 | #include <c10/util/Optional.h> |
5 | #include <torch/csrc/lazy/backend/backend_interface.h> |
6 | #include <torch/csrc/lazy/core/config.h> |
7 | #include <torch/csrc/lazy/core/ir.h> |
8 | #include <torch/csrc/lazy/core/tensor.h> |
9 | #include <torch/csrc/lazy/core/trie.h> |
10 | #include <vector> |
11 | |
// This file is part of the backend interface, so ops shouldn't be added or
// removed without due process. The exception is the view ops, which will be
// removed soon pending functionalization.
15 | |
16 | namespace torch { |
17 | namespace lazy { |
18 | |
19 | template <typename T, typename... Args> |
20 | NodePtr ReuseNode(Args&&... args) { |
21 | if (FLAGS_torch_lazy_reuse_ir) { |
22 | return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...); |
23 | } |
24 | return nullptr; |
25 | } |
26 | |
27 | // Caching an IR node into TrieCache |
28 | static inline void CacheNode(NodePtr node) { |
29 | if (FLAGS_torch_lazy_reuse_ir) { |
30 | TrieCache::Get()->Insert(std::move(node)); |
31 | } |
32 | } |
33 | |
34 | template <typename T, typename... Args> |
35 | NodePtr MakeNode(Args&&... args) { |
36 | return std::make_shared<T>(std::forward<Args>(args)...); |
37 | } |
38 | |
39 | // op is passed in for a more efficient node casting, see the implementation of |
40 | // NodeCast |
41 | template <typename T, typename... Args> |
42 | NodePtr ReuseOrMakeNode(Args&&... args) { |
43 | NodePtr node = ReuseNode<T>(std::forward<Args>(args)...); |
44 | if (!node) { |
45 | node = MakeNode<T>(std::forward<Args>(args)...); |
46 | CacheNode(node); |
47 | } |
48 | return node; |
49 | } |
50 | |
// Interface for constructing the core IR nodes declared below.
//
// Implementations are not visible in this header. The free Make* helpers
// that follow this struct forward to the instance returned by
// getIrBuilder().
//
// NOTE: this struct is part of the backend interface (see the note at the
// top of this file). Adding, removing, or reordering the virtual methods
// changes the vtable layout that existing overrides were compiled against.
//
// NOTE: default arguments on these virtual methods come from the static
// type of the call expression (IrBuilder here), not from any override.
struct IrBuilder {
  // Builds a node for the backend data handle `data`.
  virtual NodePtr MakeDeviceData(
      const std::shared_ptr<BackendData>& data) const = 0;
  // Builds a node for the scalar `value` with scalar type `type`.
  virtual NodePtr MakeScalar(
      const at::Scalar& value,
      const at::ScalarType& type) const = 0;
  // Builds an expand of `input0` to `size`. The exact meaning of
  // `is_scalar_expand` is implementation-defined; presumably it marks
  // expansion of a scalar input — confirm against implementations.
  virtual NodePtr MakeExpand(
      const Value& input0,
      const std::vector<int64_t>& size,
      const bool& is_scalar_expand) const = 0;
  // Builds a cast of `input0` to `dtype`. `stype` defaults to nullopt;
  // presumably it is the source scalar type — TODO confirm.
  virtual NodePtr MakeCast(
      const Value& input0,
      const at::ScalarType& dtype,
      const c10::optional<at::ScalarType>& stype = c10::nullopt) const = 0;
  // Builds a node over the list `inputs` (presumably a tensor-list node,
  // per the name — implementation-defined).
  virtual NodePtr MakeTensorList(const OpList& inputs) const = 0;
  // Builds a node of kind `op` over `operands` with output `shape`;
  // `num_outputs` defaults to 1.
  // NOTE(review): 0x5a2d296e9 is a 36-bit literal. static_cast<uint32_t>
  // truncates it to 0xa2d296e9 before conversion to hash_t. Changing it
  // would presumably alter every default-seeded hash, so it is kept as is.
  virtual NodePtr MakeGeneric(
      const OpKind& op,
      const OpList& operands,
      const Shape& shape,
      const size_t& num_outputs = 1,
      const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const = 0;

  // dynamic ir nodes
  // Builds a node for dimension `dim` of `input` (presumably a symbolic
  // size query — confirm).
  virtual NodePtr MakeSizeNode(const Value& input, size_t dim) const = 0;
  // Arithmetic on size values `a` and `b`; presumably a + b, a * b and
  // a / b respectively — confirm against implementations.
  virtual NodePtr MakeSizeAdd(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeMul(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeDiv(const Value& a, const Value& b) const = 0;

  // Virtual so implementations can be destroyed through an IrBuilder*.
  virtual ~IrBuilder() = default;
};
81 | |
82 | static inline NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) { |
83 | return getIrBuilder()->MakeDeviceData(data); |
84 | } |
85 | static inline NodePtr MakeScalar( |
86 | const at::Scalar& value, |
87 | const at::ScalarType& type) { |
88 | return getIrBuilder()->MakeScalar(value, type); |
89 | } |
90 | static inline NodePtr MakeExpand( |
91 | const Value& input0, |
92 | const std::vector<int64_t>& size, |
93 | const bool& is_scalar_expand) { |
94 | return getIrBuilder()->MakeExpand(input0, size, is_scalar_expand); |
95 | } |
96 | static inline NodePtr MakeCast( |
97 | const Value& input0, |
98 | const at::ScalarType& dtype, |
99 | const c10::optional<at::ScalarType>& stype = c10::nullopt) { |
100 | return getIrBuilder()->MakeCast(input0, dtype, stype); |
101 | } |
102 | static inline NodePtr MakeTensorList(const OpList& inputs) { |
103 | return getIrBuilder()->MakeTensorList(inputs); |
104 | } |
105 | static inline NodePtr MakeGeneric( |
106 | const OpKind& op, |
107 | const OpList& operands, |
108 | const Shape& shape, |
109 | const size_t& num_outputs = 1, |
110 | const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) { |
111 | return getIrBuilder()->MakeGeneric( |
112 | op, operands, shape, num_outputs, hash_seed); |
113 | } |
114 | |
115 | // dynamic ir nodes |
116 | static inline NodePtr MakeSizeNode(const Value& input, size_t dim) { |
117 | return getIrBuilder()->MakeSizeNode(input, dim); |
118 | } |
119 | static inline NodePtr MakeSizeAdd(const Value& a, const Value& b) { |
120 | return getIrBuilder()->MakeSizeAdd(a, b); |
121 | } |
122 | static inline NodePtr MakeSizeMul(const Value& a, const Value& b) { |
123 | return getIrBuilder()->MakeSizeAdd(a, b); |
124 | } |
125 | static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) { |
126 | return getIrBuilder()->MakeSizeDiv(a, b); |
127 | } |
128 | |
129 | inline Value GetSymIntValue(c10::SymInt a) { |
130 | return Value( |
131 | a.is_symbolic() |
132 | ? dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImpl().get()) |
133 | ->node_ |
134 | : MakeScalar(a.as_int_unchecked(), at::kLong), |
135 | 0); |
136 | } |
137 | |
138 | // TODO: this should return Value |
139 | inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) { |
140 | std::vector<int64_t> r; |
141 | for (const auto& a : arr) { |
142 | r.emplace_back(a.expect_int()); |
143 | } |
144 | return r; |
145 | } |
146 | |
147 | } // namespace lazy |
148 | } // namespace torch |
149 | |