/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
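//
// This file mostly contains kernel registrations; the op implementations
// themselves (ShapeOp, ShapeNOp, RankOp, SizeOp, ExpandDimsOp, SqueezeOp)
// live in shape_ops.h. Ops with an `out_type` attr are registered once for
// int32 and once for int64 shape outputs.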

#include "tensorflow/core/kernels/shape_ops.h"

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

// Shape ----------------------------------------
REGISTER_KERNEL_BUILDER(Name("Shape")
                            .Device(DEVICE_CPU)
                            .HostMemory("output")
                            .TypeConstraint<int32>("out_type"),
                        ShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("Shape")
                            .Device(DEVICE_CPU)
                            .HostMemory("output")
                            .TypeConstraint<int64_t>("out_type"),
                        ShapeOp<int64_t>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
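// Registers "Shape" GPU kernels for element type `type`, one per supported
// out_type. The shape output is always placed in host memory: shape metadata
// is tiny and is almost always consumed by host-side code.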
#define REGISTER_GPU_KERNEL(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("Shape")                            \
                              .Device(DEVICE_GPU)                  \
                              .HostMemory("output")                \
                              .TypeConstraint<int32>("out_type")   \
                              .TypeConstraint<type>("T"),          \
                          ShapeOp<int32>);                         \
  REGISTER_KERNEL_BUILDER(Name("Shape")                            \
                              .Device(DEVICE_GPU)                  \
                              .HostMemory("output")                \
                              .TypeConstraint<int64_t>("out_type") \
                              .TypeConstraint<type>("T"),          \
                          ShapeOp<int64_t>);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_bool(REGISTER_GPU_KERNEL);
TF_CALL_variant(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Shape")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("out_type"),
                        ShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("Shape")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("out_type"),
                        ShapeOp<int64_t>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

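// DEVICE_DEFAULT registrations act as a generic fallback for device types
// that have no more specific kernel registered (e.g. pluggable devices).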
#define REGISTER_DEFAULT_KERNEL(type)                              \
  REGISTER_KERNEL_BUILDER(Name("Shape")                            \
                              .Device(DEVICE_DEFAULT)              \
                              .HostMemory("output")                \
                              .TypeConstraint<int32>("out_type")   \
                              .TypeConstraint<type>("T"),          \
                          ShapeOp<int32>);                         \
  REGISTER_KERNEL_BUILDER(Name("Shape")                            \
                              .Device(DEVICE_DEFAULT)              \
                              .HostMemory("output")                \
                              .TypeConstraint<int64_t>("out_type") \
                              .TypeConstraint<type>("T"),          \
                          ShapeOp<int64_t>);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_DEFAULT_KERNEL);
TF_CALL_bool(REGISTER_DEFAULT_KERNEL);
TF_CALL_variant(REGISTER_DEFAULT_KERNEL);
#undef REGISTER_DEFAULT_KERNEL

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Shape")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("out_type"),
                        ShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("Shape")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("out_type"),
                        ShapeOp<int64_t>);

// ShapeN ---------------------------------------
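// ShapeN takes N input tensors and emits one shape vector per input.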
REGISTER_KERNEL_BUILDER(Name("ShapeN")
                            .Device(DEVICE_CPU)
                            .HostMemory("output")
                            .TypeConstraint<int32>("out_type"),
                        ShapeNOp<int32>);
REGISTER_KERNEL_BUILDER(Name("ShapeN")
                            .Device(DEVICE_CPU)
                            .HostMemory("output")
                            .TypeConstraint<int64_t>("out_type"),
                        ShapeNOp<int64_t>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNEL(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                           \
                              .Device(DEVICE_GPU)                  \
                              .HostMemory("output")                \
                              .TypeConstraint<int32>("out_type")   \
                              .TypeConstraint<type>("T"),          \
                          ShapeNOp<int32>);                        \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                           \
                              .Device(DEVICE_GPU)                  \
                              .HostMemory("output")                \
                              .TypeConstraint<int64_t>("out_type") \
                              .TypeConstraint<type>("T"),          \
                          ShapeNOp<int64_t>)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_bool(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("ShapeN")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("out_type"),
                        ShapeNOp<int32>);
REGISTER_KERNEL_BUILDER(Name("ShapeN")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("out_type"),
                        ShapeNOp<int64_t>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_DEFAULT_KERNEL(type)                              \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                           \
                              .Device(DEVICE_DEFAULT)              \
                              .HostMemory("output")                \
                              .TypeConstraint<int32>("out_type")   \
                              .TypeConstraint<type>("T"),          \
                          ShapeNOp<int32>);                        \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                           \
                              .Device(DEVICE_DEFAULT)              \
                              .HostMemory("output")                \
                              .TypeConstraint<int64_t>("out_type") \
                              .TypeConstraint<type>("T"),          \
                          ShapeNOp<int64_t>)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_DEFAULT_KERNEL);
TF_CALL_bool(REGISTER_DEFAULT_KERNEL);
#undef REGISTER_DEFAULT_KERNEL

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("ShapeN")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("out_type"),
                        ShapeNOp<int32>);
REGISTER_KERNEL_BUILDER(Name("ShapeN")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("out_type"),
                        ShapeNOp<int64_t>);

// Rank ------------------------------------------
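// Rank emits a scalar int32 holding the number of dimensions of its input.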
REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"),
                        RankOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNEL(type)                        \
  REGISTER_KERNEL_BUILDER(Name("Rank")                   \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("output"),     \
                          RankOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_variant(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

// A special GPU kernel for int32 and bool.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Rank")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("output"),
                        RankOp);

REGISTER_KERNEL_BUILDER(Name("Rank")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<bool>("T")
                            .HostMemory("input")
                            .HostMemory("output"),
                        RankOp);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_DEFAULT_KERNEL(type)                    \
  REGISTER_KERNEL_BUILDER(Name("Rank")                   \
                              .Device(DEVICE_DEFAULT)    \
                              .TypeConstraint<type>("T") \
                              .HostMemory("output"),     \
                          RankOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_DEFAULT_KERNEL);
TF_CALL_variant(REGISTER_DEFAULT_KERNEL);
#undef REGISTER_DEFAULT_KERNEL

// A special DEVICE_DEFAULT kernel for int32 and bool.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Rank")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("output"),
                        RankOp);

REGISTER_KERNEL_BUILDER(Name("Rank")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<bool>("T")
                            .HostMemory("input")
                            .HostMemory("output"),
                        RankOp);

// Size ------------------------------------------
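// Size emits a scalar (int32 or int64, per `out_type`) holding the total
// number of elements in its input.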
REGISTER_KERNEL_BUILDER(Name("Size")
                            .Device(DEVICE_CPU)
                            .HostMemory("output")
                            .TypeConstraint<int32>("out_type"),
                        SizeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("Size")
                            .Device(DEVICE_CPU)
                            .HostMemory("output")
                            .TypeConstraint<int64_t>("out_type"),
                        SizeOp<int64_t>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNEL(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("Size")                             \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<type>("T")           \
                              .TypeConstraint<int32>("out_type")   \
                              .HostMemory("output"),               \
                          SizeOp<int32>);                          \
  REGISTER_KERNEL_BUILDER(Name("Size")                             \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<type>("T")           \
                              .TypeConstraint<int64_t>("out_type") \
                              .HostMemory("output"),               \
                          SizeOp<int64_t>);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_bool(REGISTER_GPU_KERNEL);
TF_CALL_variant(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Size")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("input")
                            .HostMemory("output"),
                        SizeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("Size")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("out_type")
                            .HostMemory("input")
                            .HostMemory("output"),
                        SizeOp<int64_t>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_DEFAULT_KERNEL(type)                              \
  REGISTER_KERNEL_BUILDER(Name("Size")                             \
                              .Device(DEVICE_DEFAULT)              \
                              .TypeConstraint<type>("T")           \
                              .TypeConstraint<int32>("out_type")   \
                              .HostMemory("output"),               \
                          SizeOp<int32>);                          \
  REGISTER_KERNEL_BUILDER(Name("Size")                             \
                              .Device(DEVICE_DEFAULT)              \
                              .TypeConstraint<type>("T")           \
                              .TypeConstraint<int64_t>("out_type") \
                              .HostMemory("output"),               \
                          SizeOp<int64_t>);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_DEFAULT_KERNEL);
TF_CALL_bool(REGISTER_DEFAULT_KERNEL);
TF_CALL_variant(REGISTER_DEFAULT_KERNEL);
#undef REGISTER_DEFAULT_KERNEL

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Size")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("input")
                            .HostMemory("output"),
                        SizeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("Size")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("out_type")
                            .HostMemory("input")
                            .HostMemory("output"),
                        SizeOp<int64_t>);

// ExpandDims ------------------------------------
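// ExpandDims inserts a new dimension of size 1 at index `dim`; the `Tdim`
// attr selects the integer type of the `dim` input, which is always read on
// the host.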
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
                            .Device(DEVICE_CPU)
                            .HostMemory("dim")
                            .TypeConstraint<int32>("Tdim"),
                        ExpandDimsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
                            .Device(DEVICE_CPU)
                            .HostMemory("dim")
                            .TypeConstraint<int64_t>("Tdim"),
                        ExpandDimsOp<int64_t>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNEL(type)                              \
  REGISTER_KERNEL_BUILDER(Name("ExpandDims")                   \
                              .Device(DEVICE_GPU)              \
                              .TypeConstraint<type>("T")       \
                              .TypeConstraint<int32>("Tdim")   \
                              .HostMemory("dim"),              \
                          ExpandDimsOp<int32>);                \
  REGISTER_KERNEL_BUILDER(Name("ExpandDims")                   \
                              .Device(DEVICE_GPU)              \
                              .TypeConstraint<type>("T")       \
                              .TypeConstraint<int64_t>("Tdim") \
                              .HostMemory("dim"),              \
                          ExpandDimsOp<int64_t>);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_bool(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

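// A special GPU kernel for int32, mirroring the pattern above: all int32
// inputs and outputs stay in host memory.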
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tdim")
                            .HostMemory("input")
                            .HostMemory("dim")
                            .HostMemory("output"),
                        ExpandDimsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("Tdim")
                            .HostMemory("input")
                            .HostMemory("dim")
                            .HostMemory("output"),
                        ExpandDimsOp<int64_t>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_DEFAULT_KERNEL(type)                          \
  REGISTER_KERNEL_BUILDER(Name("ExpandDims")                   \
                              .Device(DEVICE_DEFAULT)          \
                              .TypeConstraint<type>("T")       \
                              .TypeConstraint<int32>("Tdim")   \
                              .HostMemory("dim"),              \
                          ExpandDimsOp<int32>);                \
  REGISTER_KERNEL_BUILDER(Name("ExpandDims")                   \
                              .Device(DEVICE_DEFAULT)          \
                              .TypeConstraint<type>("T")       \
                              .TypeConstraint<int64_t>("Tdim") \
                              .HostMemory("dim"),              \
                          ExpandDimsOp<int64_t>);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_DEFAULT_KERNEL);
TF_CALL_bool(REGISTER_DEFAULT_KERNEL);
#undef REGISTER_DEFAULT_KERNEL

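// A special DEVICE_DEFAULT kernel for int32, with all int32 inputs and
// outputs in host memory.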
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tdim")
                            .HostMemory("input")
                            .HostMemory("dim")
                            .HostMemory("output"),
                        ExpandDimsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("ExpandDims")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("Tdim")
                            .HostMemory("input")
                            .HostMemory("dim")
                            .HostMemory("output"),
                        ExpandDimsOp<int64_t>);

// Squeeze ---------------------------------------
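// Squeeze removes dimensions of size 1 from the input's shape, so SqueezeOp
// needs no output-type template argument.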
REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNEL(type)                                   \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Squeeze").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SqueezeOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
TF_CALL_bool(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Squeeze")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("output"),
                        SqueezeOp);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_DEFAULT_KERNEL(type)                                   \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("Squeeze").Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      SqueezeOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_DEFAULT_KERNEL);
TF_CALL_bool(REGISTER_DEFAULT_KERNEL);
#undef REGISTER_DEFAULT_KERNEL

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Squeeze")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("output"),
                        SqueezeOp);

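// EnsureShapeOp verifies at run time that the shape of its input is
// compatible with the `shape` attr supplied at graph-construction time, and
// forwards the input to the output unchanged.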
class EnsureShapeOp : public OpKernel {
 public:
  explicit EnsureShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));

    if (!expected_shape_.IsCompatibleWith(shape)) {
      ctx->SetStatus(errors::InvalidArgument(
          "Shape of tensor ", this->def().input(0), " ", shape.DebugString(),
          " is not compatible with expected shape ",
          expected_shape_.DebugString(), "."));
    }

    // Forward the input to the output; if the shape check above failed, the
    // error status already set on the context takes precedence.
    if (IsRefType(ctx->input_dtype(0))) {
      ctx->forward_ref_input_to_ref_output(0, 0);
    } else {
      ctx->set_output(0, ctx->input(0));
    }
  }

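  // The shape check is cheap, so allow the executor to run this kernel
  // inline rather than scheduling it onto a thread pool.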
  bool IsExpensive() override { return false; }

 private:
  PartialTensorShape expected_shape_;
};

// NOTE(rachelim): The kernel registrations for EnsureShapeOp are identical to
// those of the identity op, since the ops have the same device type
// constraints.
REGISTER_KERNEL_BUILDER(Name("EnsureShape").Device(DEVICE_CPU), EnsureShapeOp);

#define REGISTER_DEVICE_KERNEL(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("EnsureShape").Device(DEVICE_DEFAULT).TypeConstraint<type>("T"), \
      EnsureShapeOp)

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_DEVICE_KERNEL);
REGISTER_DEVICE_KERNEL(Variant);

#undef REGISTER_DEVICE_KERNEL

// A special DEVICE_DEFAULT kernel for int32 and bool.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_DEVICE_HOST_KERNEL(type)                 \
  REGISTER_KERNEL_BUILDER(Name("EnsureShape")             \
                              .Device(DEVICE_DEFAULT)     \
                              .HostMemory("input")        \
                              .HostMemory("output")       \
                              .TypeConstraint<type>("T"), \
                          EnsureShapeOp)

REGISTER_DEVICE_HOST_KERNEL(int32);
REGISTER_DEVICE_HOST_KERNEL(bool);
REGISTER_DEVICE_HOST_KERNEL(tstring);
REGISTER_DEVICE_HOST_KERNEL(ResourceHandle);

#undef REGISTER_DEVICE_HOST_KERNEL

}  // namespace tensorflow