1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/types.h" |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
19 | #include "tensorflow/core/common_runtime/device.h" |
20 | #include "tensorflow/core/framework/device_base.h" |
21 | #include "tensorflow/core/framework/function.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_shape.h" |
25 | #include "tensorflow/core/lib/core/errors.h" |
26 | #include "tensorflow/core/lib/core/threadpool.h" |
27 | #include "tensorflow/core/platform/casts.h" |
28 | #include "tensorflow/core/platform/errors.h" |
29 | #include "tensorflow/core/platform/macros.h" |
30 | #include "tensorflow/core/profiler/lib/traceme.h" |
31 | |
32 | namespace tensorflow { |
33 | typedef Eigen::GpuDevice GPUDevice; |
34 | typedef Eigen::ThreadPoolDevice CPUDevice; |
35 | typedef FunctionLibraryRuntime::Handle FHandle; |
36 | typedef std::vector<Tensor> TensorVec; |
37 | |
38 | namespace { |
39 | |
40 | // Helper to instantiate function "func" in the library "lib". |
41 | Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func, |
42 | FunctionLibraryRuntime::Handle* handle) { |
43 | return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle); |
44 | } |
45 | |
46 | Status Instantiate(OpKernelContext* ctx, const NameAttrList& func, |
47 | FunctionLibraryRuntime::Handle* handle) { |
48 | FunctionLibraryRuntime::InstantiateOptions opts; |
49 | opts.executor_type = ctx->executor_type(); |
50 | return ctx->function_library()->Instantiate( |
51 | func.name(), AttrSlice(&func.attr()), opts, handle); |
52 | } |
53 | |
54 | // If "t" is a scalar of a supported type, returns t != 0 in "*v". |
55 | Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) { |
56 | if (t.size() != 1) { |
57 | return errors::InvalidArgument( |
58 | "Expected a single scalar which can be converted to a boolean, got " , |
59 | t.size(), " tensors." ); |
60 | } |
61 | if (TensorShapeUtils::IsScalar(t[0].shape())) { |
62 | switch (t[0].dtype()) { |
63 | #define CASE(T) \ |
64 | case DataTypeToEnum<T>::value: \ |
65 | *v = t[0].scalar<T>()() != 0; \ |
66 | break; |
67 | |
68 | CASE(float); |
69 | CASE(double); |
70 | CASE(int32); |
71 | CASE(uint8); |
72 | CASE(int16); |
73 | CASE(int8); |
74 | CASE(int64_t); |
75 | #undef CASE |
76 | case DT_BOOL: |
77 | *v = t[0].scalar<bool>()(); |
78 | break; |
79 | case DT_STRING: |
80 | *v = !t[0].scalar<tstring>()().empty(); |
81 | break; |
82 | default: |
83 | return errors::InvalidArgument(DataTypeString(t[0].dtype()), |
84 | " cannot be converted to a boolean" ); |
85 | } |
86 | } else { |
87 | *v = t[0].NumElements() > 0; |
88 | } |
89 | return OkStatus(); |
90 | } |
91 | |
92 | // Sets "rets" to be the output of "ctx". Validates rets' types based |
93 | // on "kernel". |
94 | Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx, |
95 | gtl::ArraySlice<Tensor> rets) { |
96 | if (rets.size() != ctx->num_outputs()) { |
97 | return errors::Internal("Expect to produce " , ctx->num_outputs(), |
98 | " tensors, but only get " , rets.size()); |
99 | } |
100 | for (int i = 0; i < rets.size(); ++i) { |
101 | if (rets[i].dtype() != kernel->output_type(i)) { |
102 | return errors::Internal("Expect " , i, "-th output is of type " , |
103 | DataTypeString(kernel->output_type(i)), |
104 | " but get " , DataTypeString(rets[i].dtype())); |
105 | } |
106 | ctx->set_output(i, rets[i]); |
107 | } |
108 | return OkStatus(); |
109 | } |
110 | |
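// Copies the pieces of per-step state that a called function needs
// (rendezvous, cancellation manager, collective executor, runner, etc.) from
// the kernel's context into the function runtime's run options.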
111 | void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, |
112 | bool always_collect_stats) { |
113 | opts->rendezvous = ctx->rendezvous(); |
114 | opts->cancellation_manager = ctx->cancellation_manager(); |
115 | opts->collective_executor = ctx->collective_executor(); |
116 | if (always_collect_stats) { |
117 | opts->stats_collector = ctx->stats_collector(); |
118 | } |
119 | opts->runner = ctx->runner(); |
120 | opts->run_all_kernels_inline = ctx->run_all_kernels_inline(); |
121 | opts->step_container = ctx->step_container(); |
122 | } |
123 | |
124 | class IfOp : public AsyncOpKernel { |
125 | public: |
126 | explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { |
127 | auto lib = ctx->function_library(); |
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
131 | } |
132 | |
133 | ~IfOp() override {} |
134 | |
135 | void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { |
136 | FHandle then_handle; |
137 | FHandle else_handle; |
138 | OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle), |
139 | done); |
140 | bool cond; |
    OP_REQUIRES_OK_ASYNC(ctx, ToBool({ctx->input(0)}, &cond), done);
142 | (new State(this, ctx, cond, then_handle, else_handle, done))->Start(); |
143 | } |
144 | |
145 | private: |
146 | NameAttrList then_func_; |
147 | NameAttrList else_func_; |
148 | |
149 | mutex mu_; |
150 | std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> |
151 | handles_ ABSL_GUARDED_BY(mu_); |
152 | |
153 | class State { |
154 | public: |
155 | State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle, |
156 | FHandle else_handle, DoneCallback done) |
157 | : kernel_(kernel), |
158 | ctx_(ctx), |
159 | cond_(cond), |
160 | then_handle_(then_handle), |
161 | else_handle_(else_handle), |
162 | done_(std::move(done)), |
163 | lib_(CHECK_NOTNULL(ctx_->function_library())) { |
164 | SetRunOptions(ctx_, &opts_, true /* always_collect_stats */); |
165 | for (int i = 1; i < ctx_->num_inputs(); ++i) { |
166 | args_.push_back(ctx_->input(i)); |
167 | } |
168 | } |
169 | |
170 | ~State() {} |
171 | |
172 | void Start() { |
173 | FHandle handle = cond_ ? then_handle_ : else_handle_; |
174 | rets_.clear(); |
175 | profiler::TraceMe trace_me("IfOp" ); |
176 | lib_->Run( |
177 | // Evaluate one of the branch. |
178 | opts_, handle, args_, &rets_, |
179 | // Done callback |
180 | [this](Status s) { |
181 | if (s.ok()) { |
182 | s = SetOutputs(kernel_, ctx_, rets_); |
183 | } |
184 | ctx_->SetStatus(s); |
185 | DoneCallback captured_done(std::move(done_)); |
186 | delete this; |
187 | captured_done(); |
188 | }); |
189 | } |
190 | |
191 | private: |
192 | IfOp* const kernel_; |
193 | OpKernelContext* const ctx_; |
194 | const bool cond_; |
195 | FHandle then_handle_; |
196 | FHandle else_handle_; |
197 | DoneCallback done_; |
198 | FunctionLibraryRuntime* const lib_; |
199 | FunctionLibraryRuntime::Options opts_; |
200 | TensorVec args_; |
201 | TensorVec rets_; |
202 | }; |
203 | |
204 | Status GetHandles(OpKernelContext* ctx, FHandle* then_handle, |
205 | FHandle* else_handle) { |
206 | // TODO(b/37549631): Because this op has `SetIsStateful()` in its |
207 | // op registration, this kernel may be shared by multiple |
208 | // subgraphs, which have different associated |
209 | // `FunctionLibraryRuntime` objects and hence different `FHandle` |
210 | // namespaces. We currently work around this by caching the map |
211 | // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two |
212 | // functions this op uses. |
213 | auto lib = ctx->function_library(); |
    if (lib == nullptr) return errors::Internal("No function library");
215 | *then_handle = kInvalidHandle; |
216 | *else_handle = kInvalidHandle; |
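    // Fast path: look up the cached handles for `lib` under a shared (reader)
    // lock; on a miss, retry below under an exclusive lock and instantiate.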
217 | { |
218 | tf_shared_lock l(mu_); |
219 | const auto iter = handles_.find(lib); |
220 | if (TF_PREDICT_TRUE(iter != handles_.end())) { |
221 | *then_handle = iter->second.first; |
222 | *else_handle = iter->second.second; |
223 | } |
224 | } |
225 | if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) { |
226 | mutex_lock l(mu_); |
227 | const auto iter = handles_.find(lib); |
228 | if (TF_PREDICT_TRUE(iter != handles_.end())) { |
229 | *then_handle = iter->second.first; |
230 | *else_handle = iter->second.second; |
231 | } else { |
232 | TF_RETURN_IF_ERROR(Instantiate(ctx, then_func_, then_handle)); |
233 | TF_RETURN_IF_ERROR(Instantiate(ctx, else_func_, else_handle)); |
234 | handles_[lib] = {*then_handle, *else_handle}; |
235 | } |
236 | } |
237 | return OkStatus(); |
238 | } |
239 | }; |
240 | |
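// CaseOp dispatches among the functions in its "branches" attr based on the
// int32 scalar "branch_index" input; an out-of-range index runs the last
// branch, which serves as the default. Roughly (a sketch):
//   i = branch_index;
//   outputs = branches[(0 <= i && i < N) ? i : N - 1](inputs[1:]);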
241 | class CaseOp : public AsyncOpKernel { |
242 | public: |
243 | explicit CaseOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { |
244 | auto lib = ctx->function_library(); |
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branch_funcs_));
247 | } |
248 | |
249 | ~CaseOp() override {} |
250 | |
251 | void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { |
252 | auto lib = ctx->function_library(); |
253 | OP_REQUIRES_ASYNC(ctx, lib != nullptr, |
254 | errors::Internal("No function library" ), done); |
255 | |
256 | // TODO(b/37549631): Because this op has `SetIsStateful()` in its op |
257 | // registration, this kernel may be shared by multiple subgraphs, which have |
258 | // different associated `FunctionLibraryRuntime` objects and hence different |
259 | // `FHandle` namespaces. So we must call Instantiate() to make sure we get |
260 | // the correct function handles with respect to `lib`. Note the underlying |
261 | // `lib->Instantiate()` caches the created function handles, so calling |
262 | // `Instantiate()` repeatedly on the same `lib` and function is cheap. |
263 | std::vector<FHandle> branch_handles(branch_funcs_.size()); |
264 | for (int i = 0; i < branch_funcs_.size(); i++) { |
265 | OP_REQUIRES_OK_ASYNC( |
266 | ctx, Instantiate(lib, branch_funcs_[i], &branch_handles[i]), done); |
267 | } |
268 | |
269 | const Tensor& branch_index = ctx->input(0); |
270 | OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()), |
271 | errors::InvalidArgument("branch_index must be scalar" ), |
272 | done); |
273 | int32_t branch = branch_index.scalar<int32>()(); |
274 | (new State(this, ctx, branch, branch_handles, done))->Start(); |
275 | } |
276 | |
277 | private: |
278 | std::vector<NameAttrList> branch_funcs_; |
279 | |
280 | class State { |
281 | public: |
282 | State(CaseOp* kernel, OpKernelContext* ctx, int branch, |
283 | std::vector<FHandle> branch_handles, DoneCallback done) |
284 | : kernel_(kernel), |
285 | ctx_(ctx), |
286 | branch_(branch), |
287 | branch_handles_(branch_handles), |
288 | done_(std::move(done)), |
289 | lib_(CHECK_NOTNULL(ctx_->function_library())) { |
290 | SetRunOptions(ctx_, &opts_, true /* always_collect_stats */); |
291 | for (int i = 1; i < ctx_->num_inputs(); ++i) { |
292 | args_.push_back(ctx_->input(i)); |
293 | } |
294 | } |
295 | |
296 | ~State() {} |
297 | |
298 | void Start() { |
299 | int branch = branch_; |
300 | // The last branch is the default branch. |
301 | if (branch < 0 || branch >= branch_handles_.size()) { |
302 | branch = branch_handles_.size() - 1; |
303 | } |
304 | rets_.clear(); |
305 | profiler::TraceMe trace_me("CaseOp" ); |
306 | lib_->Run( |
307 | // Evaluate one of the branch. |
308 | opts_, branch_handles_[branch], args_, &rets_, |
309 | // Done callback |
310 | [this](Status s) { |
311 | if (s.ok()) { |
312 | s = SetOutputs(kernel_, ctx_, rets_); |
313 | } |
314 | ctx_->SetStatus(s); |
315 | DoneCallback captured_done(std::move(done_)); |
316 | delete this; |
317 | captured_done(); |
318 | }); |
319 | } |
320 | |
321 | private: |
322 | CaseOp* const kernel_; |
323 | OpKernelContext* const ctx_; |
324 | const int branch_; |
325 | std::vector<FHandle> branch_handles_; |
326 | DoneCallback done_; |
327 | FunctionLibraryRuntime* const lib_; |
328 | FunctionLibraryRuntime::Options opts_; |
329 | TensorVec args_; |
330 | TensorVec rets_; |
331 | }; |
332 | }; |
333 | |
334 | // TODO(drpng): remove this. |
335 | REGISTER_KERNEL_BUILDER(Name("_If" ).Device(DEVICE_CPU), IfOp); |
336 | REGISTER_KERNEL_BUILDER(Name("_If" ).Device(DEVICE_DEFAULT).HostMemory("cond" ), |
337 | IfOp); |
338 | |
339 | REGISTER_KERNEL_BUILDER(Name("If" ).Device(DEVICE_CPU), IfOp); |
340 | REGISTER_KERNEL_BUILDER(Name("If" ).Device(DEVICE_DEFAULT).HostMemory("cond" ), |
341 | IfOp); |
342 | |
343 | REGISTER_KERNEL_BUILDER(Name("Case" ).Device(DEVICE_CPU), CaseOp); |
344 | REGISTER_KERNEL_BUILDER( |
345 | Name("Case" ).Device(DEVICE_DEFAULT).HostMemory("branch_index" ), CaseOp); |
346 | REGISTER_KERNEL_BUILDER(Name("StatelessCase" ).Device(DEVICE_CPU), CaseOp); |
347 | REGISTER_KERNEL_BUILDER( |
348 | Name("StatelessCase" ).Device(DEVICE_DEFAULT).HostMemory("branch_index" ), |
349 | CaseOp); |
350 | |
351 | REGISTER_KERNEL_BUILDER(Name("StatelessIf" ).Device(DEVICE_CPU), IfOp); |
352 | REGISTER_KERNEL_BUILDER( |
353 | Name("StatelessIf" ).Device(DEVICE_DEFAULT).HostMemory("cond" ), IfOp); |
354 | |
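// WhileOp repeatedly evaluates the "cond" function on the loop variables and,
// while the result is true, feeds them through the "body" function.
// Roughly (a sketch):
//   vars = inputs;
//   while (ToBool(cond(vars))) vars = body(vars);
//   outputs = vars;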
355 | class WhileOp : public AsyncOpKernel { |
356 | public: |
357 | explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { |
358 | OP_REQUIRES_OK(ctx, ctx->GetAttr("cond" , &cond_func_)); |
359 | OP_REQUIRES_OK(ctx, ctx->GetAttr("body" , &body_func_)); |
360 | } |
361 | |
362 | ~WhileOp() override {} |
363 | |
364 | void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { |
365 | if (ctx->run_all_kernels_inline()) { |
366 | // Use the non-callback-based implementation when kernels (and function |
367 | // callbacks) execute inline to avoid stack overflow. |
368 | OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done); |
369 | done(); |
370 | } else { |
371 | FHandle cond_handle; |
372 | FHandle body_handle; |
373 | OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle), |
374 | done); |
375 | (new State(this, ctx, cond_handle, body_handle, done))->Start(); |
376 | } |
377 | } |
378 | |
379 | void Compute(OpKernelContext* ctx) override { |
380 | // Use the non-callback-based implementation when the synchronous Compute() |
381 | // method is invoked, because the caller is explicitly donating a thread. |
382 | Status s = DoComputeSync(ctx); |
383 | // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is |
384 | // still an AsyncOpKernel, and there is a run-time check to avoid calling |
385 | // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in |
386 | // the event of an error). |
387 | if (TF_PREDICT_FALSE(!s.ok())) { |
388 | ctx->SetStatus(s); |
389 | } |
390 | } |
391 | |
392 | private: |
393 | NameAttrList cond_func_; |
394 | NameAttrList body_func_; |
395 | |
396 | mutex mu_; |
397 | std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> |
398 | handles_ ABSL_GUARDED_BY(mu_); |
399 | |
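  // Converts the cond function's single return value to a bool, copying it to
  // host memory first if it may live on an accelerator device.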
400 | static Status CondResultToBool(OpKernelContext* ctx, |
401 | const FunctionLibraryRuntime::Options& opts, |
402 | const Tensor& cond_t, bool* out_result) { |
403 | bool is_pluggable = ctx->op_device_context() && |
404 | ctx->op_device_context()->IsPluggableDevice(); |
405 | const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = |
406 | ctx->device()->tensorflow_accelerator_device_info(); |
407 | const bool is_hostmem_dtype = |
408 | cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64; |
409 | if (!is_hostmem_dtype && (is_pluggable || accelerator_device_info) && |
410 | (opts.rets_alloc_attrs.empty() || |
411 | !opts.rets_alloc_attrs[0].on_host())) { |
412 | // Copy the ret value to host if it's allocated on device. |
413 | Device* device = down_cast<Device*>(ctx->device()); |
414 | DeviceContext* device_ctx = ctx->op_device_context(); |
415 | Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape()); |
416 | TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync( |
417 | &cond_t, /*tensor_name=*/"" , device, &host_cond_t)); |
418 | return ToBool({host_cond_t}, out_result); |
419 | } |
420 | return ToBool({cond_t}, out_result); |
421 | } |
422 | |
423 | // The initial loop variable args are the inputs to the kernel. |
424 | // |
425 | // We attempt to forward the input so that it can be consumed inside the |
426 | // body function (and participate in buffer forwarding, etc.). |
427 | static void GetArgsFromContext(OpKernelContext* ctx, |
428 | std::vector<Tensor>* out_args, |
429 | DataTypeVector* out_var_types) { |
430 | const int num_loop_vars = ctx->num_inputs(); |
431 | out_args->reserve(num_loop_vars); |
432 | out_var_types->resize(num_loop_vars); |
433 | for (int i = 0; i < num_loop_vars; ++i) { |
434 | const Tensor& input = ctx->input(i); |
435 | (*out_var_types)[i] = input.dtype(); |
436 | std::unique_ptr<Tensor> maybe_forwarded_input = ctx->forward_input( |
437 | i, /* output_index= */ OpKernelContext::Params::kNoReservation, |
438 | input.dtype(), input.shape(), ctx->input_memory_type(i), |
439 | ctx->input_alloc_attr(i)); |
440 | if (maybe_forwarded_input) { |
441 | out_args->push_back(std::move(*maybe_forwarded_input)); |
442 | } else { |
443 | out_args->push_back(input); |
444 | } |
445 | } |
446 | } |
447 | |
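  // A call frame that lets the body function read (or destructively consume)
  // the current loop variables in place, and type-checks each value it
  // returns against the recorded loop-variable dtypes.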
448 | class BodyFuncCallFrame : public CallFrameInterface { |
449 | public: |
450 | BodyFuncCallFrame(std::vector<Tensor>* args, std::vector<Tensor>* retvals, |
451 | DataTypeSlice ret_types) |
452 | : args_(args), retvals_(retvals), ret_types_(ret_types) {} |
453 | |
454 | size_t num_args() const override { return args_->size(); } |
455 | size_t num_retvals() const override { return retvals_->size(); } |
456 | |
457 | Status GetArg(int index, const Tensor** val) override { |
458 | if (index < args_->size()) { |
459 | *val = &(*args_)[index]; |
460 | return OkStatus(); |
461 | } else { |
462 | return errors::InvalidArgument("Argument " , index, " is out of range." ); |
463 | } |
464 | } |
465 | |
466 | void ConsumeArg(int index, Tensor* val) override { |
467 | DCHECK_GE(index, 0); |
468 | DCHECK_LT(index, args_->size()); |
469 | *val = std::move((*args_)[index]); |
470 | } |
471 | bool CanConsumeArg(int index) const override { |
472 | return index >= 0 && index < args_->size(); |
473 | } |
474 | |
475 | Status SetRetval(int index, const Tensor& val) override { |
476 | if (TF_PREDICT_FALSE(index < 0)) { |
477 | return errors::InvalidArgument( |
478 | "Expected non-negative return value index, but got: " , index, "." ); |
479 | } else if (TF_PREDICT_FALSE(index >= retvals_->size())) { |
480 | return errors::InvalidArgument("While loop body returned " , index + 1, |
481 | " arguments. Expected: " , num_retvals(), |
482 | "." ); |
483 | } else if (TF_PREDICT_FALSE(val.dtype() != ret_types_[index])) { |
484 | return errors::InvalidArgument("Expected type " , |
485 | DataTypeString(ret_types_[index]), |
486 | " for return value " , index, " but got " , |
487 | DataTypeString(val.dtype()), "." ); |
488 | } |
489 | (*retvals_)[index] = val; |
490 | return OkStatus(); |
491 | } |
492 | |
493 | private: |
494 | std::vector<Tensor>* const args_; // Not owned. |
495 | std::vector<Tensor>* const retvals_; // Not owned. |
496 | DataTypeSlice ret_types_; |
497 | |
498 | TF_DISALLOW_COPY_AND_ASSIGN(BodyFuncCallFrame); |
499 | }; |
500 | |
501 | class State { |
502 | public: |
503 | State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, |
504 | FHandle body_handle, DoneCallback done) |
505 | : kernel_(kernel), |
506 | ctx_(ctx), |
507 | cond_handle_(cond_handle), |
508 | body_handle_(body_handle), |
509 | done_(std::move(done)), |
510 | lib_(CHECK_NOTNULL(ctx_->function_library())) { |
511 | SetRunOptions(ctx_, &opts_, false /* always_collect_stats */); |
512 | GetArgsFromContext(ctx, &args_, &loop_var_types_); |
513 | body_frame_ = |
514 | std::make_unique<BodyFuncCallFrame>(&args_, &rets_, loop_var_types_); |
515 | } |
516 | |
517 | ~State() {} |
518 | |
519 | void Start() { EvalCond(); } |
520 | |
521 | private: |
522 | WhileOp* const kernel_; |
523 | OpKernelContext* const ctx_; |
524 | const FHandle cond_handle_; |
525 | const FHandle body_handle_; |
526 | const DoneCallback done_; |
527 | FunctionLibraryRuntime* const lib_; |
528 | FunctionLibraryRuntime::Options opts_; |
529 | TensorVec args_; |
530 | TensorVec rets_; |
531 | DataTypeVector loop_var_types_; |
532 | std::unique_ptr<BodyFuncCallFrame> body_frame_; |
533 | |
534 | void EvalCond() { |
535 | profiler::TraceMe trace_me("WhileOp-EvalCond" ); |
536 | lib_->Run( |
537 | // Evaluate the condition. |
538 | opts_, cond_handle_, args_, &rets_, |
539 | // Done cb. |
540 | [this](const Status& s) { |
541 | if (!s.ok()) { |
542 | return Finish(s); |
543 | } |
544 | StartBody(); |
545 | }); |
546 | } |
547 | |
548 | void StartBody() { |
      if (rets_.size() != 1) {
        return Finish(errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            rets_.size(), " tensors."));
      }

      bool cond;
      Status s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
562 | if (!s.ok()) { |
563 | return Finish(s); |
564 | } |
565 | |
566 | if (!cond) { |
567 | return Finish(OkStatus()); |
568 | } |
569 | rets_.clear(); |
570 | rets_.resize(args_.size()); |
571 | profiler::TraceMe trace_me("WhileOp-StartBody" ); |
572 | lib_->Run( |
573 | // Evaluate the body. |
574 | opts_, body_handle_, body_frame_.get(), |
575 | // Done callback |
576 | [this](const Status& s) { |
577 | if (!s.ok()) { |
578 | return Finish(s); |
579 | } |
580 | if (args_.size() != rets_.size()) { |
581 | return Finish(errors::InvalidArgument( |
582 | "While loop body returned " , rets_.size(), |
583 | " arguments. Expected: " , args_.size())); |
584 | } |
585 | args_.clear(); |
586 | using std::swap; |
587 | swap(args_, rets_); |
588 | EvalCond(); |
589 | }); |
590 | } |
591 | |
592 | void Finish(Status s) { |
593 | if (s.ok()) { |
594 | s = SetOutputs(kernel_, ctx_, args_); |
595 | } |
596 | ctx_->SetStatus(s); |
597 | done_(); |
598 | delete this; |
599 | } |
600 | }; |
601 | |
602 | Status DoComputeSync(OpKernelContext* ctx) { |
603 | FHandle cond_handle; |
604 | FHandle body_handle; |
605 | TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle)); |
606 | auto lib = ctx->function_library(); |
607 | FunctionLibraryRuntime::Options opts; |
608 | SetRunOptions(ctx, &opts, false /* always_collect_stats */); |
609 | |
610 | // Pre-allocate argument and return value vectors for the cond and body |
611 | // functions. |
612 | std::vector<Tensor> args; |
613 | const int num_loop_vars = ctx->num_inputs(); |
614 | DataTypeVector loop_var_types(num_loop_vars); |
615 | GetArgsFromContext(ctx, &args, &loop_var_types); |
616 | std::vector<Tensor> cond_rets; |
617 | cond_rets.reserve(1); |
618 | std::vector<Tensor> body_rets; |
619 | body_rets.reserve(num_loop_vars); |
620 | |
621 | // Implement the logic of the while loop as a single C++ do-while loop that |
622 | // executes the cond and body functions synchronously. |
623 | do { |
624 | // Evaluate the cond function on the current loop variables. |
625 | { |
626 | profiler::TraceMe trace_me("WhileOp-EvalCond" ); |
627 | TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets)); |
628 | } |
629 | if (cond_rets.size() != 1) { |
630 | return errors::InvalidArgument( |
631 | "Expected a single scalar return value from WhileOp cond, got " , |
632 | cond_rets.size(), " tensors." ); |
633 | } |
634 | |
635 | // If the cond function evaluates to false, we are done: output the |
636 | // current loop variables. |
637 | bool cond_result; |
638 | TF_RETURN_IF_ERROR( |
639 | CondResultToBool(ctx, opts, cond_rets[0], &cond_result)); |
640 | if (!cond_result) { |
641 | return SetOutputs(this, ctx, args); |
642 | } |
643 | |
644 | // Evaluate the body function on the current loop variables, to get an |
645 | // updated vector of loop variables. |
646 | { |
647 | profiler::TraceMe trace_me("WhileOp-StartBody" ); |
648 | body_rets.resize(num_loop_vars); |
649 | BodyFuncCallFrame call_frame(&args, &body_rets, loop_var_types); |
650 | TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, &call_frame)); |
651 | } |
652 | std::swap(body_rets, args); |
653 | body_rets.clear(); |
654 | } while (true); |
655 | } |
656 | |
657 | Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle, |
658 | FHandle* body_handle) { |
659 | // TODO(b/37549631): Because this op has `SetIsStateful()` in its |
660 | // op registration, this kernel may be shared by multiple |
661 | // subgraphs, which have different associated |
662 | // `FunctionLibraryRuntime` objects and hence different `FHandle` |
663 | // namespaces. We currently work around this by caching the map |
664 | // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two |
665 | // functions this op uses. |
666 | auto lib = ctx->function_library(); |
    if (lib == nullptr) return errors::Internal("No function library");
668 | *cond_handle = kInvalidHandle; |
669 | *body_handle = kInvalidHandle; |
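    // Fast path under a shared lock; on a cache miss we take the exclusive
    // lock below, re-check, and instantiate the functions if still absent.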
670 | { |
671 | tf_shared_lock l(mu_); |
672 | const auto iter = handles_.find(lib); |
673 | if (TF_PREDICT_TRUE(iter != handles_.end())) { |
674 | *cond_handle = iter->second.first; |
675 | *body_handle = iter->second.second; |
676 | } |
677 | } |
678 | if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) { |
679 | mutex_lock l(mu_); |
680 | const auto iter = handles_.find(lib); |
681 | if (TF_PREDICT_TRUE(iter != handles_.end())) { |
682 | *cond_handle = iter->second.first; |
683 | *body_handle = iter->second.second; |
684 | } else { |
685 | TF_RETURN_IF_ERROR(Instantiate(ctx, cond_func_, cond_handle)); |
686 | TF_RETURN_IF_ERROR(Instantiate(ctx, body_func_, body_handle)); |
687 | handles_[lib] = {*cond_handle, *body_handle}; |
688 | } |
689 | } |
690 | return OkStatus(); |
691 | } |
692 | }; |
693 | // TODO(drpng): remove these. |
694 | REGISTER_KERNEL_BUILDER(Name("_While" ).Device(DEVICE_CPU), WhileOp); |
695 | REGISTER_KERNEL_BUILDER(Name("_While" ).Device(DEVICE_DEFAULT), WhileOp); |
696 | |
697 | REGISTER_KERNEL_BUILDER(Name("While" ).Device(DEVICE_CPU), WhileOp); |
698 | REGISTER_KERNEL_BUILDER(Name("While" ).Device(DEVICE_DEFAULT), WhileOp); |
699 | |
700 | REGISTER_KERNEL_BUILDER(Name("StatelessWhile" ).Device(DEVICE_CPU), WhileOp); |
701 | REGISTER_KERNEL_BUILDER(Name("StatelessWhile" ).Device(DEVICE_DEFAULT), WhileOp); |
702 | |
703 | class ToBoolOp : public OpKernel { |
704 | public: |
705 | explicit ToBoolOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
706 | void Compute(OpKernelContext* ctx) override { |
707 | bool b; |
708 | OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &b)); |
709 | Tensor* out; |
710 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); |
711 | out->scalar<bool>()() = b; |
712 | } |
713 | }; |
714 | |
715 | REGISTER_KERNEL_BUILDER(Name("ToBool" ).Device(DEVICE_CPU), ToBoolOp); |
716 | |
717 | Status GetScalar(OpKernelContext* ctx, int index, int32* value, |
718 | const char* label) { |
719 | Tensor t = ctx->input(index); |
720 | if (!TensorShapeUtils::IsScalar(t.shape())) { |
    return errors::InvalidArgument(label, " must be a scalar, but has shape ",
                                   t.shape().DebugString());
723 | } |
724 | *value = t.scalar<int32>()(); |
725 | return OkStatus(); |
726 | } |
727 | |
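// ForOp runs the "body" function over the arithmetic sequence defined by its
// first three scalar int32 inputs (start, limit, delta), threading the
// remaining inputs through as loop variables. Roughly (a sketch):
//   vars = inputs[3:];
//   for (i = start; delta > 0 ? i < limit : i > limit; i += delta)
//     vars = body(i, vars);
//   outputs = vars;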
728 | class ForOp : public AsyncOpKernel { |
729 | public: |
730 | explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { |
731 | auto lib = ctx->function_library(); |
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    const NameAttrList* func;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
735 | OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_)); |
736 | } |
737 | |
738 | ~ForOp() override {} |
739 | |
740 | void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { |
741 | (new State(this, ctx, done))->Start(); |
742 | } |
743 | |
744 | private: |
745 | FHandle body_handle_; |
746 | |
747 | class State { |
748 | public: |
749 | State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done) |
750 | : kernel_(kernel), |
751 | ctx_(ctx), |
752 | done_(std::move(done)), |
753 | lib_(CHECK_NOTNULL(ctx_->function_library())), |
754 | args_(1 + ctx_->num_inputs() - 3) { |
755 | args_[0] = Tensor(DT_INT32, {}); |
756 | iter_ = &args_[0].scalar<int32>()(); |
757 | |
758 | const int32_t num_loop_inputs = ctx_->num_inputs() - 3; |
759 | rets_.reserve(num_loop_inputs); |
760 | for (int i = 0; i < num_loop_inputs; ++i) { |
761 | rets_.push_back(ctx_->input(3 + i)); |
762 | } |
763 | } |
764 | |
765 | ~State() {} |
766 | |
767 | void Start() { |
768 | Status s = StartLoop(); |
769 | if (!s.ok()) Finish(s); |
770 | } |
771 | |
772 | private: |
773 | ForOp* const kernel_; |
774 | OpKernelContext* const ctx_; |
775 | const DoneCallback done_; |
776 | FunctionLibraryRuntime* const lib_; |
777 | FunctionLibraryRuntime::Options opts_; |
778 | TensorVec args_; |
779 | TensorVec rets_; |
780 | |
781 | int32* iter_; // points to args_[0]. |
782 | int32 limit_; |
783 | int32 delta_; |
784 | |
785 | // If an error e is returned, caller must call Finish(e). |
786 | // If OK is returned, the async loop execution has been started. |
787 | Status StartLoop() { |
788 | SetRunOptions(ctx_, &opts_, false /* always_collect_stats */); |
789 | |
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));
793 | |
794 | if ((delta_ > 0 && *iter_ <= limit_) || |
795 | (delta_ < 0 && *iter_ >= limit_) || |
796 | (delta_ == 0 && *iter_ == limit_)) { |
797 | RunNext(); |
798 | return OkStatus(); |
799 | } else { |
800 | return errors::InvalidArgument("Invalid start/limit/delta: " , *iter_, |
801 | " " , limit_, " " , delta_); |
802 | } |
803 | } |
804 | |
805 | void RunNext() { |
806 | bool done_loop; |
807 | if (delta_ > 0) { |
808 | done_loop = *iter_ >= limit_; |
809 | } else { |
810 | done_loop = *iter_ <= limit_; |
811 | } |
812 | if (done_loop) { |
813 | Finish(OkStatus()); |
814 | return; |
815 | } |
816 | |
817 | if (rets_.size() >= args_.size()) { |
        Finish(errors::InvalidArgument(
            "For loop body returned ", rets_.size(),
            " arguments. Expected: ", args_.size() - 1));
821 | return; |
822 | } |
823 | for (int i = 0; i < rets_.size(); ++i) { |
824 | args_[1 + i] = std::move(rets_[i]); |
825 | } |
826 | rets_.clear(); |
827 | profiler::TraceMe trace_me("ForOp" ); |
828 | lib_->Run(opts_, kernel_->body_handle_, args_, &rets_, |
829 | [this](const Status& s) { |
830 | if (s.ok()) { |
831 | *iter_ += delta_; |
832 | RunNext(); |
833 | } else { |
834 | Finish(s); |
835 | } |
836 | }); |
837 | } |
838 | |
839 | void Finish(Status s) { |
840 | if (s.ok()) { |
841 | s = SetOutputs(kernel_, ctx_, rets_); |
842 | } |
843 | ctx_->SetStatus(s); |
844 | done_(); |
845 | delete this; |
846 | } |
847 | }; |
848 | }; |
849 | |
850 | REGISTER_KERNEL_BUILDER(Name("For" ).Device(DEVICE_CPU), ForOp); |
851 | REGISTER_KERNEL_BUILDER(Name("For" ) |
852 | .Device(DEVICE_DEFAULT) |
853 | .HostMemory("start" ) |
854 | .HostMemory("limit" ) |
855 | .HostMemory("delta" ), |
856 | ForOp); |
857 | |
858 | // FakeParamOp allocates a tensor with a shape conforming to the expected |
859 | // output. This is necessary if the value will be stored in a while_loop's |
860 | // TensorList. The output is otherwise not expected to be consumed by anything |
861 | // else. |
862 | class FakeParamOp : public OpKernel { |
863 | public: |
864 | explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) { |
865 | DataType dtype; |
866 | OP_REQUIRES_OK(context, context->GetAttr("dtype" , &dtype)); |
867 | |
    // Set shape to the specified shape, with unknown dimensions set to zero.
    // If the rank itself is unknown, leave the shape empty (scalar).
870 | TensorShape shape; |
871 | PartialTensorShape partial_shape; |
872 | OP_REQUIRES_OK(context, context->GetAttr("shape" , &partial_shape)); |
873 | if (!partial_shape.unknown_rank()) { |
874 | for (int64_t d : partial_shape.dim_sizes()) { |
875 | shape.AddDim(d == -1 ? 0 : d); |
876 | } |
877 | } |
878 | |
879 | // Create a tensor that we can repeatedly return to save memory. |
880 | // TODO(b/119612758): add optimization to prevent sending this across |
881 | // devices on each Compute() call. |
882 | OP_REQUIRES_OK(context, context->allocate_temp(dtype, shape, &value_)); |
883 | } |
884 | |
885 | void Compute(OpKernelContext* context) override { |
886 | context->set_output(0, value_); |
887 | } |
888 | |
889 | private: |
890 | Tensor value_; |
891 | }; |
892 | |
893 | REGISTER_KERNEL_BUILDER(Name("FakeParam" ).Device(DEVICE_CPU), FakeParamOp); |
894 | REGISTER_KERNEL_BUILDER(Name("FakeParam" ).Device(DEVICE_DEFAULT), FakeParamOp); |
895 | |
// DeviceIndexOp returns the index of the current device's type in the
// "device_names" attr, or device_names.size() if the type is not listed.
897 | class DeviceIndexOp : public OpKernel { |
898 | public: |
899 | explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
900 | OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names" , &device_names_)); |
901 | } |
902 | |
903 | void Compute(OpKernelContext* ctx) override { |
904 | Tensor* device_name_t; |
905 | OP_REQUIRES_OK(ctx, |
906 | ctx->allocate_output(0, TensorShape({}), &device_name_t)); |
907 | DeviceNameUtils::ParsedName parsed_name; |
908 | int index = device_names_.size(); |
909 | if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) && |
910 | parsed_name.has_type) { |
911 | auto it = absl::c_find(device_names_, parsed_name.type); |
912 | if (it != device_names_.end()) { |
913 | index = it - device_names_.begin(); |
914 | } |
915 | } |
916 | device_name_t->scalar<int32>()() = index; |
917 | } |
918 | |
919 | private: |
920 | std::vector<string> device_names_; |
921 | }; |
922 | |
923 | REGISTER_KERNEL_BUILDER(Name("DeviceIndex" ).Device(DEVICE_CPU), DeviceIndexOp); |
924 | REGISTER_KERNEL_BUILDER( |
925 | Name("DeviceIndex" ).Device(DEVICE_DEFAULT).HostMemory("index" ), |
926 | DeviceIndexOp); |
927 | |
928 | } // namespace |
929 | } // namespace tensorflow |
930 | |