1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
17#define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
18
#include <cstdint>
#include <functional>
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>

#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
29
30namespace xla {
31
32// A utility class for primitive comparisons. A comparison includes three
33// components: the type of the elements being compared (F32, S16, etc), whether
34// it is a partial or total order comparison, and the actual comparison operator
35// (==, <=, >, etc).
36//
37// Note that integer comparisons are always total order. Float comparisons can
38// be either total or partial order.
39//
40// Some examples:
41//
42// Comparison a(
43// Comparison::Direction::kLt,
44// xla::PrimitiveType::BF16,
45// Comparison::Order::kTotal
46// );
47// a.ToString(); /* ".LT.BF16.TOTALORDER" */
48//
49// Comparison b(Comparison::Direction::kEq, xla::PrimitiveType::U32);
50// b.IsTotalOrder(); /* true */
51class Comparison {
52 public:
53 // Represents the ordering of the comparison.
54 enum class Order : uint8_t {
55 // https://en.wikipedia.org/wiki/Total_order
56 kTotal,
57 // https://en.wikipedia.org/wiki/Partially_ordered_set
58 kPartial,
59 };
60
61 // Represents different comparison operations.
62 enum class Direction : uint8_t {
63 kEq,
64 kNe,
65 kGe,
66 kGt,
67 kLe,
68 kLt,
69 };
70
71 // (DEPRECATED) Represents the type of comparison. Prefer xla::PrimitiveType
72 // and Comparison::Order, since there are multiple floating point
73 // representations that support total ordering.
74 enum class [[deprecated("Use PrimitiveType and Order")]] Type : uint8_t{
75 kFloat,
76 kFloatTotalOrder,
77 kSigned,
78 kUnsigned,
79 };
80
81 Comparison() = delete;
82
83 // This will default to the expected behavior for Comparison::Order: integers
84 // will use total ordering, and floats will use partial ordering.
85 explicit Comparison(Direction dir, PrimitiveType type);
86
87 // Pass in a Comparison::Order to specify a non-default ordering, e.g., some
88 // targets may support total order floating point type comparisons.
89 explicit Comparison(Direction dir, PrimitiveType type, Order order);
90
91 // Returns a comparison with a primitive type matching the Comparison::Type
92 // and using a default bit width of 32. For example,
93 // Comparison(Direction::kLt, Type::kFloat).PrimitiveType() /* F32 */
94 [[deprecated(
95 "Use Comparison(Comparison::Direction, "
96 "PrimitiveType)")]] explicit Comparison(Direction dir, Type type);
97
98 inline Direction GetDirection() const { return dir_; }
99 inline PrimitiveType GetPrimitiveType() const { return primitive_type_; }
100 inline Order GetOrder() const { return order_; }
101
102 [[deprecated("Use GetPrimitiveType() and GetOrder()")]] inline Type GetType()
103 const {
104 return type_;
105 }
106
107 inline bool IsEq() const { return dir_ == Direction::kEq; }
108 inline bool IsNe() const { return dir_ == Direction::kNe; }
109 inline bool IsGe() const { return dir_ == Direction::kGe; }
110 inline bool IsGt() const { return dir_ == Direction::kGt; }
111 inline bool IsLt() const { return dir_ == Direction::kLt; }
112 inline bool IsTotalOrder() const { return order_ == Order::kTotal; }
113 inline bool IsPartialOrder() const { return order_ == Order::kPartial; }
114
115 // Returns whether this is a floating point total order comparison.
116 inline bool IsF32TotalOrder() const {
117 return primitive_type_ == PrimitiveType::F32 && IsTotalOrder();
118 }
119 inline bool IsBf16TotalOrder() const {
120 return primitive_type_ == PrimitiveType::BF16 && IsTotalOrder();
121 }
122
123 // Returns whether this is a standard comparison, i.e., what you would expect
124 // as the industry standard on most architectures.
125 inline bool IsStandardF32() const {
126 return primitive_type_ == PrimitiveType::F32 && IsPartialOrder();
127 }
128 inline bool IsStandardBf16() const {
129 return primitive_type_ == PrimitiveType::BF16 && IsPartialOrder();
130 }
131 inline bool IsStandardS32() const {
132 return primitive_type_ == PrimitiveType::S32 && IsTotalOrder();
133 }
134 inline bool IsStandardU32() const {
135 return primitive_type_ == PrimitiveType::U32 && IsTotalOrder();
136 }
137
138 inline bool IsIntegralPrimitiveType() const {
139 return primitive_util::IsIntegralType(primitive_type_);
140 }
141 inline bool IsFloatingPointPrimitiveType() const {
142 return primitive_util::IsFloatingPointType(primitive_type_);
143 }
144
145 // Returns whether (a dir a) is always true for this comparison.
146 bool IsReflexive() const;
147
148 // Returns whether (a dir a) is always false for this comparison.
149 bool IsAntireflexive() const;
150
151 // Gets the converse of the given comparison direction (e.g. >= turns to <=).
152 // Useful when commuting operands to get constants into immediate-accepting
153 // positions in the ISA.
154 Comparison Converse() const;
155
156 // Gets the inverse of the given comparison if it exists (e.g. >= turns to <).
157 // Returns optional value because not all inversions may be supported.
158 std::optional<Comparison> Inverse() const;
159
160 // Returns a string version of this comparison, e.g., ".GT.F32.TOTALORDER"
161 std::string ToString(std::string prefix1 = ".", std::string prefix2 = ".",
162 std::string prefix3 = ".") const;
163
164 // Returns a comparison operator: (T, T) -> bool for this Comparison's
165 // Direction.
166 template <typename T>
167 std::function<bool(T, T)> GetComparator() const {
168 switch (GetDirection()) {
169 case Direction::kEq:
170 return std::equal_to<T>();
171 case Direction::kNe:
172 return std::not_equal_to<T>();
173 case Direction::kGe:
174 return std::greater_equal<T>();
175 case Direction::kGt:
176 return std::greater<T>();
177 case Direction::kLe:
178 return std::less_equal<T>();
179 case Direction::kLt:
180 return std::less<T>();
181 }
182 }
183
184 // Applies the comparison from this Comparison's direction and ordering for
185 // integral types.
186 template <typename T, absl::enable_if_t<std::is_integral<T>::value, int> = 0>
187 bool Compare(const T a, const T b) const {
188 CHECK(primitive_util::IsCanonicalRepresentation<T>(primitive_type_));
189 return GetComparator<T>()(a, b);
190 }
191
192 // Applies the comparison from this Comparison's direction and ordering
193 // for floating point types.
194 template <typename T,
195 absl::enable_if_t<std::is_floating_point<T>::value ||
196 std::is_same<T, xla::bfloat16>::value,
197 int> = 0>
198 bool Compare(const T a, const T b) const {
199 CHECK(primitive_util::IsCanonicalRepresentation<T>(primitive_type_));
200 if (IsTotalOrder()) {
201 // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
202 // Reference:
203 // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
204 using R = typename SignedIntegerTypeForSize<sizeof(T)>::type;
205 return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
206 }
207 return GetComparator<T>()(a, b);
208 }
209
210 // Returns the Comparison::Type for the given primitive type. This assumes
211 // that each numerical representation follows the standard behavior, e.g.,
212 // integers are total order and floats are partial order.
213 [[deprecated("Use PrimitiveType and Order")]] static Comparison::Type
214 DefaultComparisonType(PrimitiveType type);
215
216 private:
217 // The direction of the Comparison, e.g., GT.
218 const Direction dir_;
219 // The primitive type of the Comparison operands, e.g., F32.
220 const PrimitiveType primitive_type_;
221 // The ordering of the Comparison, e.g., kPartial.
222 const Order order_;
223 // The Type of the Comparison. This tries to mesh together the ordering and
224 // the numerical data classification.
225 [[deprecated]] const Type type_;
226};
227
228using ComparisonDirection = Comparison::Direction;
229using ComparisonOrder = Comparison::Order;
230
231inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) {
232 return os << cmp.ToString();
233}
234
235std::string ComparisonDirectionToString(Comparison::Direction direction);
236std::string ComparisonTypeToString(Comparison::Type type);
237std::string ComparisonPrimitiveTypeToString(PrimitiveType type);
238std::string ComparisonOrderToString(Comparison::Order order);
239
240StatusOr<Comparison::Direction> StringToComparisonDirection(
241 absl::string_view direction);
242StatusOr<Comparison::Type> StringToComparisonType(absl::string_view comparison);
243StatusOr<Comparison::Order> StringToComparisonOrder(absl::string_view order);
244
// Returns a comparison function using the provided key function on each value,
// i.e. `key_fn(a) < key_fn(b)`.
//
// The key function is stored inside the returned closure. An lvalue argument
// is copied and an rvalue argument is moved, so move-only callables are
// accepted. Nothing is captured by reference, so the closure may outlive the
// argument.
template <typename KeyFn>
auto LessThanByKey(KeyFn&& key_fn) {
  // Init-capture with std::forward: a plain [=] capture would always copy,
  // rejecting move-only key functions and copying rvalues needlessly.
  return [key_fn = std::forward<KeyFn>(key_fn)](const auto& a, const auto& b) {
    return key_fn(a) < key_fn(b);
  };
}
251
252} // namespace xla
253
254#endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
255