1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <limits> |
18 | |
19 | #define EIGEN_USE_THREADS |
20 | |
21 | #include "absl/container/flat_hash_map.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/op_requires.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_types.h" |
27 | #include "tensorflow/core/platform/errors.h" |
28 | #include "tensorflow/core/platform/types.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | // Don't allocate too large `BatchedMap<T>` objects |
33 | static int kMaxBatches = std::numeric_limits<int>::max(); |
34 | |
35 | template <class T> |
36 | using BatchedMap = std::vector<absl::flat_hash_map<int64_t, T>>; |
37 | |
38 | namespace { |
39 | // TODO(momernick): Extend this function to work with outputs of rank > 2. |
40 | template <class T> |
41 | Status OutputSparse(const BatchedMap<T>& per_batch_counts, int64_t num_values, |
42 | bool is_1d, OpKernelContext* context) { |
43 | int total_values = 0; |
44 | int num_batches = per_batch_counts.size(); |
45 | for (const auto& per_batch_count : per_batch_counts) { |
46 | total_values += per_batch_count.size(); |
47 | } |
48 | |
49 | Tensor* indices; |
50 | int inner_dim = is_1d ? 1 : 2; |
51 | TF_RETURN_IF_ERROR(context->allocate_output( |
52 | 0, TensorShape({total_values, inner_dim}), &indices)); |
53 | |
54 | Tensor* values; |
55 | TF_RETURN_IF_ERROR( |
56 | context->allocate_output(1, TensorShape({total_values}), &values)); |
57 | |
58 | auto output_indices = indices->matrix<int64_t>(); |
59 | auto output_values = values->flat<T>(); |
60 | int64_t value_loc = 0; |
61 | for (int b = 0; b < num_batches; ++b) { |
62 | const auto& per_batch_count = per_batch_counts[b]; |
63 | std::vector<std::pair<int64_t, T>> pairs(per_batch_count.begin(), |
64 | per_batch_count.end()); |
65 | std::sort(pairs.begin(), pairs.end()); |
66 | for (const auto& x : pairs) { |
67 | if (is_1d) { |
68 | output_indices(value_loc, 0) = x.first; |
69 | } else { |
70 | output_indices(value_loc, 0) = b; |
71 | output_indices(value_loc, 1) = x.first; |
72 | } |
73 | output_values(value_loc) = x.second; |
74 | ++value_loc; |
75 | } |
76 | } |
77 | Tensor* dense_shape; |
78 | if (is_1d) { |
79 | TF_RETURN_IF_ERROR( |
80 | context->allocate_output(2, TensorShape({1}), &dense_shape)); |
81 | dense_shape->flat<int64_t>().data()[0] = num_values; |
82 | } else { |
83 | TF_RETURN_IF_ERROR( |
84 | context->allocate_output(2, TensorShape({2}), &dense_shape)); |
85 | dense_shape->flat<int64_t>().data()[0] = num_batches; |
86 | dense_shape->flat<int64_t>().data()[1] = num_values; |
87 | } |
88 | |
89 | return OkStatus(); |
90 | } |
91 | |
// Returns the size of the output's count dimension: `max_length` when a
// non-negative maxlength was requested; otherwise just enough to cover the
// largest value seen, but never less than `min_length`.
int64_t GetOutputSize(int64_t max_seen, int64_t max_length,
                      int64_t min_length) {
  if (max_length >= 0) {
    return max_length;
  }
  return std::max(max_seen + 1, min_length);
}
96 | |
97 | } // namespace |
98 | |
99 | template <class T, class W> |
100 | class DenseCount : public OpKernel { |
101 | public: |
102 | explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) { |
103 | OP_REQUIRES_OK(context, context->GetAttr("minlength" , &minlength_)); |
104 | OP_REQUIRES_OK(context, context->GetAttr("maxlength" , &maxlength_)); |
105 | OP_REQUIRES_OK(context, context->GetAttr("binary_output" , &binary_output_)); |
106 | } |
107 | |
108 | void Compute(OpKernelContext* context) override { |
109 | const Tensor& data = context->input(0); |
110 | const Tensor& weights = context->input(1); |
111 | bool use_weights = weights.NumElements() > 0; |
112 | |
113 | OP_REQUIRES(context, |
114 | TensorShapeUtils::IsVector(data.shape()) || |
115 | TensorShapeUtils::IsMatrix(data.shape()), |
116 | errors::InvalidArgument( |
117 | "Input must be a 1 or 2-dimensional tensor. Got: " , |
118 | data.shape().DebugString())); |
119 | |
120 | // Ensure all values are non-negative. |
121 | const auto data_values = data.flat<T>(); |
122 | Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative; |
123 | nonnegative.device(context->eigen_cpu_device()) = |
124 | (data_values >= static_cast<T>(0)).all(); |
125 | OP_REQUIRES( |
126 | context, nonnegative(), |
127 | errors::InvalidArgument("Input values must all be non-negative" )); |
128 | |
129 | if (use_weights) { |
130 | OP_REQUIRES( |
131 | context, weights.shape() == data.shape(), |
132 | errors::InvalidArgument( |
133 | "Weights and data must have the same shape. Weight shape: " , |
134 | weights.shape().DebugString(), |
135 | "; data shape: " , data.shape().DebugString())); |
136 | } |
137 | |
138 | bool is_1d = TensorShapeUtils::IsVector(data.shape()); |
139 | int negative_valued_axis = -1; |
140 | int num_batch_dimensions = (data.shape().dims() + negative_valued_axis); |
141 | |
142 | int num_batch_elements = 1; |
143 | for (int i = 0; i < num_batch_dimensions; ++i) { |
144 | OP_REQUIRES(context, data.shape().dim_size(i) != 0, |
145 | errors::InvalidArgument( |
146 | "Invalid input: Shapes dimension cannot be 0." )); |
147 | num_batch_elements *= data.shape().dim_size(i); |
148 | } |
149 | int num_value_elements = data.shape().num_elements() / num_batch_elements; |
150 | auto per_batch_counts = BatchedMap<W>(num_batch_elements); |
151 | |
152 | T max_value = 0; |
153 | |
154 | const auto weight_values = weights.flat<W>(); |
155 | int i = 0; |
156 | for (int b = 0; b < num_batch_elements; ++b) { |
157 | for (int v = 0; v < num_value_elements; ++v) { |
158 | const auto& value = data_values(i); |
159 | if (maxlength_ < 0 || value < maxlength_) { |
160 | if (binary_output_) { |
161 | per_batch_counts[b][value] = 1; |
162 | } else if (use_weights) { |
163 | per_batch_counts[b][value] += weight_values(i); |
164 | } else { |
165 | per_batch_counts[b][value]++; |
166 | } |
167 | if (value > max_value) { |
168 | max_value = value; |
169 | } |
170 | } |
171 | ++i; |
172 | } |
173 | } |
174 | |
175 | int64_t num_output_values = |
176 | GetOutputSize(max_value, maxlength_, minlength_); |
177 | OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values, |
178 | is_1d, context)); |
179 | } |
180 | |
181 | private: |
182 | int64_t maxlength_; |
183 | int64_t minlength_; |
184 | bool binary_output_; |
185 | }; |
186 | |
187 | template <class T, class W> |
188 | class SparseCount : public OpKernel { |
189 | public: |
190 | explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) { |
191 | OP_REQUIRES_OK(context, context->GetAttr("minlength" , &minlength_)); |
192 | OP_REQUIRES_OK(context, context->GetAttr("maxlength" , &maxlength_)); |
193 | OP_REQUIRES_OK(context, context->GetAttr("binary_output" , &binary_output_)); |
194 | } |
195 | |
196 | void Compute(OpKernelContext* context) override { |
197 | const Tensor& indices = context->input(0); |
198 | const Tensor& values = context->input(1); |
199 | const Tensor& shape = context->input(2); |
200 | const Tensor& weights = context->input(3); |
201 | bool use_weights = weights.NumElements() > 0; |
202 | |
203 | OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()), |
204 | errors::InvalidArgument( |
205 | "Input indices must be a 2-dimensional tensor. Got: " , |
206 | indices.shape().DebugString())); |
207 | OP_REQUIRES(context, TensorShapeUtils::IsVector(values.shape()), |
208 | errors::InvalidArgument("Input values must be a vector. Got: " , |
209 | values.shape().DebugString())); |
210 | OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()), |
211 | errors::InvalidArgument("Input shape must be a vector. Got: " , |
212 | shape.shape().DebugString())); |
213 | OP_REQUIRES(context, |
214 | values.shape().dim_size(0) == indices.shape().dim_size(0), |
215 | errors::InvalidArgument( |
216 | "Number of values must match first dimension of indices." , |
217 | "Got " , values.shape().dim_size(0), |
218 | " values, indices shape: " , indices.shape().DebugString())); |
219 | OP_REQUIRES( |
220 | context, shape.shape().dim_size(0) == indices.shape().dim_size(1), |
221 | errors::InvalidArgument( |
222 | "Number of dimensions must match second dimension of indices." , |
223 | "Got " , shape.shape().dim_size(0), |
224 | " dimensions, indices shape: " , indices.shape().DebugString())); |
225 | OP_REQUIRES(context, shape.NumElements() > 0, |
226 | errors::InvalidArgument( |
227 | "The shape argument requires at least one element." )); |
228 | // Validate indices: each index must be valid for the corresponding |
229 | // dimension. This could be possibly done better. |
230 | const auto indices_values = indices.matrix<int64_t>(); |
231 | const auto shape_vector = shape.vec<int64_t>(); |
232 | int num_values = values.NumElements(); // same as first dim of indices |
233 | int rank = indices.shape().dim_size(1); |
234 | for (int i = 0; i < num_values; ++i) { |
235 | for (int j = 0; j < rank; ++j) { |
236 | OP_REQUIRES( |
237 | context, |
238 | indices_values(i, j) >= 0 && indices_values(i, j) < shape_vector(j), |
239 | errors::InvalidArgument( |
240 | "Invalid index value at " , i, ": dimension " , j, " has value " , |
241 | indices_values(i, j), " which is not in [0, " , shape_vector(j), |
242 | ") (as given by dense shape " , shape.DebugString())); |
243 | } |
244 | } |
245 | |
246 | // Ensure all values are non-negative. |
247 | const auto values_values = values.flat<T>(); |
248 | Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative; |
249 | nonnegative.device(context->eigen_cpu_device()) = |
250 | (values_values >= static_cast<T>(0)).all(); |
251 | OP_REQUIRES( |
252 | context, nonnegative(), |
253 | errors::InvalidArgument("Input values must all be non-negative" )); |
254 | |
255 | if (use_weights) { |
256 | OP_REQUIRES( |
257 | context, weights.shape() == values.shape(), |
258 | errors::InvalidArgument( |
259 | "Weights and values must have the same shape. Weight shape: " , |
260 | weights.shape().DebugString(), |
261 | "; values shape: " , values.shape().DebugString())); |
262 | } |
263 | |
264 | bool is_1d = shape.NumElements() == 1; |
265 | int num_batches = is_1d ? 1 : shape_vector(0); |
266 | OP_REQUIRES( |
267 | context, 0 < num_batches && num_batches < kMaxBatches, |
268 | errors::InvalidArgument("Cannot allocate " , num_batches, |
269 | " batches, is the dense shape too wide?" )); |
270 | |
271 | const auto weight_values = weights.flat<W>(); |
272 | |
273 | auto per_batch_counts = BatchedMap<W>(num_batches); |
274 | |
275 | T max_value = 0; |
276 | |
277 | for (int idx = 0; idx < num_values; ++idx) { |
278 | int batch = is_1d ? 0 : indices_values(idx, 0); |
279 | if (batch >= num_batches) { |
280 | OP_REQUIRES(context, batch < num_batches, |
281 | errors::InvalidArgument( |
282 | "Indices value along the first dimension must be " , |
283 | "lower than the first index of the shape." , "Got " , |
284 | batch, " as batch and " , num_batches, |
285 | " as the first dimension of the shape." )); |
286 | } |
287 | const auto& value = values_values(idx); |
288 | if (maxlength_ < 0 || value < maxlength_) { |
289 | if (binary_output_) { |
290 | per_batch_counts[batch][value] = 1; |
291 | } else if (use_weights) { |
292 | per_batch_counts[batch][value] += weight_values(idx); |
293 | } else { |
294 | per_batch_counts[batch][value]++; |
295 | } |
296 | if (value > max_value) { |
297 | max_value = value; |
298 | } |
299 | } |
300 | } |
301 | |
302 | int64_t num_output_values = |
303 | GetOutputSize(max_value, maxlength_, minlength_); |
304 | OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values, |
305 | is_1d, context)); |
306 | } |
307 | |
308 | private: |
309 | int64_t maxlength_; |
310 | int64_t minlength_; |
311 | bool binary_output_; |
312 | bool validate_; |
313 | }; |
314 | |
315 | template <class T, class W> |
316 | class RaggedCount : public OpKernel { |
317 | public: |
318 | explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) { |
319 | OP_REQUIRES_OK(context, context->GetAttr("minlength" , &minlength_)); |
320 | OP_REQUIRES_OK(context, context->GetAttr("maxlength" , &maxlength_)); |
321 | OP_REQUIRES_OK(context, context->GetAttr("binary_output" , &binary_output_)); |
322 | } |
323 | |
324 | void Compute(OpKernelContext* context) override { |
325 | const Tensor& splits = context->input(0); |
326 | const Tensor& values = context->input(1); |
327 | const Tensor& weights = context->input(2); |
328 | bool use_weights = weights.NumElements() > 0; |
329 | bool is_1d = false; |
330 | |
331 | if (use_weights) { |
332 | OP_REQUIRES( |
333 | context, weights.shape() == values.shape(), |
334 | errors::InvalidArgument( |
335 | "Weights and values must have the same shape. Weight shape: " , |
336 | weights.shape().DebugString(), |
337 | "; values shape: " , values.shape().DebugString())); |
338 | } |
339 | |
340 | const auto splits_values = splits.flat<int64_t>(); |
341 | const auto values_values = values.flat<T>(); |
342 | const auto weight_values = weights.flat<W>(); |
343 | int num_batches = splits.NumElements() - 1; |
344 | int num_values = values.NumElements(); |
345 | |
346 | OP_REQUIRES( |
347 | context, num_batches > 0, |
348 | errors::InvalidArgument( |
349 | "Must provide at least 2 elements for the splits argument" )); |
350 | OP_REQUIRES(context, splits_values(0) == 0, |
351 | errors::InvalidArgument("Splits must start with 0, not with " , |
352 | splits_values(0))); |
353 | OP_REQUIRES(context, splits_values(num_batches) == num_values, |
354 | errors::InvalidArgument( |
355 | "Splits must end with the number of values, got " , |
356 | splits_values(num_batches), " instead of " , num_values)); |
357 | |
358 | // Ensure all values are non-negative. |
359 | Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative; |
360 | nonnegative.device(context->eigen_cpu_device()) = |
361 | (values_values >= static_cast<T>(0)).all(); |
362 | OP_REQUIRES( |
363 | context, nonnegative(), |
364 | errors::InvalidArgument("Input values must all be non-negative" )); |
365 | |
366 | auto per_batch_counts = BatchedMap<W>(num_batches); |
367 | T max_value = 0; |
368 | int batch_idx = 0; |
369 | |
370 | for (int idx = 0; idx < num_values; ++idx) { |
371 | while (idx >= splits_values(batch_idx)) { |
372 | batch_idx++; |
373 | } |
374 | const auto& value = values_values(idx); |
375 | if (maxlength_ < 0 || value < maxlength_) { |
376 | if (binary_output_) { |
377 | per_batch_counts[batch_idx - 1][value] = 1; |
378 | } else if (use_weights) { |
379 | per_batch_counts[batch_idx - 1][value] += weight_values(idx); |
380 | } else { |
381 | per_batch_counts[batch_idx - 1][value]++; |
382 | } |
383 | if (value > max_value) { |
384 | max_value = value; |
385 | } |
386 | } |
387 | } |
388 | |
389 | int64_t num_output_values = |
390 | GetOutputSize(max_value, maxlength_, minlength_); |
391 | OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values, |
392 | is_1d, context)); |
393 | } |
394 | |
395 | private: |
396 | int64_t maxlength_; |
397 | int64_t minlength_; |
398 | bool binary_output_; |
399 | bool validate_; |
400 | }; |
401 | |
// Instantiates REGISTER for both supported value (index) types of a given
// output/weight type.
#define REGISTER_W(W_TYPE) \
  REGISTER(int32, W_TYPE)  \
  REGISTER(int64_t, W_TYPE)

// Registers the CPU kernels for the three *CountSparseOutput ops for one
// (value type I_TYPE, output type W_TYPE) combination.
#define REGISTER(I_TYPE, W_TYPE)                                     \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput")             \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          DenseCount<I_TYPE, W_TYPE>)                \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          SparseCount<I_TYPE, W_TYPE>)               \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          RaggedCount<I_TYPE, W_TYPE>)

// Register kernels for all integral output types plus float and double.
TF_CALL_INTEGRAL_TYPES(REGISTER_W);
TF_CALL_float(REGISTER_W);
TF_CALL_double(REGISTER_W);

#undef REGISTER_W
#undef REGISTER
432 | |
433 | } // namespace tensorflow |
434 | |