/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/function_ops.h"

#include <deque>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gradients.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;

ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

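// Fetches the function argument at `index_` from the current call frame.
// When the frame allows it, the argument is consumed (moved) into the output
// to avoid a copy; otherwise it is forwarded by reference.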
void ArgOp::Compute(OpKernelContext* ctx) {
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  const Tensor* val;

  auto validate_type = [this](const Tensor& val) {
    if (val.dtype() == dtype_) {
      return OkStatus();
    } else {
      return errors::InvalidArgument("Type mismatch: actual ",
                                     DataTypeString(val.dtype()),
                                     " vs. expect ", DataTypeString(dtype_));
    }
  };

  if (frame->CanConsumeArg(index_)) {
    Tensor val;
    frame->ConsumeArg(index_, &val);
    OP_REQUIRES_OK(ctx, validate_type(val));
    ctx->set_output(0, std::move(val));
  } else {
    OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
    OP_REQUIRES_OK(ctx, validate_type(*val));
    ctx->set_output(0, *val);
  }
}

RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

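// Checks that the returned value has the expected dtype and records it as
// the function's `index_`-th return value in the current call frame.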
void RetvalOp::Compute(OpKernelContext* ctx) {
  const Tensor& val = ctx->input(0);
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
}

REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

// TPU ops are only registered when they are required as part of the larger
// TPU runtime, and do not need to be registered when selective registration
// is turned on.
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);

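// Registers ArgOp on DEVICE_DEFAULT for the numeric, quantized, and bool
// dtypes. int32 is excluded from the macro and handled separately below:
// _DeviceArg keeps int32 on the device, while plain _Arg pins its int32
// output to host memory.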
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER);
TF_CALL_QUANTIZED_TYPES(REGISTER);
TF_CALL_bool(REGISTER);

REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_DEFAULT).TypeConstraint<int32>("T"),
    ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ArgOp);
#undef REGISTER

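// ResourceHandle and tstring tensors always live in host memory, so their
// _Arg outputs are pinned to the host even on accelerator devices.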
REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<tstring>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_DEFAULT).TypeConstraint<Variant>("T"), ArgOp);

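// The RetvalOp registrations mirror the ArgOp ones above, with int32,
// ResourceHandle, and tstring inputs pinned to host memory.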
#define REGISTER(type)                                               \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(kRetOp).Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      RetvalOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER);
TF_CALL_QUANTIZED_TYPES(REGISTER);
TF_CALL_qint16(REGISTER);
TF_CALL_quint16(REGISTER);
REGISTER(Variant);
TF_CALL_bool(REGISTER);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .TypeConstraint<int32>("T"),
                        RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_DEFAULT).TypeConstraint<int32>("T"),
    RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<tstring>("T")
                            .HostMemory("input"),
                        RetvalOp);

#undef REGISTER

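// Forwards each input tensor unchanged to the corresponding output. This
// kernel backs the internal _ListToArray and _ArrayToList ops, which only
// repackage a list of tensors without touching their contents.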
class PassOn : public OpKernel {
 public:
  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                 " vs. ", ctx->num_outputs()));
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(
          ctx, input_type(i) == output_type(i),
          errors::Internal("Input and output types for position ", i,
                           " do not match: ", DataTypeString(input_type(i)),
                           " vs. ", DataTypeString(output_type(i))));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      ctx->set_output(i, ctx->input(i));
    }
  }
};

REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

#define REGISTER_DEFAULT_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("_ListToArray").Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      PassOn);                                                               \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("_ArrayToList").Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_DEFAULT_KERNELS(Eigen::half);
REGISTER_DEFAULT_KERNELS(float);
REGISTER_DEFAULT_KERNELS(double);

#undef REGISTER_DEFAULT_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

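// Instantiates the gradient function registered under kGradientOp using this
// node's attrs, then runs it asynchronously: the kernel's inputs become the
// function's arguments and the function's results become the outputs.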
class SymbolicGradientOp : public AsyncOpKernel {
 public:
  explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  ~SymbolicGradientOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(
        ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);

    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.collective_executor = ctx->collective_executor();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.stats_collector = ctx->stats_collector();
    opts.step_container = ctx->step_container();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
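    // `rets` is heap-allocated because it must outlive ComputeAsync; the
    // completion callback below fills it in and deletes it.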
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me("SymbolicGradientOp");
    lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
      if (!status.ok()) {
        ctx->SetStatus(status);
      } else if (rets->size() != ctx->num_outputs()) {
        ctx->SetStatus(errors::InvalidArgument(
            "SymGrad expects to return ", ctx->num_outputs(),
            " tensor(s), but got ", rets->size(), " tensor(s) instead."));
      } else {
        for (size_t i = 0; i < rets->size(); ++i) {
          ctx->set_output(i, std::move((*rets)[i]));
        }
      }
      delete rets;
      done();
    });
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
};

REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_DEFAULT),
                        SymbolicGradientOp);

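// RemoteCallOp invokes a function `f` on a (possibly remote) target device.
// The constructor reads the function attr and the input/output dtypes, and
// captures any full type information recorded on the node def.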
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx)
    : AsyncOpKernel(ctx), return_type_(ctx->def().experimental_type()) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}

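// Canonicalizes the target device name, instantiates the function on that
// device (caching the handle per (device, FLR) pair), and runs it
// asynchronously with this op's inputs as arguments.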
void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);

  FunctionTarget function_target;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(
          target->scalar<tstring>()(), source_device, &function_target.first),
      done);
  function_target.second = lib;

  const string& target_device = function_target.first;
  const string& func_name = func_.name();

  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_name << " on " << target_device;
      profiler::TraceMe activity(
          [&] {
            return strings::StrCat("RemoteCall: Instantiate: ", func_name,
                                   " on ", target_device);
          },
          profiler::TraceMeLevel::kInfo);
      FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
      const auto* config = (ctx->function_library())
                               ? ctx->function_library()->config_proto()
                               : nullptr;
      if (config) {
        instantiate_opts.config_proto = *config;
      }
      instantiate_opts.target = target_device;
      OP_REQUIRES_OK_ASYNC(ctx,
                           lib->Instantiate(func_name, AttrSlice(&func_.attr()),
                                            instantiate_opts, &handle),
                           done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_name << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

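  // Assemble per-run options. remote_execution is only set when the target
  // differs from the source device, and a child CancellationManager ties the
  // remote run to this kernel's cancellation; it is deleted in the callback.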
  FunctionLibraryRuntime::Options opts;
  opts.runner = nullptr;  // Use default runner at remote device.
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
  CancellationManager* cancel_mgr = nullptr;
  if (ctx->cancellation_manager() != nullptr) {
    cancel_mgr = new CancellationManager(ctx->cancellation_manager());
  }
  opts.cancellation_manager = cancel_mgr;
  opts.collective_executor = ctx->collective_executor();
  std::vector<Tensor> args(arguments.begin(), arguments.end());
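  // Allocator attributes tell the runtime, for each argument and return
  // value, whether the tensor should live in host or device memory.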
  opts.args_alloc_attrs.reserve(input_dtypes_.size());
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    arg_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  opts.rets_alloc_attrs.reserve(output_dtypes_.size());
  DCHECK(!return_type_.IsInitialized() ||
         (return_type_.type_id() == TFT_UNSET) ||
         (output_dtypes_.size() == return_type_.args_size()))
      << "RemoteCall op has full type information for "
      << return_type_.args_size() << " outputs but the number of outputs is "
      << output_dtypes_.size();
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    bool on_host = DataTypeAlwaysOnHost(dtype);
    if (return_type_.IsInitialized() && (return_type_.type_id() != TFT_UNSET)) {
      DCHECK(return_type_.type_id() == TFT_PRODUCT)
          << return_type_.DebugString();
      FullTypeDef ftd = full_type::GetArgDefaultUnset(
          return_type_, opts.rets_alloc_attrs.size());
      if (full_type::IsHostMemoryType(ftd)) {
        on_host = true;
      }
      VLOG(5) << "FullTypeDef for RemoteCall output="
              << opts.rets_alloc_attrs.size()
              << ", IsHostMemoryType=" << full_type::IsHostMemoryType(ftd)
              << ":\n"
              << ftd.DebugString();
    }
    ret_alloc_attrs.set_on_host(on_host);
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
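  // As in SymbolicGradientOp, `rets` must outlive this method; the completion
  // callback below takes ownership and deletes it.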
  auto* rets = new std::vector<Tensor>;
  VLOG(1) << "Running " << func_name << " on " << target_device
          << " with handle: " << handle;
  profiler::TraceMe trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "RemoteCallOp",
            {{"func_name", func_name}, {"device", target_device}});
      },
      profiler::TraceMeLevel::kInfo);
  lib->Run(
      opts, handle, args, rets,
      [rets, done = std::move(done), func_name, ctx, cancel_mgr,
       target_device = std::move(function_target.first)](const Status& status) {
        profiler::TraceMe activity(
            [&] {
              return profiler::TraceMeEncode(
                  "RemoteCallOpDone",
                  {{"func_name", func_name}, {"device", target_device}});
            },
            profiler::TraceMeLevel::kInfo);
        if (!status.ok()) {
          ctx->SetStatus(status);
        } else {
          for (size_t i = 0; i < rets->size(); ++i) {
            ctx->set_output(i, std::move((*rets)[i]));
          }
        }
        delete cancel_mgr;
        delete rets;
        done();
      });
}

string RemoteCallOp::TraceString(const OpKernelContext& ctx,
                                 bool verbose) const {
  string trace_string = profiler::TraceMeOp(
      strings::StrCat(name_view(), "__", func_.name()), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_DEFAULT).HostMemory("target"),
    RemoteCallOp);
}  // namespace tensorflow