1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/function_ops.h" |
17 | |
18 | #include <deque> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/common_runtime/device.h" |
22 | #include "tensorflow/core/common_runtime/executor.h" |
23 | #include "tensorflow/core/common_runtime/function.h" |
24 | #include "tensorflow/core/common_runtime/gradients.h" |
25 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
26 | #include "tensorflow/core/common_runtime/memory_types.h" |
27 | #include "tensorflow/core/framework/cancellation.h" |
28 | #include "tensorflow/core/framework/full_type.pb.h" |
29 | #include "tensorflow/core/framework/full_type_util.h" |
30 | #include "tensorflow/core/framework/op.h" |
31 | #include "tensorflow/core/framework/register_types.h" |
32 | #include "tensorflow/core/graph/algorithm.h" |
33 | #include "tensorflow/core/platform/macros.h" |
34 | #include "tensorflow/core/platform/tracing.h" |
35 | #include "tensorflow/core/profiler/lib/traceme.h" |
36 | #include "tensorflow/core/util/device_name_utils.h" |
37 | |
38 | namespace tensorflow { |
39 | |
// Local alias for the op name under which function gradients are registered
// (see FunctionLibraryDefinition::kGradientOp).
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
42 | |
// Reads the expected dtype ("T") and the argument position ("index") from
// the node's attributes; both are validated against the call frame at
// Compute time.
ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}
47 | |
48 | void ArgOp::Compute(OpKernelContext* ctx) { |
49 | auto frame = ctx->call_frame(); |
50 | OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame" )); |
51 | const Tensor* val; |
52 | |
53 | auto validate_type = [this](const Tensor& val) { |
54 | if (val.dtype() == dtype_) { |
55 | return OkStatus(); |
56 | } else { |
57 | return errors::InvalidArgument("Type mismatch: actual " , |
58 | DataTypeString(val.dtype()), |
59 | " vs. expect " , DataTypeString(dtype_)); |
60 | } |
61 | }; |
62 | |
63 | if (frame->CanConsumeArg(index_)) { |
64 | Tensor val; |
65 | frame->ConsumeArg(index_, &val); |
66 | OP_REQUIRES_OK(ctx, validate_type(val)); |
67 | ctx->set_output(0, std::move(val)); |
68 | } else { |
69 | OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); |
70 | OP_REQUIRES_OK(ctx, validate_type(*val)); |
71 | ctx->set_output(0, *val); |
72 | } |
73 | } |
74 | |
// Reads the expected dtype ("T") and the return-value position ("index")
// from the node's attributes.
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}
79 | |
80 | void RetvalOp::Compute(OpKernelContext* ctx) { |
81 | const Tensor& val = ctx->input(0); |
82 | OP_REQUIRES(ctx, val.dtype() == dtype_, |
83 | errors::InvalidArgument("Type mismatch: actual " , |
84 | DataTypeString(val.dtype()), |
85 | " vs. expect " , DataTypeString(dtype_))); |
86 | auto frame = ctx->call_frame(); |
87 | OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame" )); |
88 | OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val)); |
89 | } |
90 | |
// CPU kernels for the function argument/return-value ops. Registered as
// "system" kernels so they survive selective kernel registration.
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

// TPU ops are only registered when they are required as part of the larger
// TPU runtime, and do not need to be registered when selective registration
// is turned on.
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);

// Arg kernels on DEVICE_DEFAULT, one registration per supported dtype.
#define REGISTER(type)                                               \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(kArgOp).Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER);
TF_CALL_QUANTIZED_TYPES(REGISTER);
TF_CALL_bool(REGISTER);

REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_DEFAULT).TypeConstraint<int32>("T"),
    ArgOp);

// int32 Arg output is pinned to host memory on DEVICE_DEFAULT.
REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ArgOp);
#undef REGISTER

// ResourceHandle and tstring Arg outputs are likewise pinned to host memory.
REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("output")
                            .TypeConstraint<tstring>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_DEFAULT).TypeConstraint<Variant>("T"), ArgOp);

// Retval kernels on DEVICE_DEFAULT, one registration per supported dtype.
#define REGISTER(type)                                               \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(kRetOp).Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      RetvalOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER);
TF_CALL_QUANTIZED_TYPES(REGISTER);
TF_CALL_qint16(REGISTER);
TF_CALL_quint16(REGISTER);
REGISTER(Variant);
TF_CALL_bool(REGISTER);

// int32 Retval input is pinned to host memory on DEVICE_DEFAULT.
REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .TypeConstraint<int32>("T"),
                        RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_DEFAULT).TypeConstraint<int32>("T"),
    RetvalOp);

// ResourceHandle and tstring Retval inputs are likewise pinned to host
// memory.
REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<tstring>("T")
                            .HostMemory("input"),
                        RetvalOp);

#undef REGISTER
168 | |
169 | class PassOn : public OpKernel { |
170 | public: |
171 | explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) { |
172 | OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(), |
173 | errors::Internal("#inputs != #outputs : " , ctx->num_inputs(), |
174 | " vs. " , ctx->num_outputs())); |
175 | for (int i = 0; i < ctx->num_inputs(); ++i) { |
176 | OP_REQUIRES( |
177 | ctx, input_type(i) == output_type(i), |
178 | errors::Internal("Input and output types for position " , i, |
179 | " do not match: " , DataTypeString(input_type(i)), |
180 | " vs. " , DataTypeString(output_type(i)))); |
181 | } |
182 | } |
183 | |
184 | void Compute(OpKernelContext* ctx) override { |
185 | for (int i = 0; i < ctx->num_inputs(); ++i) { |
186 | ctx->set_output(i, ctx->input(i)); |
187 | } |
188 | } |
189 | }; |
190 | |
// CPU registrations for the list/array regrouping ops; "system" kernels so
// they survive selective kernel registration.
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

// DEVICE_DEFAULT registrations for the floating-point dtypes used below.
#define REGISTER_DEFAULT_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("_ListToArray").Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      PassOn);                                                               \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("_ArrayToList").Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_DEFAULT_KERNELS(Eigen::half);
REGISTER_DEFAULT_KERNELS(float);
REGISTER_DEFAULT_KERNELS(double);

#undef REGISTER_DEFAULT_KERNELS

// int32 variants keep both inputs and outputs pinned to host memory.
REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
220 | |
221 | class SymbolicGradientOp : public AsyncOpKernel { |
222 | public: |
223 | explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {} |
224 | |
225 | ~SymbolicGradientOp() override {} |
226 | |
227 | void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { |
228 | FunctionLibraryRuntime* lib = ctx->function_library(); |
229 | OP_REQUIRES_ASYNC(ctx, lib != nullptr, |
230 | errors::Internal("No function library is provided." ), |
231 | done); |
232 | |
233 | FunctionLibraryRuntime::Handle handle; |
234 | OP_REQUIRES_OK_ASYNC( |
235 | ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done); |
236 | |
237 | FunctionLibraryRuntime::Options opts; |
238 | opts.rendezvous = ctx->rendezvous(); |
239 | opts.cancellation_manager = ctx->cancellation_manager(); |
240 | opts.collective_executor = ctx->collective_executor(); |
241 | opts.runner = ctx->runner(); |
242 | opts.run_all_kernels_inline = ctx->run_all_kernels_inline(); |
243 | opts.stats_collector = ctx->stats_collector(); |
244 | opts.step_container = ctx->step_container(); |
245 | std::vector<Tensor> args; |
246 | args.reserve(ctx->num_inputs()); |
247 | for (int i = 0; i < ctx->num_inputs(); ++i) { |
248 | args.push_back(ctx->input(i)); |
249 | } |
250 | std::vector<Tensor>* rets = new std::vector<Tensor>; |
251 | profiler::TraceMe trace_me("SymbolicGradientOp" ); |
252 | lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) { |
253 | if (!status.ok()) { |
254 | ctx->SetStatus(status); |
255 | } else if (rets->size() != ctx->num_outputs()) { |
256 | ctx->SetStatus(errors::InvalidArgument( |
257 | "SymGrad expects to return " , ctx->num_outputs(), |
258 | " tensor(s), but get " , rets->size(), " tensor(s) instead." )); |
259 | } else { |
260 | for (size_t i = 0; i < rets->size(); ++i) { |
261 | ctx->set_output(i, std::move((*rets)[i])); |
262 | } |
263 | } |
264 | delete rets; |
265 | done(); |
266 | }); |
267 | } |
268 | |
269 | private: |
270 | TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp); |
271 | }; |
272 | |
// SymbolicGradient is registered for CPU and DEVICE_DEFAULT without type
// constraints; the instantiated gradient function handles the dtypes.
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_DEFAULT),
                        SymbolicGradientOp);
277 | |
// Reads the function to call (FunctionLibraryDefinition::kFuncAttr) and the
// declared input/output dtype lists ("Tin"/"Tout"); also captures any full
// type information attached to the node, used later to pick output
// allocator attributes.
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx)
    : AsyncOpKernel(ctx), return_type_(ctx->def().experimental_type()) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}
285 | |
// Runs `func_` on the device named by the "target" input, forwarding the
// "args" inputs as function arguments and emitting the function's results
// as this op's outputs. Completion is reported asynchronously via `done`.
void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);

  // Canonicalize the (possibly partial) target device name relative to the
  // device this kernel is running on.
  FunctionTarget function_target;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(
          target->scalar<tstring>()(), source_device, &function_target.first),
      done);
  function_target.second = lib;

  const string& target_device = function_target.first;
  const string& func_name = func_.name();

  // Reuse a previously instantiated handle for this (device, FLR) pair when
  // available; otherwise instantiate the function and cache the handle.
  // The cache is guarded by mu_.
  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_name << " on " << target_device;
      profiler::TraceMe activity(
          [&] {
            return strings::StrCat("RemoteCall: Instantiate: ", func_name,
                                   " on ", target_device);
          },
          profiler::TraceMeLevel::kInfo);
      FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
      // Propagate the session config to the remote instantiation if one is
      // available.
      const auto* config = (ctx->function_library())
                               ? ctx->function_library()->config_proto()
                               : nullptr;
      if (config) {
        instantiate_opts.config_proto = *config;
      }
      instantiate_opts.target = target_device;
      OP_REQUIRES_OK_ASYNC(ctx,
                           lib->Instantiate(func_name, AttrSlice(&func_.attr()),
                                            instantiate_opts, &handle),
                           done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_name << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  FunctionLibraryRuntime::Options opts;
  opts.runner = nullptr;  // Use default runner at remote device.
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
  // Child cancellation manager tied to the step's manager; owned here and
  // deleted in the completion callback below.
  CancellationManager* cancel_mgr = nullptr;
  if (ctx->cancellation_manager() != nullptr) {
    cancel_mgr = new CancellationManager(ctx->cancellation_manager());
  }
  opts.cancellation_manager = cancel_mgr;
  opts.collective_executor = ctx->collective_executor();
  std::vector<Tensor> args(arguments.begin(), arguments.end());
  // Arguments whose dtype always lives on host get host allocator attrs.
  opts.args_alloc_attrs.reserve(input_dtypes_.size());
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    arg_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  opts.rets_alloc_attrs.reserve(output_dtypes_.size());
  // When full type information is present it must describe exactly one
  // entry per output.
  DCHECK(!return_type_.IsInitialized() ||
         (return_type_.type_id() == TFT_UNSET) ||
         (output_dtypes_.size() == return_type_.args_size()))
      << "RemoteCall op has a full type information for "
      << return_type_.args_size() << " outputs but the number of outputs is "
      << output_dtypes_.size();
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    bool on_host = DataTypeAlwaysOnHost(dtype);
    // Full type information, when set, may additionally force an output
    // into host memory.
    if (return_type_.IsInitialized() && (return_type_.type_id() != TFT_UNSET)) {
      DCHECK(return_type_.type_id() == TFT_PRODUCT)
          << return_type_.DebugString();
      FullTypeDef ftd = full_type::GetArgDefaultUnset(
          return_type_, opts.rets_alloc_attrs.size());
      if (full_type::IsHostMemoryType(ftd)) {
        on_host = true;
      }
      VLOG(5) << "FulltypeDef for RemoteCall output="
              << opts.rets_alloc_attrs.size()
              << ", IsHostMemoryType=" << full_type::IsHostMemoryType(ftd)
              << ":\n"
              << ftd.DebugString();
    }
    ret_alloc_attrs.set_on_host(on_host);
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
  // Heap-allocated so the results outlive this frame; released in the
  // callback.
  auto* rets = new std::vector<Tensor>;
  VLOG(1) << "Running " << func_name << " on " << target_device
          << " with handle: " << handle;
  profiler::TraceMe trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "RemoteCallOp",
            {{"func_name", func_name}, {"device", target_device}});
      },
      profiler::TraceMeLevel::kInfo);
  // Run asynchronously; the callback forwards the results (or the error) to
  // ctx, frees the heap-allocated state, and signals completion.
  lib->Run(
      opts, handle, args, rets,
      [rets, done = std::move(done), func_name, ctx, cancel_mgr,
       target_device = std::move(function_target.first)](const Status& status) {
        profiler::TraceMe activity(
            [&] {
              return profiler::TraceMeEncode(
                  "RemoteCallOpDone",
                  {{"func_name", func_name}, {"device", target_device}});
            },
            profiler::TraceMeLevel::kInfo);
        if (!status.ok()) {
          ctx->SetStatus(status);
        } else {
          for (size_t i = 0; i < rets->size(); ++i) {
            ctx->set_output(i, std::move((*rets)[i]));
          }
        }
        delete cancel_mgr;
        delete rets;
        done();
      });
}
423 | |
424 | string RemoteCallOp::TraceString(const OpKernelContext& ctx, |
425 | bool verbose) const { |
426 | string trace_string = profiler::TraceMeOp( |
427 | strings::StrCat(name_view(), "__" , func_.name()), type_string_view()); |
428 | if (verbose) { |
429 | string shape = ShapeTraceString(ctx); |
430 | if (!shape.empty()) { |
431 | trace_string = |
432 | profiler::TraceMeEncode(std::move(trace_string), {{"shape" , shape}}); |
433 | } |
434 | } |
435 | return trace_string; |
436 | } |
437 | |
// The "target" input is read as a string scalar by ComputeAsync, so it is
// pinned to host memory on all devices.
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_DEFAULT).HostMemory("target"),
    RemoteCallOp);
443 | } // namespace tensorflow |
444 | |