1#include "tensorflow/core/framework/tensor_key.h"
2/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15==============================================================================*/
16
17#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
18#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
19
20#define EIGEN_USE_THREADS
21#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22#define EIGEN_USE_GPU
23#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
24
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/util/tensor_ops_util.h"
33
34namespace tensorflow {
35
36// Class used to store a RaggedTensor as a Variant scalar.
37class RaggedTensorVariant {
38 public:
39 RaggedTensorVariant() {}
40 RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits)
41 : values_(std::move(values)), nested_splits_(nested_splits) {}
42
43 // Variant support methods.
44 string TypeName() const;
45 string DebugString() const;
46 void Encode(VariantTensorData* data) const;
47 bool Decode(const VariantTensorData& data);
48
49 // The flat_values of the RaggedTensor.
50 const Tensor& values() const { return values_; }
51 Tensor* mutable_values() { return &values_; }
52 void set_values(const Tensor& new_values) { values_ = new_values; }
53
54 // The nested row_splits of the RaggedTensor.
55 int ragged_rank() const { return nested_splits_.size(); }
56 const std::vector<Tensor>& nested_splits() const { return nested_splits_; }
57 std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; }
58 const Tensor& splits(int i) const { return nested_splits_[i]; }
59 Tensor* mutable_splits(int i) { return &nested_splits_[i]; }
60 void set_nested_splits(const std::vector<Tensor>& nested_splits) {
61 nested_splits_ = nested_splits;
62 }
63 void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); }
64
65 private:
66 Tensor values_;
67 std::vector<Tensor> nested_splits_;
68};
69
70template <typename Device>
71Status RaggedTensorVariantZerosLike(OpKernelContext* c,
72 const RaggedTensorVariant& x,
73 RaggedTensorVariant* y) {
74 y->set_nested_splits(x.nested_splits());
75 TF_RETURN_IF_ERROR(
76 ZerosLikeTensor<Device>(c, x.values(), y->mutable_values()));
77 return OkStatus();
78}
79
80template <typename Device>
81Status RaggedTensorVariantBinaryAdd(OpKernelContext* c,
82 const RaggedTensorVariant& x,
83 const RaggedTensorVariant& y,
84 RaggedTensorVariant* out) {
85 if (x.values().dtype() != y.values().dtype()) {
86 return errors::InvalidArgument(
87 "Can't add RaggedTensorVariants of different dtypes. One is ",
88 DataTypeString(x.values().dtype()), " and the other is ",
89 DataTypeString(y.values().dtype()));
90 }
91 if (x.ragged_rank() != y.ragged_rank()) {
92 return errors::InvalidArgument(
93 "Can't add RaggedTensorVariants of different ragged rank. ", "One is ",
94 x.ragged_rank(), " and the other is ", y.ragged_rank());
95 }
96 for (int i = 0; i < x.ragged_rank(); ++i) {
97 if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) {
98 return errors::InvalidArgument(
99 "Can't add RaggedTensorVariants with different row_splits.");
100 }
101 }
102 out->set_nested_splits(x.nested_splits());
103 TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(),
104 out->mutable_values()));
105 return OkStatus();
106}
107
108} // namespace tensorflow
109
110#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
111