1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/common_shape_fns.h"
16#include "tensorflow/core/framework/op.h"
17#include "tensorflow/core/framework/shape_inference.h"
18#include "tensorflow/core/platform/errors.h"
19#include "tensorflow/core/platform/status.h"
20
21namespace tensorflow {
22namespace {
23
24using shape_inference::DimensionHandle;
25using shape_inference::ShapeHandle;
26using tensorflow::errors::InvalidArgument;
27
28Status ScalesZeroPointsShapeValid(shape_inference::InferenceContext* context,
29 DimensionHandle match_dimension_handle,
30 ShapeHandle scales, ShapeHandle zero_points) {
31 const int32_t scales_rank = shape_inference::InferenceContext::Rank(scales);
32 const int32_t zero_points_rank =
33 shape_inference::InferenceContext::Rank(zero_points);
34 // Skip validation when rank is unknown.
35 if (scales_rank == shape_inference::InferenceContext::kUnknownRank ||
36 zero_points_rank == shape_inference::InferenceContext::kUnknownRank) {
37 return OkStatus();
38 }
39
40 if (scales_rank != zero_points_rank) {
41 return InvalidArgument("scales and zero_points must have same rank.");
42 }
43 if (scales_rank == 0) {
44 return OkStatus();
45 }
46 DimensionHandle scales_size = context->Dim(scales, 0);
47 DimensionHandle zero_points_size = context->Dim(zero_points, 0);
48 DimensionHandle merged_scales;
49 TF_RETURN_IF_ERROR(
50 context->Merge(scales_size, match_dimension_handle, &merged_scales));
51 DimensionHandle merged_zero_points;
52 TF_RETURN_IF_ERROR(context->Merge(zero_points_size, match_dimension_handle,
53 &merged_zero_points));
54 return OkStatus();
55}
56
57Status DotShape(shape_inference::InferenceContext* context) {
58 ShapeHandle lhs;
59 TF_RETURN_IF_ERROR(context->WithRank(context->input(0), 2, &lhs));
60 ShapeHandle rhs;
61 TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 2, &rhs));
62 // lhs scales and zero_points must be scalar tensors.
63 ShapeHandle lhs_scales;
64 TF_RETURN_IF_ERROR(
65 context->WithRankAtMost(context->input(2), 0, &lhs_scales));
66 ShapeHandle lhs_zero_points;
67 TF_RETURN_IF_ERROR(
68 context->WithRankAtMost(context->input(3), 0, &lhs_zero_points));
69 ShapeHandle rhs_scales;
70 TF_RETURN_IF_ERROR(
71 context->WithRankAtMost(context->input(4), 1, &rhs_scales));
72 ShapeHandle rhs_zero_points;
73 TF_RETURN_IF_ERROR(
74 context->WithRankAtMost(context->input(5), 1, &rhs_zero_points));
75 ShapeHandle output_scales;
76 TF_RETURN_IF_ERROR(
77 context->WithRankAtMost(context->input(6), 1, &output_scales));
78 ShapeHandle output_zero_points;
79 TF_RETURN_IF_ERROR(
80 context->WithRankAtMost(context->input(7), 1, &output_zero_points));
81
82 // Validate that the inner shapes are compatible.
83 DimensionHandle inner_lhs = context->Dim(lhs, 1);
84 DimensionHandle inner_rhs = context->Dim(rhs, 0);
85 DimensionHandle merged;
86 TF_RETURN_IF_ERROR(context->Merge(inner_lhs, inner_rhs, &merged));
87
88 DimensionHandle output_rows = context->Dim(lhs, 0);
89 DimensionHandle output_cols = context->Dim(rhs, 1);
90
91 TF_RETURN_IF_ERROR(ScalesZeroPointsShapeValid(context, output_cols,
92 rhs_scales, rhs_zero_points));
93 TF_RETURN_IF_ERROR(ScalesZeroPointsShapeValid(
94 context, output_cols, output_scales, output_zero_points));
95
96 context->set_output(0, context->Matrix(output_rows, output_cols));
97 return OkStatus();
98}
99
100Status DotHybridShape(shape_inference::InferenceContext* context) {
101 ShapeHandle lhs;
102 TF_RETURN_IF_ERROR(context->WithRank(context->input(0), 2, &lhs));
103 ShapeHandle rhs;
104 TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 2, &rhs));
105 ShapeHandle rhs_scales;
106 TF_RETURN_IF_ERROR(
107 context->WithRankAtMost(context->input(2), 1, &rhs_scales));
108 ShapeHandle rhs_zero_points;
109 TF_RETURN_IF_ERROR(
110 context->WithRankAtMost(context->input(3), 1, &rhs_zero_points));
111
112 // Validate that the inner shapes are compatible.
113 DimensionHandle inner_lhs = context->Dim(lhs, 1);
114 DimensionHandle inner_rhs = context->Dim(rhs, 0);
115 DimensionHandle merged;
116 TF_RETURN_IF_ERROR(context->Merge(inner_lhs, inner_rhs, &merged));
117
118 DimensionHandle output_rows = context->Dim(lhs, 0);
119 DimensionHandle output_cols = context->Dim(rhs, 1);
120
121 TF_RETURN_IF_ERROR(ScalesZeroPointsShapeValid(context, output_cols,
122 rhs_scales, rhs_zero_points));
123
124 context->set_output(0, context->Matrix(output_rows, output_cols));
125 return OkStatus();
126}
127
128} // namespace
129
// Quantizes a float tensor to qint8/qint32 given scales and zero_points.
// Output shape equals input shape. `quantization_axis = -1` presumably means
// per-tensor params, with a non-negative axis selecting a per-axis dimension
// -- NOTE(review): axis semantics enforced by the kernel, not visible here.
REGISTER_OP("UniformQuantize")
    .Input("input: Tin")
    .Input("scales: float")
    .Input("zero_points: int32")
    .Output("output: Tout")
    .Attr("Tin: {float}")
    .Attr("Tout: {qint8, qint32}")
    .Attr("quantization_axis: int = -1")
    .Attr("quantization_min_val: int")
    .Attr("quantization_max_val: int")
    .SetShapeFn(shape_inference::UnchangedShape);

// Converts an already-quantized tensor from one set of quantization params
// (input_scales/input_zero_points) to another (output_scales/
// output_zero_points), possibly changing the quantized type. Output shape
// equals input shape.
REGISTER_OP("UniformRequantize")
    .Input("input: Tin")
    .Input("input_scales: float")
    .Input("input_zero_points: int32")
    .Input("output_scales: float")
    .Input("output_zero_points: int32")
    .Output("output: Tout")
    .Attr("Tin: {qint8, qint32}")
    .Attr("Tout: {qint8, qint32}")
    .Attr("input_quantization_axis: int = -1")
    .Attr("input_quantization_min_val: int")
    .Attr("input_quantization_max_val: int")
    .Attr("output_quantization_axis: int = -1")
    .Attr("output_quantization_min_val: int")
    .Attr("output_quantization_max_val: int")
    .SetShapeFn(shape_inference::UnchangedShape);

// Dequantizes a qint8/qint32 tensor back to float given scales and
// zero_points. Output shape equals input shape.
REGISTER_OP("UniformDequantize")
    .Input("input: Tin")
    .Input("scales: float")
    .Input("zero_points: int32")
    .Output("output: Tout")
    .Attr("Tin: {qint8, qint32}")
    .Attr("Tout: {float}")
    .Attr("quantization_axis: int = -1")
    .Attr("quantization_min_val: int")
    .Attr("quantization_max_val: int")
    .SetShapeFn(shape_inference::UnchangedShape);

// Matrix multiply of two quantized qint8 matrices producing a quantized
// qint32 matrix. DotShape requires rank-2 lhs/rhs with matching inner dims,
// scalar lhs quantization params, and rhs/output params of rank <= 1 whose
// 1-D size (if any) matches the output column dimension.
REGISTER_OP("UniformQuantizedDot")
    .Input("lhs: Tin")
    .Input("rhs: Tin")
    .Input("lhs_scales: float")
    .Input("lhs_zero_points: int32")
    .Input("rhs_scales: float")
    .Input("rhs_zero_points: int32")
    .Input("output_scales: float")
    .Input("output_zero_points: int32")
    .Output("output: Tout")
    .Attr("Tin: {qint8}")
    .Attr("Tout: {qint32}")
    .Attr("lhs_quantization_axis: int = -1")
    .Attr("lhs_quantization_min_val: int")
    .Attr("lhs_quantization_max_val: int")
    .Attr("rhs_quantization_axis: int = -1")
    .Attr("rhs_quantization_min_val: int")
    .Attr("rhs_quantization_max_val: int")
    .Attr("output_quantization_axis: int = -1")
    .Attr("output_quantization_min_val: int")
    .Attr("output_quantization_max_val: int")
    .SetShapeFn(DotShape);

// "Hybrid" matrix multiply: float lhs times quantized qint8 rhs, producing a
// float output. Only rhs carries quantization params; DotHybridShape
// requires rank-2 lhs/rhs with matching inner dims and rhs params of
// rank <= 1 whose 1-D size (if any) matches the output column dimension.
REGISTER_OP("UniformQuantizedDotHybrid")
    .Input("lhs: Tlhs")
    .Input("rhs: Trhs")
    .Input("rhs_scales: float")
    .Input("rhs_zero_points: int32")
    .Output("output: Tout")
    .Attr("Tlhs: {float}")
    .Attr("Trhs: {qint8}")
    .Attr("Tout: {float}")
    .Attr("rhs_quantization_axis: int = -1")
    .Attr("rhs_quantization_min_val: int")
    .Attr("rhs_quantization_max_val: int")
    .SetShapeFn(DotHybridShape);

// Clips a quantized qint32 tensor to the [min, max] operand range. All three
// value inputs share type T; scales/zero_points accompany them. Output shape
// equals operand shape.
REGISTER_OP("UniformQuantizedClipByValue")
    .Input("operand: T")
    .Input("min: T")
    .Input("max: T")
    .Input("scales: float")
    .Input("zero_points: int32")
    .Output("output: T")
    .Attr("T: {qint32}")
    .Attr("quantization_axis: int = -1")
    .Attr("quantization_min_val: int")
    .Attr("quantization_max_val: int")
    .SetShapeFn(shape_inference::UnchangedShape);
220
221} // namespace tensorflow
222