/**
 * Most of the utilities in this file are adapted from PyTorch/XLA:
 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/util.h
 */
5
6#pragma once
7
#include <exception>
#include <functional>
#include <type_traits>
#include <utility>
#include <vector>

#include <c10/util/Optional.h>
#include <c10/util/OptionalArrayRef.h>
14
15namespace torch {
16namespace lazy {
17
// Similar to c10::scope_exit but with a status: when a Cleanup is destroyed
// while still holding a callback, the callback is invoked once with the stored
// status (moved out). Release() disarms it.
// TODO(alanwaketan): Consolidate it with c10::scope_exit.
//
// StatusType must be default-constructible and move-assignable; until
// SetStatus() is called the status is value-initialized (e.g. 0 for integers,
// a null std::exception_ptr for ExceptionCleanup).
template <typename T>
class Cleanup {
 public:
  using StatusType = T;

  // Arms the guard with `func`. An empty `func` yields a guard whose
  // destructor does nothing.
  explicit Cleanup(std::function<void(StatusType&&)>&& func)
      : func_(std::move(func)) {}

  // Transfers the callback and status out of `ref`. `ref.func_` is reset
  // explicitly: a moved-from std::function is only guaranteed to be in a
  // valid-but-unspecified state, so without the reset both objects could end
  // up invoking the callback.
  Cleanup(Cleanup&& ref) noexcept
      : func_(std::exchange(ref.func_, nullptr)),
        status_(std::move(ref.status_)) {}
  Cleanup(const Cleanup&) = delete;

  // Runs the callback, if still armed, with the current status.
  ~Cleanup() {
    if (func_ != nullptr) {
      func_(std::move(status_));
    }
  }

  Cleanup& operator=(const Cleanup&) = delete;

  // Takes over `ref`'s callback and status and disarms `ref`. A callback
  // already held by *this is replaced without being invoked.
  // NOTE(review): an RAII guard would usually run the replaced callback;
  // confirm no caller depends on it being dropped before changing this.
  Cleanup& operator=(Cleanup&& ref) noexcept {
    if (this != &ref) {
      func_ = std::exchange(ref.func_, nullptr);
      status_ = std::move(ref.status_);
    }
    return *this;
  }

  // Disarms the guard so that destruction does nothing.
  void Release() noexcept {
    func_ = nullptr;
  }

  // Replaces the status that will be passed to the callback.
  void SetStatus(StatusType&& status) {
    status_ = std::move(status);
  }

  const StatusType& GetStatus() const {
    return status_;
  }

 private:
  std::function<void(StatusType&&)> func_;
  // Value-initialized so scalar status types never hold an indeterminate value.
  StatusType status_{};
};

// Cleanup whose status is a captured exception (null when none was set).
using ExceptionCleanup = Cleanup<std::exception_ptr>;
65
66// Allows APIs which might return const references and values, to not be forced
67// to return values in the signature.
68// TODO(alanwaketan): This is clever, but is there really no std or c10
69// supports? Needs more investigations.
70template <typename T>
71class MaybeRef {
72 public:
73 /* implicit */ MaybeRef(const T& ref) : ref_(ref) {}
74 /* implicit */ MaybeRef(T&& value)
75 : storage_(std::move(value)), ref_(*storage_) {}
76
77 const T& Get() const {
78 return ref_;
79 }
80 const T& operator*() const {
81 return Get();
82 }
83 operator const T&() const {
84 return Get();
85 }
86
87 bool IsStored() const {
88 return storage_.has_value();
89 }
90
91 private:
92 c10::optional<T> storage_;
93 const T& ref_;
94};
95
// Returns `size` values forming an arithmetic progression that starts at
// `init` and advances by `incr`: {init, init + incr, init + 2 * incr, ...}.
// Values are produced by repeated `+=`, so for floating-point T they may
// differ slightly from init + i * incr.
template <typename T>
std::vector<T> Iota(size_t size, T init = 0, T incr = 1) {
  std::vector<T> sequence(size);
  T next = init;
  for (T& slot : sequence) {
    slot = next;
    next += incr;
  }
  return sequence;
}
105
// Copies every element of `input` (any type exposing begin()/end()) into a
// new std::vector<T>, constructing each T from the corresponding element.
template <typename T, typename S>
std::vector<T> ToVector(const S& input) {
  auto first = input.begin();
  auto last = input.end();
  return std::vector<T>(first, last);
}
110
111template <typename T>
112c10::optional<std::vector<T>> ToOptionalVector(
113 c10::OptionalArrayRef<T> arrayRef) {
114 if (arrayRef) {
115 return arrayRef->vec();
116 }
117 return c10::nullopt;
118}
119
// Returns the underlying integral value of the enumerator `value` (its
// std::underlying_type), avoiding a spelled-out static_cast at call sites.
// constexpr so it can also be used in constant expressions (e.g. array bounds,
// static_assert, switch labels).
template <typename T>
constexpr typename std::underlying_type<T>::type GetEnumValue(
    T value) noexcept {
  return static_cast<typename std::underlying_type<T>::type>(value);
}
124
125} // namespace lazy
126} // namespace torch
127