1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17#include <limits>
18
19#define EIGEN_USE_THREADS
20
21#include "absl/container/flat_hash_map.h"
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/framework/op_requires.h"
24#include "tensorflow/core/framework/register_types.h"
25#include "tensorflow/core/framework/tensor.h"
26#include "tensorflow/core/framework/tensor_types.h"
27#include "tensorflow/core/platform/errors.h"
28#include "tensorflow/core/platform/types.h"
29
30namespace tensorflow {
31
32// Don't allocate too large `BatchedMap<T>` objects
33static int kMaxBatches = std::numeric_limits<int>::max();
34
35template <class T>
36using BatchedMap = std::vector<absl::flat_hash_map<int64_t, T>>;
37
38namespace {
39// TODO(momernick): Extend this function to work with outputs of rank > 2.
40template <class T>
41Status OutputSparse(const BatchedMap<T>& per_batch_counts, int64_t num_values,
42 bool is_1d, OpKernelContext* context) {
43 int total_values = 0;
44 int num_batches = per_batch_counts.size();
45 for (const auto& per_batch_count : per_batch_counts) {
46 total_values += per_batch_count.size();
47 }
48
49 Tensor* indices;
50 int inner_dim = is_1d ? 1 : 2;
51 TF_RETURN_IF_ERROR(context->allocate_output(
52 0, TensorShape({total_values, inner_dim}), &indices));
53
54 Tensor* values;
55 TF_RETURN_IF_ERROR(
56 context->allocate_output(1, TensorShape({total_values}), &values));
57
58 auto output_indices = indices->matrix<int64_t>();
59 auto output_values = values->flat<T>();
60 int64_t value_loc = 0;
61 for (int b = 0; b < num_batches; ++b) {
62 const auto& per_batch_count = per_batch_counts[b];
63 std::vector<std::pair<int64_t, T>> pairs(per_batch_count.begin(),
64 per_batch_count.end());
65 std::sort(pairs.begin(), pairs.end());
66 for (const auto& x : pairs) {
67 if (is_1d) {
68 output_indices(value_loc, 0) = x.first;
69 } else {
70 output_indices(value_loc, 0) = b;
71 output_indices(value_loc, 1) = x.first;
72 }
73 output_values(value_loc) = x.second;
74 ++value_loc;
75 }
76 }
77 Tensor* dense_shape;
78 if (is_1d) {
79 TF_RETURN_IF_ERROR(
80 context->allocate_output(2, TensorShape({1}), &dense_shape));
81 dense_shape->flat<int64_t>().data()[0] = num_values;
82 } else {
83 TF_RETURN_IF_ERROR(
84 context->allocate_output(2, TensorShape({2}), &dense_shape));
85 dense_shape->flat<int64_t>().data()[0] = num_batches;
86 dense_shape->flat<int64_t>().data()[1] = num_values;
87 }
88
89 return OkStatus();
90}
91
// Returns the size of the count dimension of the output. A non-negative
// `max_length` is an explicit cap and wins outright; otherwise the output
// must hold every value seen (`max_seen` + 1) and be at least `min_length`.
int64_t GetOutputSize(int64_t max_seen, int64_t max_length,
                      int64_t min_length) {
  if (max_length >= 0) {
    return max_length;
  }
  const int64_t required = max_seen + 1;
  return required > min_length ? required : min_length;
}
96
97} // namespace
98
99template <class T, class W>
100class DenseCount : public OpKernel {
101 public:
102 explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) {
103 OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
104 OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
105 OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
106 }
107
108 void Compute(OpKernelContext* context) override {
109 const Tensor& data = context->input(0);
110 const Tensor& weights = context->input(1);
111 bool use_weights = weights.NumElements() > 0;
112
113 OP_REQUIRES(context,
114 TensorShapeUtils::IsVector(data.shape()) ||
115 TensorShapeUtils::IsMatrix(data.shape()),
116 errors::InvalidArgument(
117 "Input must be a 1 or 2-dimensional tensor. Got: ",
118 data.shape().DebugString()));
119
120 // Ensure all values are non-negative.
121 const auto data_values = data.flat<T>();
122 Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative;
123 nonnegative.device(context->eigen_cpu_device()) =
124 (data_values >= static_cast<T>(0)).all();
125 OP_REQUIRES(
126 context, nonnegative(),
127 errors::InvalidArgument("Input values must all be non-negative"));
128
129 if (use_weights) {
130 OP_REQUIRES(
131 context, weights.shape() == data.shape(),
132 errors::InvalidArgument(
133 "Weights and data must have the same shape. Weight shape: ",
134 weights.shape().DebugString(),
135 "; data shape: ", data.shape().DebugString()));
136 }
137
138 bool is_1d = TensorShapeUtils::IsVector(data.shape());
139 int negative_valued_axis = -1;
140 int num_batch_dimensions = (data.shape().dims() + negative_valued_axis);
141
142 int num_batch_elements = 1;
143 for (int i = 0; i < num_batch_dimensions; ++i) {
144 OP_REQUIRES(context, data.shape().dim_size(i) != 0,
145 errors::InvalidArgument(
146 "Invalid input: Shapes dimension cannot be 0."));
147 num_batch_elements *= data.shape().dim_size(i);
148 }
149 int num_value_elements = data.shape().num_elements() / num_batch_elements;
150 auto per_batch_counts = BatchedMap<W>(num_batch_elements);
151
152 T max_value = 0;
153
154 const auto weight_values = weights.flat<W>();
155 int i = 0;
156 for (int b = 0; b < num_batch_elements; ++b) {
157 for (int v = 0; v < num_value_elements; ++v) {
158 const auto& value = data_values(i);
159 if (maxlength_ < 0 || value < maxlength_) {
160 if (binary_output_) {
161 per_batch_counts[b][value] = 1;
162 } else if (use_weights) {
163 per_batch_counts[b][value] += weight_values(i);
164 } else {
165 per_batch_counts[b][value]++;
166 }
167 if (value > max_value) {
168 max_value = value;
169 }
170 }
171 ++i;
172 }
173 }
174
175 int64_t num_output_values =
176 GetOutputSize(max_value, maxlength_, minlength_);
177 OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
178 is_1d, context));
179 }
180
181 private:
182 int64_t maxlength_;
183 int64_t minlength_;
184 bool binary_output_;
185};
186
187template <class T, class W>
188class SparseCount : public OpKernel {
189 public:
190 explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) {
191 OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
192 OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
193 OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
194 }
195
196 void Compute(OpKernelContext* context) override {
197 const Tensor& indices = context->input(0);
198 const Tensor& values = context->input(1);
199 const Tensor& shape = context->input(2);
200 const Tensor& weights = context->input(3);
201 bool use_weights = weights.NumElements() > 0;
202
203 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
204 errors::InvalidArgument(
205 "Input indices must be a 2-dimensional tensor. Got: ",
206 indices.shape().DebugString()));
207 OP_REQUIRES(context, TensorShapeUtils::IsVector(values.shape()),
208 errors::InvalidArgument("Input values must be a vector. Got: ",
209 values.shape().DebugString()));
210 OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()),
211 errors::InvalidArgument("Input shape must be a vector. Got: ",
212 shape.shape().DebugString()));
213 OP_REQUIRES(context,
214 values.shape().dim_size(0) == indices.shape().dim_size(0),
215 errors::InvalidArgument(
216 "Number of values must match first dimension of indices.",
217 "Got ", values.shape().dim_size(0),
218 " values, indices shape: ", indices.shape().DebugString()));
219 OP_REQUIRES(
220 context, shape.shape().dim_size(0) == indices.shape().dim_size(1),
221 errors::InvalidArgument(
222 "Number of dimensions must match second dimension of indices.",
223 "Got ", shape.shape().dim_size(0),
224 " dimensions, indices shape: ", indices.shape().DebugString()));
225 OP_REQUIRES(context, shape.NumElements() > 0,
226 errors::InvalidArgument(
227 "The shape argument requires at least one element."));
228 // Validate indices: each index must be valid for the corresponding
229 // dimension. This could be possibly done better.
230 const auto indices_values = indices.matrix<int64_t>();
231 const auto shape_vector = shape.vec<int64_t>();
232 int num_values = values.NumElements(); // same as first dim of indices
233 int rank = indices.shape().dim_size(1);
234 for (int i = 0; i < num_values; ++i) {
235 for (int j = 0; j < rank; ++j) {
236 OP_REQUIRES(
237 context,
238 indices_values(i, j) >= 0 && indices_values(i, j) < shape_vector(j),
239 errors::InvalidArgument(
240 "Invalid index value at ", i, ": dimension ", j, " has value ",
241 indices_values(i, j), " which is not in [0, ", shape_vector(j),
242 ") (as given by dense shape ", shape.DebugString()));
243 }
244 }
245
246 // Ensure all values are non-negative.
247 const auto values_values = values.flat<T>();
248 Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative;
249 nonnegative.device(context->eigen_cpu_device()) =
250 (values_values >= static_cast<T>(0)).all();
251 OP_REQUIRES(
252 context, nonnegative(),
253 errors::InvalidArgument("Input values must all be non-negative"));
254
255 if (use_weights) {
256 OP_REQUIRES(
257 context, weights.shape() == values.shape(),
258 errors::InvalidArgument(
259 "Weights and values must have the same shape. Weight shape: ",
260 weights.shape().DebugString(),
261 "; values shape: ", values.shape().DebugString()));
262 }
263
264 bool is_1d = shape.NumElements() == 1;
265 int num_batches = is_1d ? 1 : shape_vector(0);
266 OP_REQUIRES(
267 context, 0 < num_batches && num_batches < kMaxBatches,
268 errors::InvalidArgument("Cannot allocate ", num_batches,
269 " batches, is the dense shape too wide?"));
270
271 const auto weight_values = weights.flat<W>();
272
273 auto per_batch_counts = BatchedMap<W>(num_batches);
274
275 T max_value = 0;
276
277 for (int idx = 0; idx < num_values; ++idx) {
278 int batch = is_1d ? 0 : indices_values(idx, 0);
279 if (batch >= num_batches) {
280 OP_REQUIRES(context, batch < num_batches,
281 errors::InvalidArgument(
282 "Indices value along the first dimension must be ",
283 "lower than the first index of the shape.", "Got ",
284 batch, " as batch and ", num_batches,
285 " as the first dimension of the shape."));
286 }
287 const auto& value = values_values(idx);
288 if (maxlength_ < 0 || value < maxlength_) {
289 if (binary_output_) {
290 per_batch_counts[batch][value] = 1;
291 } else if (use_weights) {
292 per_batch_counts[batch][value] += weight_values(idx);
293 } else {
294 per_batch_counts[batch][value]++;
295 }
296 if (value > max_value) {
297 max_value = value;
298 }
299 }
300 }
301
302 int64_t num_output_values =
303 GetOutputSize(max_value, maxlength_, minlength_);
304 OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
305 is_1d, context));
306 }
307
308 private:
309 int64_t maxlength_;
310 int64_t minlength_;
311 bool binary_output_;
312 bool validate_;
313};
314
315template <class T, class W>
316class RaggedCount : public OpKernel {
317 public:
318 explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) {
319 OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
320 OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
321 OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
322 }
323
324 void Compute(OpKernelContext* context) override {
325 const Tensor& splits = context->input(0);
326 const Tensor& values = context->input(1);
327 const Tensor& weights = context->input(2);
328 bool use_weights = weights.NumElements() > 0;
329 bool is_1d = false;
330
331 if (use_weights) {
332 OP_REQUIRES(
333 context, weights.shape() == values.shape(),
334 errors::InvalidArgument(
335 "Weights and values must have the same shape. Weight shape: ",
336 weights.shape().DebugString(),
337 "; values shape: ", values.shape().DebugString()));
338 }
339
340 const auto splits_values = splits.flat<int64_t>();
341 const auto values_values = values.flat<T>();
342 const auto weight_values = weights.flat<W>();
343 int num_batches = splits.NumElements() - 1;
344 int num_values = values.NumElements();
345
346 OP_REQUIRES(
347 context, num_batches > 0,
348 errors::InvalidArgument(
349 "Must provide at least 2 elements for the splits argument"));
350 OP_REQUIRES(context, splits_values(0) == 0,
351 errors::InvalidArgument("Splits must start with 0, not with ",
352 splits_values(0)));
353 OP_REQUIRES(context, splits_values(num_batches) == num_values,
354 errors::InvalidArgument(
355 "Splits must end with the number of values, got ",
356 splits_values(num_batches), " instead of ", num_values));
357
358 // Ensure all values are non-negative.
359 Eigen::TensorFixedSize<bool, Eigen::Sizes<>, Eigen::RowMajor> nonnegative;
360 nonnegative.device(context->eigen_cpu_device()) =
361 (values_values >= static_cast<T>(0)).all();
362 OP_REQUIRES(
363 context, nonnegative(),
364 errors::InvalidArgument("Input values must all be non-negative"));
365
366 auto per_batch_counts = BatchedMap<W>(num_batches);
367 T max_value = 0;
368 int batch_idx = 0;
369
370 for (int idx = 0; idx < num_values; ++idx) {
371 while (idx >= splits_values(batch_idx)) {
372 batch_idx++;
373 }
374 const auto& value = values_values(idx);
375 if (maxlength_ < 0 || value < maxlength_) {
376 if (binary_output_) {
377 per_batch_counts[batch_idx - 1][value] = 1;
378 } else if (use_weights) {
379 per_batch_counts[batch_idx - 1][value] += weight_values(idx);
380 } else {
381 per_batch_counts[batch_idx - 1][value]++;
382 }
383 if (value > max_value) {
384 max_value = value;
385 }
386 }
387 }
388
389 int64_t num_output_values =
390 GetOutputSize(max_value, maxlength_, minlength_);
391 OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
392 is_1d, context));
393 }
394
395 private:
396 int64_t maxlength_;
397 int64_t minlength_;
398 bool binary_output_;
399 bool validate_;
400};
401
// Expands REGISTER for both supported index ("T") types — int32 and
// int64 — with the given output/weight type.
#define REGISTER_W(W_TYPE) \
  REGISTER(int32, W_TYPE)  \
  REGISTER(int64_t, W_TYPE)

// Registers the CPU kernels of the three *CountSparseOutput ops for one
// (index type, output type) combination.
#define REGISTER(I_TYPE, W_TYPE)                                     \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput")             \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          DenseCount<I_TYPE, W_TYPE>)                \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          SparseCount<I_TYPE, W_TYPE>)               \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          RaggedCount<I_TYPE, W_TYPE>)

// Instantiate for all integral output types plus float and double.
TF_CALL_INTEGRAL_TYPES(REGISTER_W);
TF_CALL_float(REGISTER_W);
TF_CALL_double(REGISTER_W);

#undef REGISTER_W
#undef REGISTER
432
433} // namespace tensorflow
434