1 | /** |
2 | * Unique in this file is adapted from PyTorch/XLA |
3 | * https://github.com/pytorch/xla/blob/master/third_party/xla_client/unique.h |
4 | */ |
5 | |
6 | #pragma once |
7 | |
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <functional>
#include <set>
#include <utility>
12 | |
13 | namespace torch { |
14 | namespace lazy { |
15 | |
16 | // Helper class to allow tracking zero or more things, which should be forcibly |
17 | // be one only thing. |
18 | template <typename T, typename C = std::equal_to<T>> |
19 | class Unique { |
20 | public: |
21 | std::pair<bool, const T&> set(const T& value) { |
22 | if (value_) { |
23 | TORCH_CHECK(C()(*value_, value), "'" , *value_, "' vs '" , value); |
24 | return std::pair<bool, const T&>(false, *value_); |
25 | } |
26 | value_ = value; |
27 | return std::pair<bool, const T&>(true, *value_); |
28 | } |
29 | |
30 | operator bool() const { |
31 | return value_.has_value(); |
32 | } |
33 | operator const T&() const { |
34 | return *value_; |
35 | } |
36 | const T& operator*() const { |
37 | return *value_; |
38 | } |
39 | const T* operator->() const { |
40 | return value_.operator->(); |
41 | } |
42 | |
43 | std::set<T> AsSet() const { |
44 | std::set<T> vset; |
45 | if (value_.has_value()) { |
46 | vset.insert(*value_); |
47 | } |
48 | return vset; |
49 | } |
50 | |
51 | private: |
52 | c10::optional<T> value_; |
53 | }; |
54 | |
55 | } // namespace lazy |
56 | } // namespace torch |
57 | |