/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/types.h"
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef FunctionLibraryRuntime::Handle FHandle;
typedef std::vector<Tensor> TensorVec;

namespace {

// Helper to instantiate function "func" in the library "lib".
Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}

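// Overload that instantiates "func" using the function library attached to
// "ctx", forwarding the context's executor type to the runtime.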
Status Instantiate(OpKernelContext* ctx, const NameAttrList& func,
                   FunctionLibraryRuntime::Handle* handle) {
  FunctionLibraryRuntime::InstantiateOptions opts;
  opts.executor_type = ctx->executor_type();
  return ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), opts, handle);
}

// Converts the single tensor in "t" to a boolean in "*v". A scalar of a
// supported numeric type converts as t != 0, a scalar string as "non-empty",
// and a non-scalar tensor as "has at least one element".
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
  if (t.size() != 1) {
    return errors::InvalidArgument(
        "Expected a single scalar which can be converted to a boolean, got ",
        t.size(), " tensors.");
  }
  if (TensorShapeUtils::IsScalar(t[0].shape())) {
    switch (t[0].dtype()) {
#define CASE(T)                   \
  case DataTypeToEnum<T>::value:  \
    *v = t[0].scalar<T>()() != 0; \
    break;

      CASE(float);
      CASE(double);
      CASE(int32);
      CASE(uint8);
      CASE(int16);
      CASE(int8);
      CASE(int64_t);
#undef CASE
      case DT_BOOL:
        *v = t[0].scalar<bool>()();
        break;
      case DT_STRING:
        *v = !t[0].scalar<tstring>()().empty();
        break;
      default:
        return errors::InvalidArgument(DataTypeString(t[0].dtype()),
                                       " cannot be converted to a boolean");
    }
  } else {
    *v = t[0].NumElements() > 0;
  }
  return OkStatus();
}

// Sets the outputs of "ctx" to "rets", validating their types against the
// output types registered for "kernel".
Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
                  gtl::ArraySlice<Tensor> rets) {
  if (rets.size() != ctx->num_outputs()) {
    return errors::Internal("Expected to produce ", ctx->num_outputs(),
                            " tensors, but only got ", rets.size());
  }
  for (int i = 0; i < rets.size(); ++i) {
    if (rets[i].dtype() != kernel->output_type(i)) {
      return errors::Internal("Expected output ", i, " to be of type ",
                              DataTypeString(kernel->output_type(i)),
                              " but got ", DataTypeString(rets[i].dtype()));
    }
    ctx->set_output(i, rets[i]);
  }
  return OkStatus();
}

void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
                   bool always_collect_stats) {
  opts->rendezvous = ctx->rendezvous();
  opts->cancellation_manager = ctx->cancellation_manager();
  opts->collective_executor = ctx->collective_executor();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  opts->runner = ctx->runner();
  opts->run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts->step_container = ctx->step_container();
}

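// IfOp implements the `If` family of ops: it evaluates either the
// `then_branch` or the `else_branch` function, depending on whether input 0
// converts to true (see ToBool above), and forwards the remaining inputs as
// arguments to the chosen branch. Illustrative only: a Python-level
// `tf.cond(pred, true_fn, false_fn)` is typically expressed with such a node,
// with `pred` as input 0 and the two functions named by the branch attrs.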
class IfOp : public AsyncOpKernel {
 public:
  explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
  }

  ~IfOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FHandle then_handle;
    FHandle else_handle;
    OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &then_handle, &else_handle),
                         done);
    bool cond;
    OP_REQUIRES_OK_ASYNC(ctx, ToBool({ctx->input(0)}, &cond), done);
    (new State(this, ctx, cond, then_handle, else_handle, done))->Start();
  }

 private:
  NameAttrList then_func_;
  NameAttrList else_func_;

  mutex mu_;
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ ABSL_GUARDED_BY(mu_);

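  // Per-invocation state for one asynchronous IfOp execution. It runs the
  // selected branch via FunctionLibraryRuntime::Run and deletes itself from
  // the done callback after the outputs have been set.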
  class State {
   public:
    State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
          FHandle else_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_(cond),
          then_handle_(then_handle),
          else_handle_(else_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      FHandle handle = cond_ ? then_handle_ : else_handle_;
      rets_.clear();
      profiler::TraceMe trace_me("IfOp");
      lib_->Run(
          // Evaluate one of the branches.
          opts_, handle, args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    IfOp* const kernel_;
    OpKernelContext* const ctx_;
    const bool cond_;
    FHandle then_handle_;
    FHandle else_handle_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };

  Status GetHandles(OpKernelContext* ctx, FHandle* then_handle,
                    FHandle* else_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    auto lib = ctx->function_library();
    if (lib == nullptr) return errors::Internal("No function library");
    *then_handle = kInvalidHandle;
    *else_handle = kInvalidHandle;
    {
      tf_shared_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *then_handle = iter->second.first;
        *else_handle = iter->second.second;
      }
    }
    if (TF_PREDICT_FALSE(*then_handle == kInvalidHandle)) {
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *then_handle = iter->second.first;
        *else_handle = iter->second.second;
      } else {
        TF_RETURN_IF_ERROR(Instantiate(ctx, then_func_, then_handle));
        TF_RETURN_IF_ERROR(Instantiate(ctx, else_func_, else_handle));
        handles_[lib] = {*then_handle, *else_handle};
      }
    }
    return OkStatus();
  }
};

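// CaseOp implements the `Case` family of ops: input 0 is a scalar int32
// branch index that selects one of the functions in the `branches` attr to
// run on the remaining inputs. An out-of-range index falls through to the
// last branch, which acts as the default.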
class CaseOp : public AsyncOpKernel {
 public:
  explicit CaseOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branch_funcs_));
  }

  ~CaseOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
    // registration, this kernel may be shared by multiple subgraphs, which
    // have different associated `FunctionLibraryRuntime` objects and hence
    // different `FHandle` namespaces. So we must call Instantiate() to make
    // sure we get the correct function handles with respect to `lib`. Note the
    // underlying `lib->Instantiate()` caches the created function handles, so
    // calling `Instantiate()` repeatedly on the same `lib` and function is
    // cheap.
    std::vector<FHandle> branch_handles(branch_funcs_.size());
    for (int i = 0; i < branch_funcs_.size(); i++) {
      OP_REQUIRES_OK_ASYNC(
          ctx, Instantiate(lib, branch_funcs_[i], &branch_handles[i]), done);
    }

    const Tensor& branch_index = ctx->input(0);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()),
                      errors::InvalidArgument("branch_index must be scalar"),
                      done);
    int32_t branch = branch_index.scalar<int32>()();
    (new State(this, ctx, branch, branch_handles, done))->Start();
  }

 private:
  std::vector<NameAttrList> branch_funcs_;

  class State {
   public:
    State(CaseOp* kernel, OpKernelContext* ctx, int branch,
          std::vector<FHandle> branch_handles, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          branch_(branch),
          branch_handles_(branch_handles),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
      for (int i = 1; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() {
      int branch = branch_;
      // The last branch is the default branch.
      if (branch < 0 || branch >= branch_handles_.size()) {
        branch = branch_handles_.size() - 1;
      }
      rets_.clear();
      profiler::TraceMe trace_me("CaseOp");
      lib_->Run(
          // Evaluate one of the branches.
          opts_, branch_handles_[branch], args_, &rets_,
          // Done callback
          [this](Status s) {
            if (s.ok()) {
              s = SetOutputs(kernel_, ctx_, rets_);
            }
            ctx_->SetStatus(s);
            DoneCallback captured_done(std::move(done_));
            delete this;
            captured_done();
          });
    }

   private:
    CaseOp* const kernel_;
    OpKernelContext* const ctx_;
    const int branch_;
    std::vector<FHandle> branch_handles_;
    DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
  };
};

// TODO(drpng): remove this.
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_DEFAULT).HostMemory("cond"),
                        IfOp);

REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_DEFAULT).HostMemory("cond"),
                        IfOp);

REGISTER_KERNEL_BUILDER(Name("Case").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("Case").Device(DEVICE_DEFAULT).HostMemory("branch_index"), CaseOp);
REGISTER_KERNEL_BUILDER(Name("StatelessCase").Device(DEVICE_CPU), CaseOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessCase").Device(DEVICE_DEFAULT).HostMemory("branch_index"),
    CaseOp);

REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(
    Name("StatelessIf").Device(DEVICE_DEFAULT).HostMemory("cond"), IfOp);

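// WhileOp implements the `While` family of ops: it repeatedly runs the `body`
// function on the loop variables for as long as the `cond` function, applied
// to the same variables, evaluates to true, and then emits the final loop
// variables as outputs. Roughly, it corresponds to a
// `tf.while_loop(cond, body, loop_vars)` expressed as a single functional
// node.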
class WhileOp : public AsyncOpKernel {
 public:
  explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
  }

  ~WhileOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    if (ctx->run_all_kernels_inline()) {
      // Use the non-callback-based implementation when kernels (and function
      // callbacks) execute inline to avoid stack overflow.
      OP_REQUIRES_OK_ASYNC(ctx, DoComputeSync(ctx), done);
      done();
    } else {
      FHandle cond_handle;
      FHandle body_handle;
      OP_REQUIRES_OK_ASYNC(ctx, GetHandles(ctx, &cond_handle, &body_handle),
                           done);
      (new State(this, ctx, cond_handle, body_handle, done))->Start();
    }
  }

  void Compute(OpKernelContext* ctx) override {
    // Use the non-callback-based implementation when the synchronous Compute()
    // method is invoked, because the caller is explicitly donating a thread.
    Status s = DoComputeSync(ctx);
    // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is
    // still an AsyncOpKernel, and there is a run-time check to avoid calling
    // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in
    // the event of an error).
    if (TF_PREDICT_FALSE(!s.ok())) {
      ctx->SetStatus(s);
    }
  }

 private:
  NameAttrList cond_func_;
  NameAttrList body_func_;

  mutex mu_;
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ ABSL_GUARDED_BY(mu_);

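  // Converts the tensor produced by the `cond` function into a boolean. If the
  // tensor lives in device memory (e.g. on a GPU or a pluggable device) and is
  // not already in host memory, it is first copied to the host synchronously.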
  static Status CondResultToBool(OpKernelContext* ctx,
                                 const FunctionLibraryRuntime::Options& opts,
                                 const Tensor& cond_t, bool* out_result) {
    bool is_pluggable = ctx->op_device_context() &&
                        ctx->op_device_context()->IsPluggableDevice();
    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
        ctx->device()->tensorflow_accelerator_device_info();
    const bool is_hostmem_dtype =
        cond_t.dtype() == DT_INT32 || cond_t.dtype() == DT_INT64;
    if (!is_hostmem_dtype && (is_pluggable || accelerator_device_info) &&
        (opts.rets_alloc_attrs.empty() ||
         !opts.rets_alloc_attrs[0].on_host())) {
      // Copy the ret value to host if it's allocated on device.
      Device* device = down_cast<Device*>(ctx->device());
      DeviceContext* device_ctx = ctx->op_device_context();
      Tensor host_cond_t = Tensor(cond_t.dtype(), cond_t.shape());
      TF_RETURN_IF_ERROR(device_ctx->CopyDeviceTensorToCPUSync(
          &cond_t, /*tensor_name=*/"", device, &host_cond_t));
      return ToBool({host_cond_t}, out_result);
    }
    return ToBool({cond_t}, out_result);
  }

  // The initial loop variable args are the inputs to the kernel.
  //
  // We attempt to forward the input so that it can be consumed inside the
  // body function (and participate in buffer forwarding, etc.).
  static void GetArgsFromContext(OpKernelContext* ctx,
                                 std::vector<Tensor>* out_args,
                                 DataTypeVector* out_var_types) {
    const int num_loop_vars = ctx->num_inputs();
    out_args->reserve(num_loop_vars);
    out_var_types->resize(num_loop_vars);
    for (int i = 0; i < num_loop_vars; ++i) {
      const Tensor& input = ctx->input(i);
      (*out_var_types)[i] = input.dtype();
      std::unique_ptr<Tensor> maybe_forwarded_input = ctx->forward_input(
          i, /* output_index= */ OpKernelContext::Params::kNoReservation,
          input.dtype(), input.shape(), ctx->input_memory_type(i),
          ctx->input_alloc_attr(i));
      if (maybe_forwarded_input) {
        out_args->push_back(std::move(*maybe_forwarded_input));
      } else {
        out_args->push_back(input);
      }
    }
  }

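  // Call frame used to run the `body` function. Arguments are read from (and
  // may be consumed out of) the current loop variables in `args`, and return
  // values are type-checked against `ret_types` and written into `retvals`.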
  class BodyFuncCallFrame : public CallFrameInterface {
   public:
    BodyFuncCallFrame(std::vector<Tensor>* args, std::vector<Tensor>* retvals,
                      DataTypeSlice ret_types)
        : args_(args), retvals_(retvals), ret_types_(ret_types) {}

    size_t num_args() const override { return args_->size(); }
    size_t num_retvals() const override { return retvals_->size(); }

    Status GetArg(int index, const Tensor** val) override {
      if (index < args_->size()) {
        *val = &(*args_)[index];
        return OkStatus();
      } else {
        return errors::InvalidArgument("Argument ", index,
                                       " is out of range.");
      }
    }

    void ConsumeArg(int index, Tensor* val) override {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, args_->size());
      *val = std::move((*args_)[index]);
    }
    bool CanConsumeArg(int index) const override {
      return index >= 0 && index < args_->size();
    }

    Status SetRetval(int index, const Tensor& val) override {
      if (TF_PREDICT_FALSE(index < 0)) {
        return errors::InvalidArgument(
            "Expected non-negative return value index, but got: ", index, ".");
      } else if (TF_PREDICT_FALSE(index >= retvals_->size())) {
        return errors::InvalidArgument("While loop body returned ", index + 1,
                                       " arguments. Expected: ", num_retvals(),
                                       ".");
      } else if (TF_PREDICT_FALSE(val.dtype() != ret_types_[index])) {
        return errors::InvalidArgument("Expected type ",
                                       DataTypeString(ret_types_[index]),
                                       " for return value ", index,
                                       " but got ",
                                       DataTypeString(val.dtype()), ".");
      }
      (*retvals_)[index] = val;
      return OkStatus();
    }

   private:
    std::vector<Tensor>* const args_;     // Not owned.
    std::vector<Tensor>* const retvals_;  // Not owned.
    DataTypeSlice ret_types_;

    TF_DISALLOW_COPY_AND_ASSIGN(BodyFuncCallFrame);
  };

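  // Per-invocation state for one asynchronous WhileOp execution. It alternates
  // between EvalCond() and StartBody() until the condition becomes false, then
  // Finish() publishes the loop variables as outputs, invokes the done
  // callback, and deletes the state.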
  class State {
   public:
    State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
          FHandle body_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_handle_(cond_handle),
          body_handle_(body_handle),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
      GetArgsFromContext(ctx, &args_, &loop_var_types_);
      body_frame_ =
          std::make_unique<BodyFuncCallFrame>(&args_, &rets_, loop_var_types_);
    }

    ~State() {}

    void Start() { EvalCond(); }

   private:
    WhileOp* const kernel_;
    OpKernelContext* const ctx_;
    const FHandle cond_handle_;
    const FHandle body_handle_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;
    DataTypeVector loop_var_types_;
    std::unique_ptr<BodyFuncCallFrame> body_frame_;

    void EvalCond() {
      profiler::TraceMe trace_me("WhileOp-EvalCond");
      lib_->Run(
          // Evaluate the condition.
          opts_, cond_handle_, args_, &rets_,
          // Done cb.
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            StartBody();
          });
    }

    void StartBody() {
      if (rets_.size() != 1) {
        return Finish(errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            rets_.size(), " tensors."));
      }

      bool cond;
      Status s = CondResultToBool(ctx_, opts_, rets_[0], &cond);
      if (!s.ok()) {
        return Finish(s);
      }

      if (!cond) {
        return Finish(OkStatus());
      }
      rets_.clear();
      rets_.resize(args_.size());
      profiler::TraceMe trace_me("WhileOp-StartBody");
      lib_->Run(
          // Evaluate the body.
          opts_, body_handle_, body_frame_.get(),
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            if (args_.size() != rets_.size()) {
              return Finish(errors::InvalidArgument(
                  "While loop body returned ", rets_.size(),
                  " arguments. Expected: ", args_.size()));
            }
            args_.clear();
            using std::swap;
            swap(args_, rets_);
            EvalCond();
          });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, args_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };

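  // Synchronous implementation of the while loop, used when Compute() is
  // called or when kernels run inline: the cond and body functions are
  // executed with RunSync inside a single C++ loop, which avoids growing the
  // stack through chained callbacks.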
  Status DoComputeSync(OpKernelContext* ctx) {
    FHandle cond_handle;
    FHandle body_handle;
    TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle));
    auto lib = ctx->function_library();
    FunctionLibraryRuntime::Options opts;
    SetRunOptions(ctx, &opts, false /* always_collect_stats */);

    // Pre-allocate argument and return value vectors for the cond and body
    // functions.
    std::vector<Tensor> args;
    const int num_loop_vars = ctx->num_inputs();
    DataTypeVector loop_var_types(num_loop_vars);
    GetArgsFromContext(ctx, &args, &loop_var_types);
    std::vector<Tensor> cond_rets;
    cond_rets.reserve(1);
    std::vector<Tensor> body_rets;
    body_rets.reserve(num_loop_vars);

    // Implement the logic of the while loop as a single C++ do-while loop that
    // executes the cond and body functions synchronously.
    do {
      // Evaluate the cond function on the current loop variables.
      {
        profiler::TraceMe trace_me("WhileOp-EvalCond");
        TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets));
      }
      if (cond_rets.size() != 1) {
        return errors::InvalidArgument(
            "Expected a single scalar return value from WhileOp cond, got ",
            cond_rets.size(), " tensors.");
      }

      // If the cond function evaluates to false, we are done: output the
      // current loop variables.
      bool cond_result;
      TF_RETURN_IF_ERROR(
          CondResultToBool(ctx, opts, cond_rets[0], &cond_result));
      if (!cond_result) {
        return SetOutputs(this, ctx, args);
      }

      // Evaluate the body function on the current loop variables, to get an
      // updated vector of loop variables.
      {
        profiler::TraceMe trace_me("WhileOp-StartBody");
        body_rets.resize(num_loop_vars);
        BodyFuncCallFrame call_frame(&args, &body_rets, loop_var_types);
        TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, &call_frame));
      }
      std::swap(body_rets, args);
      body_rets.clear();
    } while (true);
  }

  Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle,
                    FHandle* body_handle) {
    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    auto lib = ctx->function_library();
    if (lib == nullptr) return errors::Internal("No function library");
    *cond_handle = kInvalidHandle;
    *body_handle = kInvalidHandle;
    {
      tf_shared_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *cond_handle = iter->second.first;
        *body_handle = iter->second.second;
      }
    }
    if (TF_PREDICT_FALSE(*cond_handle == kInvalidHandle)) {
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (TF_PREDICT_TRUE(iter != handles_.end())) {
        *cond_handle = iter->second.first;
        *body_handle = iter->second.second;
      } else {
        TF_RETURN_IF_ERROR(Instantiate(ctx, cond_func_, cond_handle));
        TF_RETURN_IF_ERROR(Instantiate(ctx, body_func_, body_handle));
        handles_[lib] = {*cond_handle, *body_handle};
      }
    }
    return OkStatus();
  }
};
// TODO(drpng): remove these.
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_DEFAULT), WhileOp);

REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_DEFAULT), WhileOp);

REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_DEFAULT), WhileOp);

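// ToBoolOp converts its input to a boolean scalar using the same rules as the
// ToBool helper above (non-zero scalar, non-empty string, or non-empty
// tensor).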
class ToBoolOp : public OpKernel {
 public:
  explicit ToBoolOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    bool b;
    OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &b));
    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    out->scalar<bool>()() = b;
  }
};

REGISTER_KERNEL_BUILDER(Name("ToBool").Device(DEVICE_CPU), ToBoolOp);

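// Reads the int32 scalar input at position "index" into "*value", returning an
// InvalidArgument error (tagged with "label") if the input is not a scalar.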
Status GetScalar(OpKernelContext* ctx, int index, int32* value,
                 const char* label) {
  Tensor t = ctx->input(index);
  if (!TensorShapeUtils::IsScalar(t.shape())) {
    return errors::InvalidArgument(label, " must be a scalar, but has shape ",
                                   t.shape().DebugString());
  }
  *value = t.scalar<int32>()();
  return OkStatus();
}

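// ForOp runs the `body` function once per loop iteration, with the current
// induction value (start, start + delta, ...) as its first argument and the
// previous iteration's outputs as the remaining arguments, until the
// induction value reaches `limit`.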
class ForOp : public AsyncOpKernel {
 public:
  explicit ForOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    auto lib = ctx->function_library();
    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    const NameAttrList* func;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &func));
    OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &body_handle_));
  }

  ~ForOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    (new State(this, ctx, done))->Start();
  }

 private:
  FHandle body_handle_;

  class State {
   public:
    State(ForOp* kernel, OpKernelContext* ctx, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          done_(std::move(done)),
          lib_(CHECK_NOTNULL(ctx_->function_library())),
          args_(1 + ctx_->num_inputs() - 3) {
      args_[0] = Tensor(DT_INT32, {});
      iter_ = &args_[0].scalar<int32>()();

      const int32_t num_loop_inputs = ctx_->num_inputs() - 3;
      rets_.reserve(num_loop_inputs);
      for (int i = 0; i < num_loop_inputs; ++i) {
        rets_.push_back(ctx_->input(3 + i));
      }
    }

    ~State() {}

    void Start() {
      Status s = StartLoop();
      if (!s.ok()) Finish(s);
    }

   private:
    ForOp* const kernel_;
    OpKernelContext* const ctx_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;
    TensorVec rets_;

    int32* iter_;  // points to args_[0].
    int32 limit_;
    int32 delta_;

    // If an error e is returned, caller must call Finish(e).
    // If OK is returned, the async loop execution has been started.
    Status StartLoop() {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);

      TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 1, &limit_, "limit"));
      TF_RETURN_IF_ERROR(GetScalar(ctx_, 2, &delta_, "delta"));

      if ((delta_ > 0 && *iter_ <= limit_) ||
          (delta_ < 0 && *iter_ >= limit_) ||
          (delta_ == 0 && *iter_ == limit_)) {
        RunNext();
        return OkStatus();
      } else {
        return errors::InvalidArgument("Invalid start/limit/delta: ", *iter_,
                                       " ", limit_, " ", delta_);
      }
    }


    void RunNext() {
      bool done_loop;
      if (delta_ > 0) {
        done_loop = *iter_ >= limit_;
      } else {
        done_loop = *iter_ <= limit_;
      }
      if (done_loop) {
        Finish(OkStatus());
        return;
      }

      if (rets_.size() >= args_.size()) {
        Finish(errors::InvalidArgument(
            "For loop body returned ", rets_.size(),
            " arguments. Expected: ", args_.size() - 1));
        return;
      }
      for (int i = 0; i < rets_.size(); ++i) {
        args_[1 + i] = std::move(rets_[i]);
      }
      rets_.clear();
      profiler::TraceMe trace_me("ForOp");
      lib_->Run(opts_, kernel_->body_handle_, args_, &rets_,
                [this](const Status& s) {
                  if (s.ok()) {
                    *iter_ += delta_;
                    RunNext();
                  } else {
                    Finish(s);
                  }
                });
    }

    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, rets_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};

REGISTER_KERNEL_BUILDER(Name("For").Device(DEVICE_CPU), ForOp);
REGISTER_KERNEL_BUILDER(Name("For")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("start")
                            .HostMemory("limit")
                            .HostMemory("delta"),
                        ForOp);

// FakeParamOp allocates a tensor with a shape conforming to the expected
// output. This is necessary if the value will be stored in a while_loop's
// TensorList. The output is otherwise not expected to be consumed by anything
// else.
class FakeParamOp : public OpKernel {
 public:
  explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
    DataType dtype;
    OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));

    // Set shape to the specified shape, setting unknown dimensions to empty.
    // If the specified shape is unknown, leave as an empty shape.
    TensorShape shape;
    PartialTensorShape partial_shape;
    OP_REQUIRES_OK(context, context->GetAttr("shape", &partial_shape));
    if (!partial_shape.unknown_rank()) {
      for (int64_t d : partial_shape.dim_sizes()) {
        shape.AddDim(d == -1 ? 0 : d);
      }
    }

    // Create a tensor that we can repeatedly return to save memory.
    // TODO(b/119612758): add optimization to prevent sending this across
    // devices on each Compute() call.
    OP_REQUIRES_OK(context, context->allocate_temp(dtype, shape, &value_));
  }

  void Compute(OpKernelContext* context) override {
    context->set_output(0, value_);
  }

 private:
  Tensor value_;
};

REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_DEFAULT), FakeParamOp);


// DeviceIndexOp returns the index of the current device's type within the
// `device_names` attr, or `device_names.size()` if the type is not listed.
class DeviceIndexOp : public OpKernel {
 public:
  explicit DeviceIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* device_name_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &device_name_t));
    DeviceNameUtils::ParsedName parsed_name;
    int index = device_names_.size();
    if (DeviceNameUtils::ParseFullName(ctx->device()->name(), &parsed_name) &&
        parsed_name.has_type) {
      auto it = absl::c_find(device_names_, parsed_name.type);
      if (it != device_names_.end()) {
        index = it - device_names_.begin();
      }
    }
    device_name_t->scalar<int32>()() = index;
  }

 private:
  std::vector<string> device_names_;
};

REGISTER_KERNEL_BUILDER(Name("DeviceIndex").Device(DEVICE_CPU), DeviceIndexOp);
REGISTER_KERNEL_BUILDER(
    Name("DeviceIndex").Device(DEVICE_DEFAULT).HostMemory("index"),
    DeviceIndexOp);


}  // namespace
}  // namespace tensorflow