1 | #include "tensorflow/core/framework/tensor_key.h" |
2 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
3 | |
4 | Licensed under the Apache License, Version 2.0 (the "License"); |
5 | you may not use this file except in compliance with the License. |
6 | You may obtain a copy of the License at |
7 | |
8 | http://www.apache.org/licenses/LICENSE-2.0 |
9 | |
10 | Unless required by applicable law or agreed to in writing, software |
11 | distributed under the License is distributed on an "AS IS" BASIS, |
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | See the License for the specific language governing permissions and |
14 | limitations under the License. |
15 | ==============================================================================*/ |
16 | |
17 | #ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ |
18 | #define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ |
19 | |
20 | #define EIGEN_USE_THREADS |
21 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
22 | #define EIGEN_USE_GPU |
23 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
24 | |
25 | #include <vector> |
26 | |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/framework/variant_op_registry.h" |
30 | #include "tensorflow/core/framework/variant_tensor_data.h" |
31 | #include "tensorflow/core/kernels/cwise_ops_common.h" |
32 | #include "tensorflow/core/util/tensor_ops_util.h" |
33 | |
34 | namespace tensorflow { |
35 | |
36 | // Class used to store a RaggedTensor as a Variant scalar. |
37 | class RaggedTensorVariant { |
38 | public: |
39 | RaggedTensorVariant() {} |
40 | RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits) |
41 | : values_(std::move(values)), nested_splits_(nested_splits) {} |
42 | |
43 | // Variant support methods. |
44 | string TypeName() const; |
45 | string DebugString() const; |
46 | void Encode(VariantTensorData* data) const; |
47 | bool Decode(const VariantTensorData& data); |
48 | |
49 | // The flat_values of the RaggedTensor. |
50 | const Tensor& values() const { return values_; } |
51 | Tensor* mutable_values() { return &values_; } |
52 | void set_values(const Tensor& new_values) { values_ = new_values; } |
53 | |
54 | // The nested row_splits of the RaggedTensor. |
55 | int ragged_rank() const { return nested_splits_.size(); } |
56 | const std::vector<Tensor>& nested_splits() const { return nested_splits_; } |
57 | std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; } |
58 | const Tensor& splits(int i) const { return nested_splits_[i]; } |
59 | Tensor* mutable_splits(int i) { return &nested_splits_[i]; } |
60 | void set_nested_splits(const std::vector<Tensor>& nested_splits) { |
61 | nested_splits_ = nested_splits; |
62 | } |
63 | void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); } |
64 | |
65 | private: |
66 | Tensor values_; |
67 | std::vector<Tensor> nested_splits_; |
68 | }; |
69 | |
70 | template <typename Device> |
71 | Status RaggedTensorVariantZerosLike(OpKernelContext* c, |
72 | const RaggedTensorVariant& x, |
73 | RaggedTensorVariant* y) { |
74 | y->set_nested_splits(x.nested_splits()); |
75 | TF_RETURN_IF_ERROR( |
76 | ZerosLikeTensor<Device>(c, x.values(), y->mutable_values())); |
77 | return OkStatus(); |
78 | } |
79 | |
80 | template <typename Device> |
81 | Status RaggedTensorVariantBinaryAdd(OpKernelContext* c, |
82 | const RaggedTensorVariant& x, |
83 | const RaggedTensorVariant& y, |
84 | RaggedTensorVariant* out) { |
85 | if (x.values().dtype() != y.values().dtype()) { |
86 | return errors::InvalidArgument( |
87 | "Can't add RaggedTensorVariants of different dtypes. One is " , |
88 | DataTypeString(x.values().dtype()), " and the other is " , |
89 | DataTypeString(y.values().dtype())); |
90 | } |
91 | if (x.ragged_rank() != y.ragged_rank()) { |
92 | return errors::InvalidArgument( |
93 | "Can't add RaggedTensorVariants of different ragged rank. " , "One is " , |
94 | x.ragged_rank(), " and the other is " , y.ragged_rank()); |
95 | } |
96 | for (int i = 0; i < x.ragged_rank(); ++i) { |
97 | if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) { |
98 | return errors::InvalidArgument( |
99 | "Can't add RaggedTensorVariants with different row_splits." ); |
100 | } |
101 | } |
102 | out->set_nested_splits(x.nested_splits()); |
103 | TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(), |
104 | out->mutable_values())); |
105 | return OkStatus(); |
106 | } |
107 | |
108 | } // namespace tensorflow |
109 | |
110 | #endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ |
111 | |