1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_
17#define TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_
18
#include <stddef.h>

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <new>
#include <utility>

#include "tensorflow/tsl/lib/gtl/flatrep.h"
#include "tensorflow/tsl/platform/hash.h"
#include "tensorflow/tsl/platform/logging.h"
#include "tensorflow/tsl/platform/types.h"
30
31namespace tsl {
32namespace gtl {
33
34// FlatMap<K,V,...> provides a map from K to V.
35//
36// The map is implemented using an open-addressed hash table. A
37// single array holds entire map contents and collisions are resolved
38// by probing at a sequence of locations in the array.
39template <typename Key, typename Val, class Hash = hash<Key>,
40 class Eq = std::equal_to<Key>>
41class FlatMap {
42 private:
43 // Forward declare some internal types needed in public section.
44 struct Bucket;
45
46 // We cannot use std::pair<> since internal representation stores
47 // keys and values in separate arrays, so we make a custom struct
48 // that holds references to the internal key, value elements.
49 //
50 // We define the struct as private ValueType, and typedef it as public
51 // value_type, to work around a gcc bug when compiling the iterators.
52 struct ValueType {
53 typedef Key first_type;
54 typedef Val second_type;
55
56 const Key& first;
57 Val& second;
58 ValueType(const Key& k, Val& v) : first(k), second(v) {}
59 };
60
61 public:
62 typedef Key key_type;
63 typedef Val mapped_type;
64 typedef Hash hasher;
65 typedef Eq key_equal;
66 typedef size_t size_type;
67 typedef ptrdiff_t difference_type;
68 typedef ValueType value_type;
69 typedef value_type* pointer;
70 typedef const value_type* const_pointer;
71 typedef value_type& reference;
72 typedef const value_type& const_reference;
73
74 FlatMap() : FlatMap(1) {}
75
76 explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
77 : rep_(N, hf, eq) {}
78
79 FlatMap(const FlatMap& src) : rep_(src.rep_) {}
80
81 // Move constructor leaves src in a valid but unspecified state (same as
82 // std::unordered_map).
83 FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {}
84
85 template <typename InputIter>
86 FlatMap(InputIter first, InputIter last, size_t N = 1,
87 const Hash& hf = Hash(), const Eq& eq = Eq())
88 : FlatMap(N, hf, eq) {
89 insert(first, last);
90 }
91
92 FlatMap(std::initializer_list<std::pair<const Key, Val>> init, size_t N = 1,
93 const Hash& hf = Hash(), const Eq& eq = Eq())
94 : FlatMap(init.begin(), init.end(), N, hf, eq) {}
95
96 FlatMap& operator=(const FlatMap& src) {
97 rep_.CopyFrom(src.rep_);
98 return *this;
99 }
100
101 // Move-assignment operator leaves src in a valid but unspecified state (same
102 // as std::unordered_map).
103 FlatMap& operator=(FlatMap&& src) {
104 rep_.MoveFrom(std::move(src.rep_));
105 return *this;
106 }
107
108 ~FlatMap() {}
109
110 void swap(FlatMap& x) { rep_.swap(x.rep_); }
111 void clear_no_resize() { rep_.clear_no_resize(); }
112 void clear() { rep_.clear(); }
113 void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
114 void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
115 void resize(size_t N) { rep_.Resize(std::max(N, size())); }
116 size_t size() const { return rep_.size(); }
117 bool empty() const { return size() == 0; }
118 size_t bucket_count() const { return rep_.bucket_count(); }
119 hasher hash_function() const { return rep_.hash_function(); }
120 key_equal key_eq() const { return rep_.key_eq(); }
121
122 class iterator {
123 public:
124 typedef typename FlatMap::difference_type difference_type;
125 typedef typename FlatMap::value_type value_type;
126 typedef typename FlatMap::pointer pointer;
127 typedef typename FlatMap::reference reference;
128 typedef ::std::forward_iterator_tag iterator_category;
129
130 iterator() : b_(nullptr), end_(nullptr), i_(0) {}
131
132 // Make iterator pointing at first element at or after b.
133 iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { SkipUnused(); }
134
135 // Make iterator pointing exactly at ith element in b, which must exist.
136 iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) {
137 FillValue();
138 }
139
140 reference operator*() { return *val(); }
141 pointer operator->() { return val(); }
142 bool operator==(const iterator& x) const {
143 return b_ == x.b_ && i_ == x.i_;
144 }
145 bool operator!=(const iterator& x) const { return !(*this == x); }
146 iterator& operator++() {
147 DCHECK(b_ != end_);
148 i_++;
149 SkipUnused();
150 return *this;
151 }
152 iterator operator++(int /*indicates postfix*/) {
153 iterator tmp(*this);
154 ++*this;
155 return tmp;
156 }
157
158 private:
159 friend class FlatMap;
160 Bucket* b_;
161 Bucket* end_;
162 char space_ alignas(value_type)[sizeof(value_type)];
163 uint32 i_;
164
165 pointer val() { return reinterpret_cast<pointer>(space_); }
166 void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); }
167 void SkipUnused() {
168 while (b_ < end_) {
169 if (i_ >= Rep::kWidth) {
170 i_ = 0;
171 b_++;
172 } else if (b_->marker[i_] < 2) {
173 i_++;
174 } else {
175 FillValue();
176 break;
177 }
178 }
179 }
180 };
181
182 class const_iterator {
183 private:
184 mutable iterator rep_; // Share state and logic with non-const iterator.
185
186 public:
187 typedef typename FlatMap::difference_type difference_type;
188 typedef typename FlatMap::value_type value_type;
189 typedef typename FlatMap::const_pointer pointer;
190 typedef typename FlatMap::const_reference reference;
191 typedef ::std::forward_iterator_tag iterator_category;
192
193 const_iterator() : rep_() {}
194 const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {}
195 const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {}
196
197 reference operator*() const { return *rep_.val(); }
198 pointer operator->() const { return rep_.val(); }
199 bool operator==(const const_iterator& x) const { return rep_ == x.rep_; }
200 bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; }
201 const_iterator& operator++() {
202 ++rep_;
203 return *this;
204 }
205 const_iterator operator++(int /*indicates postfix*/) {
206 const_iterator tmp(*this);
207 ++*this;
208 return tmp;
209 }
210 };
211
212 iterator begin() { return iterator(rep_.start(), rep_.limit()); }
213 iterator end() { return iterator(rep_.limit(), rep_.limit()); }
214 const_iterator begin() const {
215 return const_iterator(rep_.start(), rep_.limit());
216 }
217 const_iterator end() const {
218 return const_iterator(rep_.limit(), rep_.limit());
219 }
220
221 size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
222 iterator find(const Key& k) {
223 auto r = rep_.Find(k);
224 return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
225 }
226 const_iterator find(const Key& k) const {
227 auto r = rep_.Find(k);
228 return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
229 }
230
231 Val& at(const Key& k) {
232 auto r = rep_.Find(k);
233 DCHECK(r.found);
234 return r.b->val(r.index);
235 }
236 const Val& at(const Key& k) const {
237 auto r = rep_.Find(k);
238 DCHECK(r.found);
239 return r.b->val(r.index);
240 }
241
242 template <typename P>
243 std::pair<iterator, bool> insert(const P& p) {
244 return Insert(p.first, p.second);
245 }
246 std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) {
247 return Insert(p.first, p.second);
248 }
249 template <typename InputIter>
250 void insert(InputIter first, InputIter last) {
251 for (; first != last; ++first) {
252 insert(*first);
253 }
254 }
255
256 Val& operator[](const Key& k) { return IndexOp(k); }
257 Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); }
258
259 template <typename... Args>
260 std::pair<iterator, bool> emplace(Args&&... args) {
261 return InsertPair(std::make_pair(std::forward<Args>(args)...));
262 }
263
264 size_t erase(const Key& k) {
265 auto r = rep_.Find(k);
266 if (!r.found) return 0;
267 rep_.Erase(r.b, r.index);
268 return 1;
269 }
270 iterator erase(iterator pos) {
271 rep_.Erase(pos.b_, pos.i_);
272 ++pos;
273 return pos;
274 }
275 iterator erase(iterator pos, iterator last) {
276 for (; pos != last; ++pos) {
277 rep_.Erase(pos.b_, pos.i_);
278 }
279 return pos;
280 }
281
282 std::pair<iterator, iterator> equal_range(const Key& k) {
283 auto pos = find(k);
284 if (pos == end()) {
285 return std::make_pair(pos, pos);
286 } else {
287 auto next = pos;
288 ++next;
289 return std::make_pair(pos, next);
290 }
291 }
292 std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
293 auto pos = find(k);
294 if (pos == end()) {
295 return std::make_pair(pos, pos);
296 } else {
297 auto next = pos;
298 ++next;
299 return std::make_pair(pos, next);
300 }
301 }
302
303 bool operator==(const FlatMap& x) const {
304 if (size() != x.size()) return false;
305 for (auto& p : x) {
306 auto i = find(p.first);
307 if (i == end()) return false;
308 if (i->second != p.second) return false;
309 }
310 return true;
311 }
312 bool operator!=(const FlatMap& x) const { return !(*this == x); }
313
314 // If key exists in the table, prefetch the associated value. This
315 // is a hint, and may have no effect.
316 void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
317
318 private:
319 using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
320
321 // Bucket stores kWidth <marker, key, value> triples.
322 // The data is organized as three parallel arrays to reduce padding.
323 struct Bucket {
324 uint8 marker[Rep::kWidth];
325
326 // Wrap keys and values in union to control construction and destruction.
327 union Storage {
328 struct {
329 Key key[Rep::kWidth];
330 Val val[Rep::kWidth];
331 };
332 Storage() {}
333 ~Storage() {}
334 } storage;
335
336 Key& key(uint32 i) {
337 DCHECK_GE(marker[i], 2);
338 return storage.key[i];
339 }
340 Val& val(uint32 i) {
341 DCHECK_GE(marker[i], 2);
342 return storage.val[i];
343 }
344 template <typename V>
345 void InitVal(uint32 i, V&& v) {
346 new (&storage.val[i]) Val(std::forward<V>(v));
347 }
348 void Destroy(uint32 i) {
349 storage.key[i].Key::~Key();
350 storage.val[i].Val::~Val();
351 }
352 void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
353 new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
354 new (&storage.val[i]) Val(std::move(src->storage.val[src_index]));
355 }
356 void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
357 new (&storage.key[i]) Key(src->storage.key[src_index]);
358 new (&storage.val[i]) Val(src->storage.val[src_index]);
359 }
360 };
361
362 template <typename Pair>
363 std::pair<iterator, bool> InsertPair(Pair&& p) {
364 return Insert(std::forward<decltype(p.first)>(p.first),
365 std::forward<decltype(p.second)>(p.second));
366 }
367
368 template <typename K, typename V>
369 std::pair<iterator, bool> Insert(K&& k, V&& v) {
370 rep_.MaybeResize();
371 auto r = rep_.FindOrInsert(std::forward<K>(k));
372 const bool inserted = !r.found;
373 if (inserted) {
374 r.b->InitVal(r.index, std::forward<V>(v));
375 }
376 return {iterator(r.b, rep_.limit(), r.index), inserted};
377 }
378
379 template <typename K>
380 Val& IndexOp(K&& k) {
381 rep_.MaybeResize();
382 auto r = rep_.FindOrInsert(std::forward<K>(k));
383 Val* vptr = &r.b->val(r.index);
384 if (!r.found) {
385 new (vptr) Val(); // Initialize value in new slot.
386 }
387 return *vptr;
388 }
389
390 Rep rep_;
391};
392
393} // namespace gtl
394} // namespace tsl
395
396#endif // TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_
397