/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/util.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

namespace {

//==============================================================================
// Feature Readers
//==============================================================================

// A `FeatureReader` is used to read the feature values from a single input
// tensor. Subclasses are used for reading different tensor types:
//   * RaggedFeatureReader<value_type, splits_type>
//   * SparseFeatureReader<value_type>
//   * DenseFeatureReader<value_type>
//
// where `value_type` is one of {tstring, int64}, and `splits_type` is one of
// {int32, int64}.
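//
// For illustration: if batch 0 of some feature holds the strings {"a", "b"},
// then FeatureCount(0) == 2, ReadValue(0, 1, &s) copies "b" into the tstring
// `s`, and the uint64 overload writes Fingerprint64("b") instead.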
class FeatureReader {
 public:
  // Returns the number of feature values in the specified batch.
  virtual int64_t FeatureCount(int64_t batch) const = 0;

  // Copies the value for the specified feature to `out`.
  virtual void ReadValue(int64_t batch, int64_t n, uint64* out) const = 0;
  virtual void ReadValue(int64_t batch, int64_t n, tstring* out) const = 0;

  virtual ~FeatureReader() {}
};

using FeatureReaders = std::vector<std::unique_ptr<FeatureReader>>;

// Copies a feature value `src` to a tstring `dst`, using a view if
// appropriate.
void CopyToString(const tstring& src, tstring* dst) {
  if (src.type() == tstring::SMALL) {
    *dst = src;  // string buffer fits in the tstring object (under ~24 bytes)
  } else {
    dst->assign_as_view(src);
  }
}
void CopyToString(int64_t src, tstring* dst) { *dst = std::to_string(src); }

// Copies a feature value `src` to an int64 fingerprint `dst`.
void CopyToFingerprint(const tstring& feature, uint64* dst) {
  *dst = Fingerprint64(feature);
}
void CopyToFingerprint(int64_t feature, uint64* dst) { *dst = feature; }

// A FeatureReader that is backed by a ragged tensor.
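//
// Illustrative sketch: the ragged feature [["a", "b"], ["c"]] arrives as
// values=["a", "b", "c"] with row_splits=[0, 2, 3], so row `batch` spans
// values[row_splits(batch):row_splits(batch + 1)].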
template <typename ValuesType, typename SplitsType>
class RaggedFeatureReader : public FeatureReader {
 public:
  RaggedFeatureReader(const Tensor& values, const Tensor& row_splits)
      : values_(values.flat<ValuesType>()),
        row_splits_(row_splits.flat<SplitsType>()) {}

  int64_t FeatureCount(int64_t batch) const override {
    return row_splits_(batch + 1) - row_splits_(batch);
  }

  void ReadValue(int64_t batch, int64_t n, uint64* out) const override {
    CopyToFingerprint(values_(row_splits_(batch) + n), out);
  }

  void ReadValue(int64_t batch, int64_t n, tstring* out) const override {
    CopyToString(values_(row_splits_(batch) + n), out);
  }

 private:
  const typename TTypes<ValuesType>::ConstFlat values_;
  const typename TTypes<SplitsType>::ConstFlat row_splits_;
};

// A FeatureReader that is backed by a dense tensor.
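//
// For a dense input of shape [batch_size, k], every batch has exactly k
// feature values, so FeatureCount(batch) is always k.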
template <typename ValuesType>
class DenseFeatureReader : public FeatureReader {
 public:
  explicit DenseFeatureReader(const Tensor& tensor)
      : values_(tensor.matrix<ValuesType>()),
        feature_count_(tensor.dim_size(1)) {}

  int64_t FeatureCount(int64_t batch) const override { return feature_count_; }

  void ReadValue(int64_t batch, int64_t n, uint64* out) const override {
    CopyToFingerprint(values_(batch, n), out);
  }

  void ReadValue(int64_t batch, int64_t n, tstring* out) const override {
    CopyToString(values_(batch, n), out);
  }

 private:
  const typename TTypes<ValuesType>::ConstMatrix values_;
  const int64_t feature_count_;
};

// A FeatureReader that is backed by a sparse tensor.
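//
// Illustrative sketch: a rank-2 SparseTensor with indices=[[0, 0], [0, 1],
// [2, 0]] and values=["a", "b", "c"] is converted by the constructor into
// row_splits_=[0, 2, 2, 3] (assuming the indices are sorted in row-major
// order), after which reads work the same way as in RaggedFeatureReader.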
template <typename ValuesType>
class SparseFeatureReader : public FeatureReader {
 public:
  SparseFeatureReader(const Tensor& indices_t, const Tensor& values_t,
                      int64_t batch_size)
      : values_(values_t.flat<ValuesType>()) {
    row_splits_.reserve(batch_size + 1);
    row_splits_.push_back(0);
    auto indices = indices_t.matrix<int64_t>();
    int64_t num_values = values_.size();
    int64_t i = 0;  // value index
    for (int row = 0; row < batch_size; row++) {
      while (i < num_values && indices(i, 0) <= row) ++i;
      row_splits_.push_back(i);
    }
  }

  int64_t FeatureCount(int64_t batch) const override {
    return row_splits_[batch + 1] - row_splits_[batch];
  }

  void ReadValue(int64_t batch, int64_t n, uint64* out) const override {
    CopyToFingerprint(values_(row_splits_[batch] + n), out);
  }

  void ReadValue(int64_t batch, int64_t n, tstring* out) const override {
    CopyToString(values_(row_splits_[batch] + n), out);
  }

 private:
  const typename TTypes<ValuesType>::ConstFlat values_;
  std::vector<int64_t> row_splits_;
};

//==============================================================================
// Output Writers
//==============================================================================

// An `OutputWriter` is used to write the feature crosses to the output values
// tensor. Different subclasses are used for writing different output dtypes:
//   * OutputWriterImpl<tstring, SplitsType> (for tf.ragged.cross)
//   * OutputWriterImpl<int64, SplitsType> (for tf.ragged.cross_hashed)
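//
// For illustration: crossing the batch-0 features {"a"} and {"b", "c"} yields
// the string outputs "a_X_b" and "a_X_c" for tf.ragged.cross, or the
// corresponding FingerprintCat64 hashes (optionally mod num_buckets) for
// tf.ragged.cross_hashed.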
class OutputWriter {
 public:
  virtual void WriteOutputSlice(int64_t begin, int64_t end) = 0;
  virtual ~OutputWriter() {}
};

template <typename ValuesType, typename SplitsType>
class OutputWriterImpl : public OutputWriter {
 public:
  using FlatValues = typename TTypes<ValuesType>::Flat;
  using FlatSplits = typename TTypes<SplitsType>::ConstFlat;

  OutputWriterImpl(const FeatureReaders& features, int64_t num_buckets,
                   uint64 hash_key, const Tensor* splits_out,
                   Tensor* values_out)
      : features_(features),
        num_buckets_(num_buckets),
        hash_key_(hash_key),
        splits_out_(splits_out->flat<SplitsType>()),
        values_out_(values_out->flat<ValuesType>()) {}

  // Reads features from the specified slice of batch indices, computes
  // feature crosses for each one, and writes them to values_out_.
  void WriteOutputSlice(int64_t begin, int64_t end) override {
    std::vector<int> combination(features_.size(), 0);
    for (int64_t b = begin; b < end; ++b) {
      auto row_start = splits_out_(b);
      auto row_limit = splits_out_(b + 1);
      for (auto i = row_start; i < row_limit; ++i) {
        WriteCombination(b, combination, &values_out_(i));
        NextCombination(b, &combination);
      }
      combination.assign(features_.size(), 0);  // reset for next batch.
    }
  }

 private:
  // Joins the specified combination of input features into a single string,
  // and writes it to *out.
  void WriteCombination(int64_t batch_index,
                        const std::vector<int>& combination, tstring* out) {
    static const auto k_feature_separator = "_X_";
    gtl::InlinedVector<tstring, 6> cross_vec(features_.size());
    for (int i = 0; i < combination.size(); ++i) {
      features_[i]->ReadValue(batch_index, combination[i], &cross_vec[i]);
    }
    *out = absl::StrJoin(cross_vec, k_feature_separator);
  }

  // Joins the specified combination of input features into a single
  // fingerprint, and writes it to *out.
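  //
  // Illustrative sketch for two features: the result is
  //   FingerprintCat64(FingerprintCat64(hash_key_, h0), h1)
  // reduced modulo num_buckets_ (or modulo std::numeric_limits<int64_t>::max()
  // otherwise), where h0 and h1 are the values read from the two features.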
  void WriteCombination(int64_t batch_index,
                        const std::vector<int>& combination, int64_t* out) {
    // Do the fingerprint concatenation on uint64.
    uint64 hashed_output = hash_key_;
    for (size_t i = 0; i < combination.size(); ++i) {
      uint64 hash_i;
      features_[i]->ReadValue(batch_index, combination[i], &hash_i);
      hashed_output = FingerprintCat64(hashed_output, hash_i);
    }
    // The return value is int64 based on the number of buckets.
    if (num_buckets_ > 0) {
      *out = hashed_output % num_buckets_;
    } else {
      // To prevent a negative output, take the result modulo the max int64.
      *out = hashed_output % std::numeric_limits<int64_t>::max();
    }
  }

  // Updates `combination` to the next combination of input features.
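  //
  // Illustrative sketch: with per-feature counts {2, 3} for this batch, the
  // combinations advance odometer-style:
  //   (0, 0) -> (0, 1) -> (0, 2) -> (1, 0) -> (1, 1) -> (1, 2) -> (0, 0).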
  void NextCombination(int64_t batch_index,
                       std::vector<int>* combination) const {
    bool carry = true;
    for (int i = combination->size() - 1; i >= 0; i--) {
      if (carry) {
        (*combination)[i] = (*combination)[i] + 1;
      }
      if ((*combination)[i] == features_[i]->FeatureCount(batch_index)) {
        (*combination)[i] = 0;
      } else {
        carry = false;
        break;
      }
    }
  }

  const FeatureReaders& features_;
  const int64_t num_buckets_;
  const uint64 hash_key_;
  FlatSplits splits_out_;
  FlatValues values_out_;
};

// Returns an appropriate OutputWriter, based on the dtypes of the
// given tensors.
std::unique_ptr<OutputWriter> MakeOutputWriter(const FeatureReaders& features,
                                               int64_t num_buckets,
                                               uint64 hash_key,
                                               const Tensor* splits_out,
                                               Tensor* values_out) {
  if (values_out->dtype() == DT_INT64) {
    if (splits_out->dtype() == DT_INT64) {
      return std::make_unique<OutputWriterImpl<int64_t, int64_t>>(
          features, num_buckets, hash_key, splits_out, values_out);
    } else {
      return std::make_unique<OutputWriterImpl<int64_t, int32>>(
          features, num_buckets, hash_key, splits_out, values_out);
    }
  } else {
    if (splits_out->dtype() == DT_INT64) {
      return std::make_unique<OutputWriterImpl<tstring, int64_t>>(
          features, num_buckets, hash_key, splits_out, values_out);
    } else {
      return std::make_unique<OutputWriterImpl<tstring, int32>>(
          features, num_buckets, hash_key, splits_out, values_out);
    }
  }
}

//==============================================================================
// RaggedCross Kernel
//==============================================================================

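// Computes feature crosses of the ragged, sparse, and dense inputs, combining
// them in the order given by the `input_order` attr ('R', 'S', or 'D' per
// input).
//
// Illustrative sketch: crossing the ragged input [["a", "b"], ["c"]] with the
// dense input [["d"], ["e"]] produces the ragged output
// [["a_X_d", "b_X_d"], ["c_X_e"]] as a values/row_splits pair; the hashed
// variant writes fingerprints instead of joined strings.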
template <typename SplitsType>
class RaggedCrossOp : public OpKernel {
 public:
  explicit RaggedCrossOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    // Read signed_hash_key_ as int64 since uint64 attributes are not
    // supported by REGISTER_OP.
    int64_t signed_hash_key_;
    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    hash_key_ = static_cast<uint64>(signed_hash_key_);

    int num_sparse;
    OP_REQUIRES_OK(context, context->GetAttr("Nsparse", &num_sparse));

    OP_REQUIRES_OK(context, context->GetAttr("ragged_values_types",
                                             &ragged_values_types_));
    OP_REQUIRES_OK(context, context->GetAttr("ragged_splits_types",
                                             &ragged_splits_types_));
    OP_REQUIRES_OK(context, context->GetAttr("sparse_values_types",
                                             &sparse_values_types_));
    OP_REQUIRES_OK(context, context->GetAttr("dense_types", &dense_types_));
    OP_REQUIRES_OK(context, context->GetAttr("input_order", &input_order_));
    OP_REQUIRES(context,
                ragged_values_types_.size() == ragged_splits_types_.size(),
                errors::InvalidArgument(
                    "ragged values and splits must have the same length"));
    OP_REQUIRES(context, num_sparse == sparse_values_types_.size(),
                errors::InvalidArgument(
                    "sparse indices and values must have the same length"));
    OP_REQUIRES(context,
                ragged_values_types_.size() + sparse_values_types_.size() +
                        dense_types_.size() ==
                    input_order_.size(),
                errors::InvalidArgument("Invalid length for input_order"));
  }

  void Compute(OpKernelContext* context) override {
    OpInputList ragged_values_list;
    OpInputList ragged_splits_list;
    OpInputList sparse_indices_list;
    OpInputList sparse_values_list;
    OpInputList sparse_shape_list;
    OpInputList dense_list;
    OP_REQUIRES_OK(context,
                   context->input_list("ragged_values", &ragged_values_list));
    OP_REQUIRES_OK(
        context, context->input_list("ragged_row_splits", &ragged_splits_list));
    OP_REQUIRES_OK(context,
                   context->input_list("sparse_indices", &sparse_indices_list));
    OP_REQUIRES_OK(context,
                   context->input_list("sparse_values", &sparse_values_list));
    OP_REQUIRES_OK(context,
                   context->input_list("sparse_shape", &sparse_shape_list));
    OP_REQUIRES_OK(context, context->input_list("dense_inputs", &dense_list));
    OP_REQUIRES_OK(context,
                   ValidateInput(ragged_values_list, ragged_splits_list,
                                 sparse_indices_list, sparse_values_list,
                                 sparse_shape_list, dense_list));

    int64_t batch_size =
        CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list);

    FeatureReaders features;
    OP_REQUIRES_OK(context,
                   BuildFeatureReaders(ragged_values_list, ragged_splits_list,
                                       sparse_indices_list, sparse_values_list,
                                       dense_list, batch_size, &features));

    Tensor* values_out;
    Tensor* row_splits_out;
    OP_REQUIRES_OK(context, BuildOutputTensors(features, batch_size, context,
                                               &values_out, &row_splits_out));

    std::unique_ptr<OutputWriter> output_writer = MakeOutputWriter(
        features, num_buckets_, hash_key_, row_splits_out, values_out);

    auto do_work = [&output_writer](int64_t begin, int64_t end) {
      output_writer->WriteOutputSlice(begin, end);
    };

    // TODO(edloper): optimize cost_per_batch
    const int cost_per_batch = 5000 * ragged_values_list.size();
    auto thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    thread_pool->ParallelFor(batch_size, cost_per_batch, do_work);
  }

 private:
  // Validates input tensors.
  Status ValidateInput(const OpInputList& ragged_values_list,
                       const OpInputList& ragged_splits_list,
                       const OpInputList& sparse_indices_list,
                       const OpInputList& sparse_values_list,
                       const OpInputList& sparse_shape_list,
                       const OpInputList& dense_list) {
    const auto num_ragged = ragged_values_list.size();
    const auto num_sparse = sparse_indices_list.size();

    // Validate tensor shapes.
    for (int i = 0; i < num_ragged; ++i) {
      if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape())) {
        return errors::InvalidArgument(
            "tf.ragged.cross only supports inputs with rank=2.");
      }
      if (!TensorShapeUtils::IsVector(ragged_splits_list[i].shape()) ||
          (ragged_splits_list[i].NumElements() == 0)) {
        return errors::InvalidArgument("Invalid RaggedTensor");
      }
    }
    for (int i = 0; i < num_sparse; ++i) {
      if (!TensorShapeUtils::IsMatrix(sparse_indices_list[i].shape()) ||
          !TensorShapeUtils::IsVector(sparse_values_list[i].shape()) ||
          !TensorShapeUtils::IsVector(sparse_shape_list[i].shape())) {
        return errors::InvalidArgument("Invalid SparseTensor ", i);
      }
      if (sparse_shape_list[i].NumElements() != 2) {
        return errors::InvalidArgument(
            "tf.ragged.cross only supports inputs with rank=2.");
      }
    }
    for (int i = 0; i < dense_list.size(); ++i) {
      if (!TensorShapeUtils::IsMatrix(dense_list[i].shape())) {
        return errors::InvalidArgument(
            "tf.ragged.cross only supports inputs with rank=2.");
      }
    }

    // Check that batch sizes are consistent.
    int64_t batch_size =
        CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list);
    for (int i = 0; i < num_ragged; ++i) {
      if (ragged_splits_list[i].NumElements() - 1 != batch_size) {
        return errors::InvalidArgument(
            "inputs must all have the same batch dimension size.");
      }
    }
    for (int i = 0; i < num_sparse; ++i) {
      if (sparse_shape_list[i].flat<int64_t>()(0) != batch_size) {
        return errors::InvalidArgument(
            "inputs must all have the same batch dimension size.");
      }
    }
    for (int i = 0; i < dense_list.size(); ++i) {
      if (dense_list[i].dim_size(0) != batch_size) {
        return errors::InvalidArgument(
            "inputs must all have the same batch dimension size.");
      }
    }

    return OkStatus();
  }

  // Calculates the batch size from any input tensor. (ValidateInput checks
  // that all input tensors have the same batch size.)
  int64_t CalculateBatchSize(const OpInputList& ragged_splits_list,
                             const OpInputList& sparse_shape_list,
                             const OpInputList& dense_list) {
    if (ragged_splits_list.size() > 0) {
      return ragged_splits_list[0].NumElements() - 1;
    } else if (dense_list.size() > 0) {
      return dense_list[0].dim_size(0);
    } else if (sparse_shape_list.size() > 0) {
      return sparse_shape_list[0].flat<int64_t>()(0);
    } else {
      return 0;
    }
  }

  // Builds a feature reader for each input tensor, and stores them in
  // `features`.
  Status BuildFeatureReaders(const OpInputList& ragged_values_list,
                             const OpInputList& ragged_splits_list,
                             const OpInputList& sparse_indices_list,
                             const OpInputList& sparse_values_list,
                             const OpInputList& dense_list, int64_t batch_size,
                             FeatureReaders* features) {
    features->reserve(input_order_.size());

    int next_ragged = 0;
    int next_sparse = 0;
    int next_dense = 0;
    for (char c : input_order_) {
      if (c == 'R') {
        if (next_ragged >= ragged_values_list.size())
          return errors::InvalidArgument(
              "input_order \"", input_order_,
              "\" specifies reading a ragged tensor value at index ",
              next_ragged, " from a list of ", ragged_values_list.size(),
              " values.");
        if (next_ragged >= ragged_splits_list.size())
          return errors::InvalidArgument(
              "input_order \"", input_order_,
              "\" specifies reading a ragged tensor split at index ",
              next_ragged, " from a list of ", ragged_splits_list.size(),
              " splits.");
        TF_RETURN_IF_ERROR(BuildRaggedFeatureReader(
            ragged_values_list[next_ragged], ragged_splits_list[next_ragged],
            features));
        next_ragged++;
      } else if (c == 'S') {
        if (next_sparse >= sparse_values_list.size())
          return errors::InvalidArgument(
              "input_order \"", input_order_,
              "\" specifies reading a sparse tensor value at index ",
              next_sparse, " from a list of ", sparse_values_list.size(),
              " values.");
        if (next_sparse >= sparse_indices_list.size())
          return errors::InvalidArgument(
              "input_order \"", input_order_,
              "\" specifies reading a sparse tensor index at index ",
              next_sparse, " from a list of ", sparse_indices_list.size(),
              " indices.");
        TF_RETURN_IF_ERROR(BuildSparseFeatureReader(
            sparse_indices_list[next_sparse], sparse_values_list[next_sparse],
            batch_size, features));
        next_sparse++;
      } else if (c == 'D') {
        if (next_dense >= dense_list.size())
          return errors::InvalidArgument(
              "input_order \"", input_order_,
              "\" specifies reading a dense tensor at index ", next_dense,
              " from a list of ", dense_list.size(), " tensors.");
        TF_RETURN_IF_ERROR(
            BuildDenseFeatureReader(dense_list[next_dense++], features));
      } else {
        return errors::InvalidArgument("Unexpected input_order value.");
      }
    }

    return OkStatus();
  }
  // Builds a RaggedFeatureReader.
  static Status BuildRaggedFeatureReader(const Tensor& values,
                                         const Tensor& splits,
                                         FeatureReaders* features) {
    if (values.dtype() != DT_INT64 && values.dtype() != DT_STRING) {
      return errors::InvalidArgument("Unexpected dtype for input ",
                                     (features->size() + 1), ": ",
                                     values.dtype());
    }
    if (splits.dtype() != DT_INT64 && splits.dtype() != DT_INT32) {
      return errors::InvalidArgument("Unexpected row_splits.dtype for input ",
                                     (features->size() + 1), ": ",
                                     splits.dtype());
    }
    if (values.dtype() == DT_INT64) {
      if (splits.dtype() == DT_INT64) {
        features->emplace_back(
            new RaggedFeatureReader<int64_t, int64_t>(values, splits));
      } else {
        features->emplace_back(
            new RaggedFeatureReader<int64_t, int32>(values, splits));
      }
    } else {
      if (splits.dtype() == DT_INT64) {
        features->emplace_back(
            new RaggedFeatureReader<tstring, int64_t>(values, splits));
      } else {
        features->emplace_back(
            new RaggedFeatureReader<tstring, int32>(values, splits));
      }
    }
    return OkStatus();
  }

  // Builds a DenseFeatureReader.
  static Status BuildDenseFeatureReader(const Tensor& values,
                                        FeatureReaders* features) {
    if (values.dtype() == DT_INT64) {
      features->emplace_back(new DenseFeatureReader<int64_t>(values));
    } else if (values.dtype() == DT_STRING) {
      features->emplace_back(new DenseFeatureReader<tstring>(values));
    } else {
      return errors::InvalidArgument("Unexpected dtype for input ",
                                     (features->size() + 1), ": ",
                                     values.dtype());
    }
    return OkStatus();
  }

  // Builds a SparseFeatureReader.
  static Status BuildSparseFeatureReader(const Tensor& indices,
                                         const Tensor& values,
                                         int64_t batch_size,
                                         FeatureReaders* features) {
    if (values.dtype() == DT_INT64) {
      features->emplace_back(
          new SparseFeatureReader<int64_t>(indices, values, batch_size));
    } else if (values.dtype() == DT_STRING) {
      features->emplace_back(
          new SparseFeatureReader<tstring>(indices, values, batch_size));
    } else {
      return errors::InvalidArgument("Unexpected dtype for input ",
                                     (features->size() + 1), ": ",
                                     values.dtype());
    }
    return OkStatus();
  }

  // Allocates output tensors with proper size, and populates row_splits_out.
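  //
  // For illustration: per-batch cross counts {2, 0, 3} produce
  // row_splits_out = [0, 2, 2, 5] and a values_out tensor of length 5.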
  Status BuildOutputTensors(const FeatureReaders& features, int64_t batch_size,
                            OpKernelContext* context, Tensor** values_out,
                            Tensor** row_splits_out) {
    // Allocate and populate the row_splits output tensor.
    TF_RETURN_IF_ERROR(context->allocate_output(
        1, TensorShape({batch_size + 1}), row_splits_out));
    auto flat_row_splits = (*row_splits_out)->flat<SplitsType>();
    int64_t cross_count_total = 0;
    flat_row_splits(0) = 0;
    for (int64_t b = 0; b < batch_size; b++) {
      cross_count_total += CrossCountByBatchIndex(features, b);
      flat_row_splits(b + 1) = cross_count_total;
    }

    // Allocate the values output tensor.
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({cross_count_total}), values_out));

    return OkStatus();
  }

  // Returns the number of feature crosses for the specified batch index.
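  // (For illustration: feature counts {2, 3} give 2 * 3 = 6 crosses; any
  // empty feature makes the count 0.)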
  int64_t CrossCountByBatchIndex(const FeatureReaders& features,
                                 int batch_index) {
    int64_t cross_count = 1;
    for (int i = 0; i < features.size(); ++i) {
      const auto feature_count = features[i]->FeatureCount(batch_index);
      if (feature_count == 0) return 0;
      cross_count *= feature_count;
    }
    return cross_count;
  }

  int64_t num_buckets_;
  uint64 hash_key_;
  std::vector<DataType> ragged_values_types_;
  std::vector<DataType> ragged_splits_types_;
  std::vector<DataType> sparse_values_types_;
  std::vector<DataType> dense_types_;
  tstring input_order_;
};

REGISTER_KERNEL_BUILDER(Name("RaggedCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("out_row_splits_type"),
                        RaggedCrossOp<int32>);
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("out_row_splits_type"),
                        RaggedCrossOp<int64_t>);

}  // namespace
}  // namespace tensorflow