1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
16#define TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
17
18#include <utility>
19
20#include "tensorflow/core/framework/tensor.h"
21#include "tensorflow/core/framework/variant.h"
22#include "tensorflow/core/framework/variant_tensor_data.h"
23#include "tensorflow/core/lib/core/refcount.h"
24
25namespace tensorflow {
26
// Variant-compatible type for a list of tensors. It is mutable, but instances
// should never be mutated after being stored in a variant tensor.
29//
30// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects,
31// which are accessible via TensorList::tensors(). Because it is refcounted,
32// straight copies of the form:
33//
34// TensorList b = a;
35// b.tensors().push_back(t); // WARNING: This modifies a.tensors().
36//
// do not create a true copy of the underlying container; instead, they
// increment a reference count. Modifying b.tensors() therefore modifies
// a.tensors(). In this way, TensorList should be considered similar to the
// tf::Tensor object.
40//
41// In order to get a copy of the underlying list, use the Copy method:
42//
43// TensorList b = a.Copy();
44// b.tensors().push_back(t); // This does not modify a.tensors().
45//
46// Note that this is not a deep copy: the memory locations of the underlying
47// tensors will still point to the same locations of the corresponding tensors
48// in the original. To truly perform a deep copy, Device and Type-specific
49// code needs to be applied to the underlying tensors as usual.
50//
// The most important implication of refcounted TensorLists is that OpKernels
// wishing to reuse TensorList inputs as outputs via context->forward_input()
// need to perform an additional check on the refcount of the TensorList
// to ensure that aliasing can be performed safely. For example:
55//
56// bool can_alias = false;
57// auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
58// if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
59// auto* tl = fw->scalar<Variant>()().get<TensorList>();
60// if (tl && tl->RefCountIsOne()) {
61// can_alias = true;
62// }
63// }
64//
65class TensorList {
66 public:
67 TensorList() : tensors_(new Tensors) {}
68 ~TensorList();
69
70 TensorList(const TensorList& other)
71 : element_shape(other.element_shape),
72 element_dtype(other.element_dtype),
73 max_num_elements(other.max_num_elements),
74 tensors_(other.tensors_) {
75 tensors_->Ref();
76 }
77
78 TensorList(TensorList&& rhs)
79 : element_shape(std::move(rhs.element_shape)),
80 element_dtype(rhs.element_dtype),
81 max_num_elements(rhs.max_num_elements),
82 tensors_(rhs.tensors_) {
83 rhs.tensors_ = nullptr;
84 }
85
86 TensorList& operator=(const TensorList& rhs) {
87 if (this == &rhs) return *this;
88 element_shape = rhs.element_shape;
89 element_dtype = rhs.element_dtype;
90 max_num_elements = rhs.max_num_elements;
91 tensors_->Unref();
92 tensors_ = rhs.tensors_;
93 tensors_->Ref();
94 return *this;
95 }
96
97 TensorList& operator=(TensorList&& rhs) {
98 if (this == &rhs) return *this;
99 element_shape = rhs.element_shape;
100 element_dtype = rhs.element_dtype;
101 max_num_elements = rhs.max_num_elements;
102 std::swap(tensors_, rhs.tensors_);
103 return *this;
104 }
105
106 static const char kTypeName[];
107
108 string TypeName() const { return kTypeName; }
109
110 void Encode(VariantTensorData* data) const;
111
112 bool Decode(const VariantTensorData& data);
113
114 // TODO(apassos) fill this out
115 string DebugString() const { return "TensorList"; }
116
117 PartialTensorShape element_shape;
118
119 DataType element_dtype;
120
121 // The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
122 // of `tensors` is unbounded.
123 int max_num_elements = -1;
124
125 // Access to the underlying tensor container.
126 std::vector<Tensor>& tensors() { return tensors_->values_; }
127 const std::vector<Tensor>& tensors() const { return tensors_->values_; }
128
129 // Get a new TensorList containing a copy of the underlying tensor container.
130 TensorList Copy() const {
131 TensorList out;
132 out.element_shape = element_shape;
133 out.element_dtype = element_dtype;
134 out.max_num_elements = max_num_elements;
135 // This performs a copy of the std::vector.
136 out.tensors_->values_ = tensors_->values_;
137 return out;
138 }
139
140 // Is this TensorList the only one with a reference to the underlying
141 // container?
142 bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
143
144 private:
145 class Tensors : public core::RefCounted {
146 public:
147 std::vector<Tensor> values_;
148 };
149 Tensors* tensors_;
150};
151
#ifdef PLATFORM_GOOGLE
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
// Compile-time check (PLATFORM_GOOGLE builds only): when pointers are at least
// 8 bytes wide, a TensorList must be inlinable into a Variant. Targets with
// narrower pointers (32-bit devices) are exempt, since not inlining there is
// acceptable.
static_assert(sizeof(void*) < 8 || Variant::CanInlineType<TensorList>(),
              "Must be able to inline TensorList into a Variant");
#endif  // PLATFORM_GOOGLE
158} // namespace tensorflow
159
160#endif // TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_
161