1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ |
18 | |
19 | #include <string> |
20 | #include <unordered_set> |
21 | #include <vector> |
22 | |
23 | #include "absl/container/flat_hash_set.h" |
24 | #include "tensorflow/core/distributed_runtime/message_wrappers.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | #include "tensorflow/core/platform/mutex.h" |
27 | #include "tensorflow/core/platform/protobuf.h" |
28 | #include "tensorflow/core/platform/thread_annotations.h" |
29 | #include "tensorflow/core/platform/types.h" |
30 | #include "tensorflow/core/protobuf/worker.pb.h" |
31 | |
32 | namespace tensorflow { |
33 | |
34 | // RecentRequestIds tracks recent 64-bit request_ids. When maximum capacity is |
35 | // reached, the oldest request_id is evicted. Thread safe. |
36 | // |
37 | // Some RPCs like RecvTensor are unsafe to retry. For example, RecvTensor pairs |
38 | // one sender and one receiver, and the receiver waits for the sender's tensor. |
39 | // Retried RecvTensor requests are problematic, because the original RecvTensor |
40 | // request may have consumed the sender's tensor, so a retried request might |
41 | // block forever. RecentRequestIds identifies retried requests, so we can fail |
42 | // them instead of blocking forever. |
43 | // |
44 | // Internally, recent request_ids are stored in two data structures: a set and a |
45 | // circular buffer. The set is used for efficient lookups, and the circular |
46 | // buffer tracks the oldest request_id. When the buffer is full, the new |
47 | // request_id replaces the oldest request_id in the circular buffer, and the |
48 | // oldest request_id is removed from the set. |
49 | class RecentRequestIds { |
50 | public: |
51 | // num_tracked_request_ids should be much larger than the number of RPCs that |
52 | // can be received in a small time window. For example, we observed a peak RPC |
53 | // rate of ~700 RecvTensor RPC/s when training inception v3 on TPUs, so we |
54 | // currently set num_tracked_request_ids to 100,000 for RecvTensor. |
55 | // Having a large `num_shars` can prevent run into lock contention in this |
56 | // class. |
57 | explicit RecentRequestIds(int num_tracked_request_ids, int num_shards = 1); |
58 | |
59 | // Returns OK iff request_id has not been seen in the last |
60 | // num_tracked_request_ids insertions. For backwards compatibility, this |
61 | // always returns OK for request_id 0. The method_name and the request's |
62 | // ShortDebugString are added to returned errors. |
63 | Status TrackUnique(int64_t request_id, const string& method_name, |
64 | const protobuf::Message& request); |
65 | // Overloaded version of the above function for wrapped protos. |
66 | template <typename RequestWrapper> |
67 | Status TrackUnique(int64_t request_id, const string& method_name, |
68 | const RequestWrapper* wrapper); |
69 | |
70 | private: |
71 | bool Insert(int64_t request_id); |
72 | |
73 | struct IndexBucket { |
74 | mutex mu; |
75 | // next_index indexes into circular_buffer_, and points to the next storage |
76 | // space to use. When the buffer is full, next_index_ points at the oldest |
77 | // request_id. |
78 | int next_index TF_GUARDED_BY(mu) = 0; |
79 | std::vector<int64_t> circular_buffer TF_GUARDED_BY(mu); |
80 | absl::flat_hash_set<int64_t> set TF_GUARDED_BY(mu); |
81 | }; |
82 | |
83 | // This vector is immutable so we don't need to use a mutex to protect it. |
84 | std::vector<IndexBucket> index_buckets_; |
85 | }; |
86 | |
87 | // Implementation details |
88 | |
89 | template <typename RequestWrapper> |
90 | Status RecentRequestIds::TrackUnique(int64_t request_id, |
91 | const string& method_name, |
92 | const RequestWrapper* wrapper) { |
93 | if (Insert(request_id)) { |
94 | return OkStatus(); |
95 | } else { |
96 | return errors::Aborted("The same " , method_name, |
97 | " request was received twice. " , |
98 | wrapper->ToProto().ShortDebugString()); |
99 | } |
100 | } |
101 | |
102 | } // namespace tensorflow |
103 | |
104 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ |
105 | |