1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ |
16 | #define TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ |
17 | |
18 | #include <utility> |
19 | |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/framework/variant.h" |
22 | #include "tensorflow/core/framework/variant_tensor_data.h" |
23 | #include "tensorflow/core/lib/core/refcount.h" |
24 | |
25 | namespace tensorflow { |
26 | |
// Variant-compatible type for a list of tensors. This is mutable, but instances
// should never be mutated after being stored in a variant tensor.
//
// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
// which are accessible via TensorList::tensors(). Because it is refcounted,
// straight copies of the form:
//
//   TensorList b = a;
//   b.tensors().push_back(t);  // WARNING: This modifies a.tensors().
//
// do not create a true copy of the underlying container, but instead increment
// a reference count. Modifying b.tensors() modifies a.tensors(). In this way,
// TensorList should be considered similar to the tf::Tensor object.
40 | // |
41 | // In order to get a copy of the underlying list, use the Copy method: |
42 | // |
43 | // TensorList b = a.Copy(); |
44 | // b.tensors().push_back(t); // This does not modify a.tensors(). |
45 | // |
46 | // Note that this is not a deep copy: the memory locations of the underlying |
47 | // tensors will still point to the same locations of the corresponding tensors |
48 | // in the original. To truly perform a deep copy, Device and Type-specific |
49 | // code needs to be applied to the underlying tensors as usual. |
50 | // |
51 | // The most important implication of RefCounted TLs is that OpKernels |
52 | // wishing to reuse TensorList inputs as outputs via context->forward_input() |
53 | // need to perform an additional check on the refcount of the TensorList, |
54 | // to ensure aliasing can be performed safely. For example: |
55 | // |
56 | // bool can_alias = false; |
57 | // auto fw = c->forward_input(..., DT_VARIANT, {}, ...); |
58 | // if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) { |
59 | // auto* tl = fw->scalar<Variant>()().get<TensorList>(); |
60 | // if (tl && tl->RefCountIsOne()) { |
61 | // can_alias = true; |
62 | // } |
63 | // } |
64 | // |
65 | class TensorList { |
66 | public: |
67 | TensorList() : tensors_(new Tensors) {} |
68 | ~TensorList(); |
69 | |
70 | TensorList(const TensorList& other) |
71 | : element_shape(other.element_shape), |
72 | element_dtype(other.element_dtype), |
73 | max_num_elements(other.max_num_elements), |
74 | tensors_(other.tensors_) { |
75 | tensors_->Ref(); |
76 | } |
77 | |
78 | TensorList(TensorList&& rhs) |
79 | : element_shape(std::move(rhs.element_shape)), |
80 | element_dtype(rhs.element_dtype), |
81 | max_num_elements(rhs.max_num_elements), |
82 | tensors_(rhs.tensors_) { |
83 | rhs.tensors_ = nullptr; |
84 | } |
85 | |
86 | TensorList& operator=(const TensorList& rhs) { |
87 | if (this == &rhs) return *this; |
88 | element_shape = rhs.element_shape; |
89 | element_dtype = rhs.element_dtype; |
90 | max_num_elements = rhs.max_num_elements; |
91 | tensors_->Unref(); |
92 | tensors_ = rhs.tensors_; |
93 | tensors_->Ref(); |
94 | return *this; |
95 | } |
96 | |
97 | TensorList& operator=(TensorList&& rhs) { |
98 | if (this == &rhs) return *this; |
99 | element_shape = rhs.element_shape; |
100 | element_dtype = rhs.element_dtype; |
101 | max_num_elements = rhs.max_num_elements; |
102 | std::swap(tensors_, rhs.tensors_); |
103 | return *this; |
104 | } |
105 | |
106 | static const char kTypeName[]; |
107 | |
108 | string TypeName() const { return kTypeName; } |
109 | |
110 | void Encode(VariantTensorData* data) const; |
111 | |
112 | bool Decode(const VariantTensorData& data); |
113 | |
114 | // TODO(apassos) fill this out |
115 | string DebugString() const { return "TensorList" ; } |
116 | |
117 | PartialTensorShape element_shape; |
118 | |
119 | DataType element_dtype; |
120 | |
121 | // The maximum allowed size of `tensors`. Defaults to -1 meaning that the size |
122 | // of `tensors` is unbounded. |
123 | int max_num_elements = -1; |
124 | |
125 | // Access to the underlying tensor container. |
126 | std::vector<Tensor>& tensors() { return tensors_->values_; } |
127 | const std::vector<Tensor>& tensors() const { return tensors_->values_; } |
128 | |
129 | // Get a new TensorList containing a copy of the underlying tensor container. |
130 | TensorList Copy() const { |
131 | TensorList out; |
132 | out.element_shape = element_shape; |
133 | out.element_dtype = element_dtype; |
134 | out.max_num_elements = max_num_elements; |
135 | // This performs a copy of the std::vector. |
136 | out.tensors_->values_ = tensors_->values_; |
137 | return out; |
138 | } |
139 | |
140 | // Is this TensorList the only one with a reference to the underlying |
141 | // container? |
142 | bool RefCountIsOne() const { return tensors_->RefCountIsOne(); } |
143 | |
144 | private: |
145 | class Tensors : public core::RefCounted { |
146 | public: |
147 | std::vector<Tensor> values_; |
148 | }; |
149 | Tensors* tensors_; |
150 | }; |
151 | |
#if defined(PLATFORM_GOOGLE)
// Compile-time guard that TensorList stays small enough to be stored inline in
// a Variant. It is only enforced for pointer widths of 8 bytes or more; on
// 32-bit devices it is acceptable for TensorList not to be inlined.
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
static_assert(sizeof(void*) < 8 || Variant::CanInlineType<TensorList>(),
              "Must be able to inline TensorList into a Variant");
#endif
158 | } // namespace tensorflow |
159 | |
160 | #endif // TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ |
161 | |