/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 | |
16 | #ifndef TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ |
17 | #define TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ |
18 | |
#include <stddef.h>

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <new>
#include <utility>

#include "tensorflow/tsl/lib/gtl/flatrep.h"
#include "tensorflow/tsl/platform/hash.h"
#include "tensorflow/tsl/platform/logging.h"
#include "tensorflow/tsl/platform/types.h"
30 | |
31 | namespace tsl { |
32 | namespace gtl { |
33 | |
34 | // FlatMap<K,V,...> provides a map from K to V. |
35 | // |
36 | // The map is implemented using an open-addressed hash table. A |
37 | // single array holds entire map contents and collisions are resolved |
38 | // by probing at a sequence of locations in the array. |
39 | template <typename Key, typename Val, class Hash = hash<Key>, |
40 | class Eq = std::equal_to<Key>> |
41 | class FlatMap { |
42 | private: |
43 | // Forward declare some internal types needed in public section. |
44 | struct Bucket; |
45 | |
46 | // We cannot use std::pair<> since internal representation stores |
47 | // keys and values in separate arrays, so we make a custom struct |
48 | // that holds references to the internal key, value elements. |
49 | // |
50 | // We define the struct as private ValueType, and typedef it as public |
51 | // value_type, to work around a gcc bug when compiling the iterators. |
52 | struct ValueType { |
53 | typedef Key first_type; |
54 | typedef Val second_type; |
55 | |
56 | const Key& first; |
57 | Val& second; |
58 | ValueType(const Key& k, Val& v) : first(k), second(v) {} |
59 | }; |
60 | |
61 | public: |
62 | typedef Key key_type; |
63 | typedef Val mapped_type; |
64 | typedef Hash hasher; |
65 | typedef Eq key_equal; |
66 | typedef size_t size_type; |
67 | typedef ptrdiff_t difference_type; |
68 | typedef ValueType value_type; |
69 | typedef value_type* pointer; |
70 | typedef const value_type* const_pointer; |
71 | typedef value_type& reference; |
72 | typedef const value_type& const_reference; |
73 | |
74 | FlatMap() : FlatMap(1) {} |
75 | |
76 | explicit FlatMap(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) |
77 | : rep_(N, hf, eq) {} |
78 | |
79 | FlatMap(const FlatMap& src) : rep_(src.rep_) {} |
80 | |
81 | // Move constructor leaves src in a valid but unspecified state (same as |
82 | // std::unordered_map). |
83 | FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {} |
84 | |
85 | template <typename InputIter> |
86 | FlatMap(InputIter first, InputIter last, size_t N = 1, |
87 | const Hash& hf = Hash(), const Eq& eq = Eq()) |
88 | : FlatMap(N, hf, eq) { |
89 | insert(first, last); |
90 | } |
91 | |
92 | FlatMap(std::initializer_list<std::pair<const Key, Val>> init, size_t N = 1, |
93 | const Hash& hf = Hash(), const Eq& eq = Eq()) |
94 | : FlatMap(init.begin(), init.end(), N, hf, eq) {} |
95 | |
96 | FlatMap& operator=(const FlatMap& src) { |
97 | rep_.CopyFrom(src.rep_); |
98 | return *this; |
99 | } |
100 | |
101 | // Move-assignment operator leaves src in a valid but unspecified state (same |
102 | // as std::unordered_map). |
103 | FlatMap& operator=(FlatMap&& src) { |
104 | rep_.MoveFrom(std::move(src.rep_)); |
105 | return *this; |
106 | } |
107 | |
108 | ~FlatMap() {} |
109 | |
110 | void swap(FlatMap& x) { rep_.swap(x.rep_); } |
111 | void clear_no_resize() { rep_.clear_no_resize(); } |
112 | void clear() { rep_.clear(); } |
113 | void reserve(size_t N) { rep_.Resize(std::max(N, size())); } |
114 | void rehash(size_t N) { rep_.Resize(std::max(N, size())); } |
115 | void resize(size_t N) { rep_.Resize(std::max(N, size())); } |
116 | size_t size() const { return rep_.size(); } |
117 | bool empty() const { return size() == 0; } |
118 | size_t bucket_count() const { return rep_.bucket_count(); } |
119 | hasher hash_function() const { return rep_.hash_function(); } |
120 | key_equal key_eq() const { return rep_.key_eq(); } |
121 | |
122 | class iterator { |
123 | public: |
124 | typedef typename FlatMap::difference_type difference_type; |
125 | typedef typename FlatMap::value_type value_type; |
126 | typedef typename FlatMap::pointer pointer; |
127 | typedef typename FlatMap::reference reference; |
128 | typedef ::std::forward_iterator_tag iterator_category; |
129 | |
130 | iterator() : b_(nullptr), end_(nullptr), i_(0) {} |
131 | |
132 | // Make iterator pointing at first element at or after b. |
133 | iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { SkipUnused(); } |
134 | |
135 | // Make iterator pointing exactly at ith element in b, which must exist. |
136 | iterator(Bucket* b, Bucket* end, uint32 i) : b_(b), end_(end), i_(i) { |
137 | FillValue(); |
138 | } |
139 | |
140 | reference operator*() { return *val(); } |
141 | pointer operator->() { return val(); } |
142 | bool operator==(const iterator& x) const { |
143 | return b_ == x.b_ && i_ == x.i_; |
144 | } |
145 | bool operator!=(const iterator& x) const { return !(*this == x); } |
146 | iterator& operator++() { |
147 | DCHECK(b_ != end_); |
148 | i_++; |
149 | SkipUnused(); |
150 | return *this; |
151 | } |
152 | iterator operator++(int /*indicates postfix*/) { |
153 | iterator tmp(*this); |
154 | ++*this; |
155 | return tmp; |
156 | } |
157 | |
158 | private: |
159 | friend class FlatMap; |
160 | Bucket* b_; |
161 | Bucket* end_; |
162 | char space_ alignas(value_type)[sizeof(value_type)]; |
163 | uint32 i_; |
164 | |
165 | pointer val() { return reinterpret_cast<pointer>(space_); } |
166 | void FillValue() { new (space_) value_type(b_->key(i_), b_->val(i_)); } |
167 | void SkipUnused() { |
168 | while (b_ < end_) { |
169 | if (i_ >= Rep::kWidth) { |
170 | i_ = 0; |
171 | b_++; |
172 | } else if (b_->marker[i_] < 2) { |
173 | i_++; |
174 | } else { |
175 | FillValue(); |
176 | break; |
177 | } |
178 | } |
179 | } |
180 | }; |
181 | |
182 | class const_iterator { |
183 | private: |
184 | mutable iterator rep_; // Share state and logic with non-const iterator. |
185 | |
186 | public: |
187 | typedef typename FlatMap::difference_type difference_type; |
188 | typedef typename FlatMap::value_type value_type; |
189 | typedef typename FlatMap::const_pointer pointer; |
190 | typedef typename FlatMap::const_reference reference; |
191 | typedef ::std::forward_iterator_tag iterator_category; |
192 | |
193 | const_iterator() : rep_() {} |
194 | const_iterator(Bucket* start, Bucket* end) : rep_(start, end) {} |
195 | const_iterator(Bucket* b, Bucket* end, uint32 i) : rep_(b, end, i) {} |
196 | |
197 | reference operator*() const { return *rep_.val(); } |
198 | pointer operator->() const { return rep_.val(); } |
199 | bool operator==(const const_iterator& x) const { return rep_ == x.rep_; } |
200 | bool operator!=(const const_iterator& x) const { return rep_ != x.rep_; } |
201 | const_iterator& operator++() { |
202 | ++rep_; |
203 | return *this; |
204 | } |
205 | const_iterator operator++(int /*indicates postfix*/) { |
206 | const_iterator tmp(*this); |
207 | ++*this; |
208 | return tmp; |
209 | } |
210 | }; |
211 | |
212 | iterator begin() { return iterator(rep_.start(), rep_.limit()); } |
213 | iterator end() { return iterator(rep_.limit(), rep_.limit()); } |
214 | const_iterator begin() const { |
215 | return const_iterator(rep_.start(), rep_.limit()); |
216 | } |
217 | const_iterator end() const { |
218 | return const_iterator(rep_.limit(), rep_.limit()); |
219 | } |
220 | |
221 | size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } |
222 | iterator find(const Key& k) { |
223 | auto r = rep_.Find(k); |
224 | return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); |
225 | } |
226 | const_iterator find(const Key& k) const { |
227 | auto r = rep_.Find(k); |
228 | return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); |
229 | } |
230 | |
231 | Val& at(const Key& k) { |
232 | auto r = rep_.Find(k); |
233 | DCHECK(r.found); |
234 | return r.b->val(r.index); |
235 | } |
236 | const Val& at(const Key& k) const { |
237 | auto r = rep_.Find(k); |
238 | DCHECK(r.found); |
239 | return r.b->val(r.index); |
240 | } |
241 | |
242 | template <typename P> |
243 | std::pair<iterator, bool> insert(const P& p) { |
244 | return Insert(p.first, p.second); |
245 | } |
246 | std::pair<iterator, bool> insert(const std::pair<const Key, Val>& p) { |
247 | return Insert(p.first, p.second); |
248 | } |
249 | template <typename InputIter> |
250 | void insert(InputIter first, InputIter last) { |
251 | for (; first != last; ++first) { |
252 | insert(*first); |
253 | } |
254 | } |
255 | |
256 | Val& operator[](const Key& k) { return IndexOp(k); } |
257 | Val& operator[](Key&& k) { return IndexOp(std::forward<Key>(k)); } |
258 | |
259 | template <typename... Args> |
260 | std::pair<iterator, bool> emplace(Args&&... args) { |
261 | return InsertPair(std::make_pair(std::forward<Args>(args)...)); |
262 | } |
263 | |
264 | size_t erase(const Key& k) { |
265 | auto r = rep_.Find(k); |
266 | if (!r.found) return 0; |
267 | rep_.Erase(r.b, r.index); |
268 | return 1; |
269 | } |
270 | iterator erase(iterator pos) { |
271 | rep_.Erase(pos.b_, pos.i_); |
272 | ++pos; |
273 | return pos; |
274 | } |
275 | iterator erase(iterator pos, iterator last) { |
276 | for (; pos != last; ++pos) { |
277 | rep_.Erase(pos.b_, pos.i_); |
278 | } |
279 | return pos; |
280 | } |
281 | |
282 | std::pair<iterator, iterator> equal_range(const Key& k) { |
283 | auto pos = find(k); |
284 | if (pos == end()) { |
285 | return std::make_pair(pos, pos); |
286 | } else { |
287 | auto next = pos; |
288 | ++next; |
289 | return std::make_pair(pos, next); |
290 | } |
291 | } |
292 | std::pair<const_iterator, const_iterator> equal_range(const Key& k) const { |
293 | auto pos = find(k); |
294 | if (pos == end()) { |
295 | return std::make_pair(pos, pos); |
296 | } else { |
297 | auto next = pos; |
298 | ++next; |
299 | return std::make_pair(pos, next); |
300 | } |
301 | } |
302 | |
303 | bool operator==(const FlatMap& x) const { |
304 | if (size() != x.size()) return false; |
305 | for (auto& p : x) { |
306 | auto i = find(p.first); |
307 | if (i == end()) return false; |
308 | if (i->second != p.second) return false; |
309 | } |
310 | return true; |
311 | } |
312 | bool operator!=(const FlatMap& x) const { return !(*this == x); } |
313 | |
314 | // If key exists in the table, prefetch the associated value. This |
315 | // is a hint, and may have no effect. |
316 | void prefetch_value(const Key& key) const { rep_.Prefetch(key); } |
317 | |
318 | private: |
319 | using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>; |
320 | |
321 | // Bucket stores kWidth <marker, key, value> triples. |
322 | // The data is organized as three parallel arrays to reduce padding. |
323 | struct Bucket { |
324 | uint8 marker[Rep::kWidth]; |
325 | |
326 | // Wrap keys and values in union to control construction and destruction. |
327 | union Storage { |
328 | struct { |
329 | Key key[Rep::kWidth]; |
330 | Val val[Rep::kWidth]; |
331 | }; |
332 | Storage() {} |
333 | ~Storage() {} |
334 | } storage; |
335 | |
336 | Key& key(uint32 i) { |
337 | DCHECK_GE(marker[i], 2); |
338 | return storage.key[i]; |
339 | } |
340 | Val& val(uint32 i) { |
341 | DCHECK_GE(marker[i], 2); |
342 | return storage.val[i]; |
343 | } |
344 | template <typename V> |
345 | void InitVal(uint32 i, V&& v) { |
346 | new (&storage.val[i]) Val(std::forward<V>(v)); |
347 | } |
348 | void Destroy(uint32 i) { |
349 | storage.key[i].Key::~Key(); |
350 | storage.val[i].Val::~Val(); |
351 | } |
352 | void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { |
353 | new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); |
354 | new (&storage.val[i]) Val(std::move(src->storage.val[src_index])); |
355 | } |
356 | void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { |
357 | new (&storage.key[i]) Key(src->storage.key[src_index]); |
358 | new (&storage.val[i]) Val(src->storage.val[src_index]); |
359 | } |
360 | }; |
361 | |
362 | template <typename Pair> |
363 | std::pair<iterator, bool> InsertPair(Pair&& p) { |
364 | return Insert(std::forward<decltype(p.first)>(p.first), |
365 | std::forward<decltype(p.second)>(p.second)); |
366 | } |
367 | |
368 | template <typename K, typename V> |
369 | std::pair<iterator, bool> Insert(K&& k, V&& v) { |
370 | rep_.MaybeResize(); |
371 | auto r = rep_.FindOrInsert(std::forward<K>(k)); |
372 | const bool inserted = !r.found; |
373 | if (inserted) { |
374 | r.b->InitVal(r.index, std::forward<V>(v)); |
375 | } |
376 | return {iterator(r.b, rep_.limit(), r.index), inserted}; |
377 | } |
378 | |
379 | template <typename K> |
380 | Val& IndexOp(K&& k) { |
381 | rep_.MaybeResize(); |
382 | auto r = rep_.FindOrInsert(std::forward<K>(k)); |
383 | Val* vptr = &r.b->val(r.index); |
384 | if (!r.found) { |
385 | new (vptr) Val(); // Initialize value in new slot. |
386 | } |
387 | return *vptr; |
388 | } |
389 | |
390 | Rep rep_; |
391 | }; |
392 | |
393 | } // namespace gtl |
394 | } // namespace tsl |
395 | |
396 | #endif // TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ |
397 | |