1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_requires.h" |
17 | #define EIGEN_USE_THREADS |
18 | |
19 | #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ |
20 | (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) |
21 | #define EIGEN_USE_GPU |
22 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
23 | |
24 | #include "tensorflow/core/framework/op.h" |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/register_types.h" |
27 | #include "tensorflow/core/framework/tensor_shape.h" |
28 | #include "tensorflow/core/framework/type_traits.h" |
29 | #include "tensorflow/core/framework/types.h" |
30 | #include "tensorflow/core/kernels/quantize_and_dequantize_op.h" |
31 | #include "tensorflow/core/lib/core/errors.h" |
32 | |
33 | namespace tensorflow { |
34 | namespace { |
35 | |
36 | using CpuDevice = ::Eigen::ThreadPoolDevice; |
37 | using GpuDevice = ::Eigen::GpuDevice; |
38 | using ::tensorflow::errors::InvalidArgument; |
39 | |
40 | } // namespace |
41 | |
// Simulates quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
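//
// For example, with signed_input=true and num_bits=8 the quantized grid is
// [-128, 127] ([-127, 127] with narrow_range). A rough sketch of the
// per-element math, assuming a single symmetric scale (the exact range and
// rounding handling lives in QuantizeAndDequantizeOneScaleImpl):
//
//   scale  = max(|input_min|, |input_max|) / max_quantized
//   q      = round(clamp(x, input_min, input_max) / scale)
//   output = q * scale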
47 | template <typename Device, typename T> |
48 | class QuantizeAndDequantizeV2Op : public OpKernel { |
49 | public: |
50 | explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx) |
51 | : OpKernel(ctx) { |
52 | OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input" , &signed_input_)); |
53 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis" , &axis_)); |
54 | OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits" , &num_bits_)); |
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                InvalidArgument("num_bits is out of range: ", num_bits_,
                                " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
59 | |
60 | string round_mode_string; |
61 | OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode" , &round_mode_string)); |
62 | OP_REQUIRES( |
63 | ctx, |
64 | (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN" ), |
65 | InvalidArgument("Round mode string must be " |
66 | "'HALF_UP' or " |
67 | "'HALF_TO_EVEN', is '" + |
68 | round_mode_string + "'" )); |
69 | if (round_mode_string == "HALF_UP" ) { |
70 | round_mode_ = ROUND_HALF_UP; |
71 | } else if (round_mode_string == "HALF_TO_EVEN" ) { |
72 | round_mode_ = ROUND_HALF_TO_EVEN; |
73 | } |
74 | OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range" , &narrow_range_)); |
75 | } |
76 | |
77 | void Compute(OpKernelContext* ctx) override { |
78 | const Tensor& input = ctx->input(0); |
    OP_REQUIRES(ctx, axis_ >= -1,
                InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, (axis_ == -1 || axis_ < input.shape().dims()),
                InvalidArgument("Shape must be at least rank ", axis_ + 1,
                                " but is rank ", input.shape().dims()));
84 | const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_); |
85 | Tensor input_min_tensor; |
86 | Tensor input_max_tensor; |
87 | Tensor* output = nullptr; |
88 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); |
89 | if (range_given_) { |
90 | input_min_tensor = ctx->input(1); |
91 | input_max_tensor = ctx->input(2); |
92 | if (axis_ == -1) { |
93 | auto min_val = input_min_tensor.scalar<T>()(); |
94 | auto max_val = input_max_tensor.scalar<T>()(); |
95 | OP_REQUIRES(ctx, min_val <= max_val, |
96 | InvalidArgument("Invalid range: input_min " , min_val, |
97 | " > input_max " , max_val)); |
98 | } else { |
99 | OP_REQUIRES( |
100 | ctx, input_min_tensor.dim_size(0) == depth, |
101 | InvalidArgument("input_min_tensor has incorrect size, was " , |
102 | input_min_tensor.dim_size(0), " expected " , depth, |
103 | " to match dim " , axis_, " of the input " , |
104 | input_min_tensor.shape())); |
105 | OP_REQUIRES( |
106 | ctx, input_max_tensor.dim_size(0) == depth, |
107 | InvalidArgument("input_max_tensor has incorrect size, was " , |
108 | input_max_tensor.dim_size(0), " expected " , depth, |
109 | " to match dim " , axis_, " of the input " , |
110 | input_max_tensor.shape())); |
111 | } |
112 | } else { |
113 | auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth}); |
114 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, |
115 | range_shape, &input_min_tensor)); |
116 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, |
117 | range_shape, &input_max_tensor)); |
118 | } |
119 | |
120 | if (axis_ == -1) { |
121 | functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f; |
122 | f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_, |
123 | range_given_, &input_min_tensor, &input_max_tensor, round_mode_, |
124 | narrow_range_, output->flat<T>()); |
125 | } else { |
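      // flat_inner_outer_dims<T, 3>(axis_ - 1) views the tensor as a rank-3
      // [outer, depth, inner] shape with the quantization axis in the middle,
      // so the functor can apply one scale per slice along dimension 1.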
126 | functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f; |
127 | f(ctx->eigen_device<Device>(), |
128 | input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_, |
129 | num_bits_, range_given_, &input_min_tensor, &input_max_tensor, |
130 | round_mode_, narrow_range_, |
131 | output->template flat_inner_outer_dims<T, 3>(axis_ - 1)); |
132 | } |
133 | } |
134 | |
135 | private: |
136 | int num_bits_; |
137 | int axis_; |
138 | QuantizerRoundMode round_mode_; |
139 | bool signed_input_; |
140 | bool range_given_; |
141 | bool narrow_range_; |
142 | }; |
143 | |
144 | // Implementation of QuantizeAndDequantizeV4GradientOp. |
145 | // When back-propagating the error through a quantized layer, the following |
146 | // paper gives evidence that clipped-ReLU is better than non-clipped: |
147 | // "Deep Learning with Low Precision by Half-wave Gaussian Quantization" |
148 | // http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf |
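//
// Conceptually, gradients pass through unchanged where the input fell inside
// the quantization range, and are redirected to the range limits where the
// input was clipped. A sketch of the per-element math (the exact expressions
// live in QuantizeAndDequantizeOneScaleGradientImpl):
//
//   d_input     = gradient  if input_min <= input <= input_max, else 0
//   d_input_min = sum of gradient over elements where input < input_min
//   d_input_max = sum of gradient over elements where input > input_max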
149 | template <typename Device, typename T> |
150 | class QuantizeAndDequantizeV4GradientOp : public OpKernel { |
151 | public: |
152 | explicit QuantizeAndDequantizeV4GradientOp(OpKernelConstruction* ctx) |
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
155 | } |
156 | |
157 | void Compute(OpKernelContext* ctx) override { |
158 | const Tensor& gradient = ctx->input(0); |
159 | const Tensor& input = ctx->input(1); |
160 | Tensor* input_backprop = nullptr; |
161 | OP_REQUIRES_OK(ctx, |
162 | ctx->allocate_output(0, input.shape(), &input_backprop)); |
163 | OP_REQUIRES(ctx, axis_ >= -1, |
164 | InvalidArgument("Axis must be at least -1. Found " , axis_)); |
165 | OP_REQUIRES(ctx, (axis_ == -1 || axis_ < input.shape().dims()), |
                InvalidArgument(
                    "Axis should be -1 or 0 or a positive value less than ",
                    input.shape().dims(), " but given axis value was ",
                    axis_));
169 | |
170 | OP_REQUIRES(ctx, input.IsSameSize(gradient), |
171 | InvalidArgument("gradient and input must be the same size" )); |
172 | const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_); |
173 | const Tensor& input_min_tensor = ctx->input(2); |
174 | OP_REQUIRES(ctx, |
175 | input_min_tensor.dims() == 0 || input_min_tensor.dims() == 1, |
176 | InvalidArgument( |
177 | "Input min tensor must have dimension 0 or 1. Received " , |
178 | input_min_tensor.dims(), "." )); |
179 | const Tensor& input_max_tensor = ctx->input(3); |
180 | OP_REQUIRES(ctx, |
181 | input_max_tensor.dims() == 0 || input_max_tensor.dims() == 1, |
182 | InvalidArgument( |
183 | "Input max tensor must have dimension 0 or 1. Received " , |
184 | input_max_tensor.dims(), "." )); |
185 | if (axis_ != -1) { |
186 | OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth, |
187 | InvalidArgument("min has incorrect size, expected " , depth, |
188 | " was " , input_min_tensor.dim_size(0))); |
189 | OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth, |
190 | InvalidArgument("max has incorrect size, expected " , depth, |
191 | " was " , input_max_tensor.dim_size(0))); |
192 | } |
193 | |
194 | TensorShape min_max_shape(input_min_tensor.shape()); |
195 | Tensor* input_min_backprop; |
196 | OP_REQUIRES_OK(ctx, |
197 | ctx->allocate_output(1, min_max_shape, &input_min_backprop)); |
198 | |
199 | Tensor* input_max_backprop; |
200 | OP_REQUIRES_OK(ctx, |
201 | ctx->allocate_output(2, min_max_shape, &input_max_backprop)); |
202 | |
203 | if (axis_ == -1) { |
204 | OP_REQUIRES( |
205 | ctx, TensorShapeUtils::IsScalar(input_min_tensor.shape()), |
206 | InvalidArgument("input_min must be a scalar if axis is unspecified" )); |
207 | OP_REQUIRES( |
208 | ctx, TensorShapeUtils::IsScalar(input_max_tensor.shape()), |
209 | InvalidArgument("input_max must be a scalar if axis is unspecified" )); |
210 | functor::QuantizeAndDequantizeOneScaleGradientFunctor<Device, T> f; |
211 | f(ctx->eigen_device<Device>(), gradient.template flat<T>(), |
212 | input.template flat<T>(), input_min_tensor.scalar<T>(), |
213 | input_max_tensor.scalar<T>(), input_backprop->template flat<T>(), |
214 | input_min_backprop->template scalar<T>(), |
215 | input_max_backprop->template scalar<T>()); |
216 | } else { |
217 | functor::QuantizeAndDequantizePerChannelGradientFunctor<Device, T> f; |
218 | f(ctx->eigen_device<Device>(), |
219 | gradient.template flat_inner_outer_dims<T, 3>(axis_ - 1), |
220 | input.template flat_inner_outer_dims<T, 3>(axis_ - 1), |
221 | &input_min_tensor, &input_max_tensor, |
222 | input_backprop->template flat_inner_outer_dims<T, 3>(axis_ - 1), |
223 | input_min_backprop->template flat<T>(), |
224 | input_max_backprop->template flat<T>()); |
225 | } |
226 | } |
227 | |
228 | private: |
229 | int axis_; |
230 | }; |
231 | |
// Simulates quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
// Almost identical to QuantizeAndDequantizeV2Op, except that num_bits is a
// tensor.
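// There is also no round_mode attr; rounding is fixed to ROUND_HALF_TO_EVEN.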
239 | template <typename Device, typename T> |
240 | class QuantizeAndDequantizeV3Op : public OpKernel { |
241 | public: |
242 | explicit QuantizeAndDequantizeV3Op(OpKernelConstruction* ctx) |
243 | : OpKernel(ctx) { |
244 | OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input" , &signed_input_)); |
245 | OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given" , &range_given_)); |
246 | OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range" , &narrow_range_)); |
247 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis" , &axis_)); |
248 | } |
249 | |
250 | void Compute(OpKernelContext* ctx) override { |
251 | const Tensor& input = ctx->input(0); |
    OP_REQUIRES(ctx, axis_ >= -1,
                InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, axis_ < input.dims(),
                InvalidArgument(
                    "Axis requested is larger than input dimensions. Axis: ",
                    axis_, " Input Dimensions: ", input.dims()));
256 | const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_); |
257 | Tensor* output = nullptr; |
258 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); |
259 | |
260 | // Get num_bits and validate. |
261 | const Tensor num_bits_tensor = ctx->input(3); |
262 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_bits_tensor.shape()), |
263 | InvalidArgument("Invalid shape. The `num_bits` tensor should " |
264 | "be a scalar. Got dimensions: " , |
265 | num_bits_tensor.dims())); |
266 | |
267 | const int num_bits_val = num_bits_tensor.scalar<int32>()(); |
268 | OP_REQUIRES(ctx, |
269 | num_bits_val > 0 && num_bits_val < (signed_input_ ? 62 : 63), |
270 | InvalidArgument("num_bits is out of range: " , num_bits_val, |
271 | " with `signed_input_` " , signed_input_)); |
272 | |
273 | Tensor input_min_tensor; |
274 | Tensor input_max_tensor; |
275 | if (range_given_) { |
276 | input_min_tensor = ctx->input(1); |
277 | input_max_tensor = ctx->input(2); |
278 | if (axis_ == -1) { |
279 | const auto min_val = input_min_tensor.scalar<T>()(); |
280 | const auto max_val = input_max_tensor.scalar<T>()(); |
281 | OP_REQUIRES(ctx, min_val <= max_val, |
282 | InvalidArgument("Invalid range: input_min " , min_val, |
283 | " > input_max " , max_val)); |
284 | } else { |
285 | OP_REQUIRES( |
286 | ctx, input_min_tensor.dim_size(0) == depth, |
287 | InvalidArgument("input_min_tensor has incorrect size, was " , |
288 | input_min_tensor.dim_size(0), " expected " , depth, |
289 | " to match dim " , axis_, " of the input " , |
290 | input_min_tensor.shape())); |
291 | OP_REQUIRES( |
292 | ctx, input_max_tensor.dim_size(0) == depth, |
293 | InvalidArgument("input_max_tensor has incorrect size, was " , |
294 | input_max_tensor.dim_size(0), " expected " , depth, |
295 | " to match dim " , axis_, " of the input " , |
296 | input_max_tensor.shape())); |
297 | } |
298 | } else { |
299 | auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth}); |
300 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, |
301 | range_shape, &input_min_tensor)); |
302 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, |
303 | range_shape, &input_max_tensor)); |
304 | } |
305 | |
306 | if (axis_ == -1) { |
307 | functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f; |
308 | f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, |
309 | num_bits_val, range_given_, &input_min_tensor, &input_max_tensor, |
310 | ROUND_HALF_TO_EVEN, narrow_range_, output->flat<T>()); |
311 | } else { |
312 | functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f; |
313 | f(ctx->eigen_device<Device>(), |
314 | input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_, |
315 | num_bits_val, range_given_, &input_min_tensor, &input_max_tensor, |
316 | ROUND_HALF_TO_EVEN, narrow_range_, |
317 | output->template flat_inner_outer_dims<T, 3>(axis_ - 1)); |
318 | } |
319 | } |
320 | |
321 | private: |
322 | int axis_; |
323 | bool signed_input_; |
324 | bool range_given_; |
325 | bool narrow_range_; |
326 | }; |
327 | |
328 | // DEPRECATED: Use QuantizeAndDequantizeV2Op. |
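// Unlike V2, the range is supplied through the input_min/input_max attrs
// rather than as tensor inputs, rounding is fixed to ROUND_HALF_TO_EVEN, and
// there is no per-axis support.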
329 | template <typename Device, typename T> |
330 | class QuantizeAndDequantizeOp : public OpKernel { |
331 | public: |
332 | explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
333 | OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input" , &signed_input_)); |
334 | OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits" , &num_bits_)); |
335 | OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), |
336 | InvalidArgument("num_bits is out of range: " , num_bits_, |
337 | " with signed_input_ " , signed_input_)); |
338 | OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given" , &range_given_)); |
339 | OP_REQUIRES_OK(ctx, ctx->GetAttr("input_min" , &input_min_)); |
340 | OP_REQUIRES_OK(ctx, ctx->GetAttr("input_max" , &input_max_)); |
341 | if (range_given_) { |
342 | OP_REQUIRES(ctx, input_min_ <= input_max_, |
343 | InvalidArgument("Invalid range: input_min " , input_min_, |
344 | " > input_max " , input_max_)); |
345 | } |
346 | } |
347 | |
348 | void Compute(OpKernelContext* ctx) override { |
349 | const Tensor& input = ctx->input(0); |
350 | |
351 | Tensor* output = nullptr; |
352 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); |
353 | |
354 | // One global scale. |
355 | Tensor input_min_tensor(DataTypeToEnum<T>::value, TensorShape()); |
356 | Tensor input_max_tensor(DataTypeToEnum<T>::value, TensorShape()); |
357 | // Initialize the tensors with the values in the Attrs. |
358 | input_min_tensor.template scalar<T>()() = static_cast<T>(input_min_); |
359 | input_max_tensor.template scalar<T>()() = static_cast<T>(input_max_); |
360 | |
361 | functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor; |
362 | functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, |
363 | num_bits_, range_given_, &input_min_tensor, &input_max_tensor, |
364 | ROUND_HALF_TO_EVEN, /*narrow_range=*/false, output->flat<T>()); |
365 | } |
366 | |
367 | private: |
368 | bool signed_input_; |
369 | int num_bits_; |
370 | bool range_given_; |
371 | float input_min_; |
372 | float input_max_; |
373 | }; |
374 | |
375 | // Specializations for CpuDevice. |
376 | |
377 | namespace functor { |
378 | template <typename T> |
379 | struct QuantizeAndDequantizeOneScaleFunctor<CpuDevice, T> { |
380 | void operator()(const CpuDevice& d, typename TTypes<T>::ConstVec input, |
381 | const bool signed_input, const int num_bits, |
382 | const bool range_given, Tensor* input_min_tensor, |
383 | Tensor* input_max_tensor, QuantizerRoundMode round_mode, |
384 | bool narrow_range, typename TTypes<T>::Vec out) { |
385 | QuantizeAndDequantizeOneScaleImpl<CpuDevice, T>::Compute( |
386 | d, input, signed_input, num_bits, range_given, input_min_tensor, |
387 | input_max_tensor, round_mode, narrow_range, out); |
388 | } |
389 | }; |
390 | |
391 | template <typename T> |
392 | struct QuantizeAndDequantizePerChannelFunctor<CpuDevice, T> { |
393 | void operator()(const CpuDevice& d, typename TTypes<T, 3>::ConstTensor input, |
394 | bool signed_input, int num_bits, bool range_given, |
395 | Tensor* input_min_tensor, Tensor* input_max_tensor, |
396 | QuantizerRoundMode round_mode, bool narrow_range, |
397 | typename TTypes<T, 3>::Tensor out) { |
398 | QuantizeAndDequantizePerChannelImpl<CpuDevice, T>::Compute( |
399 | d, input, signed_input, num_bits, range_given, input_min_tensor, |
400 | input_max_tensor, round_mode, narrow_range, out); |
401 | } |
402 | }; |
403 | |
404 | template <typename T> |
405 | struct QuantizeAndDequantizeOneScaleGradientFunctor<CpuDevice, T> { |
406 | void operator()(const CpuDevice& d, typename TTypes<T>::ConstFlat gradient, |
407 | typename TTypes<T>::ConstFlat input, |
408 | typename TTypes<T>::ConstScalar input_min_tensor, |
409 | typename TTypes<T>::ConstScalar input_max_tensor, |
410 | typename TTypes<T>::Flat input_backprop, |
411 | typename TTypes<T>::Scalar input_min_backprop, |
412 | typename TTypes<T>::Scalar input_max_backprop) { |
413 | QuantizeAndDequantizeOneScaleGradientImpl<CpuDevice, T>::Compute( |
414 | d, gradient, input, input_min_tensor, input_max_tensor, input_backprop, |
415 | input_min_backprop, input_max_backprop); |
416 | } |
417 | }; |
418 | |
419 | template <typename T> |
420 | struct QuantizeAndDequantizePerChannelGradientFunctor<CpuDevice, T> { |
421 | void operator()(const CpuDevice& d, |
422 | typename TTypes<T, 3>::ConstTensor gradient, |
423 | typename TTypes<T, 3>::ConstTensor input, |
424 | const Tensor* input_min_tensor, |
425 | const Tensor* input_max_tensor, |
426 | typename TTypes<T, 3>::Tensor input_backprop, |
427 | typename TTypes<T>::Flat input_min_backprop, |
428 | typename TTypes<T>::Flat input_max_backprop) { |
429 | QuantizeAndDequantizePerChannelGradientImpl<CpuDevice, T>::Compute( |
430 | d, gradient, input, input_min_tensor, input_max_tensor, input_backprop, |
431 | input_min_backprop, input_max_backprop); |
432 | } |
433 | }; |
434 | |
435 | template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CpuDevice, |
436 | float>; |
437 | template struct functor::QuantizeAndDequantizePerChannelGradientFunctor< |
438 | CpuDevice, double>; |
439 | |
440 | } // namespace functor |
441 | |
442 | #define REGISTER_CPU_KERNEL(T) \ |
443 | REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2") \ |
444 | .Device(DEVICE_CPU) \ |
445 | .TypeConstraint<T>("T"), \ |
446 | QuantizeAndDequantizeV2Op<CpuDevice, T>); \ |
447 | REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3") \ |
448 | .Device(DEVICE_CPU) \ |
449 | .TypeConstraint<T>("T"), \ |
450 | QuantizeAndDequantizeV3Op<CpuDevice, T>); \ |
451 | REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4") \ |
452 | .Device(DEVICE_CPU) \ |
453 | .TypeConstraint<T>("T"), \ |
454 | QuantizeAndDequantizeV2Op<CpuDevice, T>); \ |
455 | REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad") \ |
456 | .Device(DEVICE_CPU) \ |
457 | .TypeConstraint<T>("T"), \ |
458 | QuantizeAndDequantizeV4GradientOp<CpuDevice, T>); \ |
459 | REGISTER_KERNEL_BUILDER( \ |
460 | Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
461 | QuantizeAndDequantizeOp<CpuDevice, T>); |
462 | TF_CALL_float(REGISTER_CPU_KERNEL); |
463 | TF_CALL_double(REGISTER_CPU_KERNEL); |
464 | #undef REGISTER_CPU_KERNEL |
465 | |
466 | #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ |
467 | (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) |
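// The range inputs (and num_bits for V3) are pinned to host memory because
// Compute() reads their values directly on the host, e.g. through
// input_min_tensor.scalar<T>()().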
468 | #define REGISTER_GPU_KERNEL(T) \ |
469 | REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2") \ |
470 | .Device(DEVICE_GPU) \ |
471 | .HostMemory("input_min") \ |
472 | .HostMemory("input_max") \ |
473 | .TypeConstraint<T>("T"), \ |
474 | QuantizeAndDequantizeV2Op<GpuDevice, T>); \ |
475 | REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3") \ |
476 | .Device(DEVICE_GPU) \ |
477 | .HostMemory("input_min") \ |
478 | .HostMemory("input_max") \ |
479 | .HostMemory("num_bits") \ |
480 | .TypeConstraint<T>("T"), \ |
481 | QuantizeAndDequantizeV3Op<GpuDevice, T>); \ |
482 | REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4") \ |
483 | .Device(DEVICE_GPU) \ |
484 | .HostMemory("input_min") \ |
485 | .HostMemory("input_max") \ |
486 | .TypeConstraint<T>("T"), \ |
487 | QuantizeAndDequantizeV2Op<GpuDevice, T>); \ |
488 | REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad") \ |
489 | .Device(DEVICE_GPU) \ |
490 | .HostMemory("input_min") \ |
491 | .HostMemory("input_max") \ |
492 | .TypeConstraint<T>("T"), \ |
493 | QuantizeAndDequantizeV4GradientOp<GpuDevice, T>); \ |
494 | REGISTER_KERNEL_BUILDER( \ |
495 | Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ |
496 | QuantizeAndDequantizeOp<GpuDevice, T>); |
497 | TF_CALL_float(REGISTER_GPU_KERNEL); |
498 | TF_CALL_double(REGISTER_GPU_KERNEL); |
499 | #undef REGISTER_GPU_KERNEL |
500 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
501 | } // namespace tensorflow |
502 | |