#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#include <cstdint>
#include <cstring>

#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

// PendingCounts is an internal helper class to keep track of pending and
// dead counts for nodes, for use in the ExecutorState module.  It
// holds a map from Handles to various counts for that handle.  This
// information is needed per frame iteration.  The amount of memory
// needed for an iteration is the same across all executions of the
// iteration.  The memory amount and handles are precomputed at startup
// using a Layout object.
//
//    PendingCounts::Layout layout;
//    std::vector<PendingCounts::Handle> h(C);
//    for (int id = 0; id < C; id++) {
//      h[id] = layout.CreateHandle(max_pending[id], max_dead[id]);
//    }
//
// When we actually want to start an iteration we first create a
// PendingCounts object and then index into it using the precomputed
// handles:
//
//    PendingCounts counts(layout);
//    ...
//    counts.decrement_pending(h[id], 1);
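//
// A slightly fuller per-iteration sketch (illustrative only; the names
// "initial_pending", "dest", and "Schedule" are hypothetical and live in the
// caller, not in this class):
//
//    PendingCounts counts(layout);
//    for (int id = 0; id < C; id++) {
//      counts.set_initial_count(h[id], initial_pending[id]);
//    }
//    // When an input of node "dest" becomes available:
//    auto result = counts.adjust_for_activation(h[dest],
//                                               /*increment_dead=*/false);
//    if (result.pending_count == 0) {
//      Schedule(dest);  // All inputs are now available.
//    }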
class PendingCounts {
 public:
  // The state machine for a node's execution.
  enum NodeState {
    // The pending count for the node > 0.
    PENDING_NOTREADY,
    // The pending count for the node == 0, but the node has not
    // started executing.
    PENDING_READY,
    // The node has started executing.
    STARTED,
    // The node has finished executing.
    COMPLETED
  };
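  // Nodes move from PENDING_NOTREADY to PENDING_READY as their pending count
  // is decremented to zero, then to STARTED via mark_started() and finally to
  // COMPLETED via mark_completed().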

  // An opaque handle indicating where in the PendingCounts data structure
  // the appropriate count information can be found.
  class Handle;
  // Given a node that needs to represent counts no larger than the
  // specified "max_pending_count" and "max_dead_count", create a
  // handle that can be passed to various PendingCounts routines
  // to retrieve the count data for this node.
  class Layout {
   public:
    Handle CreateHandle(size_t max_pending_count, size_t max_dead_count);

   private:
    friend class PendingCounts;
    int next_offset_ = 0;  // Next byte offset to allocate
  };

  // Create a new PendingCounts object that can hold the state of
  // all the Handles allocated from "layout".
  explicit PendingCounts(Layout layout)
      : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]) {
    if (num_bytes_ >= sizeof(LargeCounts)) {
      CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
    }
  }

  // Create a new PendingCounts object with the same layout and counts
  // as "other".
  explicit PendingCounts(const PendingCounts& other)
      : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) {
    if (num_bytes_ >= sizeof(LargeCounts)) {
      CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
    }
    memcpy(bytes_, other.bytes_, other.num_bytes_);
  }

  ~PendingCounts() { delete[] bytes_; }

  void set_initial_count(Handle h, size_t pending_count) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending = pending_count;
      c.dead_count = 0;
      c.has_started = 0;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      DCHECK_LE(pending_count, kMaxCountForPackedCounts);
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending = pending_count;
      c.dead_count = 0;
      c.has_started = 0;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }

  NodeState node_state(Handle h) {
    if (h.is_large_) {
      return NodeStateForStruct(Large(h)->load(std::memory_order_relaxed));
    } else {
      return NodeStateForStruct(Packed(h)->load(std::memory_order_relaxed));
    }
  }
  void mark_started(Handle h) {
    DCHECK_EQ(pending(h), 0);
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 0);
      c.has_started = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 0);
      c.has_started = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }
  void mark_completed(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 1);
      c.pending = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 1);
      c.pending = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }
  int pending(Handle h) {
    if (h.is_large_) {
      LargeCounts c = Large(h)->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c.pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    } else {
      PackedCounts c = Packed(h)->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c.pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    }
  }
  // Holds the dead count and pending count of a node as observed by one of
  // the adjust_for_* routines below.  See each routine's comment for whether
  // the values reflect the counts before or after its update.
  struct AdjustResult {
    int dead_count;
    int pending_count;

    AdjustResult(int dead_count, int pending_count)
        : dead_count(dead_count), pending_count(pending_count) {}
  };
  int decrement_pending(Handle h, int v) {
    DCHECK_GE(pending(h), v);
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending -= v;
      c_ptr->store(c, std::memory_order_relaxed);
      return c.pending;
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending -= v;
      c_ptr->store(c, std::memory_order_relaxed);
      return c.pending;
    }
  }

  // Mark a merge node as live.
  // REQUIRES: Node corresponding to "h" is a merge node
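  // The caller is expected to have encoded the merge node's initial pending
  // count so that its low bit serves as a "not yet live" flag: mark_live()
  // (and adjust_for_mark_live() below) only clears that low bit, which is
  // what can allow the node to become ready.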
  void mark_live(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c.pending &= ~static_cast<int>(0x1);
        c_ptr->store(c, std::memory_order_relaxed);
      }
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        static_assert(7 == kMaxCountForPackedCounts,
                      "Live flag incorrect for max packed count");
        c.pending &= 0x6;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    }
  }

  int dead_count(Handle h) {
    int r = h.is_large_ ? Large(h)->load(std::memory_order_relaxed).dead_count
                        : Packed(h)->load(std::memory_order_relaxed).dead_count;
    return r;
  }
  void increment_dead_count(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c.dead_count++;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        DCHECK_LT(c.dead_count, kMaxCountForPackedCounts);
        c.dead_count++;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    }
  }

  // Mark a merge node as live.  Note that the returned pending count is the
  // value before the update.
  AdjustResult adjust_for_mark_live(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      auto ret_pending = 0;
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        ret_pending = c.pending;
        c.pending &= ~static_cast<int>(0x1);
        c_ptr->store(c, std::memory_order_relaxed);
      }
      return AdjustResult(c.dead_count, ret_pending);
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      auto ret_pending = 0;
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        static_assert(7 == kMaxCountForPackedCounts,
                      "Live flag incorrect for max packed count");
        ret_pending = c.pending;
        c.pending &= 0x6;
        c_ptr->store(c, std::memory_order_relaxed);
      }
      return AdjustResult(c.dead_count, ret_pending);
    }
  }

  // The same as the above, but performs the operation atomically.  This
  // is thread-safe to run concurrently with other threads.
  AdjustResult adjust_for_mark_live_atomic(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto old_val = c_ptr->load(std::memory_order_relaxed);
      while (true) {
        auto new_val = old_val;
        auto ret_pending = 0;
        // Only do anything if the node hasn't already started executing.
        if (PENDING_NOTREADY == NodeStateForStruct(new_val)) {
          ret_pending = old_val.pending;
          new_val.pending &= ~static_cast<int>(0x1);
        }
        AdjustResult ret(old_val.dead_count, ret_pending);
        if (TF_PREDICT_TRUE(c_ptr->compare_exchange_weak(old_val, new_val)))
          return ret;
      }
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto old_val = c_ptr->load(std::memory_order_relaxed);
      while (true) {
        auto new_val = old_val;
        auto ret_pending = 0;
        // Only do anything if the node hasn't already started executing.
        if (PENDING_NOTREADY == NodeStateForStruct(new_val)) {
          static_assert(7 == kMaxCountForPackedCounts,
                        "Live flag incorrect for max packed count");
          ret_pending = old_val.pending;
          new_val.pending &= 0x6;
        }
        AdjustResult ret(old_val.dead_count, ret_pending);
        if (TF_PREDICT_TRUE(c_ptr->compare_exchange_weak(old_val, new_val)))
          return ret;
      }
    }
  }

  // A streamlined routine that does several pieces of bookkeeping at
  // once.  Equivalent to:
  //    increment_dead_count(h);
  //    return {dead_count(h), pending(h)};
  AdjustResult adjust_for_increment_dead(Handle h) {
    if (h.is_large_) {
      return adjust_for_increment_dead_shared(Large(h));
    } else {
      return adjust_for_increment_dead_shared(Packed(h));
    }
  }

  // The same as the above, but performs the operation atomically.  This
  // is thread-safe to run concurrently with other threads.
  AdjustResult adjust_for_increment_dead_atomic(Handle h) {
    if (h.is_large_) {
      return adjust_for_increment_dead_shared_atomic(Large(h));
    } else {
      return adjust_for_increment_dead_shared_atomic(Packed(h));
    }
  }

  // A streamlined routine that does several pieces of bookkeeping at
  // once.  Equivalent to:
  //    decrement_pending(h, decrement_pending);
  //    return {dead_count(h), pending(h)};
  AdjustResult adjust_for_decrement_pending(Handle h, int decrement_pending) {
    DCHECK_GE(pending(h), decrement_pending);
    if (h.is_large_) {
      return adjust_for_decrement_pending_shared(Large(h), decrement_pending);
    } else {
      return adjust_for_decrement_pending_shared(Packed(h), decrement_pending);
    }
  }

  // The same as the above, but performs the operation atomically.  This
  // is thread-safe to run concurrently with other threads.
  AdjustResult adjust_for_decrement_pending_atomic(Handle h,
                                                   int decrement_pending) {
    DCHECK_GE(pending(h), decrement_pending);
    if (h.is_large_) {
      return adjust_for_decrement_pending_shared_atomic(Large(h),
                                                        decrement_pending);
    } else {
      return adjust_for_decrement_pending_shared_atomic(Packed(h),
                                                        decrement_pending);
    }
  }

  // A streamlined routine that does several pieces of bookkeeping at
  // once.  Equivalent to:
  //    if (increment_dead) increment_dead_count(h);
  //    decrement_pending(h, 1);
  //    return {dead_count(h), pending(h)};
  AdjustResult adjust_for_activation(Handle h, bool increment_dead) {
    DCHECK_GE(pending(h), 1);
    if (h.is_large_) {
      return adjust_for_activation_shared(Large(h), increment_dead);
    } else {
      return adjust_for_activation_shared(Packed(h), increment_dead);
    }
  }

  // The same as the above, but performs the operation atomically.  This
  // is thread-safe to run concurrently with other threads.
  AdjustResult adjust_for_activation_atomic(Handle h, bool increment_dead) {
    DCHECK_GE(pending(h), 1);
    if (h.is_large_) {
      return adjust_for_activation_shared_atomic(Large(h), increment_dead);
    } else {
      return adjust_for_activation_shared_atomic(Packed(h), increment_dead);
    }
  }
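
  // For example, a caller of the activation routines above (such as an
  // executor propagating an output to a destination node) might use the
  // result roughly as follows (illustrative only; "dst", "is_dead_input",
  // and the scheduling policy are hypothetical and live in the caller):
  //
  //    auto res = counts.adjust_for_activation_atomic(h[dst], is_dead_input);
  //    if (res.pending_count == 0) {
  //      // All inputs have arrived; schedule "dst", consulting
  //      // res.dead_count to decide whether it should run as dead.
  //    }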

  class Handle {
   public:
    Handle() : byte_offset_(0), is_large_(0) {}

   private:
    friend class PendingCounts;
    int byte_offset_ : 31;  // Byte offset of the rep in PendingCounts object
    bool is_large_ : 1;  // If true, rep is LargeCounts; otherwise PackedCounts
  };

 private:
  template <typename T>
  inline AdjustResult adjust_for_increment_dead_shared(std::atomic<T>* c) {
    T val = c->load(std::memory_order_relaxed);
    auto ret_pending = 0;
    // Only do anything if the node hasn't already started executing.
    if (PENDING_NOTREADY == NodeStateForStruct(val)) {
      val.dead_count++;
      ret_pending = val.pending;
      c->store(val, std::memory_order_relaxed);
    }
    return AdjustResult(val.dead_count, ret_pending);
  }

  template <typename T>
  inline AdjustResult adjust_for_increment_dead_shared_atomic(
      std::atomic<T>* c) {
    T old_val = c->load(std::memory_order_relaxed);
    while (true) {
      auto new_val = old_val;
      auto ret_pending = 0;
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(new_val)) {
        ret_pending = new_val.pending;
        new_val.dead_count++;
      }
      AdjustResult ret(new_val.dead_count, ret_pending);
      if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
        return ret;
    }
  }

  template <typename T>
  inline AdjustResult adjust_for_decrement_pending_shared(
      std::atomic<T>* c, int decrement_pending) {
    T val = c->load(std::memory_order_relaxed);
    DCHECK_GE(val.pending, decrement_pending);
    val.pending -= decrement_pending;
    c->store(val, std::memory_order_relaxed);
    return AdjustResult(val.dead_count, val.pending);
  }

  template <typename T>
  inline AdjustResult adjust_for_decrement_pending_shared_atomic(
      std::atomic<T>* c, int decrement_pending) {
    T old_val = c->load(std::memory_order_relaxed);
    while (true) {
      T new_val = old_val;
      DCHECK_GE(new_val.pending, decrement_pending);
      new_val.pending -= decrement_pending;
      AdjustResult ret(new_val.dead_count, new_val.pending);
      if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
        return ret;
    }
  }

  template <typename T>
  inline AdjustResult adjust_for_activation_shared(std::atomic<T>* c,
                                                   bool increment_dead) {
    T val = c->load(std::memory_order_relaxed);
    if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(val)) {
      val.dead_count++;
    }
    DCHECK_GE(val.pending, 1);
    val.pending--;
    c->store(val, std::memory_order_relaxed);
    return AdjustResult(val.dead_count, val.pending);
  }

  template <typename T>
  inline AdjustResult adjust_for_activation_shared_atomic(
      std::atomic<T>* c, bool increment_dead) {
    T old_val = c->load(std::memory_order_relaxed);
    while (true) {
      T new_val = old_val;
      if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(new_val)) {
        new_val.dead_count++;
      }
      DCHECK_GE(new_val.pending, 1);
      new_val.pending--;
      AdjustResult ret(new_val.dead_count, new_val.pending);
      if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
        return ret;
    }
  }

  // We keep track of the pending count and dead input count for each
  // graph node.  The representation used here is designed to be cache
  // efficient for graphs with large numbers of nodes, where most
  // nodes have relatively small maximum pending counts (e.g. for one
  // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less).  We
  // use one byte to hold both the pending and dead count for a node
  // where these together can fit in one byte, and we use a separate
  // 8-byte aligned representation for the rare nodes that need larger
  // counts than this.  Each frame in this subgraph has its own
  // PendingCounts.

  // We use 3 bits each for dead_count and pending.
  static constexpr int kMaxCountForPackedCounts = 7;

  // Most counts are small, so we pack a pending count and a dead
  // count into 3 bits each, and use 1 bit to indicate that the node
  // has started computing.
  struct PackedCounts {
    uint8 pending : 3;
    uint8 dead_count : 3;
    uint8 has_started : 1;
  };
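  // The three bit-fields above occupy 7 bits in total, so PackedCounts fits
  // in a single byte; Layout::CreateHandle below static_asserts that
  // sizeof(std::atomic<PackedCounts>) == 1.  The exact bit ordering is
  // implementation-defined but does not matter here, since each struct is
  // always read and written as a whole.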

  // NOTE: alignas(8) is critical to implement efficient atomic<LargeCounts>
  // on MSVC.
  struct alignas(8) LargeCounts {
    uint32 pending;
    uint32 dead_count : 31;
    // NOTE(tlipcon): MSVC won't pack this struct into 8 bytes unless
    // all of the member types are uint32.
    uint32 has_started : 1;
  };

  template <typename T>
  NodeState NodeStateForStruct(const T& c) const {
    if (c.has_started) {
      return (c.pending == 0) ? STARTED : COMPLETED;
    } else {
      return (c.pending == 0) ? PENDING_READY : PENDING_NOTREADY;
    }
  }
  inline std::atomic<LargeCounts>* Large(Handle h) {
    DCHECK(h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(std::atomic<LargeCounts>), num_bytes_);
    DCHECK_EQ(h.byte_offset_ % alignof(std::atomic<LargeCounts>), 0);
    return reinterpret_cast<std::atomic<LargeCounts>*>(bytes_ +
                                                       h.byte_offset_);
  }
  inline std::atomic<PackedCounts>* Packed(Handle h) {
    DCHECK(!h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_);
    return reinterpret_cast<std::atomic<PackedCounts>*>(bytes_ +
                                                        h.byte_offset_);
  }

  const int num_bytes_;  // Size in bytes of the bytes_ array
  char* bytes_;          // Array of num_bytes_ bytes

  void operator=(const PendingCounts&) = delete;
};

inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
    size_t max_pending_count, size_t max_dead_count) {
  Handle result;
  if ((max_pending_count > kMaxCountForPackedCounts) ||
      (max_dead_count > kMaxCountForPackedCounts)) {
    constexpr int B = sizeof(std::atomic<LargeCounts>);
    // Round byte offset to proper alignment
    static_assert(
        sizeof(std::atomic<LargeCounts>) >= alignof(std::atomic<LargeCounts>),
        "std::atomic<LargeCounts> must be packed");
    int64_t offset = ((static_cast<int64_t>(next_offset_) + B - 1) / B) * B;
    result.byte_offset_ = offset;
    result.is_large_ = true;
    next_offset_ = result.byte_offset_ + B;
  } else {
    result.byte_offset_ = next_offset_;
    result.is_large_ = false;
    static_assert(sizeof(std::atomic<PackedCounts>) == 1,
                  "std::atomic<PackedCounts> should be a single byte");
    next_offset_ += sizeof(std::atomic<PackedCounts>);
  }
  return result;
}
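
// For example, assuming (as the asserts above suggest) that
// sizeof(std::atomic<PackedCounts>) == 1 and
// sizeof(std::atomic<LargeCounts>) == 8, creating handles for max
// (pending, dead) counts of (3, 3), (100, 2), and (5, 5) in that order
// yields byte offsets 0, 8 (1 rounded up for alignment), and 16, so the
// resulting PendingCounts allocation is 17 bytes.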

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_