/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/control_flow_ops.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

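// SwitchOp forwards the value at input 0 ("data") to exactly one of its two
// outputs, selected by the boolean scalar at input 1 ("pred"): output 0
// ("output_false") when the predicate is false, output 1 ("output_true") when
// it is true. The unselected output is never produced, which is how untaken
// conditional branches are pruned downstream.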
void SwitchOp::Compute(OpKernelContext* context) {
  const Tensor& output_ports = context->input(1);
  OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_ports.shape()),
              errors::InvalidArgument("The second input must be a scalar, "
                                      "but it has shape ",
                                      output_ports.shape().DebugString()));

  bool pred = output_ports.scalar<bool>()();
  int port = pred ? 1 : 0;
  if (context->input_is_ref(0)) {
    context->forward_ref_input_to_ref_output(0, port);
  } else {
    context->set_output(port, context->input(0));
  }
}

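// SwitchNOp generalizes Switch to N outputs: it forwards input 0 ("data") to
// the output selected by the scalar int32 at input 1 ("output_index"). An
// out-of-range index is clamped to the last output, which therefore acts as
// the default branch.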
void SwitchNOp::Compute(OpKernelContext* context) {
  const Tensor& output_index_t = context->input(1);
  OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_index_t.shape()),
              errors::InvalidArgument("The second input must be a scalar, "
                                      "but it has shape ",
                                      output_index_t.shape().DebugString()));
  int output_index = output_index_t.scalar<int>()();
  if (output_index < 0 || output_index >= num_outputs()) {
    output_index = num_outputs() - 1;
  }
  context->set_output(output_index, context->input(0));
}

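// The predicate of Switch (and the output index of _SwitchN) is read as a
// scalar on the host in Compute() above, so the registrations below pin those
// inputs to host memory on every device.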
REGISTER_KERNEL_BUILDER(
    Name("Switch").Device(DEVICE_DEFAULT).HostMemory("pred"), SwitchOp);
REGISTER_KERNEL_BUILDER(
    Name("Switch").Device(DEVICE_TPU_SYSTEM).HostMemory("pred"), SwitchOp);

REGISTER_KERNEL_BUILDER(
    Name("_SwitchN").Device(DEVICE_DEFAULT).HostMemory("output_index"),
    SwitchNOp);

#define REGISTER_CPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("output_index") \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

#define REGISTER_CPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

#define REGISTER_GPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("output_index") \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

#define REGISTER_GPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
TF_CALL_variant(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_REF_SWITCH);

#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH

// Special GPU kernels for int32, string and resource handles. These
// registrations require all inputs and outputs to be in host memory.
// TODO(b/25387198): Also enable int32 in device memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output_index") \
                              .HostMemory("outputs")      \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL

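// RefSelectOp forwards one of its N ref inputs to its single ref output.
// Input 0 ("index") is a host-memory scalar selecting which of the candidate
// refs (inputs 1..N) to forward; an out-of-range index is an error.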
class RefSelectOp : public OpKernel {
 public:
  explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& index_tensor = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()),
                errors::InvalidArgument("Index must be a scalar, "
                                        "but it has shape ",
                                        index_tensor.shape().DebugString()));

    int32_t index = index_tensor.scalar<int32>()();

    OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_,
                errors::InvalidArgument("Index must be in the range [0, ",
                                        num_ref_inputs_, ") but got ", index));
    // Input 0 is the index, so the selected ref input is at index + 1.
    context->forward_ref_input_to_ref_output(index + 1, 0);
  }

  bool IsExpensive() override { return false; }

  ~RefSelectOp() override {}

  TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp);

 private:
  int num_ref_inputs_;
};

#define REGISTER_CPU_REF_SELECT(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSelect")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("index")        \
                              .TypeConstraint<type>("T"), \
                          RefSelectOp)
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT);

#undef REGISTER_CPU_REF_SELECT

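// MergeOp forwards the value of the first input that becomes available to
// output 0, and emits the index of that input on output 1 ("value_index").
// In a well-formed graph at most one input of a Merge is alive per frame
// iteration, so the first available input is also the only one.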
MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) {
  const DataType dt = context->input_type(0);
  const int num_in = context->num_inputs();
  OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
                                                  {dt, DT_INT32}));
}

void MergeOp::Compute(OpKernelContext* context) {
  bool input_seen = false;
  for (int i = 0; i < context->num_inputs(); ++i) {
    if (context->has_input(i)) {
      if (input_seen) {
        LOG(WARNING) << "Merge op has more than one valid input. This "
                     << "indicates that the graph doesn't use the merge op "
                     << "properly. Please check your graph. "
                     << FormatNodeDefForError(def());
        return;
      }
      input_seen = true;

      if (IsRefType(context->input_dtype(i))) {
        context->forward_ref_input_to_ref_output(i, 0);
      } else {
        context->set_output(0, context->input(i));
      }
      // The value_index output is typically used only in gradient
      // calculations, so we can avoid allocating it in many inference
      // workloads.
      if (context->output_required(1)) {
        Tensor* value_index = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
                                                         &value_index));
        value_index->scalar<int32>()() = i;
      }
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
REGISTER_KERNEL_BUILDER(
    Name("Merge").Device(DEVICE_DEFAULT).HostMemory("value_index"), MergeOp);
REGISTER_KERNEL_BUILDER(
    Name("Merge").Device(DEVICE_TPU_SYSTEM).HostMemory("value_index"), MergeOp);
REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);

#define REGISTER_GPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

#define REGISTER_GPU_REF_KERNEL(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

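// EnterOp forwards its input into the child frame named by the "frame_name"
// attr. Frame creation and iteration bookkeeping are handled by the executor;
// the kernel itself only forwards the value (or ref) unchanged.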
void EnterOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_DEFAULT), EnterOp);
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU_SYSTEM), EnterOp);
REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);

#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
#define REGISTER_GPU_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(            \
      Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Enter")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL

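// ExitOp forwards a value from a child frame back to its parent frame. As
// with Enter, the frame transition itself is performed by the executor; the
// kernel only forwards the value (or ref).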
void ExitOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_DEFAULT), ExitOp);
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU_SYSTEM), ExitOp);
REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);

#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
#define REGISTER_GPU_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(            \
      Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

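// NextIterationOp makes its input available to the next iteration of the
// enclosing loop frame. Advancing the iteration counter is the executor's
// job; the kernel only forwards the value (or ref).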
void NextIterationOp::Compute(OpKernelContext* context) {
  if (IsRefType(context->input_dtype(0))) {
    context->forward_ref_input_to_ref_output(0, 0);
  } else {
    context->set_output(0, context->input(0));
  }
}

REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_DEFAULT),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU_SYSTEM),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
                        NextIterationOp);

#define REGISTER_GPU_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
      NextIterationOp);                                                      \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      NextIterationOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL

// Special GPU kernels for int32, string and resource handles.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL

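// LoopCondOp forwards a boolean scalar from the loop predicate computation to
// the executor, which uses it to decide whether to run another iteration. It
// also checks for cancellation so a long-running loop can be interrupted
// between iterations.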
LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
LoopCondOp::~LoopCondOp() = default;

void LoopCondOp::Compute(OpKernelContext* context) {
  CancellationManager* cm = context->cancellation_manager();
  if (cm != nullptr) {
    bool already_cancelled = cm->IsCancelled();
    OP_REQUIRES(context, !already_cancelled,
                errors::Cancelled("Loop execution was cancelled."));
  }

  context->set_output(0, context->input(0));
}

bool LoopCondOp::IsExpensive() { return false; }

REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("output"),
                        LoopCondOp);

// ControlTrigger does nothing; it exists only as a placeholder to anchor
// control edges for scheduling.
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_DEFAULT),
                        ControlTriggerOp);
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_TPU_SYSTEM),
                        ControlTriggerOp);

// When called, AbortOp aborts the current process. This can be used to abort
// remote parameter servers (PSs) when needed.
class AbortOp : public OpKernel {
 public:
  explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("error_msg", &error_msg_));
    OP_REQUIRES_OK(
        context, context->GetAttr("exit_without_error", &exit_without_error_));
  }

  void Compute(OpKernelContext* context) override {
    if (!exit_without_error_) {
      LOG(FATAL) << "Abort_op intentional failure; " << error_msg_;
    } else {
      LOG(WARNING) << "Exiting the process: " << error_msg_;
      exit(0);
    }
  }

 private:
  string error_msg_;
  bool exit_without_error_;
};

REGISTER_KERNEL_BUILDER(Name("Abort").Device(DEVICE_CPU), AbortOp);

}  // namespace tensorflow