1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "glow/Quantization/Base/Profile.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>
20
21namespace glow {
22namespace quantization {
23
/// Get the index of the bin into which \p value falls, in a histogram with
/// \p nBins bins of width \p binWidth starting at \p minValue.
/// A zero \p binWidth maps every value to bin 0. Values below \p minValue
/// (and NaN) map to bin 0. Values at or beyond the start of the last bin map
/// to bin \p nBins - 1.
static size_t getBin(size_t nBins, float binWidth, float minValue,
                     float value) {
  if (binWidth == 0) {
    return 0;
  }
  const float pos = (value - minValue) / binWidth;
  // Converting a negative, NaN or out-of-range float to size_t is undefined
  // behavior, so clamp in the floating-point domain before casting.
  if (!(pos > 0)) {
    return 0;
  }
  const size_t lastBin = nBins - 1;
  if (pos >= static_cast<float>(lastBin)) {
    return lastBin;
  }
  return static_cast<size_t>(pos);
}
35
36void generateTensorHistogram(const Handle<float> inputTensor,
37 Handle<float> existingHistogram, float &min,
38 float &max) {
39 auto minMaxPos = inputTensor.minMaxArg();
40 float minInput = inputTensor.raw(minMaxPos.first);
41 float maxInput = inputTensor.raw(minMaxPos.second);
42
43 if (existingHistogram.isZero()) {
44 min = minInput;
45 max = maxInput;
46 }
47
48 size_t nBins = existingHistogram.size();
49
50 // Check if we need to rescale histogram.
51 if (minInput < min || maxInput > max) {
52 float newMin = std::min(minInput, min);
53 float newMax = std::max(maxInput, max);
54
55 float destBinWidth = (newMax - newMin) / nBins;
56 float srcBinWidth = (max - min) / nBins;
57
58 std::vector<float> scaledHistogram(nBins, 0);
59
60 for (size_t i = 0; i < nBins; ++i) {
61 if (existingHistogram.raw(i) == 0)
62 continue;
63
64 float srcBinBegin = min + srcBinWidth * i;
65 size_t destBin = (srcBinBegin - newMin) / destBinWidth;
66 float destBinEnd = newMin + destBinWidth * (destBin + 1);
67
68 float srcBinEnd = srcBinBegin + srcBinWidth;
69 size_t destBinToVerify = (srcBinEnd - newMin) / destBinWidth;
70 // Make sure that destination bin is mapped at most to 2 final bins, based
71 // on that redistribute percentage is calculated.
72 assert(destBinToVerify <= destBin + 2);
73 (void)destBinToVerify;
74
75 // Calculate how much we need to redistribute.
76 uint64_t dstBinCnt = static_cast<uint64_t>(std::min(
77 static_cast<float>(round((destBinEnd - srcBinBegin) / srcBinWidth *
78 existingHistogram.raw(i))),
79 existingHistogram.raw(i)));
80
81 size_t newBin = getBin(nBins, destBinWidth, newMin, srcBinBegin);
82 scaledHistogram[newBin] += dstBinCnt;
83
84 if (dstBinCnt < existingHistogram.raw(i)) {
85 size_t newBin =
86 getBin(nBins, destBinWidth, newMin, srcBinBegin + destBinWidth);
87 scaledHistogram[newBin] += existingHistogram.raw(i) - dstBinCnt;
88 }
89 }
90
91 // Copy scaled histogram back to the existing histogram.
92 for (size_t i = 0, e = scaledHistogram.size(); i < e; ++i) {
93 existingHistogram.raw(i) = scaledHistogram[i];
94 }
95
96 // Update global min and max.
97 min = newMin;
98 max = newMax;
99 }
100
101 float binWidth = (max - min) / nBins;
102 for (auto elem : inputTensor) {
103 size_t newBin = getBin(nBins, binWidth, min, elem);
104 existingHistogram.raw(newBin)++;
105 // Sanity check for NaN and Infinity.
106 assert(!std::isnan(elem) && "NaN value found!");
107 assert(!std::isinf(elem) && "Infinity value found!");
108 }
109}
110
/// Rescale \p srcHist into a histogram with the same number of bins.
/// The bins of \p srcHist evenly span [\p srcHistMin, \p srcHistMax]; the
/// bins of the result evenly span [\p destHistMin, \p destHistMax].
/// Each source bin is split among the destination bins it overlaps, in
/// proportion to the overlap length. Integer amounts are handed out using
/// cumulative rounding, so rounding cannot drop counts that fall inside the
/// destination range. Portions of a source bin lying outside the destination
/// range are dropped.
/// \returns the rescaled histogram. Returns \p srcHist unchanged when it is
/// empty or when the source and destination ranges are identical.
std::vector<float> rescaleHistogram(const std::vector<float> &srcHist,
                                    const float srcHistMin,
                                    const float srcHistMax,
                                    const float destHistMin,
                                    const float destHistMax) {

  // If histogram is empty then return.
  if (srcHist.empty()) {
    return srcHist;
  }

  // Check if we need to rescale the histogram.
  assert(srcHistMin < srcHistMax && "Invalid source histogram min/max range!");
  assert(destHistMin < destHistMax &&
         "Invalid destination histogram min/max range!");
  if ((srcHistMin == destHistMin) && (srcHistMax == destHistMax)) {
    return srcHist;
  }

  // Number of histogram bins and bin widths.
  const size_t numBins = srcHist.size();
  const float srcBinWidth = (srcHistMax - srcHistMin) / numBins;
  const float destBinWidth = (destHistMax - destHistMin) / numBins;

  // Iterate the source bins and distribute into the destination bins.
  std::vector<float> destHist(numBins, 0);
  for (size_t srcBinIdx = 0; srcBinIdx < numBins; srcBinIdx++) {

    // Get current source bin value.
    const float srcBinVal = srcHist[srcBinIdx];
    if (srcBinVal == 0) {
      continue;
    }

    // Get source bin start/stop values for this bin.
    const float srcBinStart = srcHistMin + srcBinIdx * srcBinWidth;
    const float srcBinStop = srcHistMin + (srcBinIdx + 1) * srcBinWidth;

    // Get the range of destination bin indices that may overlap the current
    // source bin. The range can include one extra bin; that bin gets a zero
    // overlap ratio below.
    const float dstBinIdxStartF =
        std::floor((srcBinStart - destHistMin) / destBinWidth);
    const float dstBinIdxStopF =
        std::ceil((srcBinStop - destHistMin) / destBinWidth);
    size_t dstBinIdxStart = static_cast<size_t>(std::max(dstBinIdxStartF, 0.f));
    size_t dstBinIdxStop = static_cast<size_t>(std::max(dstBinIdxStopF, 0.f));

    // Upper saturate the destination bin indices.
    if (dstBinIdxStart >= numBins) {
      dstBinIdxStart = numBins - 1;
    }
    if (dstBinIdxStop >= numBins) {
      dstBinIdxStop = numBins - 1;
    }

    // Redistribute the source bin into the destination bins. Only integer
    // amounts are handed out. Each destination bin receives the rounded
    // cumulative share minus what was already distributed. Rounding each bin's
    // share independently could lose counts entirely (e.g. a count of 1 split
    // in thirds rounds to 0 three times).
    float cumRatio = 0.0f;
    float distributed = 0.0f;
    for (size_t destBinIdx = dstBinIdxStart; destBinIdx <= dstBinIdxStop;
         destBinIdx++) {

      // Get destination bin start/stop values for this bin.
      const float destBinStart = destHistMin + destBinIdx * destBinWidth;
      const float destBinStop = destHistMin + (destBinIdx + 1) * destBinWidth;

      // Get source/destination overlap boundaries and ratio.
      const float overlapStart = std::max(srcBinStart, destBinStart);
      const float overlapStop = std::min(srcBinStop, destBinStop);
      float overlapRatio = (overlapStop - overlapStart) / srcBinWidth;
      overlapRatio = overlapRatio >= 0.0f ? overlapRatio : 0.0f;
      overlapRatio = overlapRatio <= 1.0f ? overlapRatio : 1.0f;

      // Compute distribution value from the cumulative share, never handing
      // out more than the source bin holds.
      cumRatio = std::min(cumRatio + overlapRatio, 1.0f);
      float distVal = std::round(cumRatio * srcBinVal) - distributed;
      distVal = std::min(distVal, srcBinVal - distributed);

      // Distribute value.
      destHist[destBinIdx] += distVal;
      distributed += distVal;
    }
  }

  return destHist;
}
194
195} // namespace quantization
196} // namespace glow
197