1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ |
16 | #define TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ |
17 | |
18 | #include <utility> |
19 | |
20 | #include "absl/container/flat_hash_map.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/tensor_key.h" |
23 | #include "tensorflow/core/framework/variant.h" |
24 | #include "tensorflow/core/framework/variant_tensor_data.h" |
25 | #include "tensorflow/core/lib/core/refcount.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | // Variant compatible type for a map of tensors. This is mutable but instances |
30 | // should never be mutated after stored in a variant tensor. |
31 | // |
32 | // **NOTE**: TensorMap stores a refcounted container of tf::Tensor objects, |
33 | // which are accessible via TensorMap::tensors(). Because it is refcounted, |
34 | // straight copies of the form: |
35 | // |
36 | // TensorMap b = a; |
37 | // b.tensors().insert(k,v); // WARNING: This modifies a.tensors(). |
38 | // |
39 | // Do not create a true copy of the underlying container - but instead increment |
40 | // a reference count. Modifying b.tensors() modifies a.tensors(). In this way, |
41 | // TensorMap should be considered similar to the tf::Tensor object. |
42 | // |
43 | // In order to get a copy of the underlying map, use the Copy method: |
44 | // |
45 | // TensorMap b = a.Copy(); |
46 | // b.tensors().insert(k, v); // This does not modify a.tensors(). |
47 | // |
48 | // Note that this is not a deep copy: the memory locations of the underlying |
49 | // tensors will still point to the same locations of the corresponding tensors |
50 | // in the original. To truly perform a deep copy, Device and Type-specific |
51 | // code needs to be applied to the underlying tensors as usual. |
52 | // |
53 | // The most important implication of RefCounted TensorMaps is that OpKernels |
54 | // wishing to reuse TensorMap inputs as outputs via context->forward_input() |
55 | // need to perform an additional check on the refcount of the TensorList, |
56 | // to ensure aliasing can be performed safely. For example: |
57 | // |
58 | // bool can_alias = false; |
59 | // auto fw = c->forward_input(..., DT_VARIANT, {}, ...); |
60 | // if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) { |
61 | // auto* tl = fw->scalar<Variant>()().get<TensorMap>(); |
62 | // if (tl && tl->RefCountIsOne()) { |
63 | // can_alias = true; |
64 | // } |
65 | // } |
66 | // |
67 | class TensorMap { |
68 | public: |
69 | TensorMap() : tensors_(new Tensors) {} |
70 | ~TensorMap(); |
71 | |
72 | TensorMap(const TensorMap& other) : tensors_(other.tensors_) { |
73 | tensors_->Ref(); |
74 | } |
75 | |
76 | TensorMap(TensorMap&& rhs) : tensors_(rhs.tensors_) { |
77 | rhs.tensors_ = nullptr; |
78 | } |
79 | |
80 | TensorMap& operator=(const TensorMap& rhs) { |
81 | if (this == &rhs) return *this; |
82 | tensors_->Unref(); |
83 | tensors_ = rhs.tensors_; |
84 | tensors_->Ref(); |
85 | return *this; |
86 | } |
87 | |
88 | TensorMap& operator=(TensorMap&& rhs) { |
89 | if (this == &rhs) return *this; |
90 | std::swap(tensors_, rhs.tensors_); |
91 | return *this; |
92 | } |
93 | |
94 | static const char kTypeName[]; |
95 | |
96 | string TypeName() const { return kTypeName; } |
97 | |
98 | void Encode(VariantTensorData* data) const; |
99 | |
100 | bool Decode(const VariantTensorData& data); |
101 | |
102 | // TODO(apassos) fill this out |
103 | string DebugString() const { return "TensorMap" ; } |
104 | |
105 | // Access to the underlying tensor container. |
106 | absl::flat_hash_map<TensorKey, Tensor>& tensors() { |
107 | return tensors_->values_; |
108 | } |
109 | |
110 | const absl::flat_hash_map<TensorKey, Tensor>& tensors() const { |
111 | return tensors_->values_; |
112 | } |
113 | |
114 | // Get a new TensorMap containing a copy of the underlying tensor container. |
115 | TensorMap Copy() const { |
116 | TensorMap out; |
117 | // This performs a copy of the absl::hashmap. |
118 | out.tensors_->values_ = tensors_->values_; |
119 | return out; |
120 | } |
121 | |
122 | // Insert key and value if the key does not already exist. |
123 | // Returns true if the insertion happens. |
124 | bool insert(const TensorKey& key, const Tensor& value) { |
125 | auto r = tensors_->values_.try_emplace(key, value); |
126 | return r.second; |
127 | } |
128 | |
129 | // Lookup given key. Returns iterator to found key or end. |
130 | absl::flat_hash_map<TensorKey, Tensor>::iterator find(TensorKey key) { |
131 | return tensors_->values_.find(key); |
132 | } |
133 | |
134 | Tensor& lookup(TensorKey key) { return tensors_->values_.find(key)->second; } |
135 | |
136 | Tensor& operator[](TensorKey& k) { return tensors_->values_[k]; } |
137 | |
138 | bool replace(const TensorKey& k, const Tensor& v) { |
139 | tensors_->values_[k] = v; |
140 | return true; |
141 | } |
142 | |
143 | // Removes element with given key. Return size of removed element. |
144 | size_t erase(TensorKey key) { return tensors_->values_.erase(key); } |
145 | |
146 | // Size returns the number of elements in the map |
147 | size_t size() const { return tensors_->values_.size(); } |
148 | |
149 | std::vector<Tensor> keys() const { |
150 | std::vector<Tensor> keys; |
151 | keys.reserve(tensors_->values_.size()); |
152 | absl::flat_hash_map<TensorKey, Tensor>::iterator it = |
153 | tensors_->values_.begin(); |
154 | while (it != tensors_->values_.end()) { |
155 | keys.push_back(it->first); |
156 | it++; |
157 | } |
158 | return keys; |
159 | } |
160 | |
161 | // Is this TensorMap the only one with a reference to the underlying |
162 | // container? |
163 | bool RefCountIsOne() const { return tensors_->RefCountIsOne(); } |
164 | |
165 | private: |
166 | class Tensors : public core::RefCounted { |
167 | public: |
168 | absl::flat_hash_map<TensorKey, Tensor> values_; |
169 | }; |
170 | Tensors* tensors_; |
171 | }; |
172 | |
#ifdef PLATFORM_GOOGLE
// On 64-bit Google platforms a TensorMap must fit in Variant's inline storage.
// 32-bit targets are exempt: there, falling back to heap storage is fine.
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
static_assert(sizeof(void*) < 8 || Variant::CanInlineType<TensorMap>(),
              "Must be able to inline TensorMap into a Variant");
#endif  // PLATFORM_GOOGLE
179 | } // namespace tensorflow |
180 | |
181 | #endif // TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ |
182 | |