1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_ |
18 | |
#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
24 | |
25 | namespace tensorflow { |
26 | |
27 | class CompositeTensorVariantMetadata; |
28 | |
29 | // Encoding for a `tf.ExtensionType` value, that can be saved as a Variant. |
30 | // |
31 | // `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class |
32 | // used to Python types that are supported by TensorFlow APIs. Example |
33 | // ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`. |
34 | // |
35 | // `CompositeTensorVariant` decomposes the `ExtensionType` value into two |
36 | // parts: |
37 | // |
38 | // * `components`: A list of Tensors, which encodes the value's dynamic |
39 | // data -- i.e., data that may change for different executions of a graph. |
40 | // * `metadata`: A serialized TypeSpec, which encodes the value's |
41 | // static data -- i.e., data that is the same for all executions of a graph. |
42 | // |
43 | // CompositeTensorVariant can be stored in a Tensor with dtype=DT_VARIANT. |
44 | // Typically, extension type values are encoded with a scalar tensor containing |
45 | // a single CompositeTensorVariant value. |
46 | class CompositeTensorVariant { |
47 | public: |
48 | CompositeTensorVariant(const CompositeTensorVariantMetadata& metadata, |
49 | absl::Span<Tensor> flat_components); |
50 | |
51 | CompositeTensorVariant(); |
52 | CompositeTensorVariant(const CompositeTensorVariant& other); |
53 | CompositeTensorVariant& operator=(CompositeTensorVariant&& other) = default; |
54 | CompositeTensorVariant& operator=(const CompositeTensorVariant& other) = |
55 | delete; |
56 | |
57 | // Returns the list of Tensor components that encode this value's dynamic |
58 | // data. |
59 | absl::Span<const Tensor> flat_components() const { |
60 | return absl::MakeConstSpan(flat_components_); |
61 | } |
62 | |
63 | // Returns the serialized TypeSpec that encodes the value's static data. |
64 | const CompositeTensorVariantMetadata& metadata() const { return *metadata_; } |
65 | |
66 | // Variant methods. |
67 | string TypeName() const { return kTypeName; } |
68 | |
69 | // Updates `VariantTensorData` with an encoding for this value. |
70 | void Encode(VariantTensorData* data) const; |
71 | |
72 | // Updates this value to match the encoding in a given `VariantTensorData`. |
73 | bool Decode(const VariantTensorData& data); |
74 | |
75 | // Returns a string summary for this value. |
76 | string DebugString() const; |
77 | |
78 | // Name of this type (used for variant serialization). |
79 | static constexpr const char kTypeName[] = "CompositeTensorVariant" ; |
80 | |
81 | private: |
82 | // Tensor components for this value. |
83 | std::vector<Tensor> flat_components_; |
84 | |
85 | // TypeSpec for this value. CompositeTensorVariantMetadata is a thin wrapper |
86 | // around a TypeSpecProto, which is used to retain flexibility to change the |
87 | // variant encoding. |
88 | // |
89 | // Note: we use a unique_ptr, because header files in the kernels/ directory |
90 | // are not allowed to import .pb.h files. |
91 | std::unique_ptr<CompositeTensorVariantMetadata> metadata_; |
92 | }; |
93 | |
94 | } // namespace tensorflow |
95 | |
96 | #endif // TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_ |
97 | |