1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/control_flow_ops.h" |
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/framework/register_types.h" |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/framework/types.h" |
22 | #include "tensorflow/core/platform/macros.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | void SwitchOp::Compute(OpKernelContext* context) { |
27 | const Tensor& outputPorts = context->input(1); |
28 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()), |
29 | errors::InvalidArgument("The second input must be a scalar, " |
30 | "but it has shape " , |
31 | outputPorts.shape().DebugString())); |
32 | |
33 | bool pred = outputPorts.scalar<bool>()(); |
34 | int port = (pred) ? 1 : 0; |
35 | if (context->input_is_ref(0)) { |
36 | context->forward_ref_input_to_ref_output(0, port); |
37 | } else { |
38 | context->set_output(port, context->input(0)); |
39 | } |
40 | } |
41 | |
42 | void SwitchNOp::Compute(OpKernelContext* context) { |
43 | const Tensor& output_index_t = context->input(1); |
44 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_index_t.shape()), |
45 | errors::InvalidArgument("The second input must be a scalar, " |
46 | "but it has shape " , |
47 | output_index_t.shape().DebugString())); |
48 | int output_index = output_index_t.scalar<int>()(); |
49 | if (output_index < 0 || output_index >= num_outputs()) { |
50 | output_index = num_outputs() - 1; |
51 | } |
52 | context->set_output(output_index, context->input(0)); |
53 | } |
54 | |
// Device-default and TPU-system registrations for Switch/_SwitchN. The
// predicate / output_index inputs live in host memory because Compute reads
// them synchronously on the host.
REGISTER_KERNEL_BUILDER(
    Name("Switch" ).Device(DEVICE_DEFAULT).HostMemory("pred" ), SwitchOp);
REGISTER_KERNEL_BUILDER(
    Name("Switch" ).Device(DEVICE_TPU_SYSTEM).HostMemory("pred" ), SwitchOp);

REGISTER_KERNEL_BUILDER(
    Name("_SwitchN" ).Device(DEVICE_DEFAULT).HostMemory("output_index" ),
    SwitchNOp);

// Registers CPU kernels for Switch and _SwitchN for one dtype.
#define REGISTER_CPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("output_index") \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

// Registers the CPU RefSwitch kernel (ref-typed data input) for one dtype.
#define REGISTER_CPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

// GPU counterparts of the two macros above; only "pred"/"output_index" stay
// on the host, the data input/outputs remain in device memory.
#define REGISTER_GPU_SWITCH(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("output_index") \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

#define REGISTER_GPU_REF_SWITCH(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("pred")         \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

// Instantiate the registrations over the supported dtype sets.
TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);

// int32 is excluded here; it gets dedicated host-memory kernels below.
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
TF_CALL_variant(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_REF_SWITCH);

#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH
119 | |
// Special GPU kernels for int32, string & resource handles. Requiring all
// inputs and outputs to be in host memory.
// TODO(b/25387198): Also enable int32 in device memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)                       \
  REGISTER_KERNEL_BUILDER(Name("_SwitchN")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output_index") \
                              .HostMemory("outputs")      \
                              .TypeConstraint<type>("T"), \
                          SwitchNOp)

// RefSwitch variant of the host-memory registration above.
#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefSwitch")               \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("pred")         \
                              .HostMemory("output_false") \
                              .HostMemory("output_true")  \
                              .TypeConstraint<type>("T"), \
                          SwitchOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
158 | |
159 | |
160 | class RefSelectOp : public OpKernel { |
161 | public: |
162 | explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) { |
163 | OP_REQUIRES_OK(context, context->GetAttr("N" , &num_ref_inputs_)); |
164 | } |
165 | |
166 | void Compute(OpKernelContext* context) override { |
167 | const Tensor& index_tensor = context->input(0); |
168 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(index_tensor.shape()), |
169 | errors::InvalidArgument("Index must be a scalar, " |
170 | "but it has shape " , |
171 | index_tensor.shape().DebugString())); |
172 | |
173 | int32_t index = index_tensor.scalar<int32>()(); |
174 | |
175 | OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_, |
176 | errors::InvalidArgument("Index must be in the range [0, " , |
177 | num_ref_inputs_, ") but got " , index)); |
178 | context->forward_ref_input_to_ref_output(index + 1, 0); |
179 | } |
180 | |
181 | bool IsExpensive() override { return false; } |
182 | |
183 | ~RefSelectOp() override {} |
184 | |
185 | TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp); |
186 | |
187 | private: |
188 | int num_ref_inputs_; |
189 | }; |
190 | |
191 | #define REGISTER_CPU_REF_SELECT(type) \ |
192 | REGISTER_KERNEL_BUILDER(Name("RefSelect") \ |
193 | .Device(DEVICE_CPU) \ |
194 | .HostMemory("index") \ |
195 | .TypeConstraint<type>("T"), \ |
196 | RefSelectOp) |
197 | TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT); |
198 | |
199 | #undef REGISTER_CPU_REF_SWITCH |
200 | |
201 | MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) { |
202 | const DataType dt = context->input_type(0); |
203 | const int num_in = context->num_inputs(); |
204 | OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt), |
205 | {dt, DT_INT32})); |
206 | } |
207 | |
208 | void MergeOp::Compute(OpKernelContext* context) { |
209 | bool input_seen = false; |
210 | for (int i = 0; i < context->num_inputs(); ++i) { |
211 | if (context->has_input(i)) { |
212 | if (input_seen) { |
213 | LOG(WARNING) << "Merge op has more than one valid input. This " |
214 | << "indicates that the graph doesn't use merge op " |
215 | << "properly. Please check your graph. " |
216 | << FormatNodeDefForError(def()); |
217 | return; |
218 | } |
219 | input_seen = true; |
220 | |
221 | if (IsRefType(context->input_dtype(i))) { |
222 | context->forward_ref_input_to_ref_output(i, 0); |
223 | } else { |
224 | context->set_output(0, context->input(i)); |
225 | } |
226 | // The value_index output is typically used only in gradient calculations, |
227 | // so we can avoid allocating in many inference workloads. |
228 | if (context->output_required(1)) { |
229 | Tensor* value_index = nullptr; |
230 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), |
231 | &value_index)); |
232 | value_index->scalar<int32>()() = i; |
233 | } |
234 | } |
235 | } |
236 | } |
237 | |
// Merge/RefMerge registrations. "value_index" is an int32 scalar written on
// the host, so it is pinned to host memory on non-CPU devices.
REGISTER_KERNEL_BUILDER(Name("Merge" ).Device(DEVICE_CPU), MergeOp);
REGISTER_KERNEL_BUILDER(
    Name("Merge" ).Device(DEVICE_DEFAULT).HostMemory("value_index" ), MergeOp);
REGISTER_KERNEL_BUILDER(
    Name("Merge" ).Device(DEVICE_TPU_SYSTEM).HostMemory("value_index" ), MergeOp);
REGISTER_KERNEL_BUILDER(Name("RefMerge" ).Device(DEVICE_CPU), MergeOp);

// GPU Merge registration for one dtype; data stays in device memory.
#define REGISTER_GPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

// GPU RefMerge registration for one dtype.
#define REGISTER_GPU_REF_KERNEL(type)                     \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("value_index"), \
                          MergeOp);

// int32 is excluded; it gets dedicated host-memory kernels below.
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL


// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Merge")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp);                       \
  REGISTER_KERNEL_BUILDER(Name("RefMerge")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("inputs")       \
                              .HostMemory("output")       \
                              .HostMemory("value_index")  \
                              .TypeConstraint<type>("T"), \
                          MergeOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
295 | |
296 | |
297 | void EnterOp::Compute(OpKernelContext* context) { |
298 | if (IsRefType(context->input_dtype(0))) { |
299 | context->forward_ref_input_to_ref_output(0, 0); |
300 | } else { |
301 | context->set_output(0, context->input(0)); |
302 | } |
303 | } |
304 | |
// Enter/RefEnter registrations.
REGISTER_KERNEL_BUILDER(Name("Enter" ).Device(DEVICE_DEFAULT), EnterOp);
REGISTER_KERNEL_BUILDER(Name("Enter" ).Device(DEVICE_TPU_SYSTEM), EnterOp);
REGISTER_KERNEL_BUILDER(Name("RefEnter" ).Device(DEVICE_CPU), EnterOp);

// GPU registrations for one dtype each.
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)
#define REGISTER_GPU_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(            \
      Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp)

// int32 is excluded; it gets dedicated host-memory kernels below.
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL


// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Enter")                   \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

#define REGISTER_GPU_HOST_REF_KERNEL(type)                \
  REGISTER_KERNEL_BUILDER(Name("RefEnter")                \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnterOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
353 | |
354 | void ExitOp::Compute(OpKernelContext* context) { |
355 | if (IsRefType(context->input_dtype(0))) { |
356 | context->forward_ref_input_to_ref_output(0, 0); |
357 | } else { |
358 | context->set_output(0, context->input(0)); |
359 | } |
360 | } |
361 | |
// Exit/RefExit registrations.
REGISTER_KERNEL_BUILDER(Name("Exit" ).Device(DEVICE_DEFAULT), ExitOp);
REGISTER_KERNEL_BUILDER(Name("Exit" ).Device(DEVICE_TPU_SYSTEM), ExitOp);
REGISTER_KERNEL_BUILDER(Name("RefExit" ).Device(DEVICE_CPU), ExitOp);

// GPU registrations for one dtype each.
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(        \
      Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
#define REGISTER_GPU_REF_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(            \
      Name("RefExit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);

// int32 is excluded; it gets dedicated host-memory kernels below.
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL


// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Exit")                    \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp);                        \
  REGISTER_KERNEL_BUILDER(Name("RefExit")                 \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          ExitOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
405 | |
406 | void NextIterationOp::Compute(OpKernelContext* context) { |
407 | if (IsRefType(context->input_dtype(0))) { |
408 | context->forward_ref_input_to_ref_output(0, 0); |
409 | } else { |
410 | context->set_output(0, context->input(0)); |
411 | } |
412 | } |
413 | |
// NextIteration/RefNextIteration registrations.
REGISTER_KERNEL_BUILDER(Name("NextIteration" ).Device(DEVICE_DEFAULT),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("NextIteration" ).Device(DEVICE_TPU_SYSTEM),
                        NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("RefNextIteration" ).Device(DEVICE_CPU),
                        NextIterationOp);

// GPU registrations for both variants for one dtype.
#define REGISTER_GPU_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"),    \
      NextIterationOp);                                                      \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("RefNextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      NextIterationOp)

// int32 is excluded; it gets dedicated host-memory kernels below.
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(bool);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL

// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("NextIteration")           \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp);               \
  REGISTER_KERNEL_BUILDER(Name("RefNextIteration")        \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("data")         \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          NextIterationOp)

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);

#undef REGISTER_GPU_HOST_KERNEL
457 | |
458 | |
459 | LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} |
460 | LoopCondOp::~LoopCondOp() = default; |
461 | |
462 | void LoopCondOp::Compute(OpKernelContext* context) { |
463 | CancellationManager* cm = context->cancellation_manager(); |
464 | if (cm != nullptr) { |
465 | bool already_cancelled = cm->IsCancelled(); |
466 | OP_REQUIRES(context, !already_cancelled, |
467 | errors::Cancelled("Loop execution was cancelled." )); |
468 | } |
469 | |
470 | context->set_output(0, context->input(0)); |
471 | } |
472 | |
473 | bool LoopCondOp::IsExpensive() { return false; } |
474 | |
// LoopCond registrations; the boolean input/output are pinned to host
// memory on non-CPU devices.
REGISTER_KERNEL_BUILDER(Name("LoopCond" ).Device(DEVICE_CPU), LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond" )
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input" )
                            .HostMemory("output" ),
                        LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond" )
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input" )
                            .HostMemory("output" ),
                        LoopCondOp);

// ControlTrigger kernel
REGISTER_KERNEL_BUILDER(Name("ControlTrigger" ).Device(DEVICE_DEFAULT),
                        ControlTriggerOp);
REGISTER_KERNEL_BUILDER(Name("ControlTrigger" ).Device(DEVICE_TPU_SYSTEM),
                        ControlTriggerOp);
492 | |
493 | // When called, abort op will abort the current process. This can be used to |
494 | // abort remote PSs when needed. |
495 | class AbortOp : public OpKernel { |
496 | public: |
497 | explicit AbortOp(OpKernelConstruction* context) : OpKernel(context) { |
498 | OP_REQUIRES_OK(context, context->GetAttr("error_msg" , &error_msg_)); |
499 | OP_REQUIRES_OK( |
500 | context, context->GetAttr("exit_without_error" , &exit_without_error_)); |
501 | } |
502 | |
503 | void Compute(OpKernelContext* context) override { |
504 | if (!exit_without_error_) { |
505 | LOG(FATAL) << "Abort_op intentional failure; " << error_msg_; |
506 | } else { |
507 | LOG(WARNING) << "Exiting the process: " << error_msg_; |
508 | exit(0); |
509 | } |
510 | } |
511 | |
512 | private: |
513 | string error_msg_; |
514 | bool exit_without_error_; |
515 | }; |
516 | |
// Abort has no device-specific behavior, so a CPU registration suffices.
REGISTER_KERNEL_BUILDER(Name("Abort" ).Device(DEVICE_CPU), AbortOp);
518 | |
519 | } // namespace tensorflow |
520 | |