1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
#include "glow/Quantization/Base/Calibration.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iterator>
#include <limits>
#include <numeric>
#include <vector>
21 | |
22 | namespace glow { |
23 | namespace quantization { |
24 | |
/// Prepare the histogram \p hist with \p length bins for a relative entropy
/// computation. The histogram is NOT normalized and is assumed to hold
/// non-negative integer counts (0, 1, ...) stored as floats to accommodate
/// very large values. Conditioning replaces each zero-valued bin with
/// \p epsZero while the mass added this way is compensated by subtracting a
/// matching epsilon from every non-zero bin, leaving the total sum unchanged.
static void conditionHistogram(float *hist, const size_t length,
                               const float epsZero = 0.0001) {

  // Nothing to condition for an empty histogram.
  if (length == 0) {
    return;
  }

  // Flag the zero-valued bins and count them.
  std::vector<int> zeroMask(length);
  size_t zeroCount = 0;
  for (size_t binIdx = 0; binIdx < length; binIdx++) {
    zeroMask[binIdx] = static_cast<int>(hist[binIdx] == 0.f);
    zeroCount += zeroMask[binIdx];
  }

  // An all-zero histogram is left untouched.
  if (zeroCount == length) {
    return;
  }

  // Epsilon subtracted from each non-zero bin to balance the mass added to
  // the zero bins (epsZero * zeroCount spread over the non-zero bins).
  const size_t nonZeroCount = length - zeroCount;
  const float epsNonZero = epsZero * static_cast<float>(zeroCount) /
                           static_cast<float>(nonZeroCount);

  // Bail out if the per-bin compensation would reach 1.0 (a full count),
  // which would distort the histogram too much.
  if (epsNonZero >= 1.0) {
    return;
  }

  // Apply the conditioning:
  // - zero bins are raised to epsZero.
  // - non-zero bins are lowered by epsNonZero.
  for (size_t binIdx = 0; binIdx < length; binIdx++) {
    if (zeroMask[binIdx]) {
      hist[binIdx] += epsZero;
    } else {
      hist[binIdx] -= epsNonZero;
    }
  }
}
70 | |
/// Compute the Kullback-Leibler divergence (relative entropy) of the
/// distribution \p P with respect to \p Q, denoted D(P||Q) and defined as:
///   D(P||Q) = sum(P[k] * log(P[k] / Q[k]))
/// Depending on the base of the logarithm, the divergence is measured in
/// "bits" for log2 and "nats" for ln (natural logarithm). The distributions
/// \p P and \p Q (each of length \p length) do NOT need to be normalized
/// beforehand because they are normalized automatically in-place.
/// \returns the relative entropy as a float scalar.
/// This metric quantifies the information (entropy) lost when \p Q is used
/// to approximate the ground truth (reference) distribution \p P; it is
/// always non-negative, i.e. D(P||Q) >= 0.
static float computeKL(float *P, float *Q, size_t length) {

  // Normalization factors for P and Q.
  const float sumP = std::accumulate(P, P + length, 0.f);
  const float sumQ = std::accumulate(Q, Q + length, 0.f);

  // A degenerate (all-zero) distribution yields a zero divergence.
  if (sumP == 0.f || sumQ == 0.f) {
    return 0;
  }

  // Accumulate the relative entropy over all bins.
  float relEntropy = 0;
  for (size_t k = 0; k < length; k++) {
    // Normalize both distributions in-place.
    P[k] /= sumP;
    Q[k] /= sumQ;
    // Bins where either distribution vanishes contribute nothing.
    if (P[k] > 0 && Q[k] > 0) {
      relEntropy += P[k] * std::log(P[k] / Q[k]);
    }
  }
  return relEntropy;
}
107 | |
/// Compute optimal quantization thresholds for the histogram \p hist which
/// spans the range [\p histMin, \p histMax], by minimizing the KL divergence
/// between the saturated input histogram (reference distribution P) and its
/// approximation quantized to \p numQuantizedBins bins (distribution Q).
/// The search iteratively shrinks a saturation window over the histogram
/// bins and keeps the window with the smallest divergence. When \p symmetric
/// is true the window shrinks symmetrically and the final thresholds are
/// adjusted to preserve the input min/max ratio so that the floating point
/// zero maps exactly to the quantized zero-point.
/// \returns the optimal {min, max} thresholds.
FloatRange optimizeKL(const std::vector<float> &hist, const float histMin,
                      const float histMax, const size_t numQuantizedBins,
                      const bool symmetric) {

  // Number of histogram bins.
  const size_t numBins = hist.size();

  // If the input histogram is empty or the number of histogram bins is smaller
  // than numQuantizedBins then return the histogram range.
  if ((numBins == 0) || (numBins < numQuantizedBins)) {
    return {histMin, histMax};
  }

  // Histogram bin width.
  assert(histMin < histMax && "Invalid histogram min/max range!" );
  const float histBinWidth = (histMax - histMin) / (float)numBins;

  // Optimal divergence value (minimum).
  float divergenceOpt = std::numeric_limits<float>::infinity();

  // Optimal threshold values for minimum divergence.
  float thresholdMinOpt = histMin;
  float thresholdMaxOpt = histMax;

  // Initialize start/stop bin indices (inclusive) with the first and last bin.
  size_t histWinIdxStart = 0;
  size_t histWinIdxStop = numBins - 1;

  // Start iterations by increasingly saturating the input histogram while
  // the histogram window is larger or equal to numQuantizedBins. The expected
  // behavior of the computed divergence is either:
  // (1) increase monotonically in which case it would make sense to include
  //     some logic to exit the loop prematurely in order to not waste time.
  // (2) either slightly decrease in the first iterations, settle in a local
  //     minimum and then increase monotonically. This is the case in which
  //     this algorithm hopes to achieve better ranges for quantizing a tensor.
  while ((histWinIdxStop - histWinIdxStart + 1) >= numQuantizedBins) {

    // Current histogram window size.
    const size_t histWinSize = histWinIdxStop - histWinIdxStart + 1;

    // Current histogram window raw pointer.
    const float *histWinPtr = hist.data() + histWinIdxStart;

    // Note: MXNet / TVM have an error in their programs since they explicitly
    // extract the histogram window in the variable 'sliced_nd_hist' which has
    // always 0 on the first position that is sliced_nd_hist.front() == 0.

    // -------------------------------------------------------------------------
    // Compute the reference distribution P as the input histogram saturated in
    // the current window given by histWinIdxStart and histWinIdxStop.
    // -------------------------------------------------------------------------
    std::vector<float> P(histWinSize);

    // Saturate the histogram left: all the bins at or below histWinIdxStart
    // (including the window's first bin itself) collapse into P.front().
    float leftSum = 0;
    for (size_t histIdx = 0; histIdx <= histWinIdxStart; histIdx++) {
      leftSum += hist[histIdx];
    }
    P.front() += leftSum;

    // Extract the non-saturated (interior) part of the histogram.
    for (size_t histIdx = histWinIdxStart + 1; histIdx < histWinIdxStop;
         histIdx++) {
      P[histIdx - histWinIdxStart] = hist[histIdx];
    }

    // Saturate the histogram right: all the bins at or above histWinIdxStop
    // (including the window's last bin itself) collapse into P.back().
    float rightSum = 0;
    for (size_t histIdx = histWinIdxStop; histIdx < numBins; histIdx++) {
      rightSum += hist[histIdx];
    }
    P.back() += rightSum;

    // -------------------------------------------------------------------------
    // Compute the approximation distribution Q as the input histogram sliced in
    // the current window given by histWinIdxStart and histWinIdxStop, rescaled
    // to numQuantizedBins and then expanded back to the current window length.
    // -------------------------------------------------------------------------
    // The bins from the current histogram window are distributed equally in the
    // quantized bins. The remainder is distributed in the last quantized bin.
    assert(histWinSize >= numQuantizedBins && "Invalid histogram window size!" );
    const size_t numMergedBins = histWinSize / numQuantizedBins;

    // Compute Q.
    std::vector<float> Q(histWinSize, 0);
    for (size_t qIdx = 0; qIdx < numQuantizedBins; qIdx++) {

      // Histogram window bin start index (inclusive) for this quantized bin.
      const size_t idxStart = qIdx * numMergedBins;

      // Histogram window bin stop index (exclusive) for this quantized bin.
      // If last quantized bin then go to the end of the window.
      const size_t idxStop = (qIdx < (numQuantizedBins - 1))
                                 ? (idxStart + numMergedBins)
                                 : histWinSize;

      // Sum all the values for this quantized bin.
      // Count all the non-negative values for this quantized bin to use for
      // normalization.
      float sum = 0;
      size_t norm = 0;
      for (size_t idx = idxStart; idx < idxStop; idx++) {
        sum += histWinPtr[idx];
        norm += (histWinPtr[idx] != 0);
      }

      // Compute Q by expanding and normalizing the quantized bins: the merged
      // mass is spread evenly over the positions where P is non-zero, so Q
      // mirrors P's support within each quantized bin.
      if (norm != 0) {
        for (size_t idx = idxStart; idx < idxStop; idx++) {
          if (P[idx]) {
            Q[idx] = sum / (float)norm;
          }
        }
      }
    }

    // -------------------------------------------------------------------------
    // Compute the KL divergence metric and check for optimal values.
    // -------------------------------------------------------------------------
    // Condition the histograms P and Q (replace zero bins with a small
    // epsilon) so the divergence is well-defined.
    conditionHistogram(P.data(), P.size());
    conditionHistogram(Q.data(), Q.size());

    // Compute the divergence of P with respect to Q.
    float divergence = computeKL(P.data(), Q.data(), P.size());

    // Check if current divergence is the new optimal.
    if (divergence < divergenceOpt) {

      // Update optimal divergence with current divergence.
      divergenceOpt = divergence;

      // Update optimal thresholds with current thresholds. The thresholds are
      // the real-valued edges of the current window: the left edge of bin
      // histWinIdxStart and the right edge of bin histWinIdxStop.
      thresholdMinOpt = histMin + histWinIdxStart * histBinWidth;
      thresholdMaxOpt = histMin + (histWinIdxStop + 1) * histBinWidth;
    }

    // -------------------------------------------------------------------------
    // Update histogram window for next iteration.
    // -------------------------------------------------------------------------
    // NOTE(review): both branches below assume numQuantizedBins >= 2; with a
    // single quantized bin the window can shrink to one bin and the stop
    // index decrement would wrap below zero (size_t underflow) — confirm
    // callers never pass numQuantizedBins < 2.
    if (symmetric) {
      // For symmetric schema we shrink the histogram window symmetrically.
      histWinIdxStart++;
      histWinIdxStop--;

    } else {
      // For asymmetric schema we shrink the histogram window either left-only,
      // right-only or symmetrically depending on which case has minimum
      // histogram data loss. Each candidate always drops exactly two bins:
      // one from each side, two from the left, or two from the right.
      float symmLoss = hist[histWinIdxStart] + hist[histWinIdxStop];
      float leftLoss = hist[histWinIdxStart] + hist[histWinIdxStart + 1];
      float rightLoss = hist[histWinIdxStop] + hist[histWinIdxStop - 1];

      std::vector<float> loss = {symmLoss, leftLoss, rightLoss};
      auto lossMinIdx = std::distance(
          loss.begin(), std::min_element(loss.begin(), loss.end()));
      if (lossMinIdx == 0) {
        // Saturate symmetrically.
        histWinIdxStart++;
        histWinIdxStop--;
      } else if (lossMinIdx == 1) {
        // Saturate left.
        histWinIdxStart += 2;
      } else {
        // Saturate right.
        histWinIdxStop -= 2;
      }
    }
  }

  // For symmetric schema we must make sure the optimized thresholds maintain
  // the same ratio as the input min/max of the histogram in order to map the
  // zero-point to quantized 0.
  if (symmetric) {
    assert(histMin < 0 && "Invalid histogram minimum!" );
    assert(histMax > 0 && "Invalid histogram maximum!" );
    assert(thresholdMinOpt < 0 && "Invalid threshold minimum!" );
    assert(thresholdMaxOpt > 0 && "Invalid threshold maximum!" );
    // Widen whichever threshold has the smaller relative saturation so that
    // thresholdMin/histMin == thresholdMax/histMax (ratios computed in
    // double for precision).
    double ratioMin = (double)thresholdMinOpt / (double)histMin;
    double ratioMax = (double)thresholdMaxOpt / (double)histMax;
    if (ratioMin > ratioMax) {
      thresholdMaxOpt = ratioMin * histMax;
    } else {
      thresholdMinOpt = ratioMax * histMin;
    }
  }

  return {thresholdMinOpt, thresholdMaxOpt};
}
298 | |
299 | } // namespace quantization |
300 | } // namespace glow |
301 | |