1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ |
17 | #define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ |
18 | |
#include <cstdint>
#include <functional>
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>

#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
29 | |
30 | namespace xla { |
31 | |
// A utility class for primitive comparisons. A comparison includes three
// components: the type of the elements being compared (F32, S16, etc.),
// whether it is a partial or total order comparison, and the actual comparison
// operator (==, <=, >, etc.).
36 | // |
37 | // Note that integer comparisons are always total order. Float comparisons can |
38 | // be either total or partial order. |
39 | // |
40 | // Some examples: |
41 | // |
42 | // Comparison a( |
43 | // Comparison::Direction::kLt, |
44 | // xla::PrimitiveType::BF16, |
45 | // Comparison::Order::kTotal |
46 | // ); |
47 | // a.ToString(); /* ".LT.BF16.TOTALORDER" */ |
48 | // |
49 | // Comparison b(Comparison::Direction::kEq, xla::PrimitiveType::U32); |
50 | // b.IsTotalOrder(); /* true */ |
51 | class Comparison { |
52 | public: |
53 | // Represents the ordering of the comparison. |
54 | enum class Order : uint8_t { |
55 | // https://en.wikipedia.org/wiki/Total_order |
56 | kTotal, |
57 | // https://en.wikipedia.org/wiki/Partially_ordered_set |
58 | kPartial, |
59 | }; |
60 | |
61 | // Represents different comparison operations. |
62 | enum class Direction : uint8_t { |
63 | kEq, |
64 | kNe, |
65 | kGe, |
66 | kGt, |
67 | kLe, |
68 | kLt, |
69 | }; |
70 | |
71 | // (DEPRECATED) Represents the type of comparison. Prefer xla::PrimitiveType |
72 | // and Comparison::Order, since there are multiple floating point |
73 | // representations that support total ordering. |
74 | enum class [[deprecated("Use PrimitiveType and Order" )]] Type : uint8_t{ |
75 | kFloat, |
76 | kFloatTotalOrder, |
77 | kSigned, |
78 | kUnsigned, |
79 | }; |
80 | |
81 | Comparison() = delete; |
82 | |
83 | // This will default to the expected behavior for Comparison::Order: integers |
84 | // will use total ordering, and floats will use partial ordering. |
85 | explicit Comparison(Direction dir, PrimitiveType type); |
86 | |
87 | // Pass in a Comparison::Order to specify a non-default ordering, e.g., some |
88 | // targets may support total order floating point type comparisons. |
89 | explicit Comparison(Direction dir, PrimitiveType type, Order order); |
90 | |
91 | // Returns a comparison with a primitive type matching the Comparison::Type |
92 | // and using a default bit width of 32. For example, |
93 | // Comparison(Direction::kLt, Type::kFloat).PrimitiveType() /* F32 */ |
94 | [[deprecated( |
95 | "Use Comparison(Comparison::Direction, " |
96 | "PrimitiveType)" )]] explicit Comparison(Direction dir, Type type); |
97 | |
98 | inline Direction GetDirection() const { return dir_; } |
99 | inline PrimitiveType GetPrimitiveType() const { return primitive_type_; } |
100 | inline Order GetOrder() const { return order_; } |
101 | |
102 | [[deprecated("Use GetPrimitiveType() and GetOrder()" )]] inline Type GetType() |
103 | const { |
104 | return type_; |
105 | } |
106 | |
107 | inline bool IsEq() const { return dir_ == Direction::kEq; } |
108 | inline bool IsNe() const { return dir_ == Direction::kNe; } |
109 | inline bool IsGe() const { return dir_ == Direction::kGe; } |
110 | inline bool IsGt() const { return dir_ == Direction::kGt; } |
111 | inline bool IsLt() const { return dir_ == Direction::kLt; } |
112 | inline bool IsTotalOrder() const { return order_ == Order::kTotal; } |
113 | inline bool IsPartialOrder() const { return order_ == Order::kPartial; } |
114 | |
115 | // Returns whether this is a floating point total order comparison. |
116 | inline bool IsF32TotalOrder() const { |
117 | return primitive_type_ == PrimitiveType::F32 && IsTotalOrder(); |
118 | } |
119 | inline bool IsBf16TotalOrder() const { |
120 | return primitive_type_ == PrimitiveType::BF16 && IsTotalOrder(); |
121 | } |
122 | |
123 | // Returns whether this is a standard comparison, i.e., what you would expect |
124 | // as the industry standard on most architectures. |
125 | inline bool IsStandardF32() const { |
126 | return primitive_type_ == PrimitiveType::F32 && IsPartialOrder(); |
127 | } |
128 | inline bool IsStandardBf16() const { |
129 | return primitive_type_ == PrimitiveType::BF16 && IsPartialOrder(); |
130 | } |
131 | inline bool IsStandardS32() const { |
132 | return primitive_type_ == PrimitiveType::S32 && IsTotalOrder(); |
133 | } |
134 | inline bool IsStandardU32() const { |
135 | return primitive_type_ == PrimitiveType::U32 && IsTotalOrder(); |
136 | } |
137 | |
138 | inline bool IsIntegralPrimitiveType() const { |
139 | return primitive_util::IsIntegralType(primitive_type_); |
140 | } |
141 | inline bool IsFloatingPointPrimitiveType() const { |
142 | return primitive_util::IsFloatingPointType(primitive_type_); |
143 | } |
144 | |
145 | // Returns whether (a dir a) is always true for this comparison. |
146 | bool IsReflexive() const; |
147 | |
148 | // Returns whether (a dir a) is always false for this comparison. |
149 | bool IsAntireflexive() const; |
150 | |
151 | // Gets the converse of the given comparison direction (e.g. >= turns to <=). |
152 | // Useful when commuting operands to get constants into immediate-accepting |
153 | // positions in the ISA. |
154 | Comparison Converse() const; |
155 | |
156 | // Gets the inverse of the given comparison if it exists (e.g. >= turns to <). |
157 | // Returns optional value because not all inversions may be supported. |
158 | std::optional<Comparison> Inverse() const; |
159 | |
160 | // Returns a string version of this comparison, e.g., ".GT.F32.TOTALORDER" |
161 | std::string ToString(std::string prefix1 = "." , std::string prefix2 = "." , |
162 | std::string prefix3 = "." ) const; |
163 | |
164 | // Returns a comparison operator: (T, T) -> bool for this Comparison's |
165 | // Direction. |
166 | template <typename T> |
167 | std::function<bool(T, T)> GetComparator() const { |
168 | switch (GetDirection()) { |
169 | case Direction::kEq: |
170 | return std::equal_to<T>(); |
171 | case Direction::kNe: |
172 | return std::not_equal_to<T>(); |
173 | case Direction::kGe: |
174 | return std::greater_equal<T>(); |
175 | case Direction::kGt: |
176 | return std::greater<T>(); |
177 | case Direction::kLe: |
178 | return std::less_equal<T>(); |
179 | case Direction::kLt: |
180 | return std::less<T>(); |
181 | } |
182 | } |
183 | |
184 | // Applies the comparison from this Comparison's direction and ordering for |
185 | // integral types. |
186 | template <typename T, absl::enable_if_t<std::is_integral<T>::value, int> = 0> |
187 | bool Compare(const T a, const T b) const { |
188 | CHECK(primitive_util::IsCanonicalRepresentation<T>(primitive_type_)); |
189 | return GetComparator<T>()(a, b); |
190 | } |
191 | |
192 | // Applies the comparison from this Comparison's direction and ordering |
193 | // for floating point types. |
194 | template <typename T, |
195 | absl::enable_if_t<std::is_floating_point<T>::value || |
196 | std::is_same<T, xla::bfloat16>::value, |
197 | int> = 0> |
198 | bool Compare(const T a, const T b) const { |
199 | CHECK(primitive_util::IsCanonicalRepresentation<T>(primitive_type_)); |
200 | if (IsTotalOrder()) { |
201 | // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN |
202 | // Reference: |
203 | // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations |
204 | using R = typename SignedIntegerTypeForSize<sizeof(T)>::type; |
205 | return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b)); |
206 | } |
207 | return GetComparator<T>()(a, b); |
208 | } |
209 | |
210 | // Returns the Comparison::Type for the given primitive type. This assumes |
211 | // that each numerical representation follows the standard behavior, e.g., |
212 | // integers are total order and floats are partial order. |
213 | [[deprecated("Use PrimitiveType and Order" )]] static Comparison::Type |
214 | DefaultComparisonType(PrimitiveType type); |
215 | |
216 | private: |
217 | // The direction of the Comparison, e.g., GT. |
218 | const Direction dir_; |
219 | // The primitive type of the Comparison operands, e.g., F32. |
220 | const PrimitiveType primitive_type_; |
221 | // The ordering of the Comparison, e.g., kPartial. |
222 | const Order order_; |
223 | // The Type of the Comparison. This tries to mesh together the ordering and |
224 | // the numerical data classification. |
225 | [[deprecated]] const Type type_; |
226 | }; |
227 | |
228 | using ComparisonDirection = Comparison::Direction; |
229 | using ComparisonOrder = Comparison::Order; |
230 | |
231 | inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) { |
232 | return os << cmp.ToString(); |
233 | } |
234 | |
235 | std::string ComparisonDirectionToString(Comparison::Direction direction); |
236 | std::string ComparisonTypeToString(Comparison::Type type); |
237 | std::string ComparisonPrimitiveTypeToString(PrimitiveType type); |
238 | std::string ComparisonOrderToString(Comparison::Order order); |
239 | |
240 | StatusOr<Comparison::Direction> StringToComparisonDirection( |
241 | absl::string_view direction); |
242 | StatusOr<Comparison::Type> StringToComparisonType(absl::string_view comparison); |
243 | StatusOr<Comparison::Order> StringToComparisonOrder(absl::string_view order); |
244 | |
245 | // Returns a comparison function using the provided key function on each value, |
246 | // i.e. `key_fn(a) < key_fn(b)`. |
// Returns a comparison function using the provided key function on each value,
// i.e. `key_fn(a) < key_fn(b)`.
//
// The key function is stored inside the returned closure. Lvalue arguments
// are copied and rvalue arguments are moved, so move-only key functions are
// supported. The closure's call operator is const, so `key_fn` must be
// callable on a const object.
template <typename KeyFn>
auto LessThanByKey(KeyFn&& key_fn) {
  return [key_fn = std::forward<KeyFn>(key_fn)](const auto& a, const auto& b) {
    return key_fn(a) < key_fn(b);
  };
}
251 | |
252 | } // namespace xla |
253 | |
254 | #endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ |
255 | |