1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/framework/common_shape_fns.h" |
16 | #include "tensorflow/core/framework/op.h" |
17 | #include "tensorflow/core/framework/shape_inference.h" |
18 | #include "tensorflow/core/platform/errors.h" |
19 | #include "tensorflow/core/platform/status.h" |
20 | |
21 | namespace tensorflow { |
22 | namespace { |
23 | |
24 | using shape_inference::DimensionHandle; |
25 | using shape_inference::ShapeHandle; |
26 | using tensorflow::errors::InvalidArgument; |
27 | |
28 | Status ScalesZeroPointsShapeValid(shape_inference::InferenceContext* context, |
29 | DimensionHandle match_dimension_handle, |
30 | ShapeHandle scales, ShapeHandle zero_points) { |
31 | const int32_t scales_rank = shape_inference::InferenceContext::Rank(scales); |
32 | const int32_t zero_points_rank = |
33 | shape_inference::InferenceContext::Rank(zero_points); |
34 | // Skip validation when rank is unknown. |
35 | if (scales_rank == shape_inference::InferenceContext::kUnknownRank || |
36 | zero_points_rank == shape_inference::InferenceContext::kUnknownRank) { |
37 | return OkStatus(); |
38 | } |
39 | |
40 | if (scales_rank != zero_points_rank) { |
41 | return InvalidArgument("scales and zero_points must have same rank." ); |
42 | } |
43 | if (scales_rank == 0) { |
44 | return OkStatus(); |
45 | } |
46 | DimensionHandle scales_size = context->Dim(scales, 0); |
47 | DimensionHandle zero_points_size = context->Dim(zero_points, 0); |
48 | DimensionHandle merged_scales; |
49 | TF_RETURN_IF_ERROR( |
50 | context->Merge(scales_size, match_dimension_handle, &merged_scales)); |
51 | DimensionHandle merged_zero_points; |
52 | TF_RETURN_IF_ERROR(context->Merge(zero_points_size, match_dimension_handle, |
53 | &merged_zero_points)); |
54 | return OkStatus(); |
55 | } |
56 | |
57 | Status DotShape(shape_inference::InferenceContext* context) { |
58 | ShapeHandle lhs; |
59 | TF_RETURN_IF_ERROR(context->WithRank(context->input(0), 2, &lhs)); |
60 | ShapeHandle rhs; |
61 | TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 2, &rhs)); |
62 | // lhs scales and zero_points must be scalar tensors. |
63 | ShapeHandle lhs_scales; |
64 | TF_RETURN_IF_ERROR( |
65 | context->WithRankAtMost(context->input(2), 0, &lhs_scales)); |
66 | ShapeHandle lhs_zero_points; |
67 | TF_RETURN_IF_ERROR( |
68 | context->WithRankAtMost(context->input(3), 0, &lhs_zero_points)); |
69 | ShapeHandle rhs_scales; |
70 | TF_RETURN_IF_ERROR( |
71 | context->WithRankAtMost(context->input(4), 1, &rhs_scales)); |
72 | ShapeHandle rhs_zero_points; |
73 | TF_RETURN_IF_ERROR( |
74 | context->WithRankAtMost(context->input(5), 1, &rhs_zero_points)); |
75 | ShapeHandle output_scales; |
76 | TF_RETURN_IF_ERROR( |
77 | context->WithRankAtMost(context->input(6), 1, &output_scales)); |
78 | ShapeHandle output_zero_points; |
79 | TF_RETURN_IF_ERROR( |
80 | context->WithRankAtMost(context->input(7), 1, &output_zero_points)); |
81 | |
82 | // Validate that the inner shapes are compatible. |
83 | DimensionHandle inner_lhs = context->Dim(lhs, 1); |
84 | DimensionHandle inner_rhs = context->Dim(rhs, 0); |
85 | DimensionHandle merged; |
86 | TF_RETURN_IF_ERROR(context->Merge(inner_lhs, inner_rhs, &merged)); |
87 | |
88 | DimensionHandle output_rows = context->Dim(lhs, 0); |
89 | DimensionHandle output_cols = context->Dim(rhs, 1); |
90 | |
91 | TF_RETURN_IF_ERROR(ScalesZeroPointsShapeValid(context, output_cols, |
92 | rhs_scales, rhs_zero_points)); |
93 | TF_RETURN_IF_ERROR(ScalesZeroPointsShapeValid( |
94 | context, output_cols, output_scales, output_zero_points)); |
95 | |
96 | context->set_output(0, context->Matrix(output_rows, output_cols)); |
97 | return OkStatus(); |
98 | } |
99 | |
100 | Status DotHybridShape(shape_inference::InferenceContext* context) { |
101 | ShapeHandle lhs; |
102 | TF_RETURN_IF_ERROR(context->WithRank(context->input(0), 2, &lhs)); |
103 | ShapeHandle rhs; |
104 | TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 2, &rhs)); |
105 | ShapeHandle rhs_scales; |
106 | TF_RETURN_IF_ERROR( |
107 | context->WithRankAtMost(context->input(2), 1, &rhs_scales)); |
108 | ShapeHandle rhs_zero_points; |
109 | TF_RETURN_IF_ERROR( |
110 | context->WithRankAtMost(context->input(3), 1, &rhs_zero_points)); |
111 | |
112 | // Validate that the inner shapes are compatible. |
113 | DimensionHandle inner_lhs = context->Dim(lhs, 1); |
114 | DimensionHandle inner_rhs = context->Dim(rhs, 0); |
115 | DimensionHandle merged; |
116 | TF_RETURN_IF_ERROR(context->Merge(inner_lhs, inner_rhs, &merged)); |
117 | |
118 | DimensionHandle output_rows = context->Dim(lhs, 0); |
119 | DimensionHandle output_cols = context->Dim(rhs, 1); |
120 | |
121 | TF_RETURN_IF_ERROR(ScalesZeroPointsShapeValid(context, output_cols, |
122 | rhs_scales, rhs_zero_points)); |
123 | |
124 | context->set_output(0, context->Matrix(output_rows, output_cols)); |
125 | return OkStatus(); |
126 | } |
127 | |
128 | } // namespace |
129 | |
// Quantizes a float tensor to qint8/qint32 using the given scales and
// zero_points. Output shape equals input shape (UnchangedShape).
// quantization_axis == -1 presumably selects per-tensor quantization, with
// a non-negative value selecting the per-axis dimension — confirm against
// the kernel implementation.
REGISTER_OP("UniformQuantize")
    .Input("input: Tin")
    .Input("scales: float")
    .Input("zero_points: int32")
    .Output("output: Tout")
    .Attr("Tin: {float}")
    .Attr("Tout: {qint8, qint32}")
    .Attr("quantization_axis: int = -1")
    .Attr("quantization_min_val: int")
    .Attr("quantization_max_val: int")
    .SetShapeFn(shape_inference::UnchangedShape);
141 | |
// Converts an already-quantized tensor (qint8/qint32) from one set of
// quantization params (input_scales/input_zero_points) to another
// (output_scales/output_zero_points). Output shape equals input shape.
// Input and output quantization axes are specified independently.
REGISTER_OP("UniformRequantize")
    .Input("input: Tin")
    .Input("input_scales: float")
    .Input("input_zero_points: int32")
    .Input("output_scales: float")
    .Input("output_zero_points: int32")
    .Output("output: Tout")
    .Attr("Tin: {qint8, qint32}")
    .Attr("Tout: {qint8, qint32}")
    .Attr("input_quantization_axis: int = -1")
    .Attr("input_quantization_min_val: int")
    .Attr("input_quantization_max_val: int")
    .Attr("output_quantization_axis: int = -1")
    .Attr("output_quantization_min_val: int")
    .Attr("output_quantization_max_val: int")
    .SetShapeFn(shape_inference::UnchangedShape);
158 | |
// Dequantizes a qint8/qint32 tensor back to float using the given scales and
// zero_points. Output shape equals input shape (UnchangedShape). Attr layout
// mirrors UniformQuantize with Tin/Tout swapped.
REGISTER_OP("UniformDequantize")
    .Input("input: Tin")
    .Input("scales: float")
    .Input("zero_points: int32")
    .Output("output: Tout")
    .Attr("Tin: {qint8, qint32}")
    .Attr("Tout: {float}")
    .Attr("quantization_axis: int = -1")
    .Attr("quantization_min_val: int")
    .Attr("quantization_max_val: int")
    .SetShapeFn(shape_inference::UnchangedShape);
170 | |
// Matrix multiply of two quantized qint8 matrices producing a qint32 matrix.
// Each of lhs, rhs, and output carries its own scales/zero_points inputs and
// quantization axis/min/max attrs. Shape inference (DotShape) enforces
// 2-D operands with matching inner dimensions, scalar lhs params, and
// rank<=1 rhs/output params sized to the output columns.
REGISTER_OP("UniformQuantizedDot")
    .Input("lhs: Tin")
    .Input("rhs: Tin")
    .Input("lhs_scales: float")
    .Input("lhs_zero_points: int32")
    .Input("rhs_scales: float")
    .Input("rhs_zero_points: int32")
    .Input("output_scales: float")
    .Input("output_zero_points: int32")
    .Output("output: Tout")
    .Attr("Tin: {qint8}")
    .Attr("Tout: {qint32}")
    .Attr("lhs_quantization_axis: int = -1")
    .Attr("lhs_quantization_min_val: int")
    .Attr("lhs_quantization_max_val: int")
    .Attr("rhs_quantization_axis: int = -1")
    .Attr("rhs_quantization_min_val: int")
    .Attr("rhs_quantization_max_val: int")
    .Attr("output_quantization_axis: int = -1")
    .Attr("output_quantization_min_val: int")
    .Attr("output_quantization_max_val: int")
    .SetShapeFn(DotShape);
193 | |
// Hybrid matrix multiply: float lhs x quantized qint8 rhs -> float output.
// Only the rhs carries quantization params, so only rhs scales/zero_points
// and axis/min/max attrs exist. Shape inference (DotHybridShape) enforces
// 2-D operands with matching inner dimensions and rank<=1 rhs params sized
// to the output columns.
REGISTER_OP("UniformQuantizedDotHybrid")
    .Input("lhs: Tlhs")
    .Input("rhs: Trhs")
    .Input("rhs_scales: float")
    .Input("rhs_zero_points: int32")
    .Output("output: Tout")
    .Attr("Tlhs: {float}")
    .Attr("Trhs: {qint8}")
    .Attr("Tout: {float}")
    .Attr("rhs_quantization_axis: int = -1")
    .Attr("rhs_quantization_min_val: int")
    .Attr("rhs_quantization_max_val: int")
    .SetShapeFn(DotHybridShape);
207 | |
// Clips a quantized qint32 tensor to the [min, max] operands, which share the
// operand's type T. The scales/zero_points inputs describe the operand's
// quantization params. Output shape equals operand shape (UnchangedShape).
REGISTER_OP("UniformQuantizedClipByValue")
    .Input("operand: T")
    .Input("min: T")
    .Input("max: T")
    .Input("scales: float")
    .Input("zero_points: int32")
    .Output("output: T")
    .Attr("T: {qint32}")
    .Attr("quantization_axis: int = -1")
    .Attr("quantization_min_val: int")
    .Attr("quantization_max_val: int")
    .SetShapeFn(shape_inference::UnchangedShape);
220 | |
221 | } // namespace tensorflow |
222 | |