1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <functional> |
17 | #include <unordered_map> |
18 | #include <utility> |
19 | |
20 | #include "absl/container/flat_hash_map.h" |
21 | #include "tensorflow/core/framework/bounds_check.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/register_types.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/framework/tensor_shape.h" |
26 | #include "tensorflow/core/lib/core/status.h" |
27 | #include "tensorflow/core/lib/hash/hash.h" |
28 | #include "tensorflow/core/platform/bfloat16.h" |
29 | |
30 | namespace tensorflow { |
31 | namespace { |
32 | |
33 | typedef Eigen::ThreadPoolDevice CPUDevice; |
34 | |
35 | // `UniqueOpHashMap` defines the map type that is used when elements of type |
36 | // `T` are to be uniquified. By default, we use `absl::flat_hash_map<T, TIndex>` |
37 | // as the map type. Subsequent specializations are provided for |
38 | // performance and/or correctness. |
39 | template <typename T, typename TIndex> |
40 | struct UniqueOpHashMap { |
41 | using map_type = absl::flat_hash_map<T, TIndex>; |
42 | }; |
43 | |
44 | // NOTE(mrry): For `tstring` elements, we use an `absl::string_view` key to |
45 | // avoid copying the input strings into the map. |
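// This is safe because the `string_view` keys point into the input tensor,
// which outlives the function-local map.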
46 | template <typename TIndex> |
47 | struct UniqueOpHashMap<tstring, TIndex> { |
48 | using map_type = absl::flat_hash_map<absl::string_view, TIndex>; |
49 | }; |
50 | |
51 | // NOTE(mrry): `absl::flat_hash_map<float, ...>` does not allow `NaN` as a key, |
52 | // because `NaN != NaN`, so we fall back to `std::unordered_map<>` for |
53 | // floating-point types. |
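// (`std::unordered_map` has no such restriction: each `NaN` key hashes
// consistently but compares unequal to itself, so every `NaN` in the input
// is treated as a distinct value.)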
54 | template <typename TIndex> |
55 | struct UniqueOpHashMap<float, TIndex> { |
56 | using map_type = std::unordered_map<float, TIndex>; |
57 | }; |
58 | template <typename TIndex> |
59 | struct UniqueOpHashMap<double, TIndex> { |
60 | using map_type = std::unordered_map<double, TIndex>; |
61 | }; |
62 | template <typename TIndex> |
63 | struct UniqueOpHashMap<Eigen::half, TIndex> { |
64 | using map_type = std::unordered_map<Eigen::half, TIndex>; |
65 | }; |
66 | template <typename TIndex> |
67 | struct UniqueOpHashMap<bfloat16, TIndex> { |
68 | using map_type = std::unordered_map<bfloat16, TIndex>; |
69 | }; |
70 | |
71 | // `UniqueOp` computes the unique elements in the input tensor. |
72 | // |
73 | // * `T` is the element type. |
74 | // * `TIndex` is the type used to represent indices in the output, either |
75 | // `int32` or `int64`. |
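//
// For example, for a 1-D input `x = [1, 1, 2, 4, 4, 4, 7, 8, 8]`, `Unique`
// produces `y = [1, 2, 4, 7, 8]` and `idx = [0, 0, 1, 2, 2, 2, 3, 4, 4]`,
// so that `y[idx[i]] == x[i]` for every `i`; `UniqueWithCounts` additionally
// produces `count = [2, 1, 3, 1, 2]`.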
76 | template <typename T, typename TIndex> |
77 | class UniqueOp : public OpKernel { |
78 | public: |
79 | explicit UniqueOp(OpKernelConstruction* context) : OpKernel(context) {} |
80 | |
81 | void Compute(OpKernelContext* context) override { |
82 | const Tensor& input = context->input(0); |
83 | // TODO(dga): Make unique polymorphic for returning int32 and int64 |
84 | // vectors to support large tensors. |
85 | OP_REQUIRES(context, |
86 | input.NumElements() <= std::numeric_limits<int32>::max(), |
87 | errors::InvalidArgument( |
88 | "unique does not support input tensors larger than " , |
89 | std::numeric_limits<int32>::max(), " elements" )); |
90 | |
91 | int64_t axis = 0; |
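    // Collapse the input to a rank-3 shape {product of dims before `axis`,
    // dim at `axis`, product of dims after `axis`}, so that uniquification
    // along `axis` always operates on the middle dimension. For the 1-D
    // `Unique` case this stays {1, NumElements, 1}.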
92 | std::vector<int64_t> new_sizes{1, input.NumElements(), 1}; |
93 | if (context->num_inputs() == 1) { |
94 | OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), |
95 | errors::InvalidArgument("unique expects a 1D vector." )); |
96 | } else { |
      // For UniqueV2, the axis is passed as a 1-D vector so that the caller
      // can specify either "no axis" or a single axis: `[]` means "no axis"
      // (treat the input as a flat 1-D vector), while `[x]` means `axis = x`.
100 | const Tensor& axis_tensor = context->input(1); |
101 | OP_REQUIRES(context, TensorShapeUtils::IsVector(axis_tensor.shape()), |
102 | errors::InvalidArgument("axis expects a 1D vector." )); |
103 | OP_REQUIRES( |
104 | context, axis_tensor.NumElements() <= 1, |
105 | errors::InvalidArgument( |
106 | "axis does not support input tensors larger than 1 elements" )); |
107 | if (axis_tensor.NumElements() == 0) { |
108 | OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), |
109 | errors::InvalidArgument("unique expects a 1D vector." )); |
110 | } else { |
111 | OP_REQUIRES(context, |
112 | (axis_tensor.dtype() == DT_INT32 || |
113 | axis_tensor.dtype() == DT_INT64), |
114 | errors::InvalidArgument( |
115 | "axis tensor should be int32 or int64, but got " , |
116 | DataTypeString(axis_tensor.dtype()))); |
117 | if (axis_tensor.dtype() == DT_INT32) { |
118 | axis = internal::SubtleMustCopy(axis_tensor.scalar<int32>()()); |
119 | } else { |
120 | axis = internal::SubtleMustCopy(axis_tensor.scalar<int64_t>()()); |
121 | } |
122 | axis = axis < 0 ? axis + input.dims() : axis; |
123 | OP_REQUIRES(context, 0 <= axis && axis < input.dims(), |
124 | errors::InvalidArgument("axis has to be between [0, " , |
125 | input.dims(), ")" )); |
126 | if (axis > 0) { |
127 | for (int64_t i = 0; i < axis; i++) { |
128 | new_sizes[0] *= input.dim_size(i); |
129 | } |
130 | } |
131 | new_sizes[1] = input.dim_size(axis); |
132 | if (axis + 1 < input.dims()) { |
133 | for (int64_t i = axis + 1; i < input.dims(); i++) { |
134 | new_sizes[2] *= input.dim_size(i); |
135 | } |
136 | } |
137 | } |
138 | } |
139 | |
140 | Tensor* idx = nullptr; |
141 | OP_REQUIRES_OK(context, context->allocate_output( |
142 | 1, TensorShape({new_sizes[1]}), &idx)); |
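    // Output 1 (`idx`) has one entry per position along `axis`, mapping each
    // input element (or slice) to the index of its first occurrence.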
143 | auto idx_vec = idx->template vec<TIndex>(); |
144 | |
145 | int64_t uniq_size; |
146 | if (new_sizes[0] == 1 && new_sizes[2] == 1) { |
      // Specialized, faster implementation for the case where unique is run
      // over single elements (a flat 1-D input). Here we store values of
      // type `T` directly in the map, rather than indices pointing into the
      // input as in the general case.
150 | auto Tin = input.flat<T>(); |
151 | const int64_t N = static_cast<int64_t>(Tin.size()); |
152 | |
153 | typename UniqueOpHashMap<T, TIndex>::map_type uniq; |
154 | uniq.reserve(2 * N); |
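      // Map each element to the index of its first occurrence: `j` is the
      // next unused output index and advances only when `emplace` actually
      // inserts a new key.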
155 | for (Eigen::Index i = 0, j = 0; i < N; ++i) { |
156 | auto it = uniq.emplace(Tin(i), j); |
157 | idx_vec(i) = it.first->second; |
158 | if (it.second) { |
159 | ++j; |
160 | } |
161 | } |
162 | |
163 | uniq_size = static_cast<int64_t>(uniq.size()); |
164 | TensorShape output_shape(input.shape()); |
165 | output_shape.set_dim(axis, uniq_size); |
166 | Tensor* output = nullptr; |
167 | OP_REQUIRES_OK(context, |
168 | context->allocate_output(0, output_shape, &output)); |
169 | auto Tout = output->flat<T>(); |
170 | |
171 | for (const auto& it : uniq) { |
172 | Tout(it.second) = it.first; |
173 | } |
174 | } else { |
175 | // General implementation when unique is run over multiple elements. |
176 | auto Tin = input.shaped<T, 3>(new_sizes); |
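      // Keys in the map below are indices along the middle (axis) dimension;
      // the custom hash and equality functors operate on the entire rank-2
      // slice at that index, so two indices compare equal iff their slices
      // match element-wise.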
177 | |
178 | auto hash_fn = [&Tin](const Eigen::Index& key) { |
179 | size_t h = 0; |
180 | for (Eigen::Index i = 0; i < Tin.dimension(0); i++) { |
181 | for (Eigen::Index j = 0; j < Tin.dimension(2); j++) { |
182 | h = Hash64Combine(h, hash<T>{}(Tin(i, key, j))); |
183 | } |
184 | } |
185 | return h; |
186 | }; |
187 | |
188 | auto equal_to_fn = [&Tin](const Eigen::Index& lhs, |
189 | const Eigen::Index& rhs) { |
190 | for (Eigen::Index i = 0; i < Tin.dimension(0); i++) { |
191 | for (Eigen::Index j = 0; j < Tin.dimension(2); j++) { |
192 | if (Tin(i, lhs, j) != Tin(i, rhs, j)) { |
193 | return false; |
194 | } |
195 | } |
196 | } |
197 | return true; |
198 | }; |
199 | |
200 | absl::flat_hash_map<int64_t, int64_t, decltype(hash_fn), |
201 | decltype(equal_to_fn)> |
202 | uniq(0, hash_fn, equal_to_fn); |
203 | |
204 | uniq.reserve(2 * Tin.dimension(1)); |
205 | |
206 | for (int64_t i = 0, j = 0; i < Tin.dimension(1); ++i) { |
207 | auto it = uniq.emplace(i, j); |
208 | idx_vec(i) = it.first->second; |
209 | if (it.second) { |
210 | ++j; |
211 | } |
212 | } |
213 | |
214 | uniq_size = static_cast<int64_t>(uniq.size()); |
215 | new_sizes[1] = uniq_size; |
216 | TensorShape output_shape(input.shape()); |
217 | output_shape.set_dim(axis, uniq_size); |
218 | Tensor* output = nullptr; |
219 | OP_REQUIRES_OK(context, |
220 | context->allocate_output(0, output_shape, &output)); |
221 | auto Tout = output->shaped<T, 3>(new_sizes); |
222 | |
      for (const auto& it : uniq) {
224 | Tout.chip(it.second, 1) = Tin.chip(it.first, 1); |
225 | } |
226 | } |
227 | |
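    // For `UniqueWithCounts` and `UniqueWithCountsV2`, emit a third output
    // holding the number of occurrences of each unique element (or slice).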
228 | if (num_outputs() > 2) { |
229 | Tensor* output = nullptr; |
230 | OP_REQUIRES_OK(context, context->allocate_output( |
231 | 2, TensorShape({uniq_size}), &output)); |
232 | auto count_output_vec = output->template vec<TIndex>(); |
233 | count_output_vec.setZero(); |
      const int64_t N = idx_vec.size();
235 | for (int64_t i = 0; i < N; ++i) { |
236 | count_output_vec(idx_vec(i))++; |
237 | } |
238 | } |
239 | } |
240 | }; |
241 | |
242 | #define REGISTER_UNIQUE(type) \ |
243 | REGISTER_KERNEL_BUILDER(Name("Unique") \ |
244 | .Device(DEVICE_CPU) \ |
245 | .TypeConstraint<type>("T") \ |
246 | .TypeConstraint<int32>("out_idx"), \ |
247 | UniqueOp<type, int32>); \ |
248 | REGISTER_KERNEL_BUILDER(Name("Unique") \ |
249 | .Device(DEVICE_CPU) \ |
250 | .TypeConstraint<type>("T") \ |
251 | .TypeConstraint<int64_t>("out_idx"), \ |
252 | UniqueOp<type, int64>); \ |
253 | REGISTER_KERNEL_BUILDER(Name("UniqueV2") \ |
254 | .Device(DEVICE_CPU) \ |
255 | .TypeConstraint<type>("T") \ |
256 | .TypeConstraint<int32>("out_idx"), \ |
257 | UniqueOp<type, int32>); \ |
258 | REGISTER_KERNEL_BUILDER(Name("UniqueV2") \ |
259 | .Device(DEVICE_CPU) \ |
260 | .TypeConstraint<type>("T") \ |
261 | .TypeConstraint<int64_t>("out_idx"), \ |
262 | UniqueOp<type, int64>); \ |
263 | REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts") \ |
264 | .Device(DEVICE_CPU) \ |
265 | .TypeConstraint<type>("T") \ |
266 | .TypeConstraint<int32>("out_idx"), \ |
267 | UniqueOp<type, int32>) \ |
268 | REGISTER_KERNEL_BUILDER(Name("UniqueWithCounts") \ |
269 | .Device(DEVICE_CPU) \ |
270 | .TypeConstraint<type>("T") \ |
271 | .TypeConstraint<int64_t>("out_idx"), \ |
272 | UniqueOp<type, int64>); \ |
273 | REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2") \ |
274 | .Device(DEVICE_CPU) \ |
275 | .TypeConstraint<type>("T") \ |
276 | .TypeConstraint<int32>("out_idx"), \ |
277 | UniqueOp<type, int32>) \ |
278 | REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2") \ |
279 | .Device(DEVICE_CPU) \ |
280 | .TypeConstraint<type>("T") \ |
281 | .TypeConstraint<int64_t>("out_idx"), \ |
282 | UniqueOp<type, int64>) |
283 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE); |
284 | REGISTER_UNIQUE(tstring) |
285 | REGISTER_UNIQUE(bool) |
286 | #undef REGISTER_UNIQUE |
287 | |
288 | // Fake integer GPU kernels so that the use of Unique in optimizers (to |
289 | // de-duplicate sparse gradient indices) does not conflict with gradients being |
290 | // located on a GPU. These kernels run on the CPU, their inputs and outputs |
291 | // residing in host (not GPU) memory. |
292 | #define REGISTER_UNIQUE_DEVICE(type) \ |
293 | REGISTER_KERNEL_BUILDER(Name("Unique") \ |
294 | .Device(DEVICE_DEFAULT) \ |
295 | .TypeConstraint<type>("T") \ |
296 | .TypeConstraint<int32>("out_idx") \ |
297 | .HostMemory("x") \ |
298 | .HostMemory("y") \ |
299 | .HostMemory("idx"), \ |
300 | UniqueOp<type, int32>); \ |
301 | REGISTER_KERNEL_BUILDER(Name("Unique") \ |
302 | .Device(DEVICE_DEFAULT) \ |
303 | .TypeConstraint<type>("T") \ |
304 | .TypeConstraint<int64_t>("out_idx") \ |
305 | .HostMemory("x") \ |
306 | .HostMemory("y") \ |
307 | .HostMemory("idx"), \ |
308 | UniqueOp<type, int64>); |
309 | |
310 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE_DEVICE); |
311 | REGISTER_UNIQUE_DEVICE(tstring) |
312 | REGISTER_UNIQUE_DEVICE(bool) |
313 | #undef REGISTER_UNIQUE_DEVICE |
314 | |
315 | } // namespace |
316 | } // namespace tensorflow |
317 | |