1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <limits> |
16 | #include <memory> |
17 | #include <string> |
18 | #include <vector> |
19 | |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/register_types.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/framework/tensor_shape.h" |
24 | #include "tensorflow/core/platform/errors.h" |
25 | #include "tensorflow/core/platform/fingerprint.h" |
26 | #include "tensorflow/core/util/util.h" |
27 | #include "tensorflow/core/util/work_sharder.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | namespace { |
32 | |
33 | //============================================================================== |
34 | // Feature Readers |
35 | //============================================================================== |
36 | |
37 | // A `FeatureReader` is used to read the feature values from a single input |
38 | // tensor. Subclasses are used for reading different tensor types: |
39 | // * RaggedFeatureReader<value_type, splits_type> |
40 | // * SparseFeatureReader<value_type> |
41 | // * DenseFeatureReader<value_type> |
42 | // |
43 | // Where value_type is one of: {tstring, int64}; and SplitsType is one of: |
44 | // {int32, int64}. |
45 | class FeatureReader { |
46 | public: |
47 | // Returns the number of feature values in the specified batch. |
48 | virtual int64_t FeatureCount(int64_t batch) const = 0; |
49 | |
50 | // Copies the value for the specified feature to `out`. |
51 | virtual void ReadValue(int64_t batch, int64_t n, uint64* out) const = 0; |
52 | virtual void ReadValue(int64_t batch, int64_t n, tstring* out) const = 0; |
53 | |
54 | virtual ~FeatureReader() {} |
55 | }; |
56 | |
57 | using FeatureReaders = std::vector<std::unique_ptr<FeatureReader>>; |
58 | |
59 | // Copies a feature value `src` to a tstring `dst`, using a view if appropriate. |
60 | void CopyToString(const tstring& src, tstring* dst) { |
61 | if (src.type() == tstring::SMALL) { |
62 | *dst = src; // string buffer fits in the tstring object (under ~24 bytes) |
63 | } else { |
64 | dst->assign_as_view(src); |
65 | } |
66 | } |
67 | void CopyToString(int64_t src, tstring* dst) { *dst = std::to_string(src); } |
68 | |
69 | // Copies a feature value `src` to an int64 fingerprint `dst`. |
70 | void CopyToFingerprint(const tstring& feature, uint64* dst) { |
71 | *dst = Fingerprint64(feature); |
72 | } |
73 | void CopyToFingerprint(int64_t feature, uint64* dst) { *dst = feature; } |
74 | |
75 | // A FeatureReader that is backed by a ragged tensor. |
76 | template <typename ValuesType, typename SplitsType> |
77 | class RaggedFeatureReader : public FeatureReader { |
78 | public: |
79 | RaggedFeatureReader(const Tensor& values, const Tensor& row_splits) |
80 | : values_(values.flat<ValuesType>()), |
81 | row_splits_(row_splits.flat<SplitsType>()) {} |
82 | |
83 | int64_t FeatureCount(int64_t batch) const override { |
84 | return row_splits_(batch + 1) - row_splits_(batch); |
85 | } |
86 | |
87 | void ReadValue(int64_t batch, int64_t n, uint64* out) const override { |
88 | CopyToFingerprint(values_(row_splits_(batch) + n), out); |
89 | } |
90 | |
91 | void ReadValue(int64_t batch, int64_t n, tstring* out) const override { |
92 | CopyToString(values_(row_splits_(batch) + n), out); |
93 | } |
94 | |
95 | private: |
96 | const typename TTypes<ValuesType>::ConstFlat values_; |
97 | const typename TTypes<SplitsType>::ConstFlat row_splits_; |
98 | }; |
99 | |
100 | // A FeatureReader that is backed by a dense tensor. |
101 | template <typename ValuesType> |
102 | class DenseFeatureReader : public FeatureReader { |
103 | public: |
104 | explicit DenseFeatureReader(const Tensor& tensor) |
105 | : values_(tensor.matrix<ValuesType>()), |
106 | feature_count_(tensor.dim_size(1)) {} |
107 | |
108 | int64_t FeatureCount(int64_t batch) const override { return feature_count_; } |
109 | |
110 | void ReadValue(int64_t batch, int64_t n, uint64* out) const override { |
111 | CopyToFingerprint(values_(batch, n), out); |
112 | } |
113 | |
114 | void ReadValue(int64_t batch, int64_t n, tstring* out) const override { |
115 | CopyToString(values_(batch, n), out); |
116 | } |
117 | |
118 | private: |
119 | const typename TTypes<ValuesType>::ConstMatrix values_; |
120 | const int64_t feature_count_; |
121 | }; |
122 | |
123 | // A FeatureReader that is backed by a sparse tensor. |
124 | template <typename ValuesType> |
125 | class SparseFeatureReader : public FeatureReader { |
126 | public: |
127 | SparseFeatureReader(const Tensor& indices_t, const Tensor& values_t, |
128 | int64_t batch_size) |
129 | : values_(values_t.flat<ValuesType>()) { |
130 | row_splits_.reserve(batch_size + 1); |
131 | row_splits_.push_back(0); |
132 | auto indices = indices_t.matrix<int64_t>(); |
133 | int64_t num_values = values_.size(); |
134 | int64_t i = 0; // value index |
135 | for (int row = 0; row < batch_size; row++) { |
136 | while (i < num_values && indices(i, 0) <= row) ++i; |
137 | row_splits_.push_back(i); |
138 | } |
139 | } |
140 | |
141 | int64_t FeatureCount(int64_t batch) const override { |
142 | return row_splits_[batch + 1] - row_splits_[batch]; |
143 | } |
144 | |
145 | void ReadValue(int64_t batch, int64_t n, uint64* out) const override { |
146 | CopyToFingerprint(values_(row_splits_[batch] + n), out); |
147 | } |
148 | |
149 | void ReadValue(int64_t batch, int64_t n, tstring* out) const override { |
150 | CopyToString(values_(row_splits_[batch] + n), out); |
151 | } |
152 | |
153 | private: |
154 | const typename TTypes<ValuesType>::ConstFlat values_; |
155 | std::vector<int64_t> row_splits_; |
156 | }; |
157 | |
158 | //============================================================================== |
159 | // Output Writers |
160 | //============================================================================== |
161 | |
// An `OutputWriter` is used to write the feature crosses to the output values
// tensor. Different subclasses are used for writing different output dtypes:
//   * OutputWriterImpl<tstring, SplitsType> (for tf.ragged.cross)
//   * OutputWriterImpl<int64, SplitsType> (for tf.ragged.cross_hashed)
class OutputWriter {
 public:
  // Computes the feature crosses for batch indices in [begin, end) and
  // writes them to the output values tensor.
  virtual void WriteOutputSlice(int64_t begin, int64_t end) = 0;

  // Virtual destructor so subclasses are destroyed correctly when deleted
  // through an OutputWriter pointer.
  virtual ~OutputWriter() = default;
};
171 | |
172 | template <typename ValuesType, typename SplitsType> |
173 | class OutputWriterImpl : public OutputWriter { |
174 | public: |
175 | using FlatValues = typename TTypes<ValuesType>::Flat; |
176 | using FlatSplits = typename TTypes<SplitsType>::ConstFlat; |
177 | |
178 | OutputWriterImpl(const FeatureReaders& features, int64_t num_buckets, |
179 | uint64 hash_key, const Tensor* splits_out, |
180 | Tensor* values_out) |
181 | : features_(features), |
182 | num_buckets_(num_buckets), |
183 | hash_key_(hash_key), |
184 | splits_out_(splits_out->flat<SplitsType>()), |
185 | values_out_(values_out->flat<ValuesType>()) {} |
186 | |
187 | // Reads features from the specified slice of batch indices, computes |
188 | // feature crosses for each one, and writes them to values_out_. |
189 | void WriteOutputSlice(int64_t begin, int64_t end) override { |
190 | std::vector<int> combination(features_.size(), 0); |
191 | for (int64_t b = begin; b < end; ++b) { |
192 | auto row_start = splits_out_(b); |
193 | auto row_limit = splits_out_(b + 1); |
194 | for (auto i = row_start; i < row_limit; ++i) { |
195 | WriteCombination(b, combination, &values_out_(i)); |
196 | NextCombination(b, &combination); |
197 | } |
198 | combination.assign(features_.size(), 0); // reset for next batch. |
199 | } |
200 | } |
201 | |
202 | private: |
203 | // Joins the specified combination of input features into a single string, |
204 | // and writes it to *out. |
205 | void WriteCombination(int64_t batch_index, |
206 | const std::vector<int>& combination, tstring* out) { |
207 | static const auto k_feature_separator = "_X_" ; |
208 | gtl::InlinedVector<tstring, 6> cross_vec(features_.size()); |
209 | for (int i = 0; i < combination.size(); ++i) { |
210 | features_[i]->ReadValue(batch_index, combination[i], &cross_vec[i]); |
211 | } |
212 | *out = absl::StrJoin(cross_vec, k_feature_separator); |
213 | } |
214 | |
215 | // Joins the specified combination of input features into a single |
216 | // fingerprint, and writes it to *out. |
217 | void WriteCombination(int64_t batch_index, |
218 | const std::vector<int>& combination, int64_t* out) { |
219 | // Do the fingerprint concatenation on uint64. |
220 | uint64 hashed_output = hash_key_; |
221 | for (size_t i = 0; i < combination.size(); ++i) { |
222 | uint64 hash_i; |
223 | features_[i]->ReadValue(batch_index, combination[i], &hash_i); |
224 | hashed_output = FingerprintCat64(hashed_output, hash_i); |
225 | } |
226 | // The return value is int64 based on the number of buckets. |
227 | if (num_buckets_ > 0) { |
228 | *out = hashed_output % num_buckets_; |
229 | } else { |
230 | // To prevent negative output we take modulo to max int64. |
231 | *out = hashed_output % std::numeric_limits<int64_t>::max(); |
232 | } |
233 | } |
234 | |
235 | // Updates `combination` to the next combination of input features. |
236 | void NextCombination(int64_t batch_index, |
237 | std::vector<int>* combination) const { |
238 | bool carry = true; |
239 | for (int i = combination->size() - 1; i >= 0; i--) { |
240 | if (carry) { |
241 | (*combination)[i] = (*combination)[i] + 1; |
242 | } |
243 | if ((*combination)[i] == features_[i]->FeatureCount(batch_index)) { |
244 | (*combination)[i] = 0; |
245 | } else { |
246 | carry = false; |
247 | break; |
248 | } |
249 | } |
250 | } |
251 | |
252 | const FeatureReaders& features_; |
253 | const int64_t num_buckets_; |
254 | const uint64 hash_key_; |
255 | FlatSplits splits_out_; |
256 | FlatValues values_out_; |
257 | }; |
258 | |
259 | // Returns an appropriate OutputWriter, based on the dtypes of the |
260 | // given tensors. |
261 | std::unique_ptr<OutputWriter> MakeOutputWriter(const FeatureReaders& features, |
262 | int64_t num_buckets, |
263 | uint64 hash_key, |
264 | const Tensor* splits_out, |
265 | Tensor* values_out) { |
266 | if (values_out->dtype() == DT_INT64) { |
267 | if (splits_out->dtype() == DT_INT64) { |
268 | return std::make_unique<OutputWriterImpl<int64_t, int64_t>>( |
269 | features, num_buckets, hash_key, splits_out, values_out); |
270 | } else { |
271 | return std::make_unique<OutputWriterImpl<int64_t, int32>>( |
272 | features, num_buckets, hash_key, splits_out, values_out); |
273 | } |
274 | } else { |
275 | if (splits_out->dtype() == DT_INT64) { |
276 | return std::make_unique<OutputWriterImpl<tstring, int64_t>>( |
277 | features, num_buckets, hash_key, splits_out, values_out); |
278 | } else { |
279 | return std::make_unique<OutputWriterImpl<tstring, int32>>( |
280 | features, num_buckets, hash_key, splits_out, values_out); |
281 | } |
282 | } |
283 | } |
284 | |
285 | //============================================================================== |
286 | // RaggedCross Kernel |
287 | //============================================================================== |
288 | |
289 | template <typename SplitsType> |
290 | class RaggedCrossOp : public OpKernel { |
291 | public: |
292 | explicit RaggedCrossOp(OpKernelConstruction* context) : OpKernel(context) { |
293 | OP_REQUIRES_OK(context, context->GetAttr("num_buckets" , &num_buckets_)); |
294 | // Read signed_hash_key_ as int64 since uint64 attributes are not |
295 | // supported by REGISTER_OP. |
296 | int64_t signed_hash_key_; |
297 | OP_REQUIRES_OK(context, context->GetAttr("hash_key" , &signed_hash_key_)); |
298 | hash_key_ = static_cast<uint64>(signed_hash_key_); |
299 | |
300 | int num_sparse; |
301 | OP_REQUIRES_OK(context, context->GetAttr("Nsparse" , &num_sparse)); |
302 | |
303 | OP_REQUIRES_OK(context, context->GetAttr("ragged_values_types" , |
304 | &ragged_values_types_)); |
305 | OP_REQUIRES_OK(context, context->GetAttr("ragged_splits_types" , |
306 | &ragged_splits_types_)); |
307 | OP_REQUIRES_OK(context, context->GetAttr("sparse_values_types" , |
308 | &sparse_values_types_)); |
309 | OP_REQUIRES_OK(context, context->GetAttr("dense_types" , &dense_types_)); |
310 | OP_REQUIRES_OK(context, context->GetAttr("input_order" , &input_order_)); |
311 | OP_REQUIRES(context, |
312 | ragged_values_types_.size() == ragged_splits_types_.size(), |
313 | errors::InvalidArgument( |
314 | "ragged values and splits must have the same length" )); |
315 | OP_REQUIRES(context, num_sparse == sparse_values_types_.size(), |
316 | errors::InvalidArgument( |
317 | "sparse indices and values must have the same length" )); |
318 | OP_REQUIRES(context, |
319 | ragged_values_types_.size() + sparse_values_types_.size() + |
320 | dense_types_.size() == |
321 | input_order_.size(), |
322 | errors::InvalidArgument("Invalid length for input_order" )); |
323 | } |
324 | |
325 | void Compute(OpKernelContext* context) override { |
326 | OpInputList ragged_values_list; |
327 | OpInputList ragged_splits_list; |
328 | OpInputList sparse_indices_list; |
329 | OpInputList sparse_values_list; |
330 | OpInputList sparse_shape_list; |
331 | OpInputList dense_list; |
332 | OP_REQUIRES_OK(context, |
333 | context->input_list("ragged_values" , &ragged_values_list)); |
334 | OP_REQUIRES_OK( |
335 | context, context->input_list("ragged_row_splits" , &ragged_splits_list)); |
336 | OP_REQUIRES_OK(context, |
337 | context->input_list("sparse_indices" , &sparse_indices_list)); |
338 | OP_REQUIRES_OK(context, |
339 | context->input_list("sparse_values" , &sparse_values_list)); |
340 | OP_REQUIRES_OK(context, |
341 | context->input_list("sparse_shape" , &sparse_shape_list)); |
342 | OP_REQUIRES_OK(context, context->input_list("dense_inputs" , &dense_list)); |
343 | OP_REQUIRES_OK(context, |
344 | ValidateInput(ragged_values_list, ragged_splits_list, |
345 | sparse_indices_list, sparse_values_list, |
346 | sparse_shape_list, dense_list)); |
347 | |
348 | int64_t batch_size = |
349 | CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list); |
350 | |
351 | FeatureReaders features; |
352 | OP_REQUIRES_OK(context, |
353 | BuildFeatureReaders(ragged_values_list, ragged_splits_list, |
354 | sparse_indices_list, sparse_values_list, |
355 | dense_list, batch_size, &features)); |
356 | |
357 | Tensor* values_out; |
358 | Tensor* row_splits_out; |
359 | OP_REQUIRES_OK(context, BuildOutputTensors(features, batch_size, context, |
360 | &values_out, &row_splits_out)); |
361 | |
362 | std::unique_ptr<OutputWriter> output_writer = MakeOutputWriter( |
363 | features, num_buckets_, hash_key_, row_splits_out, values_out); |
364 | |
365 | auto do_work = [&output_writer](int64_t begin, int64_t end) { |
366 | output_writer->WriteOutputSlice(begin, end); |
367 | }; |
368 | |
369 | // TODO(edloper): optimize cost_per_batch |
370 | const int cost_per_batch = 5000 * ragged_values_list.size(); |
371 | auto thread_pool = |
372 | context->device()->tensorflow_cpu_worker_threads()->workers; |
373 | thread_pool->ParallelFor(batch_size, cost_per_batch, do_work); |
374 | } |
375 | |
376 | private: |
377 | // Validates input tensors. |
378 | Status ValidateInput(const OpInputList& ragged_values_list, |
379 | const OpInputList& ragged_splits_list, |
380 | const OpInputList& sparse_indices_list, |
381 | const OpInputList& sparse_values_list, |
382 | const OpInputList& sparse_shape_list, |
383 | const OpInputList& dense_list) { |
384 | const auto num_ragged = ragged_values_list.size(); |
385 | const auto num_sparse = sparse_indices_list.size(); |
386 | |
387 | // Validate tensor shapes. |
388 | for (int i = 0; i < num_ragged; ++i) { |
389 | if (!TensorShapeUtils::IsVector(ragged_values_list[i].shape())) { |
390 | return errors::InvalidArgument( |
391 | "tf.ragged.cross only supports inputs with rank=2." ); |
392 | } |
393 | if (!TensorShapeUtils::IsVector(ragged_splits_list[i].shape()) || |
394 | (ragged_splits_list[i].NumElements() == 0)) { |
395 | return errors::InvalidArgument("Invalid RaggedTensor" ); |
396 | } |
397 | } |
398 | for (int i = 0; i < num_sparse; ++i) { |
399 | if (!TensorShapeUtils::IsMatrix(sparse_indices_list[i].shape()) || |
400 | !TensorShapeUtils::IsVector(sparse_values_list[i].shape()) || |
401 | !TensorShapeUtils::IsVector(sparse_shape_list[i].shape())) { |
402 | return errors::InvalidArgument("Invalid SparseTensor " , i); |
403 | } |
404 | if (sparse_shape_list[i].NumElements() != 2) { |
405 | return errors::InvalidArgument( |
406 | "tf.ragged.cross only supports inputs with rank=2." ); |
407 | } |
408 | } |
409 | for (int i = 0; i < dense_list.size(); ++i) { |
410 | if (!TensorShapeUtils::IsMatrix(dense_list[i].shape())) { |
411 | return errors::InvalidArgument( |
412 | "tf.ragged.cross only supports inputs with rank=2." ); |
413 | } |
414 | } |
415 | |
416 | // Check that batch sizes are consistent. |
417 | int64_t batch_size = |
418 | CalculateBatchSize(ragged_splits_list, sparse_shape_list, dense_list); |
419 | for (int i = 0; i < num_ragged; ++i) { |
420 | if (ragged_splits_list[i].NumElements() - 1 != batch_size) { |
421 | return errors::InvalidArgument( |
422 | "inputs must all have the same batch dimension size." ); |
423 | } |
424 | } |
425 | for (int i = 0; i < num_sparse; ++i) { |
426 | if (sparse_shape_list[i].flat<int64_t>()(0) != batch_size) { |
427 | return errors::InvalidArgument( |
428 | "inputs must all have the same batch dimension size." ); |
429 | } |
430 | } |
431 | for (int i = 0; i < dense_list.size(); ++i) { |
432 | if (dense_list[i].dim_size(0) != batch_size) { |
433 | return errors::InvalidArgument( |
434 | "inputs must all have the same batch dimension size." ); |
435 | } |
436 | } |
437 | |
438 | return OkStatus(); |
439 | } |
440 | |
441 | // Calculate the batch size from any input tensor. (We check that all input |
442 | // tensors have the same batch size in `ValidateInput`). |
443 | int64_t CalculateBatchSize(const OpInputList& ragged_splits_list, |
444 | const OpInputList& sparse_shape_list, |
445 | const OpInputList& dense_list) { |
446 | if (ragged_splits_list.size() > 0) { |
447 | return ragged_splits_list[0].NumElements() - 1; |
448 | } else if (dense_list.size() > 0) { |
449 | return dense_list[0].dim_size(0); |
450 | } else if (sparse_shape_list.size() > 0) { |
451 | return sparse_shape_list[0].flat<int64_t>()(0); |
452 | } else { |
453 | return 0; |
454 | } |
455 | } |
456 | |
457 | // Build a feature reader for each input tensor, and store them in `features`. |
458 | Status BuildFeatureReaders(const OpInputList& ragged_values_list, |
459 | const OpInputList& ragged_splits_list, |
460 | const OpInputList& sparse_indices_list, |
461 | const OpInputList& sparse_values_list, |
462 | const OpInputList& dense_list, int64_t batch_size, |
463 | FeatureReaders* features) { |
464 | features->reserve(input_order_.size()); |
465 | |
466 | int next_ragged = 0; |
467 | int next_sparse = 0; |
468 | int next_dense = 0; |
469 | for (char c : input_order_) { |
470 | if (c == 'R') { |
471 | if (next_ragged >= ragged_values_list.size()) |
472 | return errors::InvalidArgument( |
473 | "input_order \"" , input_order_, |
474 | "\" specifies reading a ragged tensor value at index " , |
475 | next_ragged, " from a list of " , ragged_values_list.size(), |
476 | " values." ); |
477 | if (next_ragged >= ragged_splits_list.size()) |
478 | return errors::InvalidArgument( |
479 | "input_order \"" , input_order_, |
480 | "\" specifies reading a ragged tensor split at index " , |
481 | next_ragged, " from a list of " , ragged_splits_list.size(), |
482 | " splits." ); |
483 | TF_RETURN_IF_ERROR(BuildRaggedFeatureReader( |
484 | ragged_values_list[next_ragged], ragged_splits_list[next_ragged], |
485 | features)); |
486 | next_ragged++; |
487 | } else if (c == 'S') { |
488 | if (next_sparse >= sparse_values_list.size()) |
489 | return errors::InvalidArgument( |
490 | "input_order \"" , input_order_, |
491 | "\" specifies reading a sparse tensor value at index " , |
492 | next_sparse, " from a list of " , sparse_values_list.size(), |
493 | " values." ); |
494 | if (next_sparse >= sparse_indices_list.size()) |
495 | return errors::InvalidArgument( |
496 | "input_order \"" , input_order_, |
497 | "\" specifies reading a sparse tensor index at index " , |
498 | next_sparse, " from a list of " , sparse_indices_list.size(), |
499 | " indices." ); |
500 | TF_RETURN_IF_ERROR(BuildSparseFeatureReader( |
501 | sparse_indices_list[next_sparse], sparse_values_list[next_sparse], |
502 | batch_size, features)); |
503 | next_sparse++; |
504 | } else if (c == 'D') { |
505 | if (next_dense >= dense_list.size()) |
506 | return errors::InvalidArgument( |
507 | "input_order \"" , input_order_, |
508 | "\" specifies reading a dense tensor at index " , next_dense, |
509 | " from a list of " , dense_list.size(), " tensors." ); |
510 | TF_RETURN_IF_ERROR( |
511 | BuildDenseFeatureReader(dense_list[next_dense++], features)); |
512 | } else { |
513 | return errors::InvalidArgument("Unexpected input_order value." ); |
514 | } |
515 | } |
516 | |
517 | return OkStatus(); |
518 | } |
519 | |
520 | // Builds a RaggedReatureReader |
521 | static Status BuildRaggedFeatureReader(const Tensor& values, |
522 | const Tensor& splits, |
523 | FeatureReaders* features) { |
524 | if (values.dtype() != DT_INT64 && values.dtype() != DT_STRING) { |
525 | return errors::InvalidArgument("Unexpected dtype for input " , |
526 | (features->size() + 1), ": " , |
527 | values.dtype()); |
528 | } |
529 | if (splits.dtype() != DT_INT64 && splits.dtype() != DT_INT32) { |
530 | return errors::InvalidArgument("Unexpected row_splits.dtype for input " , |
531 | (features->size() + 1), ": " , |
532 | values.dtype()); |
533 | } |
534 | if (values.dtype() == DT_INT64) { |
535 | if (splits.dtype() == DT_INT64) { |
536 | features->emplace_back( |
537 | new RaggedFeatureReader<int64_t, int64_t>(values, splits)); |
538 | } else { |
539 | features->emplace_back( |
540 | new RaggedFeatureReader<int64_t, int32>(values, splits)); |
541 | } |
542 | } else { |
543 | if (splits.dtype() == DT_INT64) { |
544 | features->emplace_back( |
545 | new RaggedFeatureReader<tstring, int64_t>(values, splits)); |
546 | } else { |
547 | features->emplace_back( |
548 | new RaggedFeatureReader<tstring, int32>(values, splits)); |
549 | } |
550 | } |
551 | return OkStatus(); |
552 | } |
553 | |
554 | // Builds a DenseFaggedReatureReader. |
555 | static Status BuildDenseFeatureReader(const Tensor& values, |
556 | FeatureReaders* features) { |
557 | if (values.dtype() == DT_INT64) { |
558 | features->emplace_back(new DenseFeatureReader<int64_t>(values)); |
559 | } else if (values.dtype() == DT_STRING) { |
560 | features->emplace_back(new DenseFeatureReader<tstring>(values)); |
561 | } else { |
562 | return errors::InvalidArgument("Unexpected dtype for input " , |
563 | (features->size() + 1), ": " , |
564 | values.dtype()); |
565 | } |
566 | return OkStatus(); |
567 | } |
568 | |
569 | // Builds a SparseFaggedReatureReader. |
570 | static Status BuildSparseFeatureReader(const Tensor& indices, |
571 | const Tensor& values, |
572 | int64_t batch_size, |
573 | FeatureReaders* features) { |
574 | if (values.dtype() == DT_INT64) { |
575 | features->emplace_back( |
576 | new SparseFeatureReader<int64_t>(indices, values, batch_size)); |
577 | } else if (values.dtype() == DT_STRING) { |
578 | features->emplace_back( |
579 | new SparseFeatureReader<tstring>(indices, values, batch_size)); |
580 | } else { |
581 | return errors::InvalidArgument("Unexpected dtype for input " , |
582 | (features->size() + 1), ": " , |
583 | values.dtype()); |
584 | } |
585 | return OkStatus(); |
586 | } |
587 | |
588 | // Allocates output tensors with proper size, and populates row_splits_out. |
589 | Status BuildOutputTensors(const FeatureReaders& features, int64_t batch_size, |
590 | OpKernelContext* context, Tensor** values_out, |
591 | Tensor** row_splits_out) { |
592 | // Allocate and populate the row_splits output tensor. |
593 | TF_RETURN_IF_ERROR(context->allocate_output( |
594 | 1, TensorShape({batch_size + 1}), row_splits_out)); |
595 | auto flat_row_splits = (*row_splits_out)->flat<SplitsType>(); |
596 | int64_t cross_count_total = 0; |
597 | flat_row_splits(0) = 0; |
598 | for (int64_t b = 0; b < batch_size; b++) { |
599 | cross_count_total += CrossCountByBatchIndex(features, b); |
600 | flat_row_splits(b + 1) = cross_count_total; |
601 | } |
602 | |
603 | // Allocate the values output tensor. |
604 | TF_RETURN_IF_ERROR(context->allocate_output( |
605 | 0, TensorShape({cross_count_total}), values_out)); |
606 | |
607 | return OkStatus(); |
608 | } |
609 | |
610 | // Returns number of crosses for a given batch_index |
611 | int64_t CrossCountByBatchIndex(const FeatureReaders& features, |
612 | int batch_index) { |
613 | int64_t cross_count = 1; |
614 | for (int i = 0; i < features.size(); ++i) { |
615 | const auto feature_count = features[i]->FeatureCount(batch_index); |
616 | if (feature_count == 0) return 0; |
617 | cross_count *= feature_count; |
618 | } |
619 | return cross_count; |
620 | } |
621 | |
622 | int64_t num_buckets_; |
623 | uint64 hash_key_; |
624 | std::vector<DataType> ragged_values_types_; |
625 | std::vector<DataType> ragged_splits_types_; |
626 | std::vector<DataType> sparse_values_types_; |
627 | std::vector<DataType> dense_types_; |
628 | tstring input_order_; |
629 | }; |
630 | |
// Register the CPU kernel twice, specialized on the row-splits dtype
// requested by the `out_row_splits_type` attr (int32 or int64).
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("out_row_splits_type"),
                        RaggedCrossOp<int32>);
REGISTER_KERNEL_BUILDER(Name("RaggedCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("out_row_splits_type"),
                        RaggedCrossOp<int64_t>);
639 | |
640 | } // namespace |
641 | } // namespace tensorflow |
642 | |