1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/sendrecv_ops.h" |
17 | |
18 | #include <utility> |
19 | |
20 | #include "tensorflow/core/framework/attr_value.pb.h" |
21 | #include "tensorflow/core/framework/op.h" |
22 | #include "tensorflow/core/framework/op_def_util.h" |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/lib/strings/numbers.h" |
25 | #include "tensorflow/core/lib/strings/strcat.h" |
26 | #include "tensorflow/core/platform/logging.h" |
27 | #include "tensorflow/core/profiler/lib/traceme.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | static string GetRendezvousKeyPrefix(const string& send_device, |
32 | const string& recv_device, |
33 | const uint64 send_device_incarnation, |
34 | const string& tensor_name) { |
35 | return strings::StrCat(send_device, ";" , |
36 | strings::FpToString(send_device_incarnation), ";" , |
37 | recv_device, ";" , tensor_name); |
38 | } |
39 | |
40 | static void GetRendezvousKey(const string& key_prefix, |
41 | const FrameAndIter& frame_iter, string* key) { |
42 | key->clear(); |
43 | strings::StrAppend(key, key_prefix, ";" , frame_iter.frame_id, ":" , |
44 | frame_iter.iter_id); |
45 | } |
46 | |
47 | static FrameAndIter GetFrameAndIter(OpKernelContext* ctx, |
48 | bool hostmem_sendrecv) { |
49 | if (hostmem_sendrecv && ctx->call_frame() != nullptr) { |
50 | // Host memory send/recv pairs are added by |
51 | // common_runtime/memory_types.cc. When the pair of nodes are |
52 | // added inside a function, we need to use the function call frame |
53 | // to formulate the unique rendezvous key. |
54 | return FrameAndIter(reinterpret_cast<uint64>(ctx->call_frame()), 0); |
55 | } else { |
56 | return ctx->frame_iter(); |
57 | } |
58 | } |
59 | |
60 | SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
61 | string send_device; |
62 | OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device" , &send_device)); |
63 | string recv_device; |
64 | OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device" , &recv_device)); |
65 | uint64 send_device_incarnation; |
66 | OP_REQUIRES_OK( |
67 | ctx, ctx->GetAttr("send_device_incarnation" , |
68 | reinterpret_cast<int64_t*>(&send_device_incarnation))); |
69 | string tensor_name; |
70 | OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name" , &tensor_name)); |
71 | key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device, |
72 | send_device_incarnation, tensor_name); |
73 | // The vast majority of Send nodes are outside any loop context, so |
74 | // proactively cache the rendezvous key for the top-level. |
75 | GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_); |
76 | OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_)); |
77 | if (!ctx->GetAttr("_hostmem_sendrecv" , &hostmem_sendrecv_).ok()) { |
78 | hostmem_sendrecv_ = false; |
79 | } |
80 | } |
81 | |
82 | void SendOp::Compute(OpKernelContext* ctx) { |
83 | OP_REQUIRES( |
84 | ctx, ctx->rendezvous() != nullptr, |
85 | errors::Internal("Op kernel context needs to provide a rendezvous." )); |
86 | |
87 | // The device context may be passed between the Send/Recv |
88 | // boundary, so that the device context used to produce the Tensor |
89 | // is used when performing the copy on the recv side (which may be |
90 | // a different device). |
91 | Rendezvous::Args args; |
92 | args.device_context = ctx->op_device_context(); |
93 | args.alloc_attrs = ctx->input_alloc_attr(0); |
94 | |
95 | FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_); |
96 | if (frame_iter == FrameAndIter(0, 0)) { |
97 | // Use the cached rendezvous key. |
98 | VLOG(2) << "Send " << parsed_key_.buf_ << " using " |
99 | << reinterpret_cast<uintptr_t>(ctx->rendezvous()); |
100 | ctx->SetStatus(ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0), |
101 | ctx->is_input_dead())); |
102 | return; |
103 | } else { |
104 | Rendezvous::ParsedKey in_loop_parsed; |
105 | GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_); |
106 | VLOG(2) << "Send " << in_loop_parsed.buf_ << " using " |
107 | << reinterpret_cast<uintptr_t>(ctx->rendezvous()); |
108 | OP_REQUIRES_OK(ctx, |
109 | Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed)); |
110 | |
111 | ctx->SetStatus(ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0), |
112 | ctx->is_input_dead())); |
113 | return; |
114 | } |
115 | } |
116 | |
117 | string SendOp::TraceString(const OpKernelContext& ctx, bool verbose) const { |
118 | const auto& attr = def().attr(); |
119 | auto src_it = attr.find("_src" ); |
120 | auto dst_it = attr.find("_dst" ); |
121 | const string& src = src_it != attr.end() ? src_it->second.s() : "" ; |
122 | const string& dst = dst_it != attr.end() ? dst_it->second.s() : "" ; |
123 | string op = profiler::TraceMeOp(name_view(), type_string_view()); |
124 | return profiler::TraceMeEncode( |
125 | std::move(op), |
126 | {{"from" , src}, {"to" , dst}, {"key" , parsed_key_.FullKey()}}); |
127 | } |
128 | |
129 | REGISTER_KERNEL_BUILDER(Name("_Send" ).Device(DEVICE_CPU), SendOp); |
130 | REGISTER_KERNEL_BUILDER(Name("_Send" ).Device(DEVICE_DEFAULT), SendOp); |
131 | REGISTER_KERNEL_BUILDER(Name("_Send" ).Device(DEVICE_TPU_SYSTEM), SendOp); |
132 | REGISTER_KERNEL_BUILDER(Name("_HostSend" ).Device(DEVICE_TPU_SYSTEM), SendOp); |
133 | |
134 | // Public alias. Added for use in Lingvo. |
135 | REGISTER_KERNEL_BUILDER(Name("Send" ).Device(DEVICE_CPU), SendOp); |
136 | REGISTER_KERNEL_BUILDER(Name("Send" ).Device(DEVICE_DEFAULT), SendOp); |
137 | |
138 | REGISTER_KERNEL_BUILDER( |
139 | Name("_HostSend" ).Device(DEVICE_DEFAULT).HostMemory("tensor" ), SendOp); |
140 | |
141 | RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { |
142 | string send_device; |
143 | OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device" , &send_device)); |
144 | string recv_device; |
145 | OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device" , &recv_device)); |
146 | uint64 send_device_incarnation; |
147 | OP_REQUIRES_OK( |
148 | ctx, ctx->GetAttr("send_device_incarnation" , |
149 | reinterpret_cast<int64_t*>(&send_device_incarnation))); |
150 | string tensor_name; |
151 | OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name" , &tensor_name)); |
152 | key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device, |
153 | send_device_incarnation, tensor_name); |
154 | // The vast majority of Recv nodes are outside any loop context, so |
155 | // proactively cache the rendezvous key for the top-level. |
156 | GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_); |
157 | OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_)); |
158 | if (!ctx->GetAttr("_hostmem_sendrecv" , &hostmem_sendrecv_).ok()) { |
159 | hostmem_sendrecv_ = false; |
160 | } |
161 | } |
162 | |
163 | string RecvOp::TraceString(const OpKernelContext& ctx, bool verbose) const { |
164 | const auto& attr = def().attr(); |
165 | auto src_it = attr.find("_src" ); |
166 | auto dst_it = attr.find("_dst" ); |
167 | const string& src = src_it != attr.end() ? src_it->second.s() : "" ; |
168 | const string& dst = dst_it != attr.end() ? dst_it->second.s() : "" ; |
169 | string op = profiler::TraceMeOp(name_view(), type_string_view()); |
170 | return profiler::TraceMeEncode( |
171 | std::move(op), |
172 | {{"from" , src}, {"to" , dst}, {"key" , parsed_key_.FullKey()}}); |
173 | } |
174 | |
175 | namespace { |
176 | Rendezvous::DoneCallback make_recv_callback(OpKernelContext* ctx, |
177 | AsyncOpKernel::DoneCallback done) { |
178 | return [ctx, done = std::move(done)](const Status& s, |
179 | const Rendezvous::Args& send_args, |
180 | const Rendezvous::Args& recv_args, |
181 | const Tensor& val, bool is_dead) { |
182 | ctx->SetStatus(s); |
183 | if (s.ok()) { |
184 | // 'ctx' allocates the output tensor of the expected type. |
185 | // The runtime checks whether the tensor received here is |
186 | // the same type. |
187 | if (!is_dead) { |
188 | ctx->set_output(0, val); |
189 | } |
190 | } |
191 | done(); |
192 | }; |
193 | } |
194 | } // namespace |
195 | |
196 | void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { |
197 | OP_REQUIRES_ASYNC( |
198 | ctx, ctx->rendezvous() != nullptr, |
199 | errors::Internal("Op kernel context needs to provide a rendezvous." ), |
200 | done); |
201 | |
202 | Rendezvous::Args args; |
203 | args.device_context = ctx->op_device_context(); |
204 | args.alloc_attrs = ctx->output_alloc_attr(0); |
205 | args.cancellation_manager = ctx->cancellation_manager(); |
206 | |
207 | FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_); |
208 | if (frame_iter == FrameAndIter(0, 0)) { |
209 | VLOG(2) << "Recv " << parsed_key_.buf_ << " using " |
210 | << reinterpret_cast<uintptr_t>(ctx->rendezvous()); |
211 | ctx->rendezvous()->RecvAsync(parsed_key_, args, |
212 | make_recv_callback(ctx, std::move(done))); |
213 | } else { |
214 | Rendezvous::ParsedKey in_loop_parsed; |
215 | GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_); |
216 | VLOG(2) << "Recv " << in_loop_parsed.buf_ << " using " |
217 | << reinterpret_cast<uintptr_t>(ctx->rendezvous()); |
218 | OP_REQUIRES_OK_ASYNC( |
219 | ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done); |
220 | ctx->rendezvous()->RecvAsync(in_loop_parsed, args, |
221 | make_recv_callback(ctx, std::move(done))); |
222 | } |
223 | } |
224 | |
225 | REGISTER_KERNEL_BUILDER(Name("_Recv" ).Device(DEVICE_CPU), RecvOp); |
226 | REGISTER_KERNEL_BUILDER(Name("_Recv" ).Device(DEVICE_DEFAULT), RecvOp); |
227 | REGISTER_KERNEL_BUILDER(Name("_Recv" ).Device(DEVICE_TPU_SYSTEM), RecvOp); |
228 | REGISTER_KERNEL_BUILDER(Name("_HostRecv" ).Device(DEVICE_TPU_SYSTEM), RecvOp); |
229 | |
230 | // Public alias. Added for use in Lingvo. |
231 | REGISTER_KERNEL_BUILDER(Name("Recv" ).Device(DEVICE_CPU), RecvOp); |
232 | REGISTER_KERNEL_BUILDER(Name("Recv" ).Device(DEVICE_DEFAULT), RecvOp); |
233 | |
234 | REGISTER_KERNEL_BUILDER( |
235 | Name("_HostRecv" ).Device(DEVICE_DEFAULT).HostMemory("tensor" ), RecvOp); |
236 | |
237 | // Environment variable `DISABLE_HOST_SEND_RECV_REGISTRATION` is used to disable |
238 | // hostSend and hostRecv registration on CPU device in the mock environment. |
239 | static bool InitModule() { |
240 | if (!std::getenv("DISABLE_HOST_SEND_RECV_REGISTRATION" )) { |
241 | REGISTER_KERNEL_BUILDER(Name("_HostRecv" ).Device(DEVICE_CPU), RecvOp); |
242 | REGISTER_KERNEL_BUILDER(Name("_HostSend" ).Device(DEVICE_CPU), SendOp); |
243 | } |
244 | return true; |
245 | } |
246 | |
247 | static bool module_initialized = InitModule(); |
248 | |
249 | } // end namespace tensorflow |
250 | |