1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/sendrecv_ops.h"
17
18#include <utility>
19
20#include "tensorflow/core/framework/attr_value.pb.h"
21#include "tensorflow/core/framework/op.h"
22#include "tensorflow/core/framework/op_def_util.h"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/lib/strings/numbers.h"
25#include "tensorflow/core/lib/strings/strcat.h"
26#include "tensorflow/core/platform/logging.h"
27#include "tensorflow/core/profiler/lib/traceme.h"
28
29namespace tensorflow {
30
31static string GetRendezvousKeyPrefix(const string& send_device,
32 const string& recv_device,
33 const uint64 send_device_incarnation,
34 const string& tensor_name) {
35 return strings::StrCat(send_device, ";",
36 strings::FpToString(send_device_incarnation), ";",
37 recv_device, ";", tensor_name);
38}
39
40static void GetRendezvousKey(const string& key_prefix,
41 const FrameAndIter& frame_iter, string* key) {
42 key->clear();
43 strings::StrAppend(key, key_prefix, ";", frame_iter.frame_id, ":",
44 frame_iter.iter_id);
45}
46
47static FrameAndIter GetFrameAndIter(OpKernelContext* ctx,
48 bool hostmem_sendrecv) {
49 if (hostmem_sendrecv && ctx->call_frame() != nullptr) {
50 // Host memory send/recv pairs are added by
51 // common_runtime/memory_types.cc. When the pair of nodes are
52 // added inside a function, we need to use the function call frame
53 // to formulate the unique rendezvous key.
54 return FrameAndIter(reinterpret_cast<uint64>(ctx->call_frame()), 0);
55 } else {
56 return ctx->frame_iter();
57 }
58}
59
60SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
61 string send_device;
62 OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
63 string recv_device;
64 OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device));
65 uint64 send_device_incarnation;
66 OP_REQUIRES_OK(
67 ctx, ctx->GetAttr("send_device_incarnation",
68 reinterpret_cast<int64_t*>(&send_device_incarnation)));
69 string tensor_name;
70 OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
71 key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
72 send_device_incarnation, tensor_name);
73 // The vast majority of Send nodes are outside any loop context, so
74 // proactively cache the rendezvous key for the top-level.
75 GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
76 OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
77 if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) {
78 hostmem_sendrecv_ = false;
79 }
80}
81
82void SendOp::Compute(OpKernelContext* ctx) {
83 OP_REQUIRES(
84 ctx, ctx->rendezvous() != nullptr,
85 errors::Internal("Op kernel context needs to provide a rendezvous."));
86
87 // The device context may be passed between the Send/Recv
88 // boundary, so that the device context used to produce the Tensor
89 // is used when performing the copy on the recv side (which may be
90 // a different device).
91 Rendezvous::Args args;
92 args.device_context = ctx->op_device_context();
93 args.alloc_attrs = ctx->input_alloc_attr(0);
94
95 FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
96 if (frame_iter == FrameAndIter(0, 0)) {
97 // Use the cached rendezvous key.
98 VLOG(2) << "Send " << parsed_key_.buf_ << " using "
99 << reinterpret_cast<uintptr_t>(ctx->rendezvous());
100 ctx->SetStatus(ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0),
101 ctx->is_input_dead()));
102 return;
103 } else {
104 Rendezvous::ParsedKey in_loop_parsed;
105 GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_);
106 VLOG(2) << "Send " << in_loop_parsed.buf_ << " using "
107 << reinterpret_cast<uintptr_t>(ctx->rendezvous());
108 OP_REQUIRES_OK(ctx,
109 Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed));
110
111 ctx->SetStatus(ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0),
112 ctx->is_input_dead()));
113 return;
114 }
115}
116
117string SendOp::TraceString(const OpKernelContext& ctx, bool verbose) const {
118 const auto& attr = def().attr();
119 auto src_it = attr.find("_src");
120 auto dst_it = attr.find("_dst");
121 const string& src = src_it != attr.end() ? src_it->second.s() : "";
122 const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
123 string op = profiler::TraceMeOp(name_view(), type_string_view());
124 return profiler::TraceMeEncode(
125 std::move(op),
126 {{"from", src}, {"to", dst}, {"key", parsed_key_.FullKey()}});
127}
128
129REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
130REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_DEFAULT), SendOp);
131REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_TPU_SYSTEM), SendOp);
132REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_TPU_SYSTEM), SendOp);
133
134// Public alias. Added for use in Lingvo.
135REGISTER_KERNEL_BUILDER(Name("Send").Device(DEVICE_CPU), SendOp);
136REGISTER_KERNEL_BUILDER(Name("Send").Device(DEVICE_DEFAULT), SendOp);
137
138REGISTER_KERNEL_BUILDER(
139 Name("_HostSend").Device(DEVICE_DEFAULT).HostMemory("tensor"), SendOp);
140
141RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
142 string send_device;
143 OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
144 string recv_device;
145 OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device));
146 uint64 send_device_incarnation;
147 OP_REQUIRES_OK(
148 ctx, ctx->GetAttr("send_device_incarnation",
149 reinterpret_cast<int64_t*>(&send_device_incarnation)));
150 string tensor_name;
151 OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
152 key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
153 send_device_incarnation, tensor_name);
154 // The vast majority of Recv nodes are outside any loop context, so
155 // proactively cache the rendezvous key for the top-level.
156 GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
157 OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
158 if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) {
159 hostmem_sendrecv_ = false;
160 }
161}
162
163string RecvOp::TraceString(const OpKernelContext& ctx, bool verbose) const {
164 const auto& attr = def().attr();
165 auto src_it = attr.find("_src");
166 auto dst_it = attr.find("_dst");
167 const string& src = src_it != attr.end() ? src_it->second.s() : "";
168 const string& dst = dst_it != attr.end() ? dst_it->second.s() : "";
169 string op = profiler::TraceMeOp(name_view(), type_string_view());
170 return profiler::TraceMeEncode(
171 std::move(op),
172 {{"from", src}, {"to", dst}, {"key", parsed_key_.FullKey()}});
173}
174
175namespace {
176Rendezvous::DoneCallback make_recv_callback(OpKernelContext* ctx,
177 AsyncOpKernel::DoneCallback done) {
178 return [ctx, done = std::move(done)](const Status& s,
179 const Rendezvous::Args& send_args,
180 const Rendezvous::Args& recv_args,
181 const Tensor& val, bool is_dead) {
182 ctx->SetStatus(s);
183 if (s.ok()) {
184 // 'ctx' allocates the output tensor of the expected type.
185 // The runtime checks whether the tensor received here is
186 // the same type.
187 if (!is_dead) {
188 ctx->set_output(0, val);
189 }
190 }
191 done();
192 };
193}
194} // namespace
195
196void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
197 OP_REQUIRES_ASYNC(
198 ctx, ctx->rendezvous() != nullptr,
199 errors::Internal("Op kernel context needs to provide a rendezvous."),
200 done);
201
202 Rendezvous::Args args;
203 args.device_context = ctx->op_device_context();
204 args.alloc_attrs = ctx->output_alloc_attr(0);
205 args.cancellation_manager = ctx->cancellation_manager();
206
207 FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
208 if (frame_iter == FrameAndIter(0, 0)) {
209 VLOG(2) << "Recv " << parsed_key_.buf_ << " using "
210 << reinterpret_cast<uintptr_t>(ctx->rendezvous());
211 ctx->rendezvous()->RecvAsync(parsed_key_, args,
212 make_recv_callback(ctx, std::move(done)));
213 } else {
214 Rendezvous::ParsedKey in_loop_parsed;
215 GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_);
216 VLOG(2) << "Recv " << in_loop_parsed.buf_ << " using "
217 << reinterpret_cast<uintptr_t>(ctx->rendezvous());
218 OP_REQUIRES_OK_ASYNC(
219 ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done);
220 ctx->rendezvous()->RecvAsync(in_loop_parsed, args,
221 make_recv_callback(ctx, std::move(done)));
222 }
223}
224
225REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
226REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_DEFAULT), RecvOp);
227REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_TPU_SYSTEM), RecvOp);
228REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_TPU_SYSTEM), RecvOp);
229
230// Public alias. Added for use in Lingvo.
231REGISTER_KERNEL_BUILDER(Name("Recv").Device(DEVICE_CPU), RecvOp);
232REGISTER_KERNEL_BUILDER(Name("Recv").Device(DEVICE_DEFAULT), RecvOp);
233
234REGISTER_KERNEL_BUILDER(
235 Name("_HostRecv").Device(DEVICE_DEFAULT).HostMemory("tensor"), RecvOp);
236
237// Environment variable `DISABLE_HOST_SEND_RECV_REGISTRATION` is used to disable
238// hostSend and hostRecv registration on CPU device in the mock environment.
239static bool InitModule() {
240 if (!std::getenv("DISABLE_HOST_SEND_RECV_REGISTRATION")) {
241 REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp);
242 REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
243 }
244 return true;
245}
246
247static bool module_initialized = InitModule();
248
249} // end namespace tensorflow
250