#pragma once

#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/trie.h>
#include <vector>

// This file is part of the backend interface, so ops shouldn't be added or
// removed without due process. The exception is the view ops, which will be
// removed soon pending functionalization.

namespace torch {
namespace lazy {

template <typename T, typename... Args>
NodePtr ReuseNode(Args&&... args) {
  if (FLAGS_torch_lazy_reuse_ir) {
    return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
  }
  return nullptr;
}

// Caches an IR node into the TrieCache
static inline void CacheNode(NodePtr node) {
  if (FLAGS_torch_lazy_reuse_ir) {
    TrieCache::Get()->Insert(std::move(node));
  }
}

template <typename T, typename... Args>
NodePtr MakeNode(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

// op is passed in for a more efficient node casting; see the implementation of
// NodeCast
template <typename T, typename... Args>
NodePtr ReuseOrMakeNode(Args&&... args) {
  NodePtr node = ReuseNode<T>(std::forward<Args>(args)...);
  if (!node) {
    node = MakeNode<T>(std::forward<Args>(args)...);
    CacheNode(node);
  }
  return node;
}
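
// Illustrative usage (hypothetical node type `MyExpandNode`; any Node subclass
// constructible from the given arguments works the same way):
//
//   NodePtr node = ReuseOrMakeNode<MyExpandNode>(
//       /*input=*/operand,
//       /*size=*/std::vector<int64_t>{2, 3},
//       /*is_scalar_expand=*/false);
//
// With FLAGS_torch_lazy_reuse_ir enabled, an equivalent node already present
// in the TrieCache is returned; otherwise a fresh node is constructed and
// cached.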

struct IrBuilder {
  virtual NodePtr MakeDeviceData(
      const std::shared_ptr<BackendData>& data) const = 0;
  virtual NodePtr MakeScalar(
      const at::Scalar& value,
      const at::ScalarType& type) const = 0;
  virtual NodePtr MakeExpand(
      const Value& input0,
      const std::vector<int64_t>& size,
      const bool& is_scalar_expand) const = 0;
  virtual NodePtr MakeCast(
      const Value& input0,
      const at::ScalarType& dtype,
      const c10::optional<at::ScalarType>& stype = c10::nullopt) const = 0;
  virtual NodePtr MakeTensorList(const OpList& inputs) const = 0;
  virtual NodePtr MakeGeneric(
      const OpKind& op,
      const OpList& operands,
      const Shape& shape,
      const size_t& num_outputs = 1,
      const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const = 0;

  // dynamic ir nodes
  virtual NodePtr MakeSizeNode(const Value& input, size_t dim) const = 0;
  virtual NodePtr MakeSizeAdd(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeMul(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeDiv(const Value& a, const Value& b) const = 0;

  virtual ~IrBuilder() = default;
};
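
// Each backend provides a concrete IrBuilder, which the free functions below
// reach through getIrBuilder(). A minimal sketch of an override, using a
// hypothetical backend node type:
//
//   struct MyBackendIrBuilder : IrBuilder {
//     NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type)
//         const override {
//       return ReuseOrMakeNode<MyScalarNode>(value, type);
//     }
//     // ... remaining Make* overrides ...
//   };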

static inline NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) {
  return getIrBuilder()->MakeDeviceData(data);
}
static inline NodePtr MakeScalar(
    const at::Scalar& value,
    const at::ScalarType& type) {
  return getIrBuilder()->MakeScalar(value, type);
}
static inline NodePtr MakeExpand(
    const Value& input0,
    const std::vector<int64_t>& size,
    const bool& is_scalar_expand) {
  return getIrBuilder()->MakeExpand(input0, size, is_scalar_expand);
}
static inline NodePtr MakeCast(
    const Value& input0,
    const at::ScalarType& dtype,
    const c10::optional<at::ScalarType>& stype = c10::nullopt) {
  return getIrBuilder()->MakeCast(input0, dtype, stype);
}
static inline NodePtr MakeTensorList(const OpList& inputs) {
  return getIrBuilder()->MakeTensorList(inputs);
}
static inline NodePtr MakeGeneric(
    const OpKind& op,
    const OpList& operands,
    const Shape& shape,
    const size_t& num_outputs = 1,
    const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) {
  return getIrBuilder()->MakeGeneric(
      op, operands, shape, num_outputs, hash_seed);
}
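
// Illustrative use of MakeGeneric for an op without a dedicated node class
// (the op name, operands, and shape here are hypothetical):
//
//   NodePtr n = MakeGeneric(
//       OpKind::Get("my_ns::my_op"),
//       {operand_a, operand_b},
//       Shape(c10::ScalarType::Float, {2, 2}),
//       /*num_outputs=*/1);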

// dynamic ir nodes
static inline NodePtr MakeSizeNode(const Value& input, size_t dim) {
  return getIrBuilder()->MakeSizeNode(input, dim);
}
static inline NodePtr MakeSizeAdd(const Value& a, const Value& b) {
  return getIrBuilder()->MakeSizeAdd(a, b);
}
static inline NodePtr MakeSizeMul(const Value& a, const Value& b) {
  return getIrBuilder()->MakeSizeMul(a, b);
}
static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
  return getIrBuilder()->MakeSizeDiv(a, b);
}
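
// Illustrative composition of the dynamic size nodes above, e.g. a symbolic
// element count for a 2-d value (`input` is a hypothetical Value):
//
//   NodePtr numel = MakeSizeMul(
//       Value(MakeSizeNode(input, /*dim=*/0), 0),
//       Value(MakeSizeNode(input, /*dim=*/1), 0));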

inline Value GetSymIntValue(c10::SymInt a) {
  return Value(
      a.is_symbolic()
          ? dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImpl().get())
                ->node_
          : MakeScalar(a.as_int_unchecked(), at::kLong),
      0);
}
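
// GetSymIntValue lowers a concrete SymInt to a scalar constant node and reuses
// the IR node backing a symbolic one. Illustrative call site (hypothetical):
//
//   Value dim0 = GetSymIntValue(tensor.sym_size(0));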

// TODO: this should return Value
inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) {
  std::vector<int64_t> r;
  for (const auto& a : arr) {
    r.emplace_back(a.expect_int());
  }
  return r;
}

} // namespace lazy
} // namespace torch