1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/framework/rendezvous.h"

#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/local_rendezvous.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
35 | |
36 | namespace tensorflow { |
37 | |
38 | Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) { |
39 | const char* b_base = b.buf_.data(); |
40 | buf_ = b.buf_; |
41 | src_device = StringPiece(buf_.data() + (b.src_device.data() - b_base), |
42 | b.src_device.size()); |
43 | src = b.src; |
44 | src_incarnation = b.src_incarnation; |
45 | dst_device = StringPiece(buf_.data() + (b.dst_device.data() - b_base), |
46 | b.dst_device.size()); |
47 | dst = b.dst; |
48 | edge_name = StringPiece(buf_.data() + (b.edge_name.data() - b_base), |
49 | b.edge_name.size()); |
50 | return *this; |
51 | } |
52 | |
53 | /* static */ |
54 | string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation, |
55 | const string& dst_device, const string& name, |
56 | const FrameAndIter& frame_iter) { |
57 | // NOTE: ';' is not used in the device name's job name. |
58 | // |
59 | // We include both sender and receiver in the key to facilitate |
60 | // debugging. For correctness, we only need to encode the receiver. |
61 | // |
62 | // "src_incarnation" is used to distinguish a worker when it |
63 | // restarts. |
64 | char buf[strings::kFastToBufferSize]; |
65 | return strings::StrCat( |
66 | src_device, ";" , strings::Uint64ToHexString(src_incarnation, buf), ";" , |
67 | dst_device, ";" , name, ";" , frame_iter.frame_id, ":" , frame_iter.iter_id); |
68 | } |
69 | |
70 | // Return the prefix of "*s" up to the next occurrence of "delim", or |
71 | // the whole remaining string if "delim" is not found. "*s" is advanced |
72 | // past the string returned plus the delimiter (if found). |
73 | static StringPiece ConsumeNextPart(StringPiece* s, char delim) { |
74 | for (size_t offset = 0; offset < s->size(); offset++) { |
75 | if ((*s)[offset] == delim) { |
76 | StringPiece result(s->data(), offset); |
77 | s->remove_prefix(offset + 1); // +1: remove delim, as well |
78 | return result; |
79 | } |
80 | } |
81 | // No delimiter found: return rest of string |
82 | StringPiece result(s->data(), s->size()); |
83 | s->remove_prefix(s->size()); |
84 | return result; |
85 | } |
86 | |
87 | /* static */ |
88 | Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { |
89 | if (key.data() == out->buf_.data()) { |
90 | // Caller used our buf_ string directly, so we don't need to copy. (The |
91 | // SendOp and RecvOp implementations do this, for example). |
92 | DCHECK_EQ(key.size(), out->buf_.size()); |
93 | } else { |
94 | // Make a copy that our StringPieces can point at a copy that will persist |
95 | // for the lifetime of the ParsedKey object. |
96 | out->buf_.assign(key.data(), key.size()); |
97 | } |
98 | StringPiece s(out->buf_); |
99 | StringPiece parts[5]; |
100 | for (int i = 0; i < 5; i++) { |
101 | parts[i] = ConsumeNextPart(&s, ';'); |
102 | } |
103 | if (s.empty() && // Consumed the whole string |
104 | !parts[4].empty() && // Exactly five parts |
105 | DeviceNameUtils::ParseFullName(parts[0], &out->src) && |
106 | strings::HexStringToUint64(parts[1], &out->src_incarnation) && |
107 | DeviceNameUtils::ParseFullName(parts[2], &out->dst) && |
108 | !parts[3].empty()) { |
109 | out->src_device = StringPiece(parts[0].data(), parts[0].size()); |
110 | out->dst_device = StringPiece(parts[2].data(), parts[2].size()); |
111 | out->edge_name = StringPiece(parts[3].data(), parts[3].size()); |
112 | return OkStatus(); |
113 | } |
114 | return errors::InvalidArgument("Invalid rendezvous key: " , key); |
115 | } |
116 | |
117 | RendezvousInterface::~RendezvousInterface() {} |
118 | |
119 | Status RendezvousInterface::Recv(const ParsedKey& key, const Args& recv_args, |
120 | Tensor* val, bool* is_dead, |
121 | int64_t timeout_ms) { |
122 | Status ret; |
123 | Notification n; |
124 | RecvAsync(key, recv_args, |
125 | [&ret, &n, val, is_dead](const Status& s, const Args& send_args, |
126 | const Args& recv_args, const Tensor& v, |
127 | const bool dead) { |
128 | ret = s; |
129 | *val = v; |
130 | *is_dead = dead; |
131 | n.Notify(); |
132 | }); |
133 | if (timeout_ms > 0) { |
134 | int64_t timeout_us = timeout_ms * 1000; |
135 | bool notified = WaitForNotificationWithTimeout(&n, timeout_us); |
136 | if (!notified) { |
137 | return Status(error::DEADLINE_EXCEEDED, |
138 | "Timed out waiting for notification" ); |
139 | } |
140 | } else { |
141 | n.WaitForNotification(); |
142 | } |
143 | return ret; |
144 | } |
145 | |
146 | Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args, |
147 | Tensor* val, bool* is_dead) { |
148 | const int64_t no_timeout = 0; |
149 | return Recv(key, args, val, is_dead, no_timeout); |
150 | } |
151 | |
152 | namespace { |
153 | class LocalRendezvousWrapper : public Rendezvous { |
154 | public: |
155 | LocalRendezvousWrapper() : impl_(this) {} |
156 | |
157 | Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, |
158 | const bool is_dead) override { |
159 | return impl_.Send(key, send_args, val, is_dead); |
160 | } |
161 | |
162 | void RecvAsync(const ParsedKey& key, const Args& recv_args, |
163 | DoneCallback done) override { |
164 | impl_.RecvAsync(key, recv_args, std::move(done)); |
165 | } |
166 | |
167 | void StartAbort(const Status& status) override { impl_.StartAbort(status); } |
168 | |
169 | private: |
170 | LocalRendezvous impl_; |
171 | |
172 | TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousWrapper); |
173 | }; |
174 | } // namespace |
175 | |
176 | Rendezvous* NewLocalRendezvous() { return new LocalRendezvousWrapper; } |
177 | |
178 | } // end namespace tensorflow |
179 | |