1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "absl/time/clock.h" |
17 | #include "absl/time/time.h" |
18 | #include "tensorflow/core/framework/common_shape_fns.h" |
19 | #include "tensorflow/core/framework/op.h" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/resource_handle.pb.h" |
22 | #include "tensorflow/core/framework/resource_mgr.h" |
23 | #include "tensorflow/core/lib/core/status.h" |
24 | #include "tensorflow/core/platform/errors.h" |
25 | #include "tensorflow/core/platform/tensor_float_32_utils.h" |
26 | #include "tensorflow/core/public/version.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | REGISTER_OP("KernelLabel" ) |
31 | .Output("result: string" ) |
32 | .SetShapeFn(shape_inference::ScalarShape); |
33 | |
34 | REGISTER_OP("KernelLabelRequired" ) |
35 | .Input("input: int32" ) |
36 | .Output("result: string" ) |
37 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
38 | shape_inference::ShapeHandle out; |
39 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out)); |
40 | c->set_output(0, c->Scalar()); |
41 | return OkStatus(); |
42 | }); |
43 | |
44 | REGISTER_OP("GraphDefVersion" ) |
45 | .Output("version: int32" ) |
46 | .SetIsStateful() |
47 | .SetShapeFn(shape_inference::ScalarShape); |
48 | |
49 | REGISTER_OP("RequiresOlderGraphVersion" ) |
50 | .Output("version: int32" ) |
51 | .SetIsStateful() |
52 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
53 | if (c->graph_def_version() != TF_GRAPH_DEF_VERSION - 1) { |
54 | return errors::InvalidArgument("Wrong graph version for shape" ); |
55 | } |
56 | return shape_inference::ScalarShape(c); |
57 | }); |
58 | |
59 | REGISTER_OP("Old" ) |
60 | .SetShapeFn(shape_inference::UnknownShape) |
61 | .Deprecated(8, "For reasons" ); |
62 | |
63 | REGISTER_OP("GetDeadline" ) |
64 | .Output("deadline_from_epoch_micros: int64" ) |
65 | .SetShapeFn(shape_inference::UnknownShape); |
66 | |
67 | REGISTER_OP("SleepOp" ) |
68 | .Input("sleep_seconds: int32" ) |
69 | .SetShapeFn(shape_inference::UnknownShape); |
70 | |
71 | REGISTER_OP("SleepIdentityOp" ) |
72 | .Input("sleep_seconds: int32" ) |
73 | .Input("input: T" ) |
74 | .Output("output: T" ) |
75 | .Attr("T: type" ) |
76 | .SetShapeFn(shape_inference::UnchangedShape); |
77 | |
78 | REGISTER_RESOURCE_HANDLE_OP(StubResource); |
79 | |
80 | REGISTER_OP("ResourceInitializedOp" ) |
81 | .Input("resource: resource" ) |
82 | .Output("initialized: bool" ) |
83 | .SetShapeFn(shape_inference::ScalarShape); |
84 | |
85 | REGISTER_OP("ResourceCreateOp" ) |
86 | .Input("resource: resource" ) |
87 | .SetShapeFn(shape_inference::UnknownShape); |
88 | |
89 | REGISTER_OP("ResourceUsingOp" ) |
90 | .Input("resource: resource" ) |
91 | .SetShapeFn(shape_inference::UnknownShape); |
92 | |
93 | REGISTER_OP("IsResourceHandleRefCounting" ) |
94 | .Input("handle: resource" ) |
95 | .Output("result: bool" ) |
96 | .SetShapeFn(shape_inference::ScalarShape); |
97 | |
98 | REGISTER_OP("MakeWeakResourceHandle" ) |
99 | .Input("handle: resource" ) |
100 | .Output("dup: resource" ) |
101 | .SetIsStateful() |
102 | .SetShapeFn(tensorflow::shape_inference::ScalarShape); |
103 | |
104 | REGISTER_OP("TestStringOutput" ) |
105 | .Input("input: float" ) |
106 | .Output("output1: float" ) |
107 | .Output("output2: string" ) |
108 | .SetShapeFn(shape_inference::UnknownShape); |
109 | |
110 | REGISTER_OP("Namespace>TestStringOutput" ) |
111 | .Input("input: float" ) |
112 | .Output("output1: float" ) |
113 | .Output("output2: string" ) |
114 | .SetShapeFn(shape_inference::UnknownShape); |
115 | |
116 | REGISTER_OP("TestAttr" ) |
117 | .Output("out: T" ) |
118 | .Attr("T: {float, double}" ) |
119 | .SetDoNotOptimize() |
120 | .SetShapeFn(shape_inference::UnknownShape); |
121 | |
122 | namespace { |
123 | enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL }; |
124 | } // namespace |
125 | |
126 | template <KernelLabel KL> |
127 | class KernelLabelOp : public OpKernel { |
128 | public: |
129 | using OpKernel::OpKernel; |
130 | |
131 | void Compute(OpKernelContext* ctx) override { |
132 | Tensor* output; |
133 | OP_REQUIRES_OK(ctx, |
134 | ctx->allocate_output("result" , TensorShape({}), &output)); |
135 | switch (KL) { |
136 | case DEFAULT_LABEL: |
137 | output->scalar<tstring>()() = "My label is: default" ; |
138 | break; |
139 | case OVERLOAD_1_LABEL: |
140 | output->scalar<tstring>()() = "My label is: overload_1" ; |
141 | break; |
142 | case OVERLOAD_2_LABEL: |
143 | output->scalar<tstring>()() = "My label is: overload_2" ; |
144 | break; |
145 | } |
146 | } |
147 | }; |
148 | |
149 | REGISTER_KERNEL_BUILDER(Name("KernelLabel" ).Device(DEVICE_CPU), |
150 | KernelLabelOp<DEFAULT_LABEL>); |
151 | REGISTER_KERNEL_BUILDER( |
152 | Name("KernelLabel" ).Device(DEVICE_CPU).Label("overload_1" ), |
153 | KernelLabelOp<OVERLOAD_1_LABEL>); |
154 | REGISTER_KERNEL_BUILDER( |
155 | Name("KernelLabel" ).Device(DEVICE_CPU).Label("overload_2" ), |
156 | KernelLabelOp<OVERLOAD_2_LABEL>); |
157 | |
158 | // All "KernelLabelRequired" kernels have labels |
159 | REGISTER_KERNEL_BUILDER( |
160 | Name("KernelLabelRequired" ).Device(DEVICE_CPU).Label("overload_1" ), |
161 | KernelLabelOp<OVERLOAD_1_LABEL>); |
162 | REGISTER_KERNEL_BUILDER( |
163 | Name("KernelLabelRequired" ).Device(DEVICE_CPU).Label("overload_2" ), |
164 | KernelLabelOp<OVERLOAD_2_LABEL>); |
165 | |
166 | class GraphDefVersionOp : public OpKernel { |
167 | public: |
168 | explicit GraphDefVersionOp(OpKernelConstruction* ctx) |
169 | : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {} |
170 | |
171 | void Compute(OpKernelContext* ctx) override { |
172 | Tensor* output = nullptr; |
173 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); |
174 | output->scalar<int>()() = graph_def_version_; |
175 | } |
176 | |
177 | private: |
178 | const int graph_def_version_; |
179 | }; |
180 | |
181 | REGISTER_KERNEL_BUILDER(Name("GraphDefVersion" ).Device(DEVICE_CPU), |
182 | GraphDefVersionOp); |
183 | |
184 | class OldOp : public OpKernel { |
185 | public: |
186 | explicit OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
187 | |
188 | void Compute(OpKernelContext* ctx) override {} |
189 | }; |
190 | |
191 | REGISTER_KERNEL_BUILDER(Name("Old" ).Device(DEVICE_CPU), OldOp); |
192 | |
193 | class GetDeadlineOp : public OpKernel { |
194 | public: |
195 | explicit GetDeadlineOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
196 | |
197 | void Compute(OpKernelContext* ctx) override { |
198 | if (!ctx->deadline()) { |
199 | ctx->SetStatus(errors::InvalidArgument("Deadline has not ben set." )); |
200 | } |
201 | Tensor* output; |
202 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); |
203 | output->scalar<int64_t>()() = absl::ToUnixMicros(*ctx->deadline()); |
204 | } |
205 | }; |
206 | |
207 | REGISTER_KERNEL_BUILDER(Name("GetDeadline" ).Device(DEVICE_CPU), GetDeadlineOp); |
208 | |
209 | class SleepOp : public OpKernel { |
210 | public: |
211 | explicit SleepOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
212 | |
213 | void Compute(OpKernelContext* ctx) override { |
214 | absl::SleepFor(absl::Seconds(ctx->input(0).scalar<int>()())); |
215 | } |
216 | }; |
217 | |
218 | REGISTER_KERNEL_BUILDER(Name("SleepOp" ).Device(DEVICE_CPU), SleepOp); |
219 | |
220 | class SleepIdentityOp : public OpKernel { |
221 | public: |
222 | explicit SleepIdentityOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
223 | |
224 | void Compute(OpKernelContext* ctx) override { |
225 | absl::SleepFor(absl::Seconds(ctx->input(0).scalar<int>()())); |
226 | ctx->set_output(0, ctx->input(1)); |
227 | } |
228 | }; |
229 | |
230 | REGISTER_KERNEL_BUILDER(Name("SleepIdentityOp" ).Device(DEVICE_CPU), |
231 | SleepIdentityOp); |
232 | |
233 | // Stubbed-out resource to test resource handle ops. |
234 | class StubResource : public ResourceBase { |
235 | public: |
236 | string DebugString() const override { return "" ; } |
237 | }; |
238 | |
239 | REGISTER_RESOURCE_HANDLE_KERNEL(StubResource); |
240 | |
241 | REGISTER_KERNEL_BUILDER(Name("ResourceInitializedOp" ).Device(DEVICE_CPU), |
242 | IsResourceInitialized<StubResource>); |
243 | |
244 | class ResourceCreateOp : public OpKernel { |
245 | public: |
246 | explicit ResourceCreateOp(OpKernelConstruction* c) : OpKernel(c) {} |
247 | |
248 | void Compute(OpKernelContext* c) override { |
249 | OP_REQUIRES_OK(c, |
250 | CreateResource(c, HandleFromInput(c, 0), new StubResource)); |
251 | } |
252 | }; |
253 | |
254 | REGISTER_KERNEL_BUILDER(Name("ResourceCreateOp" ).Device(DEVICE_CPU), |
255 | ResourceCreateOp); |
256 | |
257 | // Uses a ResourceHandle to check its validity. |
258 | class ResourceUsingOp : public OpKernel { |
259 | public: |
260 | explicit ResourceUsingOp(OpKernelConstruction* context) : OpKernel(context) {} |
261 | |
262 | void Compute(OpKernelContext* ctx) override { |
263 | StubResource* unused; |
264 | OP_REQUIRES_OK(ctx, LookupResource<StubResource>( |
265 | ctx, HandleFromInput(ctx, 0), &unused)); |
266 | } |
267 | }; |
268 | |
269 | REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp" ).Device(DEVICE_CPU), |
270 | ResourceUsingOp); |
271 | |
272 | class IsResourceHandleRefCountingOp : public OpKernel { |
273 | public: |
274 | explicit IsResourceHandleRefCountingOp(OpKernelConstruction* ctx) |
275 | : OpKernel(ctx) {} |
276 | |
277 | void Compute(OpKernelContext* ctx) override { |
278 | const auto& handle = HandleFromInput(ctx, 0); |
279 | Tensor* output; |
280 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output)); |
281 | output->flat<bool>()(0) = handle.IsRefCounting(); |
282 | } |
283 | }; |
284 | |
285 | REGISTER_KERNEL_BUILDER(Name("IsResourceHandleRefCounting" ).Device(DEVICE_CPU), |
286 | IsResourceHandleRefCountingOp); |
287 | |
288 | // Duplicates a ResourceHandle as a weak ResourceHandle. |
289 | class MakeWeakResourceHandleOp : public OpKernel { |
290 | public: |
291 | explicit MakeWeakResourceHandleOp(OpKernelConstruction* c) : OpKernel(c) {} |
292 | |
293 | void Compute(OpKernelContext* ctx) override { |
294 | Tensor tensor; |
295 | ResourceHandleProto proto; |
296 | HandleFromInput(ctx, 0).AsProto(&proto); |
297 | |
298 | AllocatorAttributes attr; |
299 | attr.set_on_host(true); |
300 | OP_REQUIRES_OK( |
301 | ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr)); |
302 | tensor.scalar<ResourceHandle>()() = ResourceHandle{proto}; |
303 | ctx->set_output(0, tensor); |
304 | } |
305 | }; |
306 | |
307 | REGISTER_KERNEL_BUILDER(Name("MakeWeakResourceHandle" ).Device(DEVICE_CPU), |
308 | MakeWeakResourceHandleOp); |
309 | REGISTER_KERNEL_BUILDER(Name("MakeWeakResourceHandle" ).Device(DEVICE_DEFAULT), |
310 | MakeWeakResourceHandleOp); |
311 | |
312 | class TestAttrOp : public OpKernel { |
313 | public: |
314 | explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
315 | |
316 | void Compute(OpKernelContext* ctx) override { |
317 | Tensor* output; |
318 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); |
319 | output->scalar<float>()() = 1.0; |
320 | } |
321 | }; |
322 | |
323 | REGISTER_KERNEL_BUILDER( |
324 | Name("TestAttr" ).Device(DEVICE_CPU).TypeConstraint<float>("T" ), TestAttrOp); |
325 | |
326 | // Various test ops without kernels. These are used to test graph construction. |
327 | |
328 | REGISTER_OP("A" ) |
329 | .Output("out: float32" ) |
330 | .SetShapeFn(shape_inference::UnknownShape); |
331 | |
332 | REGISTER_OP("B" ) |
333 | .Output("out: float32" ) |
334 | .SetShapeFn(shape_inference::UnknownShape); |
335 | |
336 | REGISTER_OP("Foo1" ) |
337 | .Input("a: float32" ) |
338 | .Input("b: int32" ) |
339 | .Input("c: int32" ) |
340 | .Output("d: float32" ) |
341 | .Output("e: int32" ) |
342 | .SetShapeFn(shape_inference::UnknownShape); |
343 | |
344 | REGISTER_OP("Foo2" ) |
345 | .Input("a: float32" ) |
346 | .Input("b: string" ) |
347 | .Input("c: string" ) |
348 | .Output("d: float32" ) |
349 | .Output("e: int32" ) |
350 | .SetShapeFn(shape_inference::UnknownShape); |
351 | |
352 | REGISTER_OP("Foo3" ) |
353 | .Input("a: float32" ) |
354 | .Input("b: string" ) |
355 | .Input("c: float32" ) |
356 | .Output("d: float32" ) |
357 | .Output("e: int32" ) |
358 | .SetShapeFn(shape_inference::UnknownShape); |
359 | |
360 | REGISTER_OP("CopyOp" ).Input("a: T" ).Output("b: T" ).Attr("T: type" ).SetShapeFn( |
361 | shape_inference::UnknownShape); |
362 | |
363 | REGISTER_OP("None" ).SetShapeFn(shape_inference::UnknownShape); |
364 | |
365 | REGISTER_OP("IntOutput" ) |
366 | .Output("a: int32" ) |
367 | .SetShapeFn(shape_inference::UnknownShape); |
368 | |
369 | REGISTER_OP("Int64Output" ) |
370 | .Output("out: int64" ) |
371 | .SetShapeFn(shape_inference::UnknownShape); |
372 | |
373 | REGISTER_OP("RefOutput" ) |
374 | .Output("a: Ref(int32)" ) |
375 | .SetShapeFn(shape_inference::UnknownShape); |
376 | |
377 | REGISTER_OP("FloatOutput" ) |
378 | .Output("a: float32" ) |
379 | .SetShapeFn(shape_inference::UnknownShape); |
380 | |
381 | REGISTER_OP("TwoFloatOutputs" ) |
382 | .Output("a: float32" ) |
383 | .Output("b: float32" ) |
384 | .SetShapeFn(shape_inference::UnknownShape); |
385 | |
386 | REGISTER_OP("FiveFloatOutputs" ) |
387 | .Output("a: float32" ) |
388 | .Output("b: float32" ) |
389 | .Output("c: float32" ) |
390 | .Output("d: float32" ) |
391 | .Output("e: float32" ) |
392 | .SetShapeFn(shape_inference::UnknownShape); |
393 | |
394 | REGISTER_OP("RefOutputFloatOutput" ) |
395 | .Output("a: Ref(float32)" ) |
396 | .Output("b: float32" ) |
397 | .SetShapeFn(shape_inference::UnknownShape); |
398 | |
399 | REGISTER_OP("RefInputFloatInput" ) |
400 | .Input("a: Ref(float)" ) |
401 | .Input("b: float" ) |
402 | .SetShapeFn(shape_inference::UnknownShape); |
403 | |
404 | REGISTER_OP("IntInput" ) |
405 | .Input("a: int32" ) |
406 | .SetShapeFn(shape_inference::UnknownShape); |
407 | |
408 | REGISTER_OP("IntInputIntOutput" ) |
409 | .Input("a: int32" ) |
410 | .Output("b: int32" ) |
411 | .SetShapeFn(shape_inference::UnknownShape); |
412 | |
413 | REGISTER_OP("FloatInput" ) |
414 | .Input("a: float32" ) |
415 | .SetShapeFn(shape_inference::UnknownShape); |
416 | |
417 | REGISTER_OP("TwoIntOutputs" ) |
418 | .Output("a: int32" ) |
419 | .Output("b: int32" ) |
420 | .SetShapeFn(shape_inference::UnknownShape); |
421 | |
422 | REGISTER_OP("IntOutputFloatOutput" ) |
423 | .Output("a: int32" ) |
424 | .Output("b: float32" ) |
425 | .SetShapeFn(shape_inference::UnknownShape); |
426 | |
427 | REGISTER_OP("FloatOutputStringOutput" ) |
428 | .Output("a: float32" ) |
429 | .Output("b: string" ) |
430 | .SetShapeFn(shape_inference::UnknownShape); |
431 | |
432 | REGISTER_OP("TwoIntInputs" ) |
433 | .Input("a: int32" ) |
434 | .Input("b: int32" ) |
435 | .SetShapeFn(shape_inference::UnknownShape); |
436 | |
437 | REGISTER_OP("TwoFloatInputs" ) |
438 | .Input("a: float32" ) |
439 | .Input("b: float32" ) |
440 | .SetShapeFn(shape_inference::UnknownShape); |
441 | |
442 | REGISTER_OP("IntInputFloatInput" ) |
443 | .Input("a: int32" ) |
444 | .Input("b: float32" ) |
445 | .SetShapeFn(shape_inference::UnknownShape); |
446 | |
447 | REGISTER_OP("RefInputIntInput" ) |
448 | .Input("a: Ref(int32)" ) |
449 | .Input("b: int32" ) |
450 | .SetShapeFn(shape_inference::UnknownShape); |
451 | |
452 | REGISTER_OP("TwoFloatInputsFloatOutput" ) |
453 | .Input("a: float32" ) |
454 | .Input("b: float32" ) |
455 | .Output("c: float32" ) |
456 | .SetShapeFn(shape_inference::UnknownShape); |
457 | |
458 | REGISTER_OP("TwoFloatInputsIntOutput" ) |
459 | .Input("a: float32" ) |
460 | .Input("b: float32" ) |
461 | .Output("c: int32" ) |
462 | .SetShapeFn(shape_inference::UnknownShape); |
463 | |
464 | REGISTER_OP("RefInputFloatInputIntOutput" ) |
465 | .Input("a: Ref(float32)" ) |
466 | .Input("b: float32" ) |
467 | .Output("c: int32" ) |
468 | .SetShapeFn(shape_inference::UnknownShape); |
469 | |
470 | REGISTER_OP("ListInput" ) |
471 | .Input("a: N * T" ) |
472 | .Attr("N: int >= 1" ) |
473 | .Attr("T: type" ) |
474 | .SetShapeFn(shape_inference::UnknownShape); |
475 | |
476 | REGISTER_OP("ListOutput" ) |
477 | .Output("a: T" ) |
478 | .Attr("T: list(type) >= 1" ) |
479 | .SetShapeFn(shape_inference::UnknownShape); |
480 | |
481 | REGISTER_OP("Unary" ).Input("a: T" ).Output("b: T" ).Attr("T: type" ).SetShapeFn( |
482 | shape_inference::UnknownShape); |
483 | |
484 | REGISTER_OP("OpWithDefaultAttr" ) |
485 | .Output("a: int32" ) |
486 | .Attr("default_float: float = 123.0" ) |
487 | .SetShapeFn(shape_inference::UnknownShape); |
488 | |
489 | REGISTER_OP("OpWithFutureDefaultAttr" ) |
490 | .SetShapeFn(shape_inference::UnknownShape); |
491 | |
492 | REGISTER_OP("IntAttr" ) |
493 | .Output("out: int64" ) |
494 | .Attr("foo: int = 1" ) |
495 | .SetShapeFn(shape_inference::UnknownShape); |
496 | |
497 | REGISTER_OP("StringListAttr" ) |
498 | .Attr("a: list(string)" ) |
499 | .Attr("b: string" ) |
500 | .SetShapeFn(shape_inference::UnknownShape); |
501 | |
502 | REGISTER_OP("DefaultAttrs" ) |
503 | .Attr("string_val: string = 'abc'" ) |
504 | .Attr("string_list_val: list(string) = ['abc', '']" ) |
505 | .Attr("int_val: int = 123" ) |
506 | .Attr("int_list_val: list(int) = [1, 2, 3]" ) |
507 | .Attr("float_val: float = 10.0" ) |
508 | .Attr("float_list_val: list(float) = [10.0]" ) |
509 | .Attr("bool_val: bool = true" ) |
510 | .Attr("bool_list_val: list(bool) = [true, false]" ) |
511 | .Attr("type_val: type = DT_INT32" ) |
512 | .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]" ) |
513 | .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }" ) |
514 | .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]" ) |
515 | .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}" ) |
516 | .Attr( |
517 | "tensor_list_val: list(tensor) = " |
518 | "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]" ) |
519 | .SetShapeFn(shape_inference::UnknownShape); |
520 | |
521 | REGISTER_OP("FuncAttr" ) |
522 | .Attr("f: func" ) |
523 | .SetShapeFn(shape_inference::UnknownShape); |
524 | |
525 | REGISTER_OP("FuncListAttr" ) |
526 | .Attr("f: list(func)" ) |
527 | .SetShapeFn(shape_inference::UnknownShape); |
528 | |
529 | REGISTER_OP("Simple" ) |
530 | .Input("a: int32" ) |
531 | .Output("out: float" ) |
532 | .SetShapeFn(shape_inference::UnknownShape); |
533 | |
534 | REGISTER_OP("OutT" ).Output("a: T" ).Attr("T: type" ).SetShapeFn( |
535 | shape_inference::UnknownShape); |
536 | |
537 | REGISTER_OP("ReservedInput" ) |
538 | .Input("input: int32" ) |
539 | .SetShapeFn(shape_inference::UnknownShape); |
540 | |
541 | REGISTER_OP("Polymorphic" ) |
542 | .Input("a: T" ) |
543 | .Output("out: T" ) |
544 | .Attr("T: type" ) |
545 | .SetShapeFn(shape_inference::UnknownShape); |
546 | |
547 | REGISTER_OP("PolymorphicOut" ) |
548 | .Output("out: T" ) |
549 | .Attr("T: type" ) |
550 | .SetShapeFn(shape_inference::UnknownShape); |
551 | |
552 | REGISTER_OP("PolymorphicDefaultOut" ) |
553 | .Output("out: T" ) |
554 | .Attr("T: type = DT_STRING" ) |
555 | .SetShapeFn(shape_inference::UnknownShape); |
556 | |
557 | REGISTER_OP("Binary" ) |
558 | .Input("a: T" ) |
559 | .Input("b: T" ) |
560 | .Output("out: T" ) |
561 | .Attr("T: type" ) |
562 | .SetShapeFn(shape_inference::UnknownShape); |
563 | |
564 | REGISTER_OP("Restrict" ) |
565 | .Input("a: T" ) |
566 | .Output("out: T" ) |
567 | .Attr("T: {string, bool}" ) |
568 | .SetShapeFn(shape_inference::UnknownShape); |
569 | |
570 | REGISTER_OP("TypeList" ) |
571 | .Input("a: T" ) |
572 | .Attr("T: list(type) >= 0" ) |
573 | .SetShapeFn(shape_inference::UnknownShape); |
574 | |
575 | REGISTER_OP("TypeListTwice" ) |
576 | .Input("a: T" ) |
577 | .Input("b: T" ) |
578 | .Attr("T: list(type) >= 0" ) |
579 | .SetShapeFn(shape_inference::UnknownShape); |
580 | |
581 | REGISTER_OP("OutTypeList" ) |
582 | .Output("out: T" ) |
583 | .Attr("T: list(type) >= 0" ) |
584 | .SetShapeFn(shape_inference::UnknownShape); |
585 | |
586 | REGISTER_OP("TypeListRestrict" ) |
587 | .Input("a: T" ) |
588 | .Attr("T: list({string, bool})" ) |
589 | .SetShapeFn(shape_inference::UnknownShape); |
590 | |
591 | REGISTER_OP("OutTypeListRestrict" ) |
592 | .Output("out: t" ) |
593 | .Attr("t: list({string, bool})" ) |
594 | .SetShapeFn(shape_inference::UnknownShape); |
595 | |
596 | REGISTER_OP("Attr" ).Attr("a: int" ).SetShapeFn(shape_inference::UnknownShape); |
597 | |
598 | REGISTER_OP("AttrFloat" ) |
599 | .Attr("a: float" ) |
600 | .SetShapeFn(shape_inference::UnknownShape); |
601 | |
602 | REGISTER_OP("AttrBool" ) |
603 | .Attr("a: bool" ) |
604 | .SetShapeFn(shape_inference::UnknownShape); |
605 | |
606 | REGISTER_OP("AttrBoolList" ) |
607 | .Attr("a: list(bool)" ) |
608 | .SetShapeFn(shape_inference::UnknownShape); |
609 | |
610 | REGISTER_OP("AttrMin" ) |
611 | .Attr("a: int >= 5" ) |
612 | .SetShapeFn(shape_inference::UnknownShape); |
613 | |
614 | REGISTER_OP("AttrListMin" ) |
615 | .Attr("a: list(int) >= 2" ) |
616 | .SetShapeFn(shape_inference::UnknownShape); |
617 | |
618 | REGISTER_OP("AttrEnum" ) |
619 | .Attr("a: {'apples', 'oranges'}" ) |
620 | .SetShapeFn(shape_inference::UnknownShape); |
621 | |
622 | REGISTER_OP("AttrEnumList" ) |
623 | .Attr("a: list({'apples', 'oranges'})" ) |
624 | .SetShapeFn(shape_inference::UnknownShape); |
625 | |
626 | REGISTER_OP("AttrShape" ) |
627 | .Attr("a: shape" ) |
628 | .SetShapeFn(shape_inference::UnknownShape); |
629 | |
630 | REGISTER_OP("AttrShapeList" ) |
631 | .Attr("a: list(shape)" ) |
632 | .SetShapeFn(shape_inference::UnknownShape); |
633 | |
634 | REGISTER_OP("AttrPartialShape" ) |
635 | .Attr("a: shape" ) |
636 | .SetShapeFn(shape_inference::UnknownShape); |
637 | |
638 | REGISTER_OP("AttrPartialShapeList" ) |
639 | .Attr("a: list(shape)" ) |
640 | .SetShapeFn(shape_inference::UnknownShape); |
641 | |
642 | REGISTER_OP("AttrDefault" ) |
643 | .Attr("a: string = 'banana'" ) |
644 | .SetShapeFn(shape_inference::UnknownShape); |
645 | |
646 | REGISTER_OP("AttrListDefault" ) |
647 | .Attr("a: list(int) = [5, 15]" ) |
648 | .SetShapeFn(shape_inference::UnknownShape); |
649 | |
650 | REGISTER_OP("AttrEmptyListDefault" ) |
651 | .Attr("a: list(float) = []" ) |
652 | .SetShapeFn(shape_inference::UnknownShape); |
653 | |
654 | REGISTER_OP("ReservedAttr" ) |
655 | .Attr("range: int" ) |
656 | .SetShapeFn(shape_inference::UnknownShape); |
657 | |
658 | REGISTER_OP("AttrTypeDefault" ) |
659 | .Input("a: T" ) |
660 | .Attr("T: type = DT_INT32" ) |
661 | .SetShapeFn(shape_inference::UnknownShape); |
662 | |
663 | REGISTER_OP("AttrListTypeDefault" ) |
664 | .Input("a: N * T" ) |
665 | .Input("b: N * T" ) |
666 | .Attr("T: type = DT_INT32" ) |
667 | .Attr("N: int" ) |
668 | .SetShapeFn(shape_inference::UnknownShape); |
669 | |
670 | REGISTER_OP("NIntsIn" ) |
671 | .Input("a: N * int32" ) |
672 | .Attr("N: int >= 2" ) |
673 | .SetShapeFn(shape_inference::UnknownShape); |
674 | |
675 | REGISTER_OP("NPolymorphicIn" ) |
676 | .Input("a: N * T" ) |
677 | .Attr("T: type" ) |
678 | .Attr("N: int >= 2" ) |
679 | .SetShapeFn(shape_inference::UnknownShape); |
680 | |
681 | REGISTER_OP("NPolymorphicRestrictIn" ) |
682 | .Input("a: N * T" ) |
683 | .Attr("T: {string, bool}" ) |
684 | .Attr("N: int >= 2" ) |
685 | .SetShapeFn(shape_inference::UnknownShape); |
686 | |
687 | REGISTER_OP("NInTwice" ) |
688 | .Input("a: N * int32" ) |
689 | .Input("b: N * string" ) |
690 | .Attr("N: int >= 0" ) |
691 | .SetShapeFn(shape_inference::UnknownShape); |
692 | |
693 | REGISTER_OP("NInPolymorphicTwice" ) |
694 | .Input("a: N * T" ) |
695 | .Input("b: N * T" ) |
696 | .Attr("T: type" ) |
697 | .Attr("N: int >= 0" ) |
698 | .SetShapeFn(shape_inference::UnknownShape); |
699 | |
700 | REGISTER_OP("NInTwoTypeVariables" ) |
701 | .Input("a: N * S" ) |
702 | .Input("b: N * T" ) |
703 | .Attr("S: type" ) |
704 | .Attr("T: type" ) |
705 | .Attr("N: int >= 0" ) |
706 | .SetShapeFn(shape_inference::UnknownShape); |
707 | |
708 | REGISTER_OP("InPolymorphicTwice" ) |
709 | .Input("a: N * T" ) |
710 | .Input("b: M * T" ) |
711 | .Attr("T: type = DT_INT32" ) |
712 | .Attr("N: int >= 0" ) |
713 | .Attr("M: int >= 0" ) |
714 | .SetShapeFn(shape_inference::UnknownShape); |
715 | |
716 | REGISTER_OP("NIntsOut" ) |
717 | .Output("a: N * int32" ) |
718 | .Attr("N: int >= 2" ) |
719 | .SetShapeFn(shape_inference::UnknownShape); |
720 | |
721 | REGISTER_OP("NIntsOutDefault" ) |
722 | .Output("a: N * int32" ) |
723 | .Attr("N: int >= 2 = 3" ) |
724 | .SetShapeFn(shape_inference::UnknownShape); |
725 | |
726 | REGISTER_OP("NPolymorphicOut" ) |
727 | .Output("a: N * T" ) |
728 | .Attr("T: type" ) |
729 | .Attr("N: int >= 2" ) |
730 | .SetShapeFn(shape_inference::UnknownShape); |
731 | |
732 | REGISTER_OP("NPolymorphicOutDefault" ) |
733 | .Output("a: N * T" ) |
734 | .Attr("T: type = DT_BOOL" ) |
735 | .Attr("N: int >= 2 = 2" ) |
736 | .SetShapeFn(shape_inference::UnknownShape); |
737 | |
738 | REGISTER_OP("NPolymorphicRestrictOut" ) |
739 | .Output("a: N * T" ) |
740 | .Attr("T: {string, bool}" ) |
741 | .Attr("N: int >= 2" ) |
742 | .SetShapeFn(shape_inference::UnknownShape); |
743 | |
744 | REGISTER_OP("RefIn" ) |
745 | .Input("a: Ref(T)" ) |
746 | .Attr("T: type" ) |
747 | .SetShapeFn(shape_inference::UnknownShape); |
748 | |
749 | REGISTER_OP("TwoRefsIn" ) |
750 | .Input("a: Ref(T)" ) |
751 | .Input("b: Ref(T)" ) |
752 | .Attr("T: type" ) |
753 | .SetShapeFn(shape_inference::UnknownShape); |
754 | |
755 | REGISTER_OP("RefOut" ) |
756 | .Output("a: Ref(T)" ) |
757 | .Attr("T: type" ) |
758 | .SetShapeFn(shape_inference::UnknownShape); |
759 | |
760 | REGISTER_OP("SimpleStruct" ) |
761 | .Output("a: n_a * int32" ) |
762 | .Attr("n_a: int >= 0" ) |
763 | .SetShapeFn(shape_inference::UnknownShape); |
764 | |
765 | REGISTER_OP("MixedStruct" ) |
766 | .Output("a: n_a * int32" ) |
767 | .Output("b: float" ) |
768 | .Attr("n_a: int >= 0" ) |
769 | .SetShapeFn(shape_inference::UnknownShape); |
770 | |
771 | REGISTER_OP("ComplexStruct" ) |
772 | .Output("a: n_a * int32" ) |
773 | .Output("b: n_b * int64" ) |
774 | .Output("c: t_c" ) |
775 | .Attr("n_a: int >= 0" ) |
776 | .Attr("n_b: int >= 0" ) |
777 | .Attr("t_c: list(type) >= 0" ) |
778 | .SetShapeFn(shape_inference::UnknownShape); |
779 | |
780 | // An op which returns its own device placement as a string, useful for testing |
781 | // where ops get placed. |
782 | REGISTER_OP("DevicePlacementOp" ) |
783 | .Output("device: string" ) |
784 | .SetIsStateful() |
785 | .SetShapeFn(shape_inference::ScalarShape); |
786 | |
787 | class DevicePlacementOp : public OpKernel { |
788 | public: |
789 | using OpKernel::OpKernel; |
790 | |
791 | void Compute(OpKernelContext* ctx) override { |
792 | Tensor* output; |
793 | OP_REQUIRES_OK(ctx, |
794 | ctx->allocate_output("device" , TensorShape({}), &output)); |
795 | output->scalar<tstring>()() = ctx->device()->name(); |
796 | } |
797 | }; |
798 | |
799 | REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp" ).Device(DEVICE_CPU), |
800 | DevicePlacementOp); |
801 | REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp" ).Device(DEVICE_DEFAULT), |
802 | DevicePlacementOp); |
803 | |
804 | // An op which returns the dtype of the tensor it was passed in. It expects |
805 | // DT_UINT8. |
806 | REGISTER_OP("DtypeWithDefaultOp" ) |
807 | .Input("in: T" ) |
808 | .Attr("T: type = DT_UINT8" ) |
809 | .Output("dtype: string" ) |
810 | .SetIsStateful() |
811 | .SetShapeFn(shape_inference::ScalarShape); |
812 | |
813 | class DTypeWithDefaultOp : public OpKernel { |
814 | public: |
815 | using OpKernel::OpKernel; |
816 | |
817 | void Compute(OpKernelContext* ctx) override { |
818 | const Tensor& input = ctx->input(0); |
819 | Tensor* output; |
820 | OP_REQUIRES_OK(ctx, |
821 | ctx->allocate_output("dtype" , TensorShape({}), &output)); |
822 | output->scalar<tstring>()() = tensorflow::DataTypeString(input.dtype()); |
823 | } |
824 | }; |
825 | |
826 | REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp" ).Device(DEVICE_CPU), |
827 | DTypeWithDefaultOp); |
828 | |
829 | // An op that returns True if TensorFloat-32 execution is enabled. Useful for |
830 | // testing that enabling/disabling TensorFloat-32 works correctly, even when |
831 | // the test does not run with a GPU that supports TensorFloat-32. |
832 | REGISTER_OP("IsTensorFloat32Enabled" ) |
833 | .Output("enabled: bool" ) |
834 | .SetIsStateful() |
835 | .SetShapeFn(shape_inference::ScalarShape); |
836 | |
837 | class IsTensorFloat32Enabled : public OpKernel { |
838 | public: |
839 | using OpKernel::OpKernel; |
840 | |
841 | void Compute(OpKernelContext* ctx) override { |
842 | Tensor* output; |
843 | OP_REQUIRES_OK(ctx, |
844 | ctx->allocate_output("enabled" , TensorShape({}), &output)); |
845 | output->scalar<bool>()() = tensor_float_32_execution_enabled(); |
846 | } |
847 | }; |
848 | |
849 | REGISTER_KERNEL_BUILDER( |
850 | Name("IsTensorFloat32Enabled" ).Device(DEVICE_CPU).HostMemory("enabled" ), |
851 | IsTensorFloat32Enabled); |
852 | REGISTER_KERNEL_BUILDER( |
853 | Name("IsTensorFloat32Enabled" ).Device(DEVICE_GPU).HostMemory("enabled" ), |
854 | IsTensorFloat32Enabled); |
855 | } // end namespace tensorflow |
856 | |