/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_requires.h"
#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/quantize_and_dequantize_op.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {

using CpuDevice = ::Eigen::ThreadPoolDevice;
using GpuDevice = ::Eigen::GpuDevice;
using ::tensorflow::errors::InvalidArgument;

}  // namespace

// Simulates the precision loss of quantization on a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used for inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
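//
// As a rough illustration only (the exact scale selection lives in the
// functors invoked below): with signed_input = true, num_bits = 8, and a
// given range of [-1.0, 1.0], the op behaves roughly like
//   q   = round(x * 127)   // quantize to an integer in [-128, 127]
//   out = q / 127          // dequantize back to float
// so an input of 0.3 becomes round(38.1) / 127 ~= 0.2992.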
template <typename Device, typename T>
class QuantizeAndDequantizeV2Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                InvalidArgument("num_bits is out of range: ", num_bits_,
                                " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));

    string round_mode_string;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string));
    OP_REQUIRES(
        ctx,
        (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"),
        InvalidArgument("Round mode string must be "
                        "'HALF_UP' or "
                        "'HALF_TO_EVEN', is '" +
                        round_mode_string + "'"));
    if (round_mode_string == "HALF_UP") {
      round_mode_ = ROUND_HALF_UP;
    } else if (round_mode_string == "HALF_TO_EVEN") {
      round_mode_ = ROUND_HALF_TO_EVEN;
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    OP_REQUIRES(ctx, axis_ >= -1,
                InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, (axis_ == -1 || axis_ < input.shape().dims()),
                InvalidArgument("Shape must be at least rank ", axis_ + 1,
                                " but is rank ", input.shape().dims()));
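    // `depth` is the number of slices along the quantization axis; when no
    // axis is given (axis_ == -1), a single scale is used for the whole
    // tensor.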
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor input_min_tensor;
    Tensor input_max_tensor;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        auto min_val = input_min_tensor.scalar<T>()();
        auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    InvalidArgument("Invalid range: input_min ", min_val,
                                    " > input_max ", max_val));
      } else {
        OP_REQUIRES(
            ctx, input_min_tensor.dim_size(0) == depth,
            InvalidArgument("input_min_tensor has incorrect size, was ",
                            input_min_tensor.dim_size(0), " expected ", depth,
                            " to match dim ", axis_, " of the input ",
                            input_min_tensor.shape()));
        OP_REQUIRES(
            ctx, input_max_tensor.dim_size(0) == depth,
            InvalidArgument("input_max_tensor has incorrect size, was ",
                            input_max_tensor.dim_size(0), " expected ", depth,
                            " to match dim ", axis_, " of the input ",
                            input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_,
        range_given_, &input_min_tensor, &input_max_tensor, round_mode_,
        narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
        round_mode_, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int num_bits_;
  int axis_;
  QuantizerRoundMode round_mode_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};

// Implementation of QuantizeAndDequantizeV4GradientOp.
// When back-propagating the error through a quantized layer, the following
// paper gives evidence that clipped-ReLU is better than non-clipped:
// "Deep Learning with Low Precision by Half-wave Gaussian Quantization"
// http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf
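//
// In practice (reading the gradient functors invoked below, not a formal
// spec): the incoming gradient passes through unchanged wherever the input
// lies inside [input_min, input_max], while the portions falling outside the
// range contribute to the gradients reported for input_min and input_max.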
template <typename Device, typename T>
class QuantizeAndDequantizeV4GradientOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV4GradientOp(OpKernelConstruction* ctx)
      : OpKernel::OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& gradient = ctx->input(0);
    const Tensor& input = ctx->input(1);
    Tensor* input_backprop = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, input.shape(), &input_backprop));
    OP_REQUIRES(ctx, axis_ >= -1,
                InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, (axis_ == -1 || axis_ < input.shape().dims()),
                InvalidArgument(
                    "Axis should be -1 or 0 or a positive value less than ",
                    input.shape().dims(), " but given axis value was ", axis_));

    OP_REQUIRES(ctx, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    const Tensor& input_min_tensor = ctx->input(2);
    OP_REQUIRES(ctx,
                input_min_tensor.dims() == 0 || input_min_tensor.dims() == 1,
                InvalidArgument(
                    "Input min tensor must have dimension 0 or 1. Received ",
                    input_min_tensor.dims(), "."));
    const Tensor& input_max_tensor = ctx->input(3);
    OP_REQUIRES(ctx,
                input_max_tensor.dims() == 0 || input_max_tensor.dims() == 1,
                InvalidArgument(
                    "Input max tensor must have dimension 0 or 1. Received ",
                    input_max_tensor.dims(), "."));
    if (axis_ != -1) {
      OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth,
                  InvalidArgument("min has incorrect size, expected ", depth,
                                  " was ", input_min_tensor.dim_size(0)));
      OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth,
                  InvalidArgument("max has incorrect size, expected ", depth,
                                  " was ", input_max_tensor.dim_size(0)));
    }

    TensorShape min_max_shape(input_min_tensor.shape());
    Tensor* input_min_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, min_max_shape, &input_min_backprop));

    Tensor* input_max_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(2, min_max_shape, &input_max_backprop));

    if (axis_ == -1) {
      OP_REQUIRES(
          ctx, TensorShapeUtils::IsScalar(input_min_tensor.shape()),
          InvalidArgument("input_min must be a scalar if axis is unspecified"));
      OP_REQUIRES(
          ctx, TensorShapeUtils::IsScalar(input_max_tensor.shape()),
          InvalidArgument("input_max must be a scalar if axis is unspecified"));
      functor::QuantizeAndDequantizeOneScaleGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), gradient.template flat<T>(),
        input.template flat<T>(), input_min_tensor.scalar<T>(),
        input_max_tensor.scalar<T>(), input_backprop->template flat<T>(),
        input_min_backprop->template scalar<T>(),
        input_max_backprop->template scalar<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        gradient.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        &input_min_tensor, &input_max_tensor,
        input_backprop->template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input_min_backprop->template flat<T>(),
        input_max_backprop->template flat<T>());
    }
  }

 private:
  int axis_;
};

// Simulates the precision loss of quantization on a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used for inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
// Almost identical to QuantizeAndDequantizeV2Op, except that num_bits is a
// tensor.
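// Here num_bits arrives as input 3 rather than an attr; as with V2, input 0
// is the tensor and inputs 1 and 2 are input_min and input_max (see the
// Compute body below).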
template <typename Device, typename T>
class QuantizeAndDequantizeV3Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV3Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    OP_REQUIRES(ctx, axis_ < input.dims(),
                InvalidArgument(
                    "Axis requested is larger than input dimensions. Axis: ",
                    axis_, " Input Dimensions: ", input.dims()));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // Get num_bits and validate.
    const Tensor num_bits_tensor = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_bits_tensor.shape()),
                InvalidArgument("Invalid shape. The `num_bits` tensor should "
                                "be a scalar. Got dimensions: ",
                                num_bits_tensor.dims()));

    const int num_bits_val = num_bits_tensor.scalar<int32>()();
    OP_REQUIRES(ctx,
                num_bits_val > 0 && num_bits_val < (signed_input_ ? 62 : 63),
                InvalidArgument("num_bits is out of range: ", num_bits_val,
                                " with `signed_input_` ", signed_input_));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        const auto min_val = input_min_tensor.scalar<T>()();
        const auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    InvalidArgument("Invalid range: input_min ", min_val,
                                    " > input_max ", max_val));
      } else {
        OP_REQUIRES(
            ctx, input_min_tensor.dim_size(0) == depth,
            InvalidArgument("input_min_tensor has incorrect size, was ",
                            input_min_tensor.dim_size(0), " expected ", depth,
                            " to match dim ", axis_, " of the input ",
                            input_min_tensor.shape()));
        OP_REQUIRES(
            ctx, input_max_tensor.dim_size(0) == depth,
            InvalidArgument("input_max_tensor has incorrect size, was ",
                            input_max_tensor.dim_size(0), " expected ", depth,
                            " to match dim ", axis_, " of the input ",
                            input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int axis_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};

// DEPRECATED: Use QuantizeAndDequantizeV2Op.
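// Unlike the later versions, this op reads input_min and input_max from attrs
// (see the GetAttr calls below) rather than from tensor inputs, so the range
// is fixed when the graph is built.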
template <typename Device, typename T>
class QuantizeAndDequantizeOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                InvalidArgument("num_bits is out of range: ", num_bits_,
                                " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_max", &input_max_));
    if (range_given_) {
      OP_REQUIRES(ctx, input_min_ <= input_max_,
                  InvalidArgument("Invalid range: input_min ", input_min_,
                                  " > input_max ", input_max_));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // One global scale.
    Tensor input_min_tensor(DataTypeToEnum<T>::value, TensorShape());
    Tensor input_max_tensor(DataTypeToEnum<T>::value, TensorShape());
    // Initialize the tensors with the values in the Attrs.
    input_min_tensor.template scalar<T>()() = static_cast<T>(input_min_);
    input_max_tensor.template scalar<T>()() = static_cast<T>(input_max_);

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor;
    functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
            num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
            ROUND_HALF_TO_EVEN, /*narrow_range=*/false, output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
  float input_min_;
  float input_max_;
};

// Specializations for CpuDevice.
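// These simply forward to the shared QuantizeAndDequantize*Impl::Compute
// helpers, instantiated for the Eigen thread-pool device.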

namespace functor {
template <typename T>
struct QuantizeAndDequantizeOneScaleFunctor<CpuDevice, T> {
  void operator()(const CpuDevice& d, typename TTypes<T>::ConstVec input,
                  const bool signed_input, const int num_bits,
                  const bool range_given, Tensor* input_min_tensor,
                  Tensor* input_max_tensor, QuantizerRoundMode round_mode,
                  bool narrow_range, typename TTypes<T>::Vec out) {
    QuantizeAndDequantizeOneScaleImpl<CpuDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelFunctor<CpuDevice, T> {
  void operator()(const CpuDevice& d, typename TTypes<T, 3>::ConstTensor input,
                  bool signed_input, int num_bits, bool range_given,
                  Tensor* input_min_tensor, Tensor* input_max_tensor,
                  QuantizerRoundMode round_mode, bool narrow_range,
                  typename TTypes<T, 3>::Tensor out) {
    QuantizeAndDequantizePerChannelImpl<CpuDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizeOneScaleGradientFunctor<CpuDevice, T> {
  void operator()(const CpuDevice& d, typename TTypes<T>::ConstFlat gradient,
                  typename TTypes<T>::ConstFlat input,
                  typename TTypes<T>::ConstScalar input_min_tensor,
                  typename TTypes<T>::ConstScalar input_max_tensor,
                  typename TTypes<T>::Flat input_backprop,
                  typename TTypes<T>::Scalar input_min_backprop,
                  typename TTypes<T>::Scalar input_max_backprop) {
    QuantizeAndDequantizeOneScaleGradientImpl<CpuDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelGradientFunctor<CpuDevice, T> {
  void operator()(const CpuDevice& d,
                  typename TTypes<T, 3>::ConstTensor gradient,
                  typename TTypes<T, 3>::ConstTensor input,
                  const Tensor* input_min_tensor,
                  const Tensor* input_max_tensor,
                  typename TTypes<T, 3>::Tensor input_backprop,
                  typename TTypes<T>::Flat input_min_backprop,
                  typename TTypes<T>::Flat input_max_backprop) {
    QuantizeAndDequantizePerChannelGradientImpl<CpuDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CpuDevice,
                                                                      float>;
template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
    CpuDevice, float>;
template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CpuDevice,
                                                                      double>;
template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
    CpuDevice, double>;

}  // namespace functor

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<CpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<CpuDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<CpuDevice, T>);
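// Note that "QuantizeAndDequantizeV4" is registered with the V2 kernel: the
// forward computation is shared, while the backward pass goes through the
// separate QuantizeAndDequantizeV4Grad kernel.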
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
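// The range inputs (and num_bits for V3) are pinned to host memory below,
// presumably so Compute can read their scalar values directly (e.g., for the
// min <= max validation) without a device-to-host copy.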
#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .HostMemory("num_bits")                          \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<GpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<GpuDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<GpuDevice, T>);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace tensorflow