1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_
17#define TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_
18
#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
24
25namespace tensorflow {
26
27class CompositeTensorVariantMetadata;
28
29// Encoding for a `tf.ExtensionType` value, that can be saved as a Variant.
30//
31// `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class
32// used to Python types that are supported by TensorFlow APIs. Example
33// ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`.
34//
35// `CompositeTensorVariant` decomposes the `ExtensionType` value into two
36// parts:
37//
38// * `components`: A list of Tensors, which encodes the value's dynamic
39// data -- i.e., data that may change for different executions of a graph.
40// * `metadata`: A serialized TypeSpec, which encodes the value's
41// static data -- i.e., data that is the same for all executions of a graph.
42//
43// CompositeTensorVariant can be stored in a Tensor with dtype=DT_VARIANT.
44// Typically, extension type values are encoded with a scalar tensor containing
45// a single CompositeTensorVariant value.
46class CompositeTensorVariant {
47 public:
48 CompositeTensorVariant(const CompositeTensorVariantMetadata& metadata,
49 absl::Span<Tensor> flat_components);
50
51 CompositeTensorVariant();
52 CompositeTensorVariant(const CompositeTensorVariant& other);
53 CompositeTensorVariant& operator=(CompositeTensorVariant&& other) = default;
54 CompositeTensorVariant& operator=(const CompositeTensorVariant& other) =
55 delete;
56
57 // Returns the list of Tensor components that encode this value's dynamic
58 // data.
59 absl::Span<const Tensor> flat_components() const {
60 return absl::MakeConstSpan(flat_components_);
61 }
62
63 // Returns the serialized TypeSpec that encodes the value's static data.
64 const CompositeTensorVariantMetadata& metadata() const { return *metadata_; }
65
66 // Variant methods.
67 string TypeName() const { return kTypeName; }
68
69 // Updates `VariantTensorData` with an encoding for this value.
70 void Encode(VariantTensorData* data) const;
71
72 // Updates this value to match the encoding in a given `VariantTensorData`.
73 bool Decode(const VariantTensorData& data);
74
75 // Returns a string summary for this value.
76 string DebugString() const;
77
78 // Name of this type (used for variant serialization).
79 static constexpr const char kTypeName[] = "CompositeTensorVariant";
80
81 private:
82 // Tensor components for this value.
83 std::vector<Tensor> flat_components_;
84
85 // TypeSpec for this value. CompositeTensorVariantMetadata is a thin wrapper
86 // around a TypeSpecProto, which is used to retain flexibility to change the
87 // variant encoding.
88 //
89 // Note: we use a unique_ptr, because header files in the kernels/ directory
90 // are not allowed to import .pb.h files.
91 std::unique_ptr<CompositeTensorVariantMetadata> metadata_;
92};
93
94} // namespace tensorflow
95
96#endif // TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_
97