1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_
16#define TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_
17
18#include <utility>
19
20#include "absl/container/flat_hash_map.h"
21#include "tensorflow/core/framework/tensor.h"
22#include "tensorflow/core/framework/tensor_key.h"
23#include "tensorflow/core/framework/variant.h"
24#include "tensorflow/core/framework/variant_tensor_data.h"
25#include "tensorflow/core/lib/core/refcount.h"
26
27namespace tensorflow {
28
// Variant-compatible type for a map of tensors. This is mutable, but instances
// should never be mutated after being stored in a variant tensor.
31//
// **NOTE**: TensorMap stores a refcounted container of tf::Tensor objects,
// which are accessible via TensorMap::tensors(). Because it is refcounted,
// straight copies of the form:
//
//   TensorMap b = a;
//   b.tensors().insert(k, v);  // WARNING: This modifies a.tensors().
//
// do not create a true copy of the underlying container, but instead increment
// a reference count. Modifying b.tensors() modifies a.tensors(). In this way,
// TensorMap should be considered similar to the tf::Tensor object.
42//
43// In order to get a copy of the underlying map, use the Copy method:
44//
45// TensorMap b = a.Copy();
46// b.tensors().insert(k, v); // This does not modify a.tensors().
47//
48// Note that this is not a deep copy: the memory locations of the underlying
49// tensors will still point to the same locations of the corresponding tensors
50// in the original. To truly perform a deep copy, Device and Type-specific
51// code needs to be applied to the underlying tensors as usual.
52//
// The most important implication of RefCounted TensorMaps is that OpKernels
// wishing to reuse TensorMap inputs as outputs via context->forward_input()
// need to perform an additional check on the refcount of the TensorMap,
// to ensure aliasing can be performed safely. For example:
//
//   bool can_alias = false;
//   auto fw = c->forward_input(..., DT_VARIANT, {}, ...);
//   if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) {
//     auto* tm = fw->scalar<Variant>()().get<TensorMap>();
//     if (tm && tm->RefCountIsOne()) {
//       can_alias = true;
//     }
//   }
66//
67class TensorMap {
68 public:
69 TensorMap() : tensors_(new Tensors) {}
70 ~TensorMap();
71
72 TensorMap(const TensorMap& other) : tensors_(other.tensors_) {
73 tensors_->Ref();
74 }
75
76 TensorMap(TensorMap&& rhs) : tensors_(rhs.tensors_) {
77 rhs.tensors_ = nullptr;
78 }
79
80 TensorMap& operator=(const TensorMap& rhs) {
81 if (this == &rhs) return *this;
82 tensors_->Unref();
83 tensors_ = rhs.tensors_;
84 tensors_->Ref();
85 return *this;
86 }
87
88 TensorMap& operator=(TensorMap&& rhs) {
89 if (this == &rhs) return *this;
90 std::swap(tensors_, rhs.tensors_);
91 return *this;
92 }
93
94 static const char kTypeName[];
95
96 string TypeName() const { return kTypeName; }
97
98 void Encode(VariantTensorData* data) const;
99
100 bool Decode(const VariantTensorData& data);
101
102 // TODO(apassos) fill this out
103 string DebugString() const { return "TensorMap"; }
104
105 // Access to the underlying tensor container.
106 absl::flat_hash_map<TensorKey, Tensor>& tensors() {
107 return tensors_->values_;
108 }
109
110 const absl::flat_hash_map<TensorKey, Tensor>& tensors() const {
111 return tensors_->values_;
112 }
113
114 // Get a new TensorMap containing a copy of the underlying tensor container.
115 TensorMap Copy() const {
116 TensorMap out;
117 // This performs a copy of the absl::hashmap.
118 out.tensors_->values_ = tensors_->values_;
119 return out;
120 }
121
122 // Insert key and value if the key does not already exist.
123 // Returns true if the insertion happens.
124 bool insert(const TensorKey& key, const Tensor& value) {
125 auto r = tensors_->values_.try_emplace(key, value);
126 return r.second;
127 }
128
129 // Lookup given key. Returns iterator to found key or end.
130 absl::flat_hash_map<TensorKey, Tensor>::iterator find(TensorKey key) {
131 return tensors_->values_.find(key);
132 }
133
134 Tensor& lookup(TensorKey key) { return tensors_->values_.find(key)->second; }
135
136 Tensor& operator[](TensorKey& k) { return tensors_->values_[k]; }
137
138 bool replace(const TensorKey& k, const Tensor& v) {
139 tensors_->values_[k] = v;
140 return true;
141 }
142
143 // Removes element with given key. Return size of removed element.
144 size_t erase(TensorKey key) { return tensors_->values_.erase(key); }
145
146 // Size returns the number of elements in the map
147 size_t size() const { return tensors_->values_.size(); }
148
149 std::vector<Tensor> keys() const {
150 std::vector<Tensor> keys;
151 keys.reserve(tensors_->values_.size());
152 absl::flat_hash_map<TensorKey, Tensor>::iterator it =
153 tensors_->values_.begin();
154 while (it != tensors_->values_.end()) {
155 keys.push_back(it->first);
156 it++;
157 }
158 return keys;
159 }
160
161 // Is this TensorMap the only one with a reference to the underlying
162 // container?
163 bool RefCountIsOne() const { return tensors_->RefCountIsOne(); }
164
165 private:
166 class Tensors : public core::RefCounted {
167 public:
168 absl::flat_hash_map<TensorKey, Tensor> values_;
169 };
170 Tensors* tensors_;
171};
172
#ifdef PLATFORM_GOOGLE
// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices.
// Targets with pointers narrower than 8 bytes (32-bit) are exempt from this
// check; on all others TensorMap must fit in Variant's inline storage.
static_assert(sizeof(void*) < 8 || Variant::CanInlineType<TensorMap>(),
              "Must be able to inline TensorMap into a Variant");
#endif  // PLATFORM_GOOGLE
179} // namespace tensorflow
180
181#endif // TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_
182