1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "absl/time/clock.h"
17#include "absl/time/time.h"
18#include "tensorflow/core/framework/common_shape_fns.h"
19#include "tensorflow/core/framework/op.h"
20#include "tensorflow/core/framework/op_kernel.h"
21#include "tensorflow/core/framework/resource_handle.pb.h"
22#include "tensorflow/core/framework/resource_mgr.h"
23#include "tensorflow/core/lib/core/status.h"
24#include "tensorflow/core/platform/errors.h"
25#include "tensorflow/core/platform/tensor_float_32_utils.h"
26#include "tensorflow/core/public/version.h"
27
28namespace tensorflow {
29
// Emits a scalar string identifying which labeled kernel variant ran (see
// KernelLabelOp below).
REGISTER_OP("KernelLabel")
    .Output("result: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Like "KernelLabel", but only labeled kernels are registered for it, and it
// takes a rank-1 int32 input.
REGISTER_OP("KernelLabelRequired")
    .Input("input: int32")
    .Output("result: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle out;
      // Require the input to be a vector; the output is a scalar string.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out));
      c->set_output(0, c->Scalar());
      return OkStatus();
    });

// Outputs the GraphDef version the kernel was constructed with.
REGISTER_OP("GraphDefVersion")
    .Output("version: int32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Shape inference deliberately fails unless the graph's version is exactly
// one less than the current TF_GRAPH_DEF_VERSION; used to test version
// handling during shape inference.
REGISTER_OP("RequiresOlderGraphVersion")
    .Output("version: int32")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      if (c->graph_def_version() != TF_GRAPH_DEF_VERSION - 1) {
        return errors::InvalidArgument("Wrong graph version for shape");
      }
      return shape_inference::ScalarShape(c);
    });
58
// Deprecated at GraphDef version 8; used to test deprecation errors.
REGISTER_OP("Old")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(8, "For reasons");

// Returns the op's deadline as microseconds since the Unix epoch.
REGISTER_OP("GetDeadline")
    .Output("deadline_from_epoch_micros: int64")
    .SetShapeFn(shape_inference::UnknownShape);

// Blocks for the given number of seconds; used to test cancellation/timeouts.
REGISTER_OP("SleepOp")
    .Input("sleep_seconds: int32")
    .SetShapeFn(shape_inference::UnknownShape);

// Sleeps, then passes its second input through unchanged.
REGISTER_OP("SleepIdentityOp")
    .Input("sleep_seconds: int32")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);

// Registers the handle-producing op for StubResource (see the StubResource
// class below).
REGISTER_RESOURCE_HANDLE_OP(StubResource);

REGISTER_OP("ResourceInitializedOp")
    .Input("resource: resource")
    .Output("initialized: bool")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("ResourceCreateOp")
    .Input("resource: resource")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ResourceUsingOp")
    .Input("resource: resource")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IsResourceHandleRefCounting")
    .Input("handle: resource")
    .Output("result: bool")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("MakeWeakResourceHandle")
    .Input("handle: resource")
    .Output("dup: resource")
    .SetIsStateful()
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);

REGISTER_OP("TestStringOutput")
    .Input("input: float")
    .Output("output1: float")
    .Output("output2: string")
    .SetShapeFn(shape_inference::UnknownShape);

// Same op signature as TestStringOutput, but inside a '>'-delimited op
// namespace; exercises namespaced op names.
REGISTER_OP("Namespace>TestStringOutput")
    .Input("input: float")
    .Output("output1: float")
    .Output("output2: string")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TestAttr")
    .Output("out: T")
    .Attr("T: {float, double}")
    .SetDoNotOptimize()
    .SetShapeFn(shape_inference::UnknownShape);

namespace {
// Selects which "KernelLabel" kernel variant a KernelLabelOp instantiation
// represents.
enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
}  // namespace
125
126template <KernelLabel KL>
127class KernelLabelOp : public OpKernel {
128 public:
129 using OpKernel::OpKernel;
130
131 void Compute(OpKernelContext* ctx) override {
132 Tensor* output;
133 OP_REQUIRES_OK(ctx,
134 ctx->allocate_output("result", TensorShape({}), &output));
135 switch (KL) {
136 case DEFAULT_LABEL:
137 output->scalar<tstring>()() = "My label is: default";
138 break;
139 case OVERLOAD_1_LABEL:
140 output->scalar<tstring>()() = "My label is: overload_1";
141 break;
142 case OVERLOAD_2_LABEL:
143 output->scalar<tstring>()() = "My label is: overload_2";
144 break;
145 }
146 }
147};
148
// "KernelLabel" has a default (unlabeled) kernel plus two labeled overloads;
// tests selecting kernels by label.
REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU),
                        KernelLabelOp<DEFAULT_LABEL>);
REGISTER_KERNEL_BUILDER(
    Name("KernelLabel").Device(DEVICE_CPU).Label("overload_1"),
    KernelLabelOp<OVERLOAD_1_LABEL>);
REGISTER_KERNEL_BUILDER(
    Name("KernelLabel").Device(DEVICE_CPU).Label("overload_2"),
    KernelLabelOp<OVERLOAD_2_LABEL>);

// All "KernelLabelRequired" kernels have labels
REGISTER_KERNEL_BUILDER(
    Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_1"),
    KernelLabelOp<OVERLOAD_1_LABEL>);
REGISTER_KERNEL_BUILDER(
    Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_2"),
    KernelLabelOp<OVERLOAD_2_LABEL>);
165
166class GraphDefVersionOp : public OpKernel {
167 public:
168 explicit GraphDefVersionOp(OpKernelConstruction* ctx)
169 : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {}
170
171 void Compute(OpKernelContext* ctx) override {
172 Tensor* output = nullptr;
173 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
174 output->scalar<int>()() = graph_def_version_;
175 }
176
177 private:
178 const int graph_def_version_;
179};
180
181REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU),
182 GraphDefVersionOp);
183
184class OldOp : public OpKernel {
185 public:
186 explicit OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
187
188 void Compute(OpKernelContext* ctx) override {}
189};
190
191REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp);
192
193class GetDeadlineOp : public OpKernel {
194 public:
195 explicit GetDeadlineOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
196
197 void Compute(OpKernelContext* ctx) override {
198 if (!ctx->deadline()) {
199 ctx->SetStatus(errors::InvalidArgument("Deadline has not ben set."));
200 }
201 Tensor* output;
202 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
203 output->scalar<int64_t>()() = absl::ToUnixMicros(*ctx->deadline());
204 }
205};
206
207REGISTER_KERNEL_BUILDER(Name("GetDeadline").Device(DEVICE_CPU), GetDeadlineOp);
208
209class SleepOp : public OpKernel {
210 public:
211 explicit SleepOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
212
213 void Compute(OpKernelContext* ctx) override {
214 absl::SleepFor(absl::Seconds(ctx->input(0).scalar<int>()()));
215 }
216};
217
218REGISTER_KERNEL_BUILDER(Name("SleepOp").Device(DEVICE_CPU), SleepOp);
219
220class SleepIdentityOp : public OpKernel {
221 public:
222 explicit SleepIdentityOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
223
224 void Compute(OpKernelContext* ctx) override {
225 absl::SleepFor(absl::Seconds(ctx->input(0).scalar<int>()()));
226 ctx->set_output(0, ctx->input(1));
227 }
228};
229
230REGISTER_KERNEL_BUILDER(Name("SleepIdentityOp").Device(DEVICE_CPU),
231 SleepIdentityOp);
232
// Stubbed-out resource to test resource handle ops.
class StubResource : public ResourceBase {
 public:
  // No interesting state; an empty debug string is sufficient for tests.
  string DebugString() const override { return ""; }
};

// Registers the kernel for the handle op declared by
// REGISTER_RESOURCE_HANDLE_OP(StubResource) above.
REGISTER_RESOURCE_HANDLE_KERNEL(StubResource);

REGISTER_KERNEL_BUILDER(Name("ResourceInitializedOp").Device(DEVICE_CPU),
                        IsResourceInitialized<StubResource>);
243
244class ResourceCreateOp : public OpKernel {
245 public:
246 explicit ResourceCreateOp(OpKernelConstruction* c) : OpKernel(c) {}
247
248 void Compute(OpKernelContext* c) override {
249 OP_REQUIRES_OK(c,
250 CreateResource(c, HandleFromInput(c, 0), new StubResource));
251 }
252};
253
254REGISTER_KERNEL_BUILDER(Name("ResourceCreateOp").Device(DEVICE_CPU),
255 ResourceCreateOp);
256
257// Uses a ResourceHandle to check its validity.
258class ResourceUsingOp : public OpKernel {
259 public:
260 explicit ResourceUsingOp(OpKernelConstruction* context) : OpKernel(context) {}
261
262 void Compute(OpKernelContext* ctx) override {
263 StubResource* unused;
264 OP_REQUIRES_OK(ctx, LookupResource<StubResource>(
265 ctx, HandleFromInput(ctx, 0), &unused));
266 }
267};
268
269REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp").Device(DEVICE_CPU),
270 ResourceUsingOp);
271
272class IsResourceHandleRefCountingOp : public OpKernel {
273 public:
274 explicit IsResourceHandleRefCountingOp(OpKernelConstruction* ctx)
275 : OpKernel(ctx) {}
276
277 void Compute(OpKernelContext* ctx) override {
278 const auto& handle = HandleFromInput(ctx, 0);
279 Tensor* output;
280 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
281 output->flat<bool>()(0) = handle.IsRefCounting();
282 }
283};
284
285REGISTER_KERNEL_BUILDER(Name("IsResourceHandleRefCounting").Device(DEVICE_CPU),
286 IsResourceHandleRefCountingOp);
287
288// Duplicates a ResourceHandle as a weak ResourceHandle.
289class MakeWeakResourceHandleOp : public OpKernel {
290 public:
291 explicit MakeWeakResourceHandleOp(OpKernelConstruction* c) : OpKernel(c) {}
292
293 void Compute(OpKernelContext* ctx) override {
294 Tensor tensor;
295 ResourceHandleProto proto;
296 HandleFromInput(ctx, 0).AsProto(&proto);
297
298 AllocatorAttributes attr;
299 attr.set_on_host(true);
300 OP_REQUIRES_OK(
301 ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr));
302 tensor.scalar<ResourceHandle>()() = ResourceHandle{proto};
303 ctx->set_output(0, tensor);
304 }
305};
306
307REGISTER_KERNEL_BUILDER(Name("MakeWeakResourceHandle").Device(DEVICE_CPU),
308 MakeWeakResourceHandleOp);
309REGISTER_KERNEL_BUILDER(Name("MakeWeakResourceHandle").Device(DEVICE_DEFAULT),
310 MakeWeakResourceHandleOp);
311
312class TestAttrOp : public OpKernel {
313 public:
314 explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
315
316 void Compute(OpKernelContext* ctx) override {
317 Tensor* output;
318 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
319 output->scalar<float>()() = 1.0;
320 }
321};
322
323REGISTER_KERNEL_BUILDER(
324 Name("TestAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"), TestAttrOp);
325
// Various test ops without kernels. These are used to test graph construction.

REGISTER_OP("A")
    .Output("out: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("B")
    .Output("out: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Foo1")
    .Input("a: float32")
    .Input("b: int32")
    .Input("c: int32")
    .Output("d: float32")
    .Output("e: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Foo2")
    .Input("a: float32")
    .Input("b: string")
    .Input("c: string")
    .Output("d: float32")
    .Output("e: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Foo3")
    .Input("a: float32")
    .Input("b: string")
    .Input("c: float32")
    .Output("d: float32")
    .Output("e: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("CopyOp").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn(
    shape_inference::UnknownShape);

REGISTER_OP("None").SetShapeFn(shape_inference::UnknownShape);

// Ops covering combinations of input/output arities, dtypes, and refs.
REGISTER_OP("IntOutput")
    .Output("a: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Int64Output")
    .Output("out: int64")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefOutput")
    .Output("a: Ref(int32)")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FloatOutput")
    .Output("a: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatOutputs")
    .Output("a: float32")
    .Output("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FiveFloatOutputs")
    .Output("a: float32")
    .Output("b: float32")
    .Output("c: float32")
    .Output("d: float32")
    .Output("e: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefOutputFloatOutput")
    .Output("a: Ref(float32)")
    .Output("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputFloatInput")
    .Input("a: Ref(float)")
    .Input("b: float")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInput")
    .Input("a: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInputIntOutput")
    .Input("a: int32")
    .Output("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FloatInput")
    .Input("a: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoIntOutputs")
    .Output("a: int32")
    .Output("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntOutputFloatOutput")
    .Output("a: int32")
    .Output("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FloatOutputStringOutput")
    .Output("a: float32")
    .Output("b: string")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoIntInputs")
    .Input("a: int32")
    .Input("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputs")
    .Input("a: float32")
    .Input("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInputFloatInput")
    .Input("a: int32")
    .Input("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputIntInput")
    .Input("a: Ref(int32)")
    .Input("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputsFloatOutput")
    .Input("a: float32")
    .Input("b: float32")
    .Output("c: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputsIntOutput")
    .Input("a: float32")
    .Input("b: float32")
    .Output("c: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputFloatInputIntOutput")
    .Input("a: Ref(float32)")
    .Input("b: float32")
    .Output("c: int32")
    .SetShapeFn(shape_inference::UnknownShape);

// Ops with sequence (list) inputs/outputs.
REGISTER_OP("ListInput")
    .Input("a: N * T")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ListOutput")
    .Output("a: T")
    .Attr("T: list(type) >= 1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Unary").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn(
    shape_inference::UnknownShape);

// Ops exercising attr declarations: defaults, lists, and constraints.
REGISTER_OP("OpWithDefaultAttr")
    .Output("a: int32")
    .Attr("default_float: float = 123.0")
    .SetShapeFn(shape_inference::UnknownShape);

// Tests may add an attr to this op's def to simulate a future op version.
REGISTER_OP("OpWithFutureDefaultAttr")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntAttr")
    .Output("out: int64")
    .Attr("foo: int = 1")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("StringListAttr")
    .Attr("a: list(string)")
    .Attr("b: string")
    .SetShapeFn(shape_inference::UnknownShape);

// One attr of every kind, each with a default value.
REGISTER_OP("DefaultAttrs")
    .Attr("string_val: string = 'abc'")
    .Attr("string_list_val: list(string) = ['abc', '']")
    .Attr("int_val: int = 123")
    .Attr("int_list_val: list(int) = [1, 2, 3]")
    .Attr("float_val: float = 10.0")
    .Attr("float_list_val: list(float) = [10.0]")
    .Attr("bool_val: bool = true")
    .Attr("bool_list_val: list(bool) = [true, false]")
    .Attr("type_val: type = DT_INT32")
    .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]")
    .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }")
    .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]")
    .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}")
    .Attr(
        "tensor_list_val: list(tensor) = "
        "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FuncAttr")
    .Attr("f: func")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FuncListAttr")
    .Attr("f: list(func)")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Simple")
    .Input("a: int32")
    .Output("out: float")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OutT").Output("a: T").Attr("T: type").SetShapeFn(
    shape_inference::UnknownShape);

// "input" is a reserved word in some client languages.
REGISTER_OP("ReservedInput")
    .Input("input: int32")
    .SetShapeFn(shape_inference::UnknownShape);

// Ops exercising polymorphic type attrs.
REGISTER_OP("Polymorphic")
    .Input("a: T")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("PolymorphicOut")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("PolymorphicDefaultOut")
    .Output("out: T")
    .Attr("T: type = DT_STRING")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Binary")
    .Input("a: T")
    .Input("b: T")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Restrict")
    .Input("a: T")
    .Output("out: T")
    .Attr("T: {string, bool}")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TypeList")
    .Input("a: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TypeListTwice")
    .Input("a: T")
    .Input("b: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OutTypeList")
    .Output("out: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TypeListRestrict")
    .Input("a: T")
    .Attr("T: list({string, bool})")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OutTypeListRestrict")
    .Output("out: t")
    .Attr("t: list({string, bool})")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Attr").Attr("a: int").SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrFloat")
    .Attr("a: float")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrBool")
    .Attr("a: bool")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrBoolList")
    .Attr("a: list(bool)")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrMin")
    .Attr("a: int >= 5")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrListMin")
    .Attr("a: list(int) >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrEnum")
    .Attr("a: {'apples', 'oranges'}")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrEnumList")
    .Attr("a: list({'apples', 'oranges'})")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrShape")
    .Attr("a: shape")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrShapeList")
    .Attr("a: list(shape)")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrPartialShape")
    .Attr("a: shape")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrPartialShapeList")
    .Attr("a: list(shape)")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrDefault")
    .Attr("a: string = 'banana'")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrListDefault")
    .Attr("a: list(int) = [5, 15]")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrEmptyListDefault")
    .Attr("a: list(float) = []")
    .SetShapeFn(shape_inference::UnknownShape);

// "range" is a reserved word in some client languages.
REGISTER_OP("ReservedAttr")
    .Attr("range: int")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrTypeDefault")
    .Input("a: T")
    .Attr("T: type = DT_INT32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrListTypeDefault")
    .Input("a: N * T")
    .Input("b: N * T")
    .Attr("T: type = DT_INT32")
    .Attr("N: int")
    .SetShapeFn(shape_inference::UnknownShape);

// Ops whose input/output arity is controlled by an int attr N.
REGISTER_OP("NIntsIn")
    .Input("a: N * int32")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicIn")
    .Input("a: N * T")
    .Attr("T: type")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicRestrictIn")
    .Input("a: N * T")
    .Attr("T: {string, bool}")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NInTwice")
    .Input("a: N * int32")
    .Input("b: N * string")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NInPolymorphicTwice")
    .Input("a: N * T")
    .Input("b: N * T")
    .Attr("T: type")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NInTwoTypeVariables")
    .Input("a: N * S")
    .Input("b: N * T")
    .Attr("S: type")
    .Attr("T: type")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("InPolymorphicTwice")
    .Input("a: N * T")
    .Input("b: M * T")
    .Attr("T: type = DT_INT32")
    .Attr("N: int >= 0")
    .Attr("M: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NIntsOut")
    .Output("a: N * int32")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NIntsOutDefault")
    .Output("a: N * int32")
    .Attr("N: int >= 2 = 3")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicOut")
    .Output("a: N * T")
    .Attr("T: type")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicOutDefault")
    .Output("a: N * T")
    .Attr("T: type = DT_BOOL")
    .Attr("N: int >= 2 = 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicRestrictOut")
    .Output("a: N * T")
    .Attr("T: {string, bool}")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

// Ops with reference-typed inputs/outputs.
REGISTER_OP("RefIn")
    .Input("a: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoRefsIn")
    .Input("a: Ref(T)")
    .Input("b: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefOut")
    .Output("a: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("SimpleStruct")
    .Output("a: n_a * int32")
    .Attr("n_a: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("MixedStruct")
    .Output("a: n_a * int32")
    .Output("b: float")
    .Attr("n_a: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ComplexStruct")
    .Output("a: n_a * int32")
    .Output("b: n_b * int64")
    .Output("c: t_c")
    .Attr("n_a: int >= 0")
    .Attr("n_b: int >= 0")
    .Attr("t_c: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

// An op which returns its own device placement as a string, useful for testing
// where ops get placed.
REGISTER_OP("DevicePlacementOp")
    .Output("device: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
786
787class DevicePlacementOp : public OpKernel {
788 public:
789 using OpKernel::OpKernel;
790
791 void Compute(OpKernelContext* ctx) override {
792 Tensor* output;
793 OP_REQUIRES_OK(ctx,
794 ctx->allocate_output("device", TensorShape({}), &output));
795 output->scalar<tstring>()() = ctx->device()->name();
796 }
797};
798
799REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_CPU),
800 DevicePlacementOp);
801REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_DEFAULT),
802 DevicePlacementOp);
803
// An op which returns the dtype of the tensor it was passed in. It expects
// DT_UINT8.
REGISTER_OP("DtypeWithDefaultOp")
    .Input("in: T")
    .Attr("T: type = DT_UINT8")  // dtype used when "T" is not specified
    .Output("dtype: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
812
813class DTypeWithDefaultOp : public OpKernel {
814 public:
815 using OpKernel::OpKernel;
816
817 void Compute(OpKernelContext* ctx) override {
818 const Tensor& input = ctx->input(0);
819 Tensor* output;
820 OP_REQUIRES_OK(ctx,
821 ctx->allocate_output("dtype", TensorShape({}), &output));
822 output->scalar<tstring>()() = tensorflow::DataTypeString(input.dtype());
823 }
824};
825
826REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp").Device(DEVICE_CPU),
827 DTypeWithDefaultOp);
828
// An op that returns True if TensorFloat-32 execution is enabled. Useful for
// testing that enabling/disabling TensorFloat-32 works correctly, even when
// the test does not run with a GPU that supports TensorFloat-32.
REGISTER_OP("IsTensorFloat32Enabled")
    .Output("enabled: bool")
    .SetIsStateful()  // stateful: result depends on a global runtime setting
    .SetShapeFn(shape_inference::ScalarShape);
836
837class IsTensorFloat32Enabled : public OpKernel {
838 public:
839 using OpKernel::OpKernel;
840
841 void Compute(OpKernelContext* ctx) override {
842 Tensor* output;
843 OP_REQUIRES_OK(ctx,
844 ctx->allocate_output("enabled", TensorShape({}), &output));
845 output->scalar<bool>()() = tensor_float_32_execution_enabled();
846 }
847};
848
849REGISTER_KERNEL_BUILDER(
850 Name("IsTensorFloat32Enabled").Device(DEVICE_CPU).HostMemory("enabled"),
851 IsTensorFloat32Enabled);
852REGISTER_KERNEL_BUILDER(
853 Name("IsTensorFloat32Enabled").Device(DEVICE_GPU).HostMemory("enabled"),
854 IsTensorFloat32Enabled);
855} // end namespace tensorflow
856