/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <stddef.h>

#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/broadcast_to_op.h"
#include "tensorflow/core/kernels/list_kernels.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/ragged_to_dense_util.h"

namespace tensorflow {

namespace {
typedef Eigen::ThreadPoolDevice CPUDevice;
using ::std::vector;

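// Input tensor layout for the op: input 0 is the desired dense shape,
// input 1 the flat ragged values, input 2 the default (padding) value, and
// inputs 3.. the row partition tensors, one per row partition type.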
const int kShapeInputIndex = 0;
const int kValueInputIndex = 1;
const int kDefaultValueInputIndex = 2;
const int kFirstPartitionInputIndex = 3;

template <typename INDEX_TYPE>
class RaggedTensorToTensorBaseOp : public OpKernel {
 public:
  typedef
      typename ::tensorflow::TTypes<const INDEX_TYPE>::Flat RowPartitionTensor;

  explicit RaggedTensorToTensorBaseOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, GetRowPartitionTypes<OpKernelConstruction>(
                                context, &row_partition_types_));
    ragged_rank_ = GetRaggedRank(row_partition_types_);
  }

  // Returns the relationship between dimension and dimension + 1.
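  // E.g., if row_partition_types_ is [FIRST_DIM_SIZE, VALUE_ROWIDS], the
  // leading FIRST_DIM_SIZE entry is skipped and dimension 0 maps to
  // VALUE_ROWIDS.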
  RowPartitionType GetRowPartitionTypeByDimension(int dimension) {
    if (row_partition_types_[0] == RowPartitionType::FIRST_DIM_SIZE) {
      return row_partition_types_[dimension + 1];
    } else {
      return row_partition_types_[dimension];
    }
  }

  // Returns the row partition tensor describing the relationship between
  // dimension and dimension + 1.
  RowPartitionTensor GetRowPartitionTensor(OpKernelContext* c, int dimension) {
    if (row_partition_types_[0] == RowPartitionType::FIRST_DIM_SIZE) {
      return c->input(dimension + 1 + kFirstPartitionInputIndex)
          .flat<INDEX_TYPE>();
    } else {
      return c->input(dimension + kFirstPartitionInputIndex).flat<INDEX_TYPE>();
    }
  }

  Status GetMaxWidth(OpKernelContext* c, int dimension, INDEX_TYPE* result) {
    const RowPartitionTensor row_partition_tensor =
        GetRowPartitionTensor(c, dimension - 1);
    switch (GetRowPartitionTypeByDimension(dimension - 1)) {
      case RowPartitionType::VALUE_ROWIDS:
        *result = GetMaxWidthValueRowID(row_partition_tensor);
        return OkStatus();
      case RowPartitionType::ROW_SPLITS:
        *result = GetMaxWidthRowSplit(row_partition_tensor);
        return OkStatus();
      default:
        return errors::InvalidArgument(
            "Cannot handle partition type ",
            RowPartitionTypeToString(
                GetRowPartitionTypeByDimension(dimension - 1)));
    }
  }

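  // Returns the widest row described by a ROW_SPLITS tensor. E.g.,
  // row_split = [0, 2, 5, 6] describes rows of lengths 2, 3, and 1, so the
  // result is 3.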
  static INDEX_TYPE GetMaxWidthRowSplit(const RowPartitionTensor& row_split) {
    const INDEX_TYPE tensor_length = row_split.size();
    if (tensor_length == 0 || tensor_length == 1) {
      return 0;
    }
    INDEX_TYPE max_width = 0;
    for (INDEX_TYPE i = 0; i < tensor_length - 1; ++i) {
      const INDEX_TYPE current_width = row_split(i + 1) - row_split(i);
      if (current_width > max_width) {
        max_width = current_width;
      }
    }
    return max_width;
  }

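  // Returns the widest row described by a (sorted) VALUE_ROWIDS tensor,
  // i.e. the length of its longest run of equal row ids. E.g.,
  // value_rowids = [0, 0, 0, 2, 2] gives 3.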
  static INDEX_TYPE GetMaxWidthValueRowID(
      const RowPartitionTensor& value_rowids) {
    const INDEX_TYPE index_length = value_rowids.size();
    if (index_length == 0) {
      return 0;
    }
    INDEX_TYPE first_equal_index = 0;
    INDEX_TYPE first_equal_index_value = value_rowids(0);
    INDEX_TYPE max_width = 0;
    for (INDEX_TYPE i = 1; i < index_length; ++i) {
      const INDEX_TYPE value = value_rowids(i);
      if (value != first_equal_index_value) {
        first_equal_index_value = value;
        max_width = std::max(i - first_equal_index, max_width);
        first_equal_index = i;
      }
    }
    return std::max(index_length - first_equal_index, max_width);
  }

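  // Computes the dense output shape, as a vector of dimension sizes. The
  // user-supplied shape is combined with the shape implied by the values;
  // any dimension still unknown (-1) is then filled in: dimension 0 from
  // first_dim, and each ragged dimension from its maximum row width.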
  Status CalculateOutputSize(INDEX_TYPE first_dim, OpKernelContext* c,
                             vector<INDEX_TYPE>* result) {
    TensorShapeProto value_shape_proto;
    c->input(kValueInputIndex).shape().AsProto(&value_shape_proto);

    TensorShapeProto default_value_shape_proto;
    c->input(kDefaultValueInputIndex)
        .shape()
        .AsProto(&default_value_shape_proto);

    TensorShapeProto output_shape_proto;
    TF_RETURN_IF_ERROR(ValidateDefaultValueShape(default_value_shape_proto,
                                                 value_shape_proto));

    TensorShapeProto shape_proto;
    {
      PartialTensorShape partial_tensor_shape;
      TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(kShapeInputIndex),
                                               &partial_tensor_shape));
      partial_tensor_shape.AsProto(&shape_proto);
    }

    TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
        ragged_rank_, shape_proto, value_shape_proto, &output_shape_proto));

    result->reserve(output_shape_proto.dim_size());
    for (const TensorShapeProto::Dim& dim : output_shape_proto.dim()) {
      // Note that this may be -1 (if dimension size is unknown).
      result->push_back(dim.size());
    }

    if ((*result)[0] < 0) {
      (*result)[0] = first_dim;
    }
    for (int i = 1; i <= ragged_rank_; ++i) {
      if ((*result)[i] < 0) {
        TF_RETURN_IF_ERROR(GetMaxWidth(c, i, &(*result)[i]));
      }
    }
    return OkStatus();
  }

  /**
   * The output_index represents the index in the output tensor
   * where the first element of a particular dimension would be written.
   * If it is -1, it indicates that the index is out of range.
   * For example, given first_dimension = 10, first_dimension_output = 6,
   * and output_index_multiplier = 100:
   * result = [0 100 200 300 400 500 -1 -1 -1 -1]
   * If first_dimension_output = 11 instead, then:
   * result = [0 100 200 300 400 500 600 700 800 900]
   */
  void CalculateFirstParentOutputIndex(INDEX_TYPE first_dimension,
                                       INDEX_TYPE output_index_multiplier,
                                       INDEX_TYPE first_dimension_output,
                                       vector<INDEX_TYPE>* result) {
    const INDEX_TYPE min_dimension =
        std::min(first_dimension, first_dimension_output);
    result->reserve(first_dimension);
    int current_output_index = 0;
    for (INDEX_TYPE i = 0; i < min_dimension;
         ++i, current_output_index += output_index_multiplier) {
      result->push_back(current_output_index);
    }
    for (INDEX_TYPE i = min_dimension; i < first_dimension; ++i) {
      result->push_back(-1);
    }
    DCHECK_EQ(result->size(), first_dimension);
  }

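  // Like CalculateOutputIndexValueRowID below, but driven by a ROW_SPLITS
  // tensor: row i emits up to output_size indices starting at
  // parent_output_index[i] and stepping by output_index_multiplier; entries
  // past output_size (or under a -1 parent) become -1. E.g., with
  // row_split = [0, 2, 3], parent_output_index = [100, 300],
  // output_index_multiplier = 10, and output_size = 2, the result is
  // [100, 110, 300].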
  Status CalculateOutputIndexRowSplit(
      const RowPartitionTensor& row_split,
      const vector<INDEX_TYPE>& parent_output_index,
      INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
      vector<INDEX_TYPE>* result) {
    INDEX_TYPE row_split_size = row_split.size();
    if (row_split_size > 0) {
      result->reserve(row_split(row_split_size - 1));
    }
    for (INDEX_TYPE i = 0; i < row_split_size - 1; ++i) {
      INDEX_TYPE row_length = row_split(i + 1) - row_split(i);
      INDEX_TYPE real_length = std::min(output_size, row_length);
      INDEX_TYPE parent_output_index_current = parent_output_index[i];

      if (parent_output_index_current == -1) {
        real_length = 0;
      }
      for (INDEX_TYPE j = 0; j < real_length; ++j) {
        result->push_back(parent_output_index_current);
        parent_output_index_current += output_index_multiplier;
      }
      for (INDEX_TYPE j = 0; j < row_length - real_length; ++j) {
        result->push_back(-1);
      }
    }
    if (row_split_size > 0 &&
        result->size() != row_split(row_split_size - 1)) {
      return errors::InvalidArgument("Invalid row split size.");
    }

    return OkStatus();
  }

  // Calculate the output index of the first element of a list.
  // The parent_output_index is the same computation for the previous list.
  // -1 indicates an element or list that is out of range.
  // The output_index_multiplier is the number of output indices one moves
  // forward for each column.
  // E.g., given:
  // value_rowids:[0 1 2 2 2 3 5 5 6]
  // parent_output_index:[1000 1100 2000 2100 -1 -1 3000]
  // output_index_multiplier: 10
  // output_size: 2
  // You get:
  // result = [1000 1100 2000 2010 -1 2100 -1 -1 3000]
  // result[0] = parent_output_index[value_rowids[0]]
  // result[1] = parent_output_index[value_rowids[1]]
  // result[2] = parent_output_index[value_rowids[2]]
  // result[3] = parent_output_index[value_rowids[3]] + 10
  // result[4] = -1 because it is the third element in a row of size 2.
  // result[5] = parent_output_index[value_rowids[5]]
  // result[6] = -1 because parent_output_index[value_rowids[6]] == -1
  // result[7] = -1 because parent_output_index[value_rowids[7]] == -1
  // result[8] = parent_output_index[value_rowids[8]]
  Status CalculateOutputIndexValueRowID(
      const RowPartitionTensor& value_rowids,
      const vector<INDEX_TYPE>& parent_output_index,
      INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
      vector<INDEX_TYPE>* result) {
    const INDEX_TYPE index_size = value_rowids.size();
    result->reserve(index_size);
    if (index_size == 0) {
      return OkStatus();
    }

    INDEX_TYPE current_output_column = 0;
    INDEX_TYPE current_value_rowid = value_rowids(0);

    if (current_value_rowid >= parent_output_index.size()) {
      return errors::InvalidArgument(
          "Got current_value_rowid=", current_value_rowid,
          " which is not less than ", parent_output_index.size());
    }

    INDEX_TYPE current_output_index = parent_output_index[current_value_rowid];
    result->push_back(current_output_index);
    for (INDEX_TYPE i = 1; i < index_size; ++i) {
      INDEX_TYPE next_value_rowid = value_rowids(i);
      if (next_value_rowid == current_value_rowid) {
        if (current_output_index >= 0) {
          ++current_output_column;
          if (current_output_column < output_size) {
            current_output_index += output_index_multiplier;
          } else {
            current_output_index = -1;
          }
        }
      } else {
        current_output_column = 0;
        current_value_rowid = next_value_rowid;

        if (next_value_rowid >= parent_output_index.size()) {
          return errors::InvalidArgument(
              "Got next_value_rowid=", next_value_rowid,
              " which is not less than ", parent_output_index.size());
        }

        current_output_index = parent_output_index[next_value_rowid];
      }
      result->push_back(current_output_index);
    }

    if (result->size() != value_rowids.size()) {
      return errors::InvalidArgument("Invalid row ids.");
    }

    return OkStatus();
  }

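  // Computes the output indices for dimension + 1 from those of dimension,
  // dispatching on that dimension's partition type.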
  Status CalculateOutputIndex(OpKernelContext* context, int dimension,
                              const vector<INDEX_TYPE>& parent_output_index,
                              INDEX_TYPE output_index_multiplier,
                              INDEX_TYPE output_size,
                              vector<INDEX_TYPE>* result) {
    const RowPartitionTensor row_partition_tensor =
        GetRowPartitionTensor(context, dimension);
    auto partition_type = GetRowPartitionTypeByDimension(dimension);
    switch (partition_type) {
      case RowPartitionType::VALUE_ROWIDS:
        return CalculateOutputIndexValueRowID(
            row_partition_tensor, parent_output_index, output_index_multiplier,
            output_size, result);
      case RowPartitionType::ROW_SPLITS:
        if (row_partition_tensor.size() - 1 > parent_output_index.size()) {
          return errors::InvalidArgument(
              "Row partition size is greater than output size: ",
              row_partition_tensor.size() - 1, " > ",
              parent_output_index.size());
        }
        return CalculateOutputIndexRowSplit(
            row_partition_tensor, parent_output_index, output_index_multiplier,
            output_size, result);
      default:
        return errors::InvalidArgument(
            "Unsupported partition type: ",
            RowPartitionTypeToString(partition_type));
    }
  }

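  // Determines the size of the output's first dimension from the first
  // partition tensor: its scalar value for FIRST_DIM_SIZE, or its number of
  // elements minus one for ROW_SPLITS.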
  Status GetFirstDimensionSize(OpKernelContext* context, INDEX_TYPE* result) {
    const Tensor first_partition_tensor =
        context->input(kFirstPartitionInputIndex);
    if (row_partition_types_.empty()) {
      return errors::InvalidArgument("No row_partition_types given.");
    }
    const RowPartitionType first_partition_type = row_partition_types_[0];
    switch (first_partition_type) {
      case RowPartitionType::FIRST_DIM_SIZE:
        *result = first_partition_tensor.scalar<INDEX_TYPE>()();
        return OkStatus();
      case RowPartitionType::VALUE_ROWIDS:
        return errors::InvalidArgument(
            "Cannot handle VALUE_ROWIDS in first dimension.");
      case RowPartitionType::ROW_SPLITS:
        *result = first_partition_tensor.shape().dim_size(0) - 1;
        return OkStatus();
      default:
        return errors::InvalidArgument(
            "Cannot handle type ",
            RowPartitionTypeToString(first_partition_type));
    }
  }

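  // Overall algorithm: (1) determine the dense output shape, (2) compute
  // multiplier[i], the stride of each ragged dimension, (3) walk the ragged
  // dimensions outermost-first, mapping each value to its output offset
  // (or -1 if it falls outside the dense shape), and (4) scatter the values
  // via SetOutput.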
  void Compute(OpKernelContext* context) override {
    INDEX_TYPE first_dimension;
    const Tensor first_partition_tensor =
        context->input(kFirstPartitionInputIndex);
    OP_REQUIRES(context, first_partition_tensor.NumElements() > 0,
                errors::InvalidArgument("Invalid first partition input. Tensor "
                                        "requires at least one element."));
    OP_REQUIRES_OK(context, GetFirstDimensionSize(context, &first_dimension));
    vector<INDEX_TYPE> output_size;
    OP_REQUIRES_OK(context,
                   CalculateOutputSize(first_dimension, context, &output_size));
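    // multiplier[i] is the stride of ragged dimension i, in units of value
    // elements; strides over the trailing dense dimensions are applied later,
    // in SetOutput, via value_element_size.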
    vector<INDEX_TYPE> multiplier;
    multiplier.resize(ragged_rank_ + 1);

    multiplier[multiplier.size() - 1] = 1;
    for (int i = multiplier.size() - 2; i >= 0; --i) {
      multiplier[i] = multiplier[i + 1] * output_size[i + 1];
    }
    // Full size of the tensor.
    TensorShape output_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(output_size, &output_shape));
    Tensor* output_tensor = nullptr;

    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output_tensor));
    const INDEX_TYPE full_size = multiplier[0] * output_size[0];
    if (full_size > 0) {
      vector<INDEX_TYPE> output_index, new_output_index;
      int nvals = context->input(kValueInputIndex).shape().dim_size(0);
      output_index.reserve(nvals);
      new_output_index.reserve(nvals);

      CalculateFirstParentOutputIndex(first_dimension, multiplier[0],
                                      output_size[0], &output_index);
      for (int i = 1; i <= ragged_rank_; ++i) {
        OP_REQUIRES_OK(context, CalculateOutputIndex(
                                    context, i - 1, output_index, multiplier[i],
                                    output_size[i], &new_output_index));
        output_index.swap(new_output_index);
        new_output_index.clear();
      }

      SetOutput(context, ragged_rank_, output_index, output_tensor);
    }
  }

  virtual void SetOutput(OpKernelContext* context, int ragged_rank,
                         const vector<INDEX_TYPE>& output_index,
                         Tensor* output_tensor) = 0;

 private:
  vector<RowPartitionType> row_partition_types_;
  int ragged_rank_;
};

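// Element-by-element copy, for value types where memcpy would be undefined
// behavior (e.g. tstring, Eigen::half).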
template <typename VALUE_TYPE, typename INDEX_TYPE>
void slow_copy_array(VALUE_TYPE* dst, const VALUE_TYPE* src, INDEX_TYPE size) {
  for (INDEX_TYPE index = 0; index < size; ++index) {
    dst[index] = src[index];
  }
}

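// Bulk copy via memcpy; only valid for trivially copyable value types. The
// specializations below fall back to slow_copy_array.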
template <typename VALUE_TYPE, typename INDEX_TYPE>
void copy_array(VALUE_TYPE* dst, const VALUE_TYPE* src, INDEX_TYPE size) {
  memcpy(dst, src, size * sizeof(VALUE_TYPE));
}

template <>
void copy_array<tstring, int64_t>(tstring* dst, const tstring* src,
                                  int64_t size) {
  slow_copy_array(dst, src, size);
}

template <>
void copy_array<tstring, int32>(tstring* dst, const tstring* src,
                                int32_t size) {
  slow_copy_array(dst, src, size);
}

// If we don't specialize for Eigen::half, the memcpy in copy_array is
// undefined behavior: the destination object type 'Eigen::half' is not
// TriviallyCopyable.
template <>
void copy_array<Eigen::half, int64_t>(Eigen::half* dst, const Eigen::half* src,
                                      int64_t size) {
  slow_copy_array(dst, src, size);
}

template <>
void copy_array<Eigen::half, int32>(Eigen::half* dst, const Eigen::half* src,
                                    int32_t size) {
  slow_copy_array(dst, src, size);
}

template <typename VALUE_TYPE, typename INDEX_TYPE>
class RaggedTensorToTensorOp : public RaggedTensorToTensorBaseOp<INDEX_TYPE> {
 public:
  explicit RaggedTensorToTensorOp(OpKernelConstruction* context)
      : RaggedTensorToTensorBaseOp<INDEX_TYPE>(context) {}

  void SetOutput(OpKernelContext* context, int ragged_rank,
                 const vector<INDEX_TYPE>& output_index,
                 Tensor* output_tensor) override {
    // Note: it's ok to use OP_REQUIRES_OK (rather than TF_RETURN_IF_ERROR)
    // in this function, but only because it's the last thing we do before
    // returning from Compute().

    if (output_tensor->NumElements() == 0) return;

    const auto& values_tensor = context->input(kValueInputIndex);
    const VALUE_TYPE* values_base = values_tensor.flat<VALUE_TYPE>().data();
    const auto& default_value_tensor = context->input(kDefaultValueInputIndex);
    VALUE_TYPE* output_base = output_tensor->flat<VALUE_TYPE>().data();

    TensorShape element_shape = output_tensor->shape();
    element_shape.RemoveDimRange(0, ragged_rank + 1);
    int value_element_size = element_shape.num_elements();
    size_t output_index_size = output_index.size();

    // Broadcast the default value to value_element_size. (We can skip this
    // if default_value_tensor.NumElements() == 1, since we use std::fill
    // when that's true.)
    const VALUE_TYPE* default_value =
        default_value_tensor.flat<VALUE_TYPE>().data();
    Tensor bcast_default;  // Temporary tensor for result of broadcast.
    if (default_value_tensor.NumElements() != value_element_size &&
        default_value_tensor.NumElements() != 1) {
      const auto& src_shape = default_value_tensor.shape();
      BCast bcast(BCast::FromShape(src_shape), BCast::FromShape(element_shape),
                  /*fewer_dims_optimization=*/true);
      // Note: bcast should always be valid, since we rejected any incompatible
      // shapes when we called ValidateDefaultValueShape().
      OP_REQUIRES(context, bcast.IsValid(),
                  errors::InvalidArgument("Error broadcasting default_value"));
      OP_REQUIRES_OK(context,
                     context->allocate_temp(default_value_tensor.dtype(),
                                            element_shape, &bcast_default));
      const CPUDevice& device = context->eigen_device<CPUDevice>();
      functor::BroadcastTo<CPUDevice, VALUE_TYPE>()(
          device, context, bcast_default, element_shape, default_value_tensor,
          src_shape, bcast);
      default_value = bcast_default.flat<VALUE_TYPE>().data();
    }

    // Loop through the output_index vector, finding contiguous regions that
    // should be copied. Once we find the end of a contiguous region, copy it
    // and add any necessary padding (with default_value).
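    // E.g., with value_element_size == 1, output_index = [0, 1, 3, -1], and
    // a 5-element output, the loop produces:
    //   output = [values[0], values[1], default, values[2], default]
    // values[0..1] form one contiguous copy, output[2] is padding, values[2]
    // lands at output[3], values[3] is dropped (index -1), and the tail is
    // padded.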
    INDEX_TYPE src_start = 0;  // Start of contiguous region (in values).
    INDEX_TYPE dst_start = 0;  // Destination of contiguous region (in output).
    INDEX_TYPE dst_end = 0;    // End of contiguous region (in output).
    for (int src_i = 0; src_i <= output_index_size; ++src_i) {
      // dst_i is the destination where the value at src_i should be copied.
      INDEX_TYPE dst_i = src_i < output_index_size ? output_index[src_i] : -1;

      // If we're still in a contiguous region, then update dst_end and go to
      // the next src_i.
      if (dst_i == dst_end) {
        ++dst_end;
        continue;
      }

      // We found the end of a contiguous region. This can be because we found
      // a gap (dst_i > dst_end), or a source value that shouldn't be copied
      // because it's out of range (dst_i == -1), or the end of the tensor
      // (dst_i == -1).
      if (dst_start < dst_end) {
        // Copy the contiguous region.
        const VALUE_TYPE* src = values_base + src_start * value_element_size;
        VALUE_TYPE* dst = output_base + dst_start * value_element_size;
        INDEX_TYPE nvals = (dst_end - dst_start) * value_element_size;
        copy_array<VALUE_TYPE, INDEX_TYPE>(dst, src, nvals);
      }

      // Add any necessary padding (w/ default_value).
      if (src_i >= output_index_size) {
        // We reached the end of values: pad to the end of output.
        size_t output_size = output_tensor->NumElements();
        dst_i = output_size / value_element_size;
      }
      if (dst_i > dst_end) {
        if (default_value_tensor.NumElements() == 1) {
          std::fill(output_base + dst_end * value_element_size,
                    output_base + dst_i * value_element_size, *default_value);
          dst_end = dst_i;
        } else {
          while (dst_i > dst_end) {
            VALUE_TYPE* dst = output_base + dst_end * value_element_size;
            copy_array<VALUE_TYPE, INDEX_TYPE>(dst, default_value,
                                               value_element_size);
            ++dst_end;
          }
        }
      }

      // Update indices.
      if (dst_i < 0) {
        // src_i should be skipped -- leave it out of the contiguous region.
        src_start = src_i + 1;
        dst_start = dst_end;
      } else {
        // src_i should be copied -- include it in the contiguous region.
        src_start = src_i;
        dst_start = dst_end;
        dst_end = dst_start + 1;
      }
    }
  }
};

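// Register the CPU kernel for every supported value type, with both int32
// and int64 row-partition index types.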
#define REGISTER_CPU_KERNEL_INDEX_TYPE(value_type, index_type)       \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToTensor")               \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<value_type>("T")       \
                              .TypeConstraint<index_type>("Tindex"), \
                          RaggedTensorToTensorOp<value_type, index_type>);

#define REGISTER_CPU_KERNEL(value_type)                \
  REGISTER_CPU_KERNEL_INDEX_TYPE(value_type, int64_t); \
  REGISTER_CPU_KERNEL_INDEX_TYPE(value_type, tensorflow::int32);

TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_string(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL);

#undef REGISTER_CPU_KERNEL
#undef REGISTER_CPU_KERNEL_INDEX_TYPE

}  // namespace
}  // namespace tensorflow