/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Forward declarations of functors that will be defined in tile_ops_impl.h
namespace functor {
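// Tile replicates the input along each dimension: output dimension i has
// size in.dim_size(i) * broadcast_array[i].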
template <typename Device, typename T, typename Tmultiple>
struct Tile {
  void operator()(const Device& d, Tensor* out, const Tensor& in,
                  const gtl::ArraySlice<Tmultiple> broadcast_array) const;
};

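// TileGrad writes one slice of the tiled gradient (described by `indices` and
// `sizes`) into `out`; `first` signals whether the slice initializes the
// output or is accumulated into it.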
template <typename Device, typename T, int NDIM>
struct TileGrad {
  void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
                  typename TTypes<T, NDIM>::ConstTensor in,
                  const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,
                  const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes,
                  bool first) const;
};

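// Specialization for scalars (NDIM == 0).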
template <typename Device, typename T>
struct TileGrad<Device, T, 0> {
  void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
                  typename TTypes<T, 0>::ConstTensor in,
                  const Eigen::DSizes<Eigen::DenseIndex, 0>&,
                  const Eigen::DSizes<Eigen::DenseIndex, 0>&, bool first) const;
};

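// ReduceAndReshape sums `in` over the dimensions in `reduce_dim` and reshapes
// the result to `reshape_dim`, writing it into `out`.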
template <typename Device, typename T, int NDIM, int REDUCEDNDIM>
struct ReduceAndReshape {
  void operator()(
      const Device& d, typename TTypes<T, NDIM>::Tensor out,
      typename TTypes<T, NDIM>::ConstTensor in,
      const Eigen::DSizes<Eigen::DenseIndex, REDUCEDNDIM>& reduce_dim,
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const;
};

// Explicit instantiations are defined in tile_ops_{cpu,gpu}_impl.*,
// below are their declarations.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
extern template struct Tile<GPUDevice, bool, int32>;
extern template struct Tile<GPUDevice, bool, int64_t>;
extern template struct Tile<GPUDevice, float, int32>;
extern template struct Tile<GPUDevice, float, int64_t>;
extern template struct Tile<GPUDevice, double, int32>;
extern template struct Tile<GPUDevice, double, int64_t>;
extern template struct Tile<GPUDevice, complex64, int32>;
extern template struct Tile<GPUDevice, complex64, int64_t>;
extern template struct Tile<GPUDevice, complex128, int32>;
extern template struct Tile<GPUDevice, complex128, int64_t>;
extern template struct Tile<GPUDevice, Eigen::half, int32>;
extern template struct Tile<GPUDevice, Eigen::half, int64_t>;
extern template struct Tile<GPUDevice, int16, int32>;
extern template struct Tile<GPUDevice, int16, int64_t>;
extern template struct Tile<GPUDevice, int32, int32>;
extern template struct Tile<GPUDevice, int32, int64_t>;
extern template struct Tile<GPUDevice, int64_t, int32>;
extern template struct Tile<GPUDevice, int64_t, int64_t>;
#define DECLARE_CUDA_DIM(T, NDIM)                      \
  extern template struct TileGrad<GPUDevice, T, NDIM>; \
  extern template struct ReduceAndReshape<GPUDevice, T, NDIM, 1>
#else  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define DECLARE_CUDA_DIM(T, NDIM)
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define DECLARE_TYPE(T)                             \
  extern template struct Tile<CPUDevice, T, int32>; \
  extern template struct Tile<CPUDevice, T, int64_t>;
TF_CALL_bool(DECLARE_TYPE);
TF_CALL_float(DECLARE_TYPE);
TF_CALL_bfloat16(DECLARE_TYPE);
TF_CALL_double(DECLARE_TYPE);
TF_CALL_uint8(DECLARE_TYPE);
TF_CALL_int32(DECLARE_TYPE);
TF_CALL_int16(DECLARE_TYPE);
TF_CALL_int64(DECLARE_TYPE);
TF_CALL_uint32(DECLARE_TYPE);
TF_CALL_uint64(DECLARE_TYPE);
TF_CALL_half(DECLARE_TYPE);
TF_CALL_complex64(DECLARE_TYPE);
TF_CALL_complex128(DECLARE_TYPE);
TF_CALL_tstring(DECLARE_TYPE);
TF_CALL_variant(DECLARE_TYPE);
#undef DECLARE_TYPE

#define DECLARE_DIM(T, NDIM)                           \
  DECLARE_CUDA_DIM(T, NDIM);                           \
  extern template struct TileGrad<CPUDevice, T, NDIM>; \
  extern template struct ReduceAndReshape<CPUDevice, T, NDIM, 1>;

#define DECLARE_TYPE(T) \
  DECLARE_DIM(T, 1)     \
  DECLARE_DIM(T, 2)     \
  DECLARE_DIM(T, 3)     \
  DECLARE_DIM(T, 4)     \
  DECLARE_DIM(T, 5)     \
  DECLARE_DIM(T, 6)     \
  DECLARE_DIM(T, 7)
TF_CALL_float(DECLARE_TYPE);
TF_CALL_bfloat16(DECLARE_TYPE);
TF_CALL_double(DECLARE_TYPE);
TF_CALL_int16(DECLARE_TYPE);
TF_CALL_int32(DECLARE_TYPE);
TF_CALL_int64(DECLARE_TYPE);
TF_CALL_half(DECLARE_TYPE);
TF_CALL_complex64(DECLARE_TYPE);
TF_CALL_complex128(DECLARE_TYPE);
#undef DECLARE_TYPE

#undef DECLARE_DIM
#undef DECLARE_CUDA_DIM

}  // namespace functor

// --------------------------------------------------------------------------
template <typename Device, typename Tmultiples>
class TileOp : public OpKernel {
 public:
  explicit TileOp(OpKernelConstruction* context) : OpKernel(context) {}

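  // Validates the `multiples` input, computes the output shape, and
  // dispatches on the input dtype to the Tile functor.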
  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& multiples = context->input(1);

    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(multiples.shape()),
        errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
                                multiples.shape().DebugString()));
    OP_REQUIRES(context, input.dims() == multiples.NumElements(),
                errors::InvalidArgument(
                    "Expected multiples argument to be a vector of length ",
                    input.dims(), " but got length ", multiples.dim_size(0)));
    const int input_dims = input.dims();

    // Eigen doesn't support scalars on the GPU, so handle 0-D specially
    if (input_dims == 0) {
      context->set_output(0, input);
      return;
    }

    const gtl::ArraySlice<Tmultiples> multiples_array(
        multiples.flat<Tmultiples>().data(), input_dims);
    TensorShape output_shape;
    for (int i = 0; i < input_dims; ++i) {
      OP_REQUIRES(
          context, multiples_array[i] >= 0,
          errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ",
                                  multiples_array[i]));
      OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(
                                  input.dim_size(i) * multiples_array[i]));
    }
    if (output_shape == input.shape()) {
      context->set_output(0, input);
      return;
    }
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));

    // If there's no output, there's nothing to do.
    if (output_shape.num_elements() == 0) return;

#define HANDLE_TYPE(DT)                               \
  if (context->input(0).dtype() == DT) {              \
    HandleCase<DT>(context, multiples_array, result); \
    return;                                           \
  }

#define HANDLE_TYPE_NAME(T) HANDLE_TYPE(DataTypeToEnum<T>::value)

    // Invoke macro using TF_CALL_* so type-filtering for platform applies.
    TF_CALL_bool(HANDLE_TYPE_NAME);
    TF_CALL_bfloat16(HANDLE_TYPE_NAME);
    TF_CALL_float(HANDLE_TYPE_NAME);
    TF_CALL_double(HANDLE_TYPE_NAME);
    TF_CALL_uint8(HANDLE_TYPE_NAME);
    TF_CALL_int8(HANDLE_TYPE_NAME);
    TF_CALL_int32(HANDLE_TYPE_NAME);
    TF_CALL_int16(HANDLE_TYPE_NAME);
    TF_CALL_int64(HANDLE_TYPE_NAME);
    TF_CALL_uint32(HANDLE_TYPE_NAME);
    TF_CALL_uint64(HANDLE_TYPE_NAME);
    TF_CALL_half(HANDLE_TYPE_NAME);
    TF_CALL_tstring(HANDLE_TYPE_NAME);  // when DEVICE=CPUDevice.
    TF_CALL_complex64(HANDLE_TYPE_NAME);
    TF_CALL_complex128(HANDLE_TYPE_NAME);
    TF_CALL_variant(HANDLE_TYPE_NAME);  // when DEVICE=CPUDevice

#undef HANDLE_TYPE_NAME
#undef HANDLE_TYPE

    OP_REQUIRES(
        context, false,
        errors::Unimplemented(
            "TileOp : The input data type is not supported, DataType : ",
            DataTypeString(context->input(0).dtype()),
            ", Dimension : ", input_dims));
  }

 private:
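  // Runs the Tile functor for the compile-time data type DT on this kernel's
  // device.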
  template <DataType DT>
  void HandleCaseImpl(OpKernelContext* context,
                      const gtl::ArraySlice<Tmultiples> multiples_array,
                      Tensor* result) {
    typedef typename EnumToDataType<DT>::Type T;
    functor::Tile<Device, T, Tmultiples>()(context->eigen_device<Device>(),
                                           result, context->input(0),
                                           multiples_array);
  }

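  // Explicitly specialized below for every supported (Device, DT) pair; the
  // generic definition after the class reports an unsupported combination.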
  template <DataType DT>
  void HandleCase(OpKernelContext* context,
                  const gtl::ArraySlice<Tmultiples> multiples_array,
                  Tensor* result);

  TF_DISALLOW_COPY_AND_ASSIGN(TileOp);
};

template <typename Device, typename Tmultiples>
template <DataType DT>
inline void TileOp<Device, Tmultiples>::HandleCase(
    OpKernelContext* context, const gtl::ArraySlice<Tmultiples> multiples_array,
    Tensor* result) {
  // TODO(vrv): print out the device name if useful. Currently disabled to avoid
  // having to use RTTI.
  LOG(FATAL) << "TileOp: Invalid combination of Device, DT: "
             // << typeid(Device).name() << ", "
             << DataTypeString(DT);
}

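// Generates the explicit HandleCase specialization for a (device, dtype,
// Tmultiples) combination by forwarding to HandleCaseImpl.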
#define HANDLE_CASE(device, dtype, Tmultiples)                             \
  template <>                                                              \
  template <>                                                              \
  void TileOp<device, Tmultiples>::HandleCase<dtype>(                      \
      OpKernelContext * context,                                           \
      const gtl::ArraySlice<Tmultiples> multiples_array, Tensor* result) { \
    HandleCaseImpl<dtype>(context, multiples_array, result);               \
  }

#define HANDLE_TYPE_NAME_CPU(T)                            \
  HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value, int32); \
  HANDLE_CASE(CPUDevice, DataTypeToEnum<T>::value, int64_t);

#define HANDLE_TYPE_NAME_GPU(T)                            \
  HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value, int32); \
  HANDLE_CASE(GPUDevice, DataTypeToEnum<T>::value, int64_t);

TF_CALL_bool(HANDLE_TYPE_NAME_CPU);
TF_CALL_float(HANDLE_TYPE_NAME_CPU);
TF_CALL_bfloat16(HANDLE_TYPE_NAME_CPU);
TF_CALL_double(HANDLE_TYPE_NAME_CPU);
TF_CALL_uint8(HANDLE_TYPE_NAME_CPU);
TF_CALL_int8(HANDLE_TYPE_NAME_CPU);
TF_CALL_int32(HANDLE_TYPE_NAME_CPU);
TF_CALL_int16(HANDLE_TYPE_NAME_CPU);
TF_CALL_int64(HANDLE_TYPE_NAME_CPU);
TF_CALL_uint32(HANDLE_TYPE_NAME_CPU);
TF_CALL_uint64(HANDLE_TYPE_NAME_CPU);
TF_CALL_half(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
TF_CALL_tstring(HANDLE_TYPE_NAME_CPU);
TF_CALL_variant(HANDLE_TYPE_NAME_CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_bool(HANDLE_TYPE_NAME_GPU);
TF_CALL_float(HANDLE_TYPE_NAME_GPU);
TF_CALL_double(HANDLE_TYPE_NAME_GPU);
TF_CALL_int16(HANDLE_TYPE_NAME_GPU);
TF_CALL_int32(HANDLE_TYPE_NAME_GPU);
TF_CALL_int64(HANDLE_TYPE_NAME_GPU);
TF_CALL_half(HANDLE_TYPE_NAME_GPU);
TF_CALL_complex64(HANDLE_TYPE_NAME_GPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_GPU);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef HANDLE_TYPE_NAME_CPU
#undef HANDLE_TYPE_NAME_GPU
#undef HANDLE_CASE

// --------------------------------------------------------------------------
template <typename Device, typename Tmultiples>
class TileGradientOp : public OpKernel {
 public:
  explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {}

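  // Computes the gradient of Tile: the output shape is the incoming gradient's
  // shape divided elementwise by `multiples`, and each output element is the
  // sum of the corresponding elements across all tiles. Dispatches on dtype
  // and rank.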
  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& multiples = context->input(1);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(multiples.shape()),
        errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
                                multiples.shape().DebugString()));
    OP_REQUIRES(context, input.dims() == multiples.NumElements(),
                errors::InvalidArgument(
                    "Expected multiples argument to be a vector of length ",
                    input.dims(), " but got length ", multiples.dim_size(0)));

    const int input_dims = input.dims();

    // Eigen doesn't support scalars on the GPU, so handle 0-D specially
    if (input_dims == 0) {
      context->set_output(0, input);
      return;
    }

    const gtl::ArraySlice<Tmultiples> multiples_array(
        multiples.flat<Tmultiples>().data(), input_dims);
    TensorShape output_shape;
    std::vector<Tmultiples> input_dim_size_vec;
    for (int i = 0; i < input_dims; ++i) {
      OP_REQUIRES(
          context, multiples_array[i] > 0,
          errors::InvalidArgument("Expected multiples[", i, "] > 0, but got ",
                                  multiples_array[i]));
      OP_REQUIRES(context, input.dim_size(i) % multiples_array[i] == 0,
                  errors::InvalidArgument("Expected input_dim[", i,
                                          "] to be divisible by multiples[", i,
                                          "], but ", input.dim_size(i), " % ",
                                          multiples_array[i], " != 0"));
      output_shape.AddDim(input.dim_size(i) / multiples_array[i]);
      input_dim_size_vec.push_back(input.dim_size(i));
    }
    if (output_shape == input.shape()) {
      context->set_output(0, input);
      return;
    }
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));

#define HANDLE_DIM(DT, NDIM)                                           \
  if (context->input(0).dtype() == DT && input_dims == NDIM) {         \
    HandleCase<DT, NDIM>(context, input_dim_size_vec, multiples_array, \
                         result);                                      \
    return;                                                            \
  }

#define HANDLE_TYPE(T) \
  HANDLE_DIM(T, 1)     \
  HANDLE_DIM(T, 2)     \
  HANDLE_DIM(T, 3)     \
  HANDLE_DIM(T, 4)     \
  HANDLE_DIM(T, 5)     \
  HANDLE_DIM(T, 6)     \
  HANDLE_DIM(T, 7)

#define HANDLE_TYPE_NAME(T) HANDLE_TYPE(DataTypeToEnum<T>::value)

    TF_CALL_float(HANDLE_TYPE_NAME);
    TF_CALL_double(HANDLE_TYPE_NAME);
    TF_CALL_int32(HANDLE_TYPE_NAME);
    TF_CALL_int16(HANDLE_TYPE_NAME);
    TF_CALL_int64(HANDLE_TYPE_NAME);
    TF_CALL_half(HANDLE_TYPE_NAME);
    TF_CALL_bfloat16(HANDLE_TYPE_NAME);
    TF_CALL_complex64(HANDLE_TYPE_NAME);
    TF_CALL_complex128(HANDLE_TYPE_NAME);

#undef HANDLE_TYPE_NAME
#undef HANDLE_TYPE
#undef HANDLE_DIM

    OP_REQUIRES(context, false,
                errors::Unimplemented("TileGradientOp : The input data type or "
                                      "dimension is not supported, DataType : ",
                                      DataTypeString(context->input(0).dtype()),
                                      ", Dimension : ", input_dims));
  }

 private:
  template <DataType DT, int NDIM>
  void HandleCase(OpKernelContext* context,
                  const std::vector<Tmultiples>& input_dims,
                  const gtl::ArraySlice<Tmultiples> multiples_array,
                  Tensor* result);

  template <DataType DT, int NDIM>
  void HandleCaseImpl(OpKernelContext* context,
                      const std::vector<Tmultiples>& input_dims,
                      const gtl::ArraySlice<Tmultiples> multiples_array,
                      Tensor* result) {
    typedef typename EnumToDataType<DT>::Type T;

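    // The gradient is a plain sum-reduction when every tiled dimension had
    // original size 1 (input_dim == multiple); in that case the slicing and
    // accumulation loop below is unnecessary.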
    bool reduction_only = true;
    std::vector<Tmultiples> reduction_dims;

    for (int i = 0; i < NDIM; ++i) {
      if (input_dims[i] > multiples_array[i] && multiples_array[i] > 1) {
        reduction_only = false;
        break;
      } else {
        if (multiples_array[i] == input_dims[i]) {
          reduction_dims.push_back(i);
        }
      }
    }

    if (reduction_only) {
#define HANDLE_DIM(D)                                            \
  if (reduction_dims.size() == (D)) {                            \
    HandleReduce<T, NDIM, (D)>(context, reduction_dims, result); \
    return;                                                      \
  }
      // NOTE(keveman): Handling the most common case here.
      // Adding more cases here would require more templating and code
      // explosion. For instance, HANDLE_DIM(2) wouldn't make sense for NDIM=1.
      HANDLE_DIM(1);

// Fall through to the unoptimized version.
#undef HANDLE_DIM
    }

    Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;

    // Accumulate slices along the dimensions into the output. The number of
    // slices along dimension 'i' is simply the multiple along dimension 'i'
    // passed to the original Tile op.
    for (int i = 0; i < NDIM; ++i) {
      sizes[i] = input_dims[i] / multiples_array[i];
      indices[i] = 0;
    }

    bool first = true;
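    // Walk over every tile of the incoming gradient, odometer-style, handing
    // each slice to the TileGrad functor; `first` distinguishes the slice that
    // initializes the output from those accumulated into it.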
    while (true) {
      functor::TileGrad<Device, T, NDIM>()(
          context->eigen_device<Device>(), result->tensor<T, NDIM>(),
          context->input(0).tensor<T, NDIM>(), indices, sizes, first);
      first = false;
      // Increment the begin indices.
      int i = 0;
      while (i < NDIM && indices[i] / sizes[i] == multiples_array[i] - 1) {
        indices[i] = 0;
        ++i;
      }
      // We are finished if we have iterated to the maximum along all
      // dimensions.
      if (i == NDIM) {
        break;
      }
      indices[i] += sizes[i];
    }
  }

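  // Sums the incoming gradient over `reduce_dim_in` and reshapes the result
  // to the output shape via the ReduceAndReshape functor.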
  template <typename T, int NDIM, int REDUCENDIM>
  void HandleReduce(OpKernelContext* context,
                    const std::vector<Tmultiples>& reduce_dim_in,
                    Tensor* result) {
    static_assert(NDIM >= REDUCENDIM, "Too many reduced dimensions");
    Eigen::DSizes<Eigen::DenseIndex, REDUCENDIM> reduce_dim;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> reshape_dim;

    for (int i = 0; i < REDUCENDIM; ++i) {
      reduce_dim[i] = reduce_dim_in[i];
    }

    for (int i = 0; i < NDIM; ++i) {
      reshape_dim[i] = result->dim_size(i);
    }

    functor::ReduceAndReshape<Device, T, NDIM, REDUCENDIM>()(
        context->eigen_device<Device>(), result->tensor<T, NDIM>(),
        context->input(0).tensor<T, NDIM>(), reduce_dim, reshape_dim);
  }

  TF_DISALLOW_COPY_AND_ASSIGN(TileGradientOp);
};

template <typename Device, typename Tmultiples>
template <DataType DT, int NDIM>
inline void TileGradientOp<Device, Tmultiples>::HandleCase(
    OpKernelContext* context, const std::vector<Tmultiples>& input_dims,
    const gtl::ArraySlice<Tmultiples> multiples_array, Tensor* result) {
  LOG(FATAL) << "TileGradientOp: Invalid combination of Device, DT and NDIM: "
             << TypeIndex::Make<Device>().name() << ", " << DataTypeString(DT)
             << ", " << NDIM;
}

#define HANDLE_CASE(device, T, dtype, Tmultiples, ndim)                        \
  template <>                                                                  \
  template <>                                                                  \
  void TileGradientOp<device, Tmultiples>::HandleCase<dtype, ndim>(            \
      OpKernelContext * context, const std::vector<Tmultiples>& input_dims,    \
      const gtl::ArraySlice<Tmultiples> multiples_array, Tensor* result) {     \
    HandleCaseImpl<dtype, ndim>(context, input_dims, multiples_array, result); \
  }

// 0-D handled specially above
#define HANDLE_CASE_DIM(device, T, dtype)    \
  HANDLE_CASE(device, T, dtype, int32, 1);   \
  HANDLE_CASE(device, T, dtype, int32, 2);   \
  HANDLE_CASE(device, T, dtype, int32, 3);   \
  HANDLE_CASE(device, T, dtype, int32, 4);   \
  HANDLE_CASE(device, T, dtype, int32, 5);   \
  HANDLE_CASE(device, T, dtype, int32, 6);   \
  HANDLE_CASE(device, T, dtype, int32, 7);   \
  HANDLE_CASE(device, T, dtype, int64_t, 1); \
  HANDLE_CASE(device, T, dtype, int64_t, 2); \
  HANDLE_CASE(device, T, dtype, int64_t, 3); \
  HANDLE_CASE(device, T, dtype, int64_t, 4); \
  HANDLE_CASE(device, T, dtype, int64_t, 5); \
  HANDLE_CASE(device, T, dtype, int64_t, 6); \
  HANDLE_CASE(device, T, dtype, int64_t, 7);

#define HANDLE_TYPE_NAME_CPU(T) \
  HANDLE_CASE_DIM(CPUDevice, T, DataTypeToEnum<T>::value);

#define HANDLE_TYPE_NAME_GPU(T) \
  HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value);

TF_CALL_float(HANDLE_TYPE_NAME_CPU);
TF_CALL_double(HANDLE_TYPE_NAME_CPU);
TF_CALL_int16(HANDLE_TYPE_NAME_CPU);
TF_CALL_int32(HANDLE_TYPE_NAME_CPU);
TF_CALL_int64(HANDLE_TYPE_NAME_CPU);
TF_CALL_half(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(HANDLE_TYPE_NAME_GPU);
TF_CALL_double(HANDLE_TYPE_NAME_GPU);
TF_CALL_int16(HANDLE_TYPE_NAME_GPU);
TF_CALL_int32(HANDLE_TYPE_NAME_GPU);
TF_CALL_int64(HANDLE_TYPE_NAME_GPU);
TF_CALL_half(HANDLE_TYPE_NAME_GPU);
TF_CALL_complex64(HANDLE_TYPE_NAME_GPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_GPU);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef HANDLE_TYPE_NAME_CPU
#undef HANDLE_TYPE_NAME_GPU
#undef HANDLE_CASE_DIM
#undef HANDLE_CASE

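// Register CPU kernels for Tile and TileGrad with both int32 and int64
// `multiples`; the multiples tensor is pinned to host memory because the
// output shape is computed from it.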
REGISTER_KERNEL_BUILDER(Name("Tile")
                            .Device(DEVICE_CPU)
                            .HostMemory("multiples")
                            .TypeConstraint<int32>("Tmultiples"),
                        TileOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("Tile")
                            .Device(DEVICE_CPU)
                            .HostMemory("multiples")
                            .TypeConstraint<int64_t>("Tmultiples"),
                        TileOp<CPUDevice, int64>);
REGISTER_KERNEL_BUILDER(Name("TileGrad")
                            .Device(DEVICE_CPU)
                            .HostMemory("multiples")
                            .TypeConstraint<int32>("Tmultiples"),
                        TileGradientOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("TileGrad")
                            .Device(DEVICE_CPU)
                            .HostMemory("multiples")
                            .TypeConstraint<int64_t>("Tmultiples"),
                        TileGradientOp<CPUDevice, int64>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
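// GPU registrations: one Tile/TileGrad kernel per element type and per
// Tmultiples type (int32 and int64); `multiples` stays in host memory.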
#define REGISTER_GPU_TILE(type)                                      \
  REGISTER_KERNEL_BUILDER(Name("Tile")                               \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int32>("Tmultiples")   \
                              .HostMemory("multiples"),              \
                          TileOp<GPUDevice, int32>);                 \
  REGISTER_KERNEL_BUILDER(Name("Tile")                               \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int64_t>("Tmultiples") \
                              .HostMemory("multiples"),              \
                          TileOp<GPUDevice, int64>);

#define REGISTER_GPU_TILE_GRAD(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("TileGrad")                           \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int32>("Tmultiples")   \
                              .HostMemory("multiples"),              \
                          TileGradientOp<GPUDevice, int32>);         \
  REGISTER_KERNEL_BUILDER(Name("TileGrad")                           \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int64_t>("Tmultiples") \
                              .HostMemory("multiples"),              \
                          TileGradientOp<GPUDevice, int64>);

#define REGISTER_GPU(type) \
  REGISTER_GPU_TILE(type); \
  REGISTER_GPU_TILE_GRAD(type);

TF_CALL_bool(REGISTER_GPU_TILE);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_half(REGISTER_GPU);
TF_CALL_int16(REGISTER_GPU);
TF_CALL_int32(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);

#undef REGISTER_GPU_TILE
#undef REGISTER_GPU_TILE_GRAD
#undef REGISTER_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow