1 | /** |
 * Most of the utilities in this file are adapted from PyTorch/XLA
3 | * https://github.com/pytorch/xla/blob/master/third_party/xla_client/util.h |
4 | */ |
5 | |
6 | #pragma once |
7 | |
8 | #include <exception> |
9 | #include <functional> |
10 | #include <vector> |
11 | |
12 | #include <c10/util/Optional.h> |
13 | #include <c10/util/OptionalArrayRef.h> |
14 | |
15 | namespace torch { |
16 | namespace lazy { |
17 | |
// Scope guard that invokes a callback with a status value when it is
// destroyed. Similar to c10::scope_exit, but the callback receives a status
// (for ExceptionCleanup, a std::exception_ptr) that the guarded code can
// record via SetStatus() before the guard fires.
// TODO(alanwaketan): Consolidate it with c10::scope_exit.
template <typename T>
class Cleanup {
 public:
  using StatusType = T;

  // `func` is invoked at most once, receiving the recorded status by rvalue,
  // unless Release() is called first. Taken by value (sink parameter) so
  // callers may pass either an lvalue or an rvalue callable.
  explicit Cleanup(std::function<void(StatusType&&)> func)
      : func_(std::move(func)) {}

  // Takes over `ref`'s pending callback and status, leaving `ref` released.
  Cleanup(Cleanup&& ref) noexcept
      : func_(std::move(ref.func_)), status_(std::move(ref.status_)) {
    // A moved-from std::function is only in a "valid but unspecified" state
    // and may still hold its target; clear it so the callback cannot run
    // a second time when `ref` is destroyed.
    ref.func_ = nullptr;
  }
  Cleanup(const Cleanup&) = delete;

  ~Cleanup() {
    RunPending();
  }

  Cleanup& operator=(const Cleanup&) = delete;

  // Fires this guard's own pending callback (if any) so it is not silently
  // dropped, then takes over `ref`'s callback and status, leaving `ref`
  // released.
  Cleanup& operator=(Cleanup&& ref) noexcept {
    if (this != &ref) {
      RunPending();
      func_ = std::move(ref.func_);
      status_ = std::move(ref.status_);
      // See the move constructor: the source must be explicitly disarmed.
      ref.func_ = nullptr;
    }
    return *this;
  }

  // Disarms the guard: the callback will not be invoked.
  void Release() noexcept {
    func_ = nullptr;
  }

  // Records the status that will be passed to the callback when it fires.
  void SetStatus(StatusType&& status) {
    status_ = std::move(status);
  }

  const StatusType& GetStatus() const {
    return status_;
  }

 private:
  // Invokes the pending callback (if armed) with the recorded status, then
  // disarms the guard.
  void RunPending() {
    if (func_ != nullptr) {
      func_(std::move(status_));
      func_ = nullptr;
    }
  }

  std::function<void(StatusType&&)> func_;
  StatusType status_;
};

using ExceptionCleanup = Cleanup<std::exception_ptr>;
65 | |
66 | // Allows APIs which might return const references and values, to not be forced |
67 | // to return values in the signature. |
68 | // TODO(alanwaketan): This is clever, but is there really no std or c10 |
69 | // supports? Needs more investigations. |
70 | template <typename T> |
71 | class MaybeRef { |
72 | public: |
73 | /* implicit */ MaybeRef(const T& ref) : ref_(ref) {} |
74 | /* implicit */ MaybeRef(T&& value) |
75 | : storage_(std::move(value)), ref_(*storage_) {} |
76 | |
77 | const T& Get() const { |
78 | return ref_; |
79 | } |
80 | const T& operator*() const { |
81 | return Get(); |
82 | } |
83 | operator const T&() const { |
84 | return Get(); |
85 | } |
86 | |
87 | bool IsStored() const { |
88 | return storage_.has_value(); |
89 | } |
90 | |
91 | private: |
92 | c10::optional<T> storage_; |
93 | const T& ref_; |
94 | }; |
95 | |
// Returns `size` values forming the arithmetic sequence
// init, init + incr, init + 2 * incr, ...
template <typename T>
std::vector<T> Iota(size_t size, T init = 0, T incr = 1) {
  std::vector<T> sequence;
  sequence.reserve(size);
  for (T current = init; sequence.size() < size; current += incr) {
    sequence.push_back(current);
  }
  return sequence;
}
105 | |
// Copies every element of `input` (any container exposing begin()/end())
// into a new std::vector<T>, converting each element to T.
template <typename T, typename S>
std::vector<T> ToVector(const S& input) {
  auto first = input.begin();
  auto last = input.end();
  std::vector<T> converted(first, last);
  return converted;
}
110 | |
111 | template <typename T> |
112 | c10::optional<std::vector<T>> ToOptionalVector( |
113 | c10::OptionalArrayRef<T> arrayRef) { |
114 | if (arrayRef) { |
115 | return arrayRef->vec(); |
116 | } |
117 | return c10::nullopt; |
118 | } |
119 | |
// Converts an enumerator to its underlying integral value (the same operation
// as C++23 std::to_underlying). constexpr/noexcept so it can be used in
// constant expressions such as static_assert or array bounds.
template <typename T>
constexpr typename std::underlying_type<T>::type GetEnumValue(
    T value) noexcept {
  return static_cast<typename std::underlying_type<T>::type>(value);
}
124 | |
125 | } // namespace lazy |
126 | } // namespace torch |
127 | |