/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements quantized versions of the Relu and Relu6 operations.
#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

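// Computes a quantized Relu: every element below the quantized representation
// of 0.0f is clamped up to it, while the float range [min_input, max_input]
// that defines the quantization is passed through unchanged.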
template <typename T>
class QuantizedReluOp : public OpKernel {
 public:
  explicit QuantizedReluOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& min_input_tensor = context->input(1);
    const Tensor& max_input_tensor = context->input(2);

    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(min_input_tensor.shape()),
        errors::InvalidArgument("`min_input` must be rank 0 but is rank ",
                                min_input_tensor.dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_input_tensor.shape()),
        errors::InvalidArgument("`max_input` must be rank 0 but is rank ",
                                max_input_tensor.dims()));

    const float min_input = min_input_tensor.scalar<float>()();
    const float max_input = max_input_tensor.scalar<float>()();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
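    // Map the Relu cutoff (0.0f) into the quantized domain defined by
    // [min_input, max_input] so it can be compared directly against the
    // quantized input values.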
    const T min_as_quantized = FloatToQuantized<T>(0.0f, min_input, max_input);

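    // Fast path: use the accelerated meta::Clamp kernel when it is supported
    // and enabled and the data is quint8; otherwise fall back to the Eigen
    // expression below.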
    if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
      auto input_ui8_array = input.flat<quint8>();
      meta::Clamp(context, input_ui8_array.data(), input_ui8_array.size(),
                  min_as_quantized, 255, output->flat<quint8>().data());
    } else {
      output->flat<T>().device(context->eigen_cpu_device()) =
          input.flat<T>().cwiseMax(min_as_quantized).template cast<T>();
    }

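    // Relu does not change the quantization parameters, so forward the input
    // min/max unchanged as the output range.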
    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_input;
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_input;
  }
};

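// Computes a quantized Relu6: elements are clamped to the quantized
// representations of 0.0f and 6.0f, again leaving the float range unchanged.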
template <typename T>
class QuantizedRelu6Op : public OpKernel {
 public:
  explicit QuantizedRelu6Op(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& min_input_tensor = context->input(1);
    const Tensor& max_input_tensor = context->input(2);

    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(min_input_tensor.shape()),
        errors::InvalidArgument("`min_input` must be rank 0 but is rank ",
                                min_input_tensor.dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_input_tensor.shape()),
        errors::InvalidArgument("`max_input` must be rank 0 but is rank ",
                                max_input_tensor.dims()));

    const float min_input = min_input_tensor.scalar<float>()();
    const float max_input = max_input_tensor.scalar<float>()();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
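    // Quantize both clamp thresholds (0.0f for the lower bound and 6.0f for
    // the upper bound) into the input's float range.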
    const T min_as_quantized = FloatToQuantized<T>(0.0f, min_input, max_input);
    const T max_as_quantized = FloatToQuantized<T>(6.0f, min_input, max_input);

    if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
      auto input_ui8_array = input.flat<quint8>();
      meta::Clamp(context, input_ui8_array.data(), input_ui8_array.size(),
                  min_as_quantized, max_as_quantized,
                  output->flat<quint8>().data());
    } else {
      output->flat<T>().device(context->eigen_cpu_device()) =
          input.flat<T>()
              .cwiseMax(min_as_quantized)
              .cwiseMin(max_as_quantized)
              .template cast<T>();
    }

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_input;
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_input;
  }
};

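// Register CPU kernels for both the qint32 and quint8 variants of each op.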
REGISTER_KERNEL_BUILDER(Name("QuantizedRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("Tinput")
                            .TypeConstraint<qint32>("out_type"),
                        QuantizedReluOp<qint32>);
REGISTER_KERNEL_BUILDER(Name("QuantizedRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<quint8>("out_type"),
                        QuantizedReluOp<quint8>);

REGISTER_KERNEL_BUILDER(Name("QuantizedRelu6")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("Tinput")
                            .TypeConstraint<qint32>("out_type"),
                        QuantizedRelu6Op<qint32>);
REGISTER_KERNEL_BUILDER(Name("QuantizedRelu6")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<quint8>("out_type"),
                        QuantizedRelu6Op<quint8>);
}  // namespace tensorflow