1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ |
19 | (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) |
20 | #define EIGEN_USE_GPU |
21 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
22 | |
23 | #include "tensorflow/core/kernels/fake_quant_ops_functor.h" |
24 | // Above is the related header but clang tidy doesn't recognize it. |
25 | #include "tensorflow/core/framework/numeric_op.h" |
26 | #include "tensorflow/core/framework/tensor.h" |
27 | #include "tensorflow/core/framework/tensor_shape.h" |
28 | #include "tensorflow/core/lib/core/errors.h" |
29 | #include "tensorflow/core/lib/monitoring/gauge.h" |
30 | #include "tensorflow/core/platform/protobuf.h" |
31 | #include "tensorflow/core/util/determinism.h" |
32 | |
33 | using tensorflow::BinaryElementWiseOp; |
34 | using tensorflow::DEVICE_CPU; |
35 | #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ |
36 | (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) |
37 | using tensorflow::DEVICE_GPU; |
38 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
39 | using tensorflow::OpKernel; |
40 | using tensorflow::OpKernelConstruction; |
41 | using tensorflow::OpKernelContext; |
42 | using tensorflow::Tensor; |
43 | using tensorflow::TensorShape; |
44 | using tensorflow::TTypes; // NOLINT This is needed in CUDA mode, do not remove. |
45 | using tensorflow::UnaryElementWiseOp; |
46 | using tensorflow::errors::InvalidArgument; |
47 | |
48 | namespace tensorflow { |
49 | |
typedef Eigen::ThreadPoolDevice CPUDevice;

// Process-wide gauge set to true the first time any fake-quant kernel is
// constructed in this process; used for usage telemetry.
auto* using_fake_quant = monitoring::Gauge<bool, 0>::New(
    "/tensorflow/api/op/using_fake_quantization",
    "True if a fake_quant op is created.");

// Flips the gauge above; called from each fake-quant kernel constructor.
#define SET_USING_FAKE_QUANT() using_fake_quant->GetCell()->Set(true)
57 | |
namespace {
// Returns true iff `num_bits` falls in the supported quantization bit-width
// range [2, 16] shared by every fake-quant kernel in this file.
bool IsNumBitsValid(int num_bits) {
  return !(num_bits < 2 || num_bits > 16);
}
}  // namespace
61 | |
62 | // ----------------------------------------------------------------------------- |
63 | // Implementation of FakeQuantWithMinMaxArgsOp, see its documentation in |
64 | // core/ops/array_ops.cc. |
65 | template <typename Device> |
66 | class FakeQuantWithMinMaxArgsOp |
67 | : public UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> { |
68 | public: |
69 | typedef UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> Base; |
70 | explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* context) |
71 | : Base::UnaryElementWiseOp(context) { |
72 | OP_REQUIRES_OK(context, context->GetAttr("min" , &min_)); |
73 | OP_REQUIRES_OK(context, context->GetAttr("max" , &max_)); |
74 | OP_REQUIRES(context, min_ < max_, |
75 | InvalidArgument("min has to be smaller than max, was: " , min_, |
76 | " >= " , max_)); |
77 | int num_bits; |
78 | OP_REQUIRES_OK(context, context->GetAttr("num_bits" , &num_bits)); |
79 | OP_REQUIRES( |
80 | context, IsNumBitsValid(num_bits), |
81 | InvalidArgument("num_bits must be between 2 and 16, inclusive" )); |
82 | bool narrow_range; |
83 | OP_REQUIRES_OK(context, context->GetAttr("narrow_range" , &narrow_range)); |
84 | quant_min_ = narrow_range ? 1 : 0; |
85 | quant_max_ = (1 << num_bits) - 1; |
86 | SET_USING_FAKE_QUANT(); |
87 | } |
88 | |
89 | void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { |
90 | FakeQuantWithMinMaxArgsFunctor<Device> functor; |
91 | functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_, |
92 | quant_min_, quant_max_, output->flat<float>()); |
93 | } |
94 | |
95 | private: |
96 | float min_; |
97 | float max_; |
98 | int quant_min_; |
99 | int quant_max_; |
100 | }; |
101 | |
102 | // Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation in |
103 | // core/ops/array_ops.cc. |
template <typename Device>
class FakeQuantWithMinMaxArgsGradientOp
    : public BinaryElementWiseOp<float,
                                 FakeQuantWithMinMaxArgsGradientOp<Device>> {
 public:
  typedef BinaryElementWiseOp<float, FakeQuantWithMinMaxArgsGradientOp<Device>>
      Base;
  // Validates the attributes once at construction: min < max, num_bits in
  // [2, 16], and derives the quantized value range [quant_min_, quant_max_].
  explicit FakeQuantWithMinMaxArgsGradientOp(OpKernelConstruction* context)
      : Base::BinaryElementWiseOp(context) {
    OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
    OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
    OP_REQUIRES(context, min_ < max_,
                InvalidArgument("min has to be smaller than max, was: ", min_,
                                " >= ", max_));
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    // Narrow range excludes the lowest quantized value.
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
  }

  // NDIMS is unused; all rank handling is flat, so every instantiation
  // forwards to the shared non-template implementation below.
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& gradient,
               const Tensor& input, Tensor* output) {
    OperateNoTemplate(context, gradient, input, output);
  }

  // Checks that `gradient` and `input` have the same size, then delegates
  // the element-wise backprop computation to the gradient functor.
  void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
                         const Tensor& input, Tensor* output) {
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
    functor(context->eigen_device<Device>(), gradient.flat<float>(),
            input.flat<float>(), min_, max_, quant_min_, quant_max_,
            output->flat<float>());
  }

 private:
  float min_;      // Lower end of the clamping range (attribute "min").
  float max_;      // Upper end of the clamping range (attribute "max").
  int quant_min_;  // Smallest quantized value: 1 when narrow_range, else 0.
  int quant_max_;  // Largest quantized value: (1 << num_bits) - 1.
};
151 | |
// CPU kernel registrations for the attribute-driven (min/max-args) forward
// and gradient ops.
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxArgsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);
157 | |
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
typedef Eigen::GpuDevice GPUDevice;

// Forward declarations for functor specializations for GPU. The definitions
// are compiled elsewhere in the GPU build; `extern template` suppresses
// implicit instantiation in this translation unit.
template <>
void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    const float min, const float max, const int quant_min, const int quant_max,
    typename TTypes<float>::Flat outputs);
extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
                        FakeQuantWithMinMaxArgsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs, const float min, const float max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Flat backprops);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
    FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
182 | |
183 | // ----------------------------------------------------------------------------- |
184 | // Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in |
185 | // core/ops/array_ops.cc. |
186 | template <typename Device> |
187 | class FakeQuantWithMinMaxVarsOp : public OpKernel { |
188 | public: |
189 | explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* context) |
190 | : OpKernel::OpKernel(context) { |
191 | int num_bits; |
192 | OP_REQUIRES_OK(context, context->GetAttr("num_bits" , &num_bits)); |
193 | OP_REQUIRES( |
194 | context, IsNumBitsValid(num_bits), |
195 | InvalidArgument("num_bits must be between 2 and 16, inclusive" )); |
196 | bool narrow_range; |
197 | OP_REQUIRES_OK(context, context->GetAttr("narrow_range" , &narrow_range)); |
198 | quant_min_ = narrow_range ? 1 : 0; |
199 | quant_max_ = (1 << num_bits) - 1; |
200 | SET_USING_FAKE_QUANT(); |
201 | } |
202 | |
203 | void Compute(OpKernelContext* context) override { |
204 | CHECK_EQ(3, context->num_inputs()); |
205 | const Tensor& input = context->input(0); |
206 | const Tensor& min = context->input(1); |
207 | const Tensor& max = context->input(2); |
208 | |
209 | OP_REQUIRES( |
210 | context, TensorShapeUtils::IsScalar(min.shape()), |
211 | InvalidArgument("`min` must be rank 0 but is rank " , min.dims())); |
212 | OP_REQUIRES( |
213 | context, TensorShapeUtils::IsScalar(max.shape()), |
214 | InvalidArgument("`max` must be rank 0 but is rank " , max.dims())); |
215 | |
216 | Tensor* output; |
217 | OP_REQUIRES_OK(context, |
218 | context->allocate_output(0, input.shape(), &output)); |
219 | |
220 | FakeQuantWithMinMaxVarsFunctor<Device> functor; |
221 | functor(context->eigen_device<Device>(), input.flat<float>(), |
222 | min.scalar<float>(), max.scalar<float>(), quant_min_, quant_max_, |
223 | output->flat<float>()); |
224 | } |
225 | |
226 | private: |
227 | int quant_min_; |
228 | int quant_max_; |
229 | }; |
230 | |
231 | // Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation in |
232 | // core/ops/array_ops.cc. |
template <typename Device>
class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
 public:
  // Parses `num_bits`/`narrow_range` and fails construction on GPU when op
  // determinism is required (see the Unimplemented error below).
  explicit FakeQuantWithMinMaxVarsGradientOp(OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    // Narrow range excludes the lowest quantized value.
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      OP_REQUIRES(
          context, !OpDeterminismRequired(),
          errors::Unimplemented(
              "Determinism is not yet supported in GPU implementation of "
              "FakeQuantWithMinMaxVarsGradient."));
    }
  }

  // Inputs: 0 = upstream gradient, 1 = original input (same size as the
  // gradient), 2 = scalar min, 3 = scalar max.
  // Outputs: 0 = gradient w.r.t. input (input-shaped), 1 = gradient w.r.t.
  // min (scalar), 2 = gradient w.r.t. max (scalar).
  void Compute(OpKernelContext* context) override {
    CHECK_EQ(4, context->num_inputs());
    const Tensor& gradient = context->input(0);
    const Tensor& input = context->input(1);
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const Tensor& min = context->input(2);
    const Tensor& max = context->input(3);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(min.shape()),
        InvalidArgument("`min` must be rank 0 but is rank ", min.dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max.shape()),
        InvalidArgument("`max` must be rank 0 but is rank ", max.dims()));

    Tensor* grad_wrt_input;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &grad_wrt_input));

    // Default-constructed TensorShape is rank 0, i.e. a scalar.
    TensorShape scalar_shape;
    Tensor* grad_wrt_min;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, scalar_shape, &grad_wrt_min));

    Tensor* grad_wrt_max;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, scalar_shape, &grad_wrt_max));

    FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
    functor(context->eigen_device<Device>(), gradient.flat<float>(),
            input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
            quant_min_, quant_max_, grad_wrt_input->flat<float>(),
            grad_wrt_min->scalar<float>(), grad_wrt_max->scalar<float>());
  }

 private:
  int quant_min_;  // Smallest quantized value: 1 when narrow_range, else 0.
  int quant_max_;  // Largest quantized value: (1 << num_bits) - 1.
};
295 | |
// CPU kernel registrations for the variable-range (min/max as inputs)
// forward and gradient ops.
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
                        FakeQuantWithMinMaxVarsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);
301 | |
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// Forward declarations of the GPU functor specializations; definitions are
// compiled elsewhere in the GPU build, and `extern template` suppresses
// implicit instantiation here.
template <>
void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat output);
extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
// The scalar min/max inputs are placed in host memory.
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstFlat gradients,
    typename TTypes<float>::ConstFlat inputs,
    typename TTypes<float>::ConstScalar min,
    typename TTypes<float>::ConstScalar max, const int quant_min,
    const int quant_max, typename TTypes<float>::Flat backprops_wrt_input,
    typename TTypes<float>::Scalar backprop_wrt_min,
    typename TTypes<float>::Scalar backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
333 | |
334 | // ----------------------------------------------------------------------------- |
335 | // Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation |
336 | // in core/ops/array_ops.cc. |
337 | template <typename Device> |
338 | class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel { |
339 | public: |
340 | explicit FakeQuantWithMinMaxVarsPerChannelOp(OpKernelConstruction* context) |
341 | : OpKernel::OpKernel(context) { |
342 | int num_bits; |
343 | OP_REQUIRES_OK(context, context->GetAttr("num_bits" , &num_bits)); |
344 | OP_REQUIRES( |
345 | context, IsNumBitsValid(num_bits), |
346 | InvalidArgument("num_bits must be between 2 and 16, inclusive" )); |
347 | bool narrow_range; |
348 | OP_REQUIRES_OK(context, context->GetAttr("narrow_range" , &narrow_range)); |
349 | quant_min_ = narrow_range ? 1 : 0; |
350 | quant_max_ = (1 << num_bits) - 1; |
351 | SET_USING_FAKE_QUANT(); |
352 | } |
353 | |
354 | void Compute(OpKernelContext* context) override { |
355 | CHECK_EQ(3, context->num_inputs()); |
356 | const Tensor& input = context->input(0); |
357 | const int depth = input.dim_size(input.dims() - 1); // last dimension size. |
358 | const Tensor& min = context->input(1); |
359 | const Tensor& max = context->input(2); |
360 | |
361 | OP_REQUIRES( |
362 | context, TensorShapeUtils::IsVector(min.shape()), |
363 | InvalidArgument("`min` must be rank 1 but is rank " , min.dims())); |
364 | OP_REQUIRES(context, min.dim_size(0) == depth, |
365 | InvalidArgument("min has incorrect size, expected " , depth, |
366 | " was " , min.dim_size(0))); |
367 | OP_REQUIRES( |
368 | context, TensorShapeUtils::IsVector(max.shape()), |
369 | InvalidArgument("`max` must be rank 1 but is rank " , max.dims())); |
370 | OP_REQUIRES(context, max.dim_size(0) == depth, |
371 | InvalidArgument("max has incorrect size, expected " , depth, |
372 | " was " , max.dim_size(0))); |
373 | |
374 | Tensor* output; |
375 | OP_REQUIRES_OK(context, |
376 | context->allocate_output(0, input.shape(), &output)); |
377 | |
378 | FakeQuantWithMinMaxVarsPerChannelFunctor<Device> functor; |
379 | functor(context->eigen_device<Device>(), input.flat_inner_dims<float, 2>(), |
380 | min.vec<float>(), max.vec<float>(), quant_min_, quant_max_, |
381 | output->flat_inner_dims<float, 2>()); |
382 | } |
383 | |
384 | private: |
385 | int quant_min_; |
386 | int quant_max_; |
387 | }; |
388 | |
389 | // Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its |
390 | // documentation in core/ops/array_ops.cc. |
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
 public:
  // Parses `num_bits`/`narrow_range` and fails construction on GPU when op
  // determinism is required (see the Unimplemented error below).
  explicit FakeQuantWithMinMaxVarsPerChannelGradientOp(
      OpKernelConstruction* context)
      : OpKernel::OpKernel(context) {
    int num_bits;
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits));
    OP_REQUIRES(
        context, IsNumBitsValid(num_bits),
        InvalidArgument("num_bits must be between 2 and 16, inclusive"));
    bool narrow_range;
    OP_REQUIRES_OK(context, context->GetAttr("narrow_range", &narrow_range));
    // Narrow range excludes the lowest quantized value.
    quant_min_ = narrow_range ? 1 : 0;
    quant_max_ = (1 << num_bits) - 1;
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      OP_REQUIRES(
          context, !OpDeterminismRequired(),
          errors::Unimplemented(
              "Determinism is not yet supported in GPU implementation of "
              "FakeQuantWithMinMaxVarsPerChannelGradient."));
    }
  }

  // Inputs: 0 = upstream gradient, 1 = original input (same size as the
  // gradient), 2 = per-channel min vector, 3 = per-channel max vector; the
  // vectors must match the size of the input's last dimension.
  // Outputs: 0 = gradient w.r.t. input (input-shaped), 1/2 = gradients
  // w.r.t. min/max (one entry per channel).
  void Compute(OpKernelContext* context) override {
    CHECK_EQ(4, context->num_inputs());
    const Tensor& gradient = context->input(0);
    const Tensor& input = context->input(1);
    OP_REQUIRES(context, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const int depth = input.dim_size(input.dims() - 1);  // last dimension size.
    const Tensor& min = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(min.shape()),
        InvalidArgument("`min` must be rank 1 but is rank ", min.dims()));
    OP_REQUIRES(context, min.dim_size(0) == depth,
                InvalidArgument("min has incorrect size, expected ", depth,
                                " was ", min.dim_size(0)));
    const Tensor& max = context->input(3);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(max.shape()),
        InvalidArgument("`max` must be rank 1 but is rank ", max.dims()));
    OP_REQUIRES(context, max.dim_size(0) == depth,
                InvalidArgument("max has incorrect size, expected ", depth,
                                " was ", max.dim_size(0)));

    Tensor* grad_wrt_input;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &grad_wrt_input));

    // One min/max gradient entry per channel.
    TensorShape min_max_shape({input.dim_size(input.dims() - 1)});
    Tensor* grad_wrt_min;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, min_max_shape, &grad_wrt_min));

    Tensor* grad_wrt_max;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, min_max_shape, &grad_wrt_max));

    // flat_inner_dims collapses all leading dimensions so the functor sees
    // a 2-D (outer, channel) view of the data.
    FakeQuantWithMinMaxVarsPerChannelGradientFunctor<Device> functor;
    functor(
        context->eigen_device<Device>(), gradient.flat_inner_dims<float, 2>(),
        input.flat_inner_dims<float, 2>(), min.vec<float>(), max.vec<float>(),
        quant_min_, quant_max_, grad_wrt_input->flat_inner_dims<float, 2>(),
        grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
  }

 private:
  int quant_min_;  // Smallest quantized value: 1 when narrow_range, else 0.
  int quant_max_;  // Largest quantized value: (1 << num_bits) - 1.
};
462 | |
// CPU kernel registrations for the per-channel forward and gradient ops.
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannel").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);
469 | |
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// Forward declarations of the GPU functor specializations; definitions are
// compiled elsewhere in the GPU build, and `extern template` suppresses
// implicit instantiation here.
template <>
void FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstFlat min,
    typename TTypes<float>::ConstFlat max, const int quant_min,
    const int quant_max, typename TTypes<float>::Matrix outputs);
extern template struct FakeQuantWithMinMaxVarsPerChannelFunctor<GPUDevice>;

// The per-channel min/max inputs are placed in host memory.
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float>::ConstMatrix gradients,
    typename TTypes<float>::ConstMatrix inputs,
    typename TTypes<float>::ConstVec min, typename TTypes<float>::ConstVec max,
    const int quant_min, const int quant_max,
    typename TTypes<float>::Matrix backprops_wrt_input,
    typename TTypes<float>::Vec backprop_wrt_min,
    typename TTypes<float>::Vec backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor<
    GPUDevice>;

REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
                            .Device(DEVICE_GPU)
                            .HostMemory("min")
                            .HostMemory("max"),
                        FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
504 | |
505 | } // namespace tensorflow |
506 | |