1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"

#include <algorithm>
#include <cstdint>
#include <utility>

#include "tensorflow/core/platform/logging.h"
21 | |
22 | namespace tensorflow { |
23 | |
24 | RecentRequestIds::RecentRequestIds(int num_tracked_request_ids, int num_shards) |
25 | : index_buckets_(num_shards > 0 ? num_shards : 1) { |
26 | DCHECK(num_tracked_request_ids >= num_shards); |
27 | const int per_bucket_size = num_tracked_request_ids / index_buckets_.size(); |
28 | for (auto& bucket : index_buckets_) { |
29 | mutex_lock l(bucket.mu); |
30 | bucket.circular_buffer.resize(per_bucket_size); |
31 | bucket.set.reserve(per_bucket_size); |
32 | } |
33 | } |
34 | |
35 | bool RecentRequestIds::Insert(int64_t request_id) { |
36 | if (request_id == 0) { |
37 | // For backwards compatibility, allow all requests with request_id 0. |
38 | return true; |
39 | } |
40 | |
41 | const int bucket_index = request_id % index_buckets_.size(); |
42 | auto& bucket = index_buckets_[bucket_index]; |
43 | |
44 | mutex_lock l(bucket.mu); |
45 | const bool inserted = bucket.set.insert(request_id).second; |
46 | if (!inserted) { |
47 | // Note: RecentRequestIds is not strict LRU because we don't update |
48 | // request_id's age in the circular_buffer_ if it's tracked again. Strict |
49 | // LRU is not useful here because returning this error will close the |
50 | // current Session. |
51 | return false; |
52 | } |
53 | |
54 | // Remove the oldest request_id from the set_. circular_buffer_ is |
55 | // zero-initialized, and zero is never tracked, so it's safe to do this even |
56 | // when the buffer is not yet full. |
57 | bucket.set.erase(bucket.circular_buffer[bucket.next_index]); |
58 | bucket.circular_buffer[bucket.next_index] = request_id; |
59 | bucket.next_index = (bucket.next_index + 1) % bucket.circular_buffer.size(); |
60 | return true; |
61 | } |
62 | |
63 | Status RecentRequestIds::TrackUnique(int64_t request_id, |
64 | const string& method_name, |
65 | const protobuf::Message& request) { |
66 | if (Insert(request_id)) { |
67 | return OkStatus(); |
68 | } else { |
69 | return errors::Aborted("The same " , method_name, |
70 | " request was received twice. " , |
71 | request.ShortDebugString()); |
72 | } |
73 | } |
74 | |
75 | } // namespace tensorflow |
76 | |