/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/fake_quant_ops_functor.h"
// The include above is this file's associated header; it is listed first even
// though clang-tidy does not recognize it as such.
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/determinism.h"

using tensorflow::BinaryElementWiseOp;
using tensorflow::DEVICE_CPU;
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
using tensorflow::DEVICE_GPU;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TTypes;  // NOLINT This is needed in CUDA mode, do not remove.
using tensorflow::UnaryElementWiseOp;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

auto* using_fake_quant = monitoring::Gauge<bool, 0>::New(
    "/tensorflow/api/op/using_fake_quantization",
    "True if a fake_quant op is created.");

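// Records, at kernel-construction time, that fake quantization is in use in
// this process (the gauge is process-wide and only ever set to true).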
#define SET_USING_FAKE_QUANT() using_fake_quant->GetCell()->Set(true)

namespace {
bool IsNumBitsValid(int num_bits) { return num_bits >= 2 && num_bits <= 16; }
}  // namespace
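
// All ops below implement "fake" quantization: the forward pass quantizes and
// immediately dequantizes the input. A rough sketch of the transfer function
// (the exact computation, including nudging of min/max so that 0.0f is
// exactly representable, lives in fake_quant_ops_functor.h):
//   scale  = (max - min) / (quant_max - quant_min)
//   out(x) = round((clamp(x, min, max) - min) / scale) * scale + min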

// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxArgsOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxArgsOp
    : public UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> {
 public:
  typedef UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> Base;
  explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* context)
      : Base::UnaryElementWiseOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
    OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
    OP_REQUIRES(context, min_ < max_,
                InvalidArgument("min has to be smaller than max, was: ", min_,
                                " >= ", max_));
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
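    // With narrow_range the lowest quantized value is reserved, e.g. for
    // num_bits = 8 the quantized range is [1, 255] instead of [0, 255].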
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
    SET_USING_FAKE_QUANT();
  }

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    FakeQuantWithMinMaxArgsFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_,
            quant_min_, quant_max_, output->flat<float>());
  }

 private:
  float min_;
  float max_;
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation
// in core/ops/array_ops.cc.
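// The gradient is a straight-through estimator: the incoming gradient is
// passed through wherever the input lies inside the (nudged) [min, max]
// interval, and is zero outside it.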
template <typename Device>
class FakeQuantWithMinMaxArgsGradientOp
    : public BinaryElementWiseOp<float,
                                 FakeQuantWithMinMaxArgsGradientOp<Device>> {
 public:
  typedef BinaryElementWiseOp<float, FakeQuantWithMinMaxArgsGradientOp<Device>>
      Base;
  explicit FakeQuantWithMinMaxArgsGradientOp(OpKernelConstruction* context)
      : Base::BinaryElementWiseOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
    OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
    OP_REQUIRES(context, min_ < max_,
                InvalidArgument("min has to be smaller than max, was: ", min_,
                                " >= ", max_));
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& gradient,
               const Tensor& input, Tensor* output) {
    OperateNoTemplate(context, gradient, input, output);
  }

  void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
                         const Tensor& input, Tensor* output) {
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
    functor(context->eigen_device<Device>(), gradient.flat<float>(),
            input.flat<float>(), min_, max_, quant_min_, quant_max_,
            output->flat<float>());
  }

 private:
  float min_;
  float max_;
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxArgsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
typedef Eigen::GpuDevice GPUDevice;

// Forward declarations for functor specializations for GPU.
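// The definitions are compiled separately as part of the CUDA/ROCm build;
// declaring the specializations here (together with the `extern template`
// lines below) suppresses implicit instantiation in this translation unit so
// the symbols resolve at link time.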
template <>
void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    const float min, const float max, const int quant_min, const int quant_max,
    typename TTypes<float>::Flat outputs);
extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
                        FakeQuantWithMinMaxArgsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs, const float min, const float max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Flat backprops);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
    FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
    SET_USING_FAKE_QUANT();
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(3, context->num_inputs());
    const Tensor& input = context->input(0);
    const Tensor& min = context->input(1);
    const Tensor& max = context->input(2);

    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(min.shape()),
        InvalidArgument("`min` must be rank 0 but is rank ", min.dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max.shape()),
        InvalidArgument("`max` must be rank 0 but is rank ", max.dims()));

    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    FakeQuantWithMinMaxVarsFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat<float>(),
            min.scalar<float>(), max.scalar<float>(), quant_min_, quant_max_,
            output->flat<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation
// in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsGradientOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
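    // The scalar min/max gradients are accumulated by a reduction that is not
    // deterministically ordered on GPU, so reject the kernel when op
    // determinism is required.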
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      OP_REQUIRES(
          context, !OpDeterminismRequired(),
          errors::Unimplemented(
              "Determinism is not yet supported in GPU implementation of "
              "FakeQuantWithMinMaxVarsGradient."));
    }
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(4, context->num_inputs());
    const Tensor& gradient = context->input(0);
    const Tensor& input = context->input(1);
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const Tensor& min = context->input(2);
    const Tensor& max = context->input(3);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(min.shape()),
        InvalidArgument("`min` must be rank 0 but is rank ", min.dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max.shape()),
        InvalidArgument("`max` must be rank 0 but is rank ", max.dims()));

    Tensor* grad_wrt_input;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &grad_wrt_input));

    TensorShape scalar_shape;
    Tensor* grad_wrt_min;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, scalar_shape, &grad_wrt_min));

    Tensor* grad_wrt_max;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, scalar_shape, &grad_wrt_max));

    FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
    functor(context->eigen_device<Device>(), gradient.flat<float>(),
            input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
            quant_min_, quant_max_, grad_wrt_input->flat<float>(),
            grad_wrt_min->scalar<float>(), grad_wrt_max->scalar<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxVarsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
template <>
void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat output);
extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
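// `min` and `max` are read on the host to compute the nudged quantization
// range before the elementwise kernel is launched (see
// fake_quant_ops_functor.h), hence the HostMemory annotations below.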
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat backprops_wrt_input,
    typename TTypes<float>::Scalar backprop_wrt_min,
    typename TTypes<float>::Scalar backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation
// in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsPerChannelOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
    SET_USING_FAKE_QUANT();
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(3, context->num_inputs());
    const Tensor& input = context->input(0);
    const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
    const Tensor& min = context->input(1);
    const Tensor& max = context->input(2);

    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(min.shape()),
        InvalidArgument("`min` must be rank 1 but is rank ", min.dims()));
    OP_REQUIRES(context, min.dim_size(0) == depth,
                InvalidArgument("min has incorrect size, expected ", depth,
                                " was ", min.dim_size(0)));
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(max.shape()),
        InvalidArgument("`max` must be rank 1 but is rank ", max.dims()));
    OP_REQUIRES(context, max.dim_size(0) == depth,
                InvalidArgument("max has incorrect size, expected ", depth,
                                " was ", max.dim_size(0)));

    Tensor* output;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

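    // flat_inner_dims<float, 2>() reshapes the tensor to [batch, depth],
    // collapsing all leading dimensions so the functor can broadcast the
    // per-channel min/max vectors along the last dimension.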
    FakeQuantWithMinMaxVarsPerChannelFunctor<Device> functor;
    functor(context->eigen_device<Device>(), input.flat_inner_dims<float, 2>(),
            min.vec<float>(), max.vec<float>(), quant_min_, quant_max_,
            output->flat_inner_dims<float, 2>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

// Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its
// documentation in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
 public:
  explicit FakeQuantWithMinMaxVarsPerChannelGradientOp(
      OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      OP_REQUIRES(
          context, !OpDeterminismRequired(),
          errors::Unimplemented(
              "Determinism is not yet supported in GPU implementation of "
              "FakeQuantWithMinMaxVarsPerChannelGradient."));
    }
  }

  void Compute(OpKernelContext* context) override {
    CHECK_EQ(4, context->num_inputs());
    const Tensor& gradient = context->input(0);
    const Tensor& input = context->input(1);
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
    const Tensor& min = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(min.shape()),
        InvalidArgument("`min` must be rank 1 but is rank ", min.dims()));
    OP_REQUIRES(context, min.dim_size(0) == depth,
                InvalidArgument("min has incorrect size, expected ", depth,
                                " was ", min.dim_size(0)));
    const Tensor& max = context->input(3);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(max.shape()),
        InvalidArgument("`max` must be rank 1 but is rank ", max.dims()));
    OP_REQUIRES(context, max.dim_size(0) == depth,
                InvalidArgument("max has incorrect size, expected ", depth,
                                " was ", max.dim_size(0)));

    Tensor* grad_wrt_input;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &grad_wrt_input));

    TensorShape min_max_shape({input.dim_size(input.dims() - 1)});
    Tensor* grad_wrt_min;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, min_max_shape, &grad_wrt_min));

    Tensor* grad_wrt_max;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, min_max_shape, &grad_wrt_max));

    FakeQuantWithMinMaxVarsPerChannelGradientFunctor<Device> functor;
    functor(
        context->eigen_device<Device>(), gradient.flat_inner_dims<float, 2>(),
        input.flat_inner_dims<float, 2>(), min.vec<float>(), max.vec<float>(),
        quant_min_, quant_max_, grad_wrt_input->flat_inner_dims<float, 2>(),
        grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
  }

 private:
  int quant_min_;
  int quant_max_;
};

REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannel").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
template <>
void FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstFlat min,
    typename TTypes<float>::ConstFlat max, const int quant_min,
    const int quant_max, typename TTypes<float>::Matrix outputs);
extern template struct FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>;

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix gradients,
    typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Matrix backprops_wrt_input,
    typename TTypes<float>::Vec backprop_wrt_min,
    typename TTypes<float>::Vec backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor<
    GPUDevice>;

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow