#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>

#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

// PendingCounts is an internal helper class to keep track of pending and
// dead counts for nodes, for use in the ExecutorState module. It holds
// the various counts associated with each Handle. This information is
// needed per frame iteration. The amount of memory needed for an
// iteration is the same across all executions of the iteration. The
// memory amount and handles are precomputed at startup using a Layout
// object.
//
//    PendingCounts::Layout layout;
//    std::vector<PendingCounts::Handle> h(C);
//    for (int id = 0; id < C; id++) {
//      h[id] = layout.CreateHandle(max_pending[id], max_dead[id]);
//    }
//
// When we actually want to start an iteration we first create a
// PendingCounts object and then index into it using the precomputed
// handles:
//
//    PendingCounts counts(layout);
//    ...
//    counts.decrement_pending(h[id], 1);
class PendingCounts {
 public:
  // The state machine for a node's execution.
  enum NodeState {
    // The pending count for the node > 0.
    PENDING_NOTREADY,
    // The pending count for the node == 0, but the node has not
    // started executing.
    PENDING_READY,
    // The node has started executing.
    STARTED,
    // The node has finished executing.
    COMPLETED
  };
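
  // Illustrative lifecycle of a node, derived from the states above and the
  // methods below (a sketch, not an exhaustive specification):
  //
  //   PENDING_NOTREADY --(pending count reaches 0)--> PENDING_READY
  //   PENDING_READY    --(mark_started)------------>  STARTED
  //   STARTED          --(mark_completed)---------->  COMPLETED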

  // An opaque handle indicating where in the PendingCounts data structure
  // the appropriate count information can be found.
  class Handle;
  // Given a node that needs to represent counts no larger than the
  // specified "max_pending_count" and "max_dead_count", create a
  // handle that can be passed to various PendingCounts routines
  // to retrieve the count data for this node.
  class Layout {
   public:
    Handle CreateHandle(size_t max_pending_count, size_t max_dead_count);

   private:
    friend class PendingCounts;
    int next_offset_ = 0;  // Next byte offset to allocate
  };

  // Create a new PendingCounts object that can hold the state of
  // all the Handles allocated from "layout".
  explicit PendingCounts(Layout layout)
      : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]) {
    if (num_bytes_ >= sizeof(LargeCounts)) {
      CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
    }
  }

  // Create a new PendingCounts object with the same layout and counts
  // as "other".
  explicit PendingCounts(const PendingCounts& other)
      : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) {
    if (num_bytes_ >= sizeof(LargeCounts)) {
      CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
    }
    memcpy(bytes_, other.bytes_, other.num_bytes_);
  }

  ~PendingCounts() { delete[] bytes_; }

  void set_initial_count(Handle h, size_t pending_count) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending = pending_count;
      c.dead_count = 0;
      c.has_started = 0;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      DCHECK_LE(pending_count, kMaxCountForPackedCounts);
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending = pending_count;
      c.dead_count = 0;
      c.has_started = 0;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }

  NodeState node_state(Handle h) {
    if (h.is_large_) {
      return NodeStateForStruct(Large(h)->load(std::memory_order_relaxed));
    } else {
      return NodeStateForStruct(Packed(h)->load(std::memory_order_relaxed));
    }
  }
  void mark_started(Handle h) {
    DCHECK_EQ(pending(h), 0);
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 0);
      c.has_started = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 0);
      c.has_started = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }
  void mark_completed(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 1);
      // A non-zero pending count together with has_started == 1 is how
      // NodeStateForStruct() encodes COMPLETED.
      c.pending = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 1);
      c.pending = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }
  int pending(Handle h) {
    if (h.is_large_) {
      LargeCounts c = Large(h)->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c.pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    } else {
      PackedCounts c = Packed(h)->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c.pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    }
  }
  // Bundles the dead count and pending count returned by the adjust_for_*
  // routines below. Unless a routine's comment says otherwise, the values
  // reflect the counts after the adjustment.
  struct AdjustResult {
    int dead_count;
    int pending_count;

    AdjustResult(int dead_count, int pending_count)
        : dead_count(dead_count), pending_count(pending_count) {}
  };
  int decrement_pending(Handle h, int v) {
    DCHECK_GE(pending(h), v);
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending -= v;
      c_ptr->store(c, std::memory_order_relaxed);
      return c.pending;
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending -= v;
      c_ptr->store(c, std::memory_order_relaxed);
      return c.pending;
    }
  }

  // Mark a merge node as live
  // REQUIRES: Node corresponding to "h" is a merge node
  void mark_live(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c.pending &= ~static_cast<int>(0x1);
        c_ptr->store(c, std::memory_order_relaxed);
      }
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        static_assert(7 == kMaxCountForPackedCounts,
                      "Live flag incorrect for max packed count");
        c.pending &= 0x6;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    }
  }

  int dead_count(Handle h) {
    int r = h.is_large_ ? Large(h)->load(std::memory_order_relaxed).dead_count
                        : Packed(h)->load(std::memory_order_relaxed).dead_count;
    return r;
  }
  void increment_dead_count(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c.dead_count++;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        DCHECK_LT(c.dead_count, kMaxCountForPackedCounts);
        c.dead_count++;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    }
  }

  // Mark a merge node as live. Note that the pending count in the returned
  // AdjustResult is the value from before the update.
  AdjustResult adjust_for_mark_live(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      auto ret_pending = 0;
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        ret_pending = c.pending;
        c.pending &= ~static_cast<int>(0x1);
        c_ptr->store(c, std::memory_order_relaxed);
      }
      return AdjustResult(c.dead_count, ret_pending);
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      auto ret_pending = 0;
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        static_assert(7 == kMaxCountForPackedCounts,
                      "Live flag incorrect for max packed count");
        ret_pending = c.pending;
        c.pending &= 0x6;
        c_ptr->store(c, std::memory_order_relaxed);
      }
      return AdjustResult(c.dead_count, ret_pending);
    }
  }

  // The same as the above, but performs the operation atomically. This
  // is thread-safe to run concurrently with other threads.
  AdjustResult adjust_for_mark_live_atomic(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto old_val = c_ptr->load(std::memory_order_relaxed);
      while (true) {
        auto new_val = old_val;
        auto ret_pending = 0;
        // Only do anything if the node hasn't already started executing.
        if (PENDING_NOTREADY == NodeStateForStruct(new_val)) {
          ret_pending = old_val.pending;
          new_val.pending &= ~static_cast<int>(0x1);
        }
        AdjustResult ret(old_val.dead_count, ret_pending);
        if (TF_PREDICT_TRUE(c_ptr->compare_exchange_weak(old_val, new_val)))
          return ret;
      }
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto old_val = c_ptr->load(std::memory_order_relaxed);
      while (true) {
        auto new_val = old_val;
        auto ret_pending = 0;
        // Only do anything if the node hasn't already started executing.
        if (PENDING_NOTREADY == NodeStateForStruct(new_val)) {
          static_assert(7 == kMaxCountForPackedCounts,
                        "Live flag incorrect for max packed count");
          ret_pending = old_val.pending;
          new_val.pending &= 0x6;
        }
        AdjustResult ret(old_val.dead_count, ret_pending);
        if (TF_PREDICT_TRUE(c_ptr->compare_exchange_weak(old_val, new_val)))
          return ret;
      }
    }
  }

  // A streamlined routine that does several pieces of bookkeeping at
  // once. Equivalent to:
  //    increment_dead_count(h);
  //    return {dead_count(h), pending(h)};
  AdjustResult adjust_for_increment_dead(Handle h) {
    if (h.is_large_) {
      return adjust_for_increment_dead_shared(Large(h));
    } else {
      return adjust_for_increment_dead_shared(Packed(h));
    }
  }

  // The same as the above, but performs the operation atomically. This
  // is thread-safe to run concurrently with other threads.
  AdjustResult adjust_for_increment_dead_atomic(Handle h) {
    if (h.is_large_) {
      return adjust_for_increment_dead_shared_atomic(Large(h));
    } else {
      return adjust_for_increment_dead_shared_atomic(Packed(h));
    }
  }

  // A streamlined routine that does several pieces of bookkeeping at
  // once. Equivalent to:
  //    decrement_pending(h, decrement_pending);
  //    return {dead_count(h), pending(h)};
  AdjustResult adjust_for_decrement_pending(Handle h, int decrement_pending) {
    DCHECK_GE(pending(h), decrement_pending);
    if (h.is_large_) {
      return adjust_for_decrement_pending_shared(Large(h), decrement_pending);
    } else {
      return adjust_for_decrement_pending_shared(Packed(h), decrement_pending);
    }
  }

  // The same as the above, but performs the operation atomically. This
  // is thread-safe to run concurrently with other threads.
  AdjustResult adjust_for_decrement_pending_atomic(Handle h,
                                                   int decrement_pending) {
    DCHECK_GE(pending(h), decrement_pending);
    if (h.is_large_) {
      return adjust_for_decrement_pending_shared_atomic(Large(h),
                                                        decrement_pending);
    } else {
      return adjust_for_decrement_pending_shared_atomic(Packed(h),
                                                        decrement_pending);
    }
  }

  // A streamlined routine that does several pieces of bookkeeping at
  // once. Equivalent to:
  //    if (increment_dead) increment_dead_count(h);
  //    decrement_pending(h, 1);
  //    return {dead_count(h), pending(h)};
  AdjustResult adjust_for_activation(Handle h, bool increment_dead) {
    DCHECK_GE(pending(h), 1);
    if (h.is_large_) {
      return adjust_for_activation_shared(Large(h), increment_dead);
    } else {
      return adjust_for_activation_shared(Packed(h), increment_dead);
    }
  }

  // The same as the above, but performs the operation atomically. This
  // is thread-safe to run concurrently with other threads.
  AdjustResult adjust_for_activation_atomic(Handle h, bool increment_dead) {
    DCHECK_GE(pending(h), 1);
    if (h.is_large_) {
      return adjust_for_activation_shared_atomic(Large(h), increment_dead);
    } else {
      return adjust_for_activation_shared_atomic(Packed(h), increment_dead);
    }
  }
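
  // Example (an illustrative sketch, not part of the executor API; "counts",
  // "h", and "edge_is_dead" are hypothetical names): a caller propagating an
  // input to the node behind "h" might do
  //
  //    PendingCounts::AdjustResult result =
  //        counts.adjust_for_activation(h, edge_is_dead);
  //    if (result.pending_count == 0) {
  //      // All inputs have arrived; the node can now be scheduled.
  //    }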

  class Handle {
   public:
    Handle() : byte_offset_(0), is_large_(0) {}

   private:
    friend class PendingCounts;
    int byte_offset_ : 31;  // Byte offset of the rep in PendingCounts object
    bool is_large_ : 1;  // If true, rep is LargeCounts; otherwise PackedCounts
  };

 private:
  template <typename T>
  inline AdjustResult adjust_for_increment_dead_shared(std::atomic<T>* c) {
    T val = c->load(std::memory_order_relaxed);
    auto ret_pending = 0;
    // Only do anything if the node hasn't already started executing.
    if (PENDING_NOTREADY == NodeStateForStruct(val)) {
      val.dead_count++;
      ret_pending = val.pending;
      c->store(val, std::memory_order_relaxed);
    }
    return AdjustResult(val.dead_count, ret_pending);
  }

  template <typename T>
  inline AdjustResult adjust_for_increment_dead_shared_atomic(
      std::atomic<T>* c) {
    T old_val = c->load(std::memory_order_relaxed);
    while (true) {
      auto new_val = old_val;
      auto ret_pending = 0;
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(new_val)) {
        ret_pending = new_val.pending;
        new_val.dead_count++;
      }
      AdjustResult ret(new_val.dead_count, ret_pending);
      if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
        return ret;
    }
  }

  template <typename T>
  inline AdjustResult adjust_for_decrement_pending_shared(
      std::atomic<T>* c, int decrement_pending) {
    T val = c->load(std::memory_order_relaxed);
    DCHECK_GE(val.pending, decrement_pending);
    val.pending -= decrement_pending;
    c->store(val, std::memory_order_relaxed);
    return AdjustResult(val.dead_count, val.pending);
  }

  template <typename T>
  inline AdjustResult adjust_for_decrement_pending_shared_atomic(
      std::atomic<T>* c, int decrement_pending) {
    T old_val = c->load(std::memory_order_relaxed);
    while (true) {
      T new_val = old_val;
      DCHECK_GE(new_val.pending, decrement_pending);
      new_val.pending -= decrement_pending;
      AdjustResult ret(new_val.dead_count, new_val.pending);
      if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
        return ret;
    }
  }

  template <typename T>
  inline AdjustResult adjust_for_activation_shared(std::atomic<T>* c,
                                                   bool increment_dead) {
    T val = c->load(std::memory_order_relaxed);
    if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(val)) {
      val.dead_count++;
    }
    DCHECK_GE(val.pending, 1);
    val.pending--;
    c->store(val, std::memory_order_relaxed);
    return AdjustResult(val.dead_count, val.pending);
  }

  template <typename T>
  inline AdjustResult adjust_for_activation_shared_atomic(
      std::atomic<T>* c, bool increment_dead) {
    T old_val = c->load(std::memory_order_relaxed);
    while (true) {
      T new_val = old_val;
      if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(new_val)) {
        new_val.dead_count++;
      }
      DCHECK_GE(new_val.pending, 1);
      new_val.pending--;
      AdjustResult ret(new_val.dead_count, new_val.pending);
      if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
        return ret;
    }
  }

  // We keep track of the pending count and dead input count for each
  // graph node. The representation used here is designed to be cache
  // efficient for graphs with large numbers of nodes, where most
  // nodes have relatively small maximum pending counts (e.g. for one
  // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less). We
  // use one byte to hold both the pending and dead count for a node
  // where these together can fit in one byte, and we fall back to the
  // larger 8-byte LargeCounts representation for the rare nodes that
  // need bigger counts than this.
  // Each frame in this subgraph has its own PendingCounts.

  // We use 3 bits each for dead_count and pending.
  static constexpr int kMaxCountForPackedCounts = 7;

  // Most counts are small, so we pack a pending count and a dead
  // count into 3 bits each, and use 1 bit to indicate whether the node
  // has started computing.
  struct PackedCounts {
    uint8 pending : 3;
    uint8 dead_count : 3;
    uint8 has_started : 1;
  };

  // NOTE: alignas(8) is critical to implement efficient atomic<LargeCounts>
  // on MSVC.
  struct alignas(8) LargeCounts {
    uint32 pending;
    uint32 dead_count : 31;
    // NOTE(tlipcon): MSVC won't pack this struct into 8 bytes unless
    // all of the member types are uint32.
    uint32 has_started : 1;
  };
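
  // A sanity check one could add here (illustrative only; these asserts are
  // not in the original header, but they follow from the packing notes above
  // and the static_assert in Layout::CreateHandle below):
  //
  //   static_assert(sizeof(PackedCounts) == 1, "PackedCounts is one byte");
  //   static_assert(sizeof(LargeCounts) == 8, "LargeCounts packs to 8 bytes");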

  template <typename T>
  NodeState NodeStateForStruct(const T& c) const {
    if (c.has_started) {
      return (c.pending == 0) ? STARTED : COMPLETED;
    } else {
      return (c.pending == 0) ? PENDING_READY : PENDING_NOTREADY;
    }
  }
  inline std::atomic<LargeCounts>* Large(Handle h) {
    DCHECK(h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(std::atomic<LargeCounts>), num_bytes_);
    DCHECK_EQ(h.byte_offset_ % alignof(std::atomic<LargeCounts>), 0);
    return reinterpret_cast<std::atomic<LargeCounts>*>(bytes_ + h.byte_offset_);
  }
  inline std::atomic<PackedCounts>* Packed(Handle h) {
    DCHECK(!h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_);
    return reinterpret_cast<std::atomic<PackedCounts>*>(bytes_ +
                                                        h.byte_offset_);
  }

  const int num_bytes_;  // Just for bounds checking in debug mode
  char* bytes_;          // Array of num_bytes_ bytes

  void operator=(const PendingCounts&) = delete;
};

inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
    size_t max_pending_count, size_t max_dead_count) {
  Handle result;
  if ((max_pending_count > kMaxCountForPackedCounts) ||
      (max_dead_count > kMaxCountForPackedCounts)) {
    constexpr int B = sizeof(std::atomic<LargeCounts>);
    // Round byte offset to proper alignment
    static_assert(
        sizeof(std::atomic<LargeCounts>) >= alignof(std::atomic<LargeCounts>),
        "std::atomic<LargeCounts> must be packed");
    int64_t offset = ((static_cast<int64_t>(next_offset_) + B - 1) / B) * B;
    result.byte_offset_ = offset;
    result.is_large_ = true;
    next_offset_ = result.byte_offset_ + B;
  } else {
    result.byte_offset_ = next_offset_;
    result.is_large_ = false;
    static_assert(sizeof(std::atomic<PackedCounts>) == 1,
                  "std::atomic<PackedCounts> should be a single byte");
    next_offset_ += sizeof(std::atomic<PackedCounts>);
  }
  return result;
}
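
// Example of how a Layout, its Handles, and a PendingCounts object fit
// together (an illustrative sketch; the variable names and counts are
// hypothetical, not part of this header):
//
//   PendingCounts::Layout layout;
//   // Small counts share packed one-byte slots.
//   PendingCounts::Handle small = layout.CreateHandle(3, 1);
//   // Counts above kMaxCountForPackedCounts get an aligned LargeCounts slot.
//   PendingCounts::Handle big = layout.CreateHandle(100, 0);
//
//   PendingCounts counts(layout);
//   counts.set_initial_count(small, 3);
//   counts.set_initial_count(big, 100);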

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_