1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Backends/Interpreter/Interpreter.h" |
18 | |
19 | #include "glow/Base/TensorSerialization.h" |
20 | #include "glow/Base/Type.h" |
21 | #include "glow/IR/Instrs.h" |
22 | #include "glow/Quantization/Base/Base.h" |
23 | #include "glow/Quantization/Base/Profile.h" |
24 | |
25 | #include "llvm/ADT/SmallVector.h" |
26 | #include "llvm/Support/Casting.h" |
27 | #include "llvm/Support/raw_ostream.h" |
28 | |
29 | #include <chrono> |
30 | #include <cmath> |
31 | #include <numeric> |
32 | #include <random> |
33 | |
34 | #ifdef WIN32 |
35 | #include <corecrt_math_defines.h> |
36 | #endif |
37 | |
38 | using namespace glow; |
39 | |
40 | namespace IntNBitSplitEmbeddingBagsHelper { |
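/// Returns the number of bytes required to store one embedding row of width
/// \p dim with element type \p weight_ty. The quantized row formats also
/// carry a float16 scale and a float16 offset (4 extra bytes per row).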
41 | inline int32_t unpaddedRowSizeInBytes(int32_t dim, |
42 | SplitEmbeddingSparseType weight_ty) { |
43 | if (weight_ty == SplitEmbeddingSparseType::EST_FLOAT) { |
44 | return dim * sizeof(float); |
45 | } |
46 | if (weight_ty == SplitEmbeddingSparseType::EST_FLOAT16) { |
47 | return dim * sizeof(float16_t); |
48 | } |
49 | if (weight_ty == SplitEmbeddingSparseType::EST_INT8) { |
50 | return dim + 2 * sizeof(float16_t); |
51 | } |
52 | if (weight_ty == SplitEmbeddingSparseType::EST_INT4) { |
53 | return dim / 2 + 2 * sizeof(float16_t); |
54 | } |
55 | if (weight_ty == SplitEmbeddingSparseType::EST_INT2) { |
56 | return dim / 4 + 2 * sizeof(float16_t); |
57 | } |
58 | llvm_unreachable("Unsupported SparseType" ); |
59 | } |
60 | |
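/// Rounds \p a up to the nearest multiple of \p b.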
61 | uint32_t roundUp(uint32_t a, uint32_t b) { return ((a + b - 1) / b) * b; } |
62 | |
63 | inline int32_t paddedRowSizeInBytes(int32_t dim, |
64 | SplitEmbeddingSparseType weight_ty) { |
65 | auto r = unpaddedRowSizeInBytes(dim, weight_ty); |
66 | return roundUp(r, 16); |
67 | } |
68 | |
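/// Accumulates one element of an embedding row into \p a, scaled by
/// \p weight. \p row points at the start of the row, where the quantized
/// formats store their float16 scale/offset pair, and \p data points at the
/// byte holding the element; for EST_INT4, \p isMSB selects which nibble of
/// that byte is used.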
69 | template <typename SumTy> |
70 | SumTy add(SumTy a, const uint8_t *row, const uint8_t *data, |
71 | SplitEmbeddingSparseType dataTy, bool isMSB, float weight) { |
72 | float sum = a; |
73 | |
74 | if (dataTy == SplitEmbeddingSparseType::EST_FLOAT) { |
75 | return sum + weight * static_cast<float>( |
76 | *(reinterpret_cast<const float *>(data))); |
77 | } |
78 | |
79 | if (dataTy == SplitEmbeddingSparseType::EST_FLOAT16) { |
80 | return sum + weight * static_cast<float>( |
81 | *(reinterpret_cast<const float16_t *>(data))); |
82 | } |
83 | |
84 | float scale = *(reinterpret_cast<const float16_t *>(row)); |
85 | float offset = *(reinterpret_cast<const float16_t *>(row) + 1); |
86 | |
87 | if (dataTy == SplitEmbeddingSparseType::EST_INT8) { |
88 | return sum + weight * (static_cast<float>(*data) * scale + offset); |
89 | } |
90 | |
91 | if (dataTy == SplitEmbeddingSparseType::EST_INT4) { |
92 | if (isMSB) { |
93 | return sum + weight * (static_cast<float>(*data >> 4) * scale + offset); |
94 | } |
95 | return sum + weight * (static_cast<float>(*data & 0xF) * scale + offset); |
96 | } |
97 | |
98 | llvm_unreachable("Unsuppored SplitEmbeddingSparseType" ); |
99 | } |
100 | |
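/// Writes the accumulated value \p a to \p data as the floating-point type
/// selected by \p dataTy; quantized output types are not handled here.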
101 | template <typename DataTy> |
102 | void save(DataTy a, uint8_t *data, SplitEmbeddingSparseType dataTy) { |
103 | if (dataTy == SplitEmbeddingSparseType::EST_FLOAT) { |
104 | *(reinterpret_cast<float *>(data)) = a; |
105 | } else if (dataTy == SplitEmbeddingSparseType::EST_FLOAT16) { |
106 | *(reinterpret_cast<float16_t *>(data)) = a; |
107 | } |
108 | } |
109 | |
110 | } // namespace IntNBitSplitEmbeddingBagsHelper |
111 | |
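// Each dispatch*Impl macro below switches over a runtime ElemKind and
// instantiates the given function template with the matching C++ type(s).
// For example, with some kernel template fwdFooInstFloatImpl (illustrative
// name), dispatchFloatingPointImpl(fwdFooInstFloatImpl, elemTy, args...)
// calls fwdFooInstFloatImpl<float>(args...) when elemTy is
// ElemKind::FloatTy, fwdFooInstFloatImpl<float16_t>(args...) for
// ElemKind::Float16Ty, and so on.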
112 | #define dispatchImpl(functionName, elemTy, ...) \ |
113 | switch (elemTy) { \ |
114 | case ElemKind::FloatTy: \ |
115 | functionName<float>(__VA_ARGS__); \ |
116 | break; \ |
117 | case ElemKind::Float16Ty: \ |
118 | functionName<float16_t>(__VA_ARGS__); \ |
119 | break; \ |
120 | case ElemKind::BFloat16Ty: \ |
121 | functionName<bfloat16_t>(__VA_ARGS__); \ |
122 | break; \ |
123 | case ElemKind::Int8QTy: \ |
124 | functionName<int8_t>(__VA_ARGS__); \ |
125 | break; \ |
126 | case ElemKind::Int16QTy: \ |
127 | functionName<int16_t>(__VA_ARGS__); \ |
128 | break; \ |
129 | case ElemKind::Int32QTy: \ |
130 | functionName<int32_t>(__VA_ARGS__); \ |
131 | break; \ |
132 | case ElemKind::Int32ITy: \ |
133 | functionName<int32_t>(__VA_ARGS__); \ |
134 | break; \ |
135 | case ElemKind::Int64ITy: \ |
136 | functionName<int64_t>(__VA_ARGS__); \ |
137 | break; \ |
138 | case ElemKind::BoolTy: \ |
139 | functionName<bool>(__VA_ARGS__); \ |
140 | break; \ |
141 | default: \ |
142 | llvm_unreachable("Type is not supported"); \ |
143 | } |
144 | |
145 | #define dispatchFloatingPointImpl(functionName, elemTy, ...) \ |
146 | switch (elemTy) { \ |
147 | case ElemKind::FloatTy: \ |
148 | functionName<float>(__VA_ARGS__); \ |
149 | break; \ |
150 | case ElemKind::Float16Ty: \ |
151 | functionName<float16_t>(__VA_ARGS__); \ |
152 | break; \ |
153 | case ElemKind::BFloat16Ty: \ |
154 | functionName<bfloat16_t>(__VA_ARGS__); \ |
155 | break; \ |
156 | default: \ |
157 | llvm_unreachable("Type is not supported"); \ |
158 | } |
159 | |
160 | #define dispatchFloatingPointAndInt32Impl(functionName, elemTy, ...) \ |
161 | switch (elemTy) { \ |
162 | case ElemKind::FloatTy: \ |
163 | functionName<float>(__VA_ARGS__); \ |
164 | break; \ |
165 | case ElemKind::Float16Ty: \ |
166 | functionName<float16_t>(__VA_ARGS__); \ |
167 | break; \ |
168 | case ElemKind::BFloat16Ty: \ |
169 | functionName<bfloat16_t>(__VA_ARGS__); \ |
170 | break; \ |
171 | case ElemKind::Int32ITy: \ |
172 | functionName<int>(__VA_ARGS__); \ |
173 | break; \ |
174 | default: \ |
175 | llvm_unreachable("Type is not supported"); \ |
176 | } |
177 | |
178 | #define dispatchFloatingPointAndIndexImpl(functionName, elemTy, elemTyIndex, \ |
179 | ...) \ |
180 | switch (elemTy) { \ |
181 | case ElemKind::FloatTy: \ |
182 | if (elemTyIndex == ElemKind::Int64ITy) { \ |
183 | functionName<float, int64_t>(__VA_ARGS__); \ |
184 | } else if (elemTyIndex == ElemKind::Int32ITy) { \ |
185 | functionName<float, int32_t>(__VA_ARGS__); \ |
186 | } \ |
187 | break; \ |
188 | case ElemKind::Float16Ty: \ |
189 | if (elemTyIndex == ElemKind::Int64ITy) { \ |
190 | functionName<float16, int64_t>(__VA_ARGS__); \ |
191 | } else if (elemTyIndex == ElemKind::Int32ITy) { \ |
192 | functionName<float16, int32_t>(__VA_ARGS__); \ |
193 | } \ |
194 | break; \ |
195 | case ElemKind::BFloat16Ty: \ |
196 | if (elemTyIndex == ElemKind::Int64ITy) { \ |
197 | functionName<bfloat16, int64_t>(__VA_ARGS__); \ |
198 | } else if (elemTyIndex == ElemKind::Int32ITy) { \ |
199 | functionName<bfloat16, int32_t>(__VA_ARGS__); \ |
200 | } \ |
201 | break; \ |
202 | default: \ |
203 | llvm_unreachable("Type is not supported"); \ |
204 | } |
205 | |
206 | #define dispatchIndexTypeImpl(functionName, elemTy, ...) \ |
207 | switch (elemTy) { \ |
208 | case ElemKind::Int32ITy: \ |
209 | functionName<int32_t>(__VA_ARGS__); \ |
210 | break; \ |
211 | case ElemKind::Int64ITy: \ |
212 | functionName<int64_t>(__VA_ARGS__); \ |
213 | break; \ |
214 | default: \ |
215 | llvm_unreachable("Type is not supported"); \ |
216 | } |
217 | |
218 | #define dispatchIndexAndOutputTypeImpl(functionName, indexTy, outputTy, ...) \ |
219 | switch (indexTy) { \ |
220 | case ElemKind::Int32ITy: \ |
221 | if (outputTy == SplitEmbeddingSparseType::EST_FLOAT) { \ |
222 | functionName<int32_t, float>(__VA_ARGS__); \ |
223 | } else if (outputTy == SplitEmbeddingSparseType::EST_FLOAT16) { \ |
224 | functionName<int32_t, float16>(__VA_ARGS__); \ |
225 | } \ |
226 | break; \ |
227 | case ElemKind::Int64ITy: \ |
228 | if (outputTy == SplitEmbeddingSparseType::EST_FLOAT) { \ |
229 | functionName<int64_t, float>(__VA_ARGS__); \ |
230 | } else if (outputTy == SplitEmbeddingSparseType::EST_FLOAT16) { \ |
231 | functionName<int64_t, float16>(__VA_ARGS__); \ |
232 | } \ |
233 | break; \ |
234 | default: \ |
235 | llvm_unreachable("Type is not supported"); \ |
236 | } |
237 | |
238 | #define dispatchArithmeticImpl(functionName, elemTy, ...) \ |
239 | switch (elemTy) { \ |
240 | case ElemKind::FloatTy: \ |
241 | functionName<float>(__VA_ARGS__); \ |
242 | break; \ |
243 | case ElemKind::Float16Ty: \ |
244 | functionName<float16_t>(__VA_ARGS__); \ |
245 | break; \ |
246 | case ElemKind::BFloat16Ty: \ |
247 | functionName<bfloat16_t>(__VA_ARGS__); \ |
248 | break; \ |
249 | case ElemKind::Int32ITy: \ |
250 | functionName<int32_t>(__VA_ARGS__); \ |
251 | break; \ |
252 | case ElemKind::Int64ITy: \ |
253 | functionName<int64_t>(__VA_ARGS__); \ |
254 | break; \ |
255 | default: \ |
256 | llvm_unreachable("Type is not supported"); \ |
257 | } |
258 | |
259 | #define dispatchBitwiseImpl(functionName, elemTy, ...) \ |
260 | switch (elemTy) { \ |
261 | case ElemKind::Int32ITy: \ |
262 | functionName<int32_t>(__VA_ARGS__); \ |
263 | break; \ |
264 | case ElemKind::Int64ITy: \ |
265 | functionName<int64_t>(__VA_ARGS__); \ |
266 | break; \ |
267 | case ElemKind::BoolTy: \ |
268 | functionName<bool>(__VA_ARGS__); \ |
269 | break; \ |
270 | default: \ |
271 | llvm_unreachable("Type is not supported"); \ |
272 | } |
273 | |
274 | #define dispatchQuantizedImpl(functionName, elemTy, ...) \ |
275 | switch (elemTy) { \ |
276 | case ElemKind::Int8QTy: \ |
277 | functionName<int8_t>(__VA_ARGS__); \ |
278 | break; \ |
279 | case ElemKind::Int16QTy: \ |
280 | functionName<int16_t>(__VA_ARGS__); \ |
281 | break; \ |
282 | case ElemKind::Int32QTy: \ |
283 | functionName<int32_t>(__VA_ARGS__); \ |
284 | break; \ |
285 | default: \ |
286 | llvm_unreachable("Type is not supported"); \ |
287 | } |
288 | |
289 | #define dispatchQuantizedWithAccumulationImpl(functionName, elemTy, ...) \ |
290 | switch (elemTy) { \ |
291 | case ElemKind::Int8QTy: \ |
292 | functionName<int8_t, int32_t>(__VA_ARGS__); \ |
293 | break; \ |
294 | case ElemKind::Int16QTy: \ |
295 | functionName<int16_t, int64_t>(__VA_ARGS__); \ |
296 | break; \ |
297 | default: \ |
298 | llvm_unreachable("Type is not supported"); \ |
299 | } |
300 | |
301 | #define dispatchQuantizedWithAccumulationAndBiasImpl(functionName, elemTy, \ |
302 | biasElemType, ...) \ |
303 | if (elemTy == ElemKind::Int8QTy && biasElemType == ElemKind::Int8QTy) { \ |
304 | functionName<int8_t, int32_t, int8_t>(__VA_ARGS__); \ |
305 | } else if (elemTy == ElemKind::Int8QTy && \ |
306 | biasElemType == ElemKind::Int32QTy) { \ |
307 | functionName<int8_t, int32_t, int32_t>(__VA_ARGS__); \ |
308 | } else if (elemTy == ElemKind::Int16QTy && \ |
309 | biasElemType == ElemKind::Int16QTy) { \ |
310 | functionName<int16_t, int64_t, int16_t>(__VA_ARGS__); \ |
311 | } else if (elemTy == ElemKind::Int16QTy && \ |
312 | biasElemType == ElemKind::Int32QTy) { \ |
313 | functionName<int16_t, int64_t, int32_t>(__VA_ARGS__); \ |
314 | } else { \ |
315 | llvm_unreachable("Type is not supported"); \ |
316 | } |
317 | |
318 | #define staticAssertFloatingPointType(ElemTy) \ |
319 | static_assert( \ |
320 | std::is_floating_point<ElemTy>::value || \ |
321 | std::is_same<float16_t, \ |
322 | typename std::remove_cv<ElemTy>::type>::value || \ |
323 | std::is_same<bfloat16_t, \ |
324 | typename std::remove_cv<ElemTy>::type>::value, \ |
325 | "This implementation is for floating-point values only") |
326 | |
327 | #define staticAssertArithmeticType(ElemTy) \ |
328 | static_assert( \ |
329 | std::is_arithmetic<ElemTy>::value || \ |
330 | std::is_same<float16_t, \ |
331 | typename std::remove_cv<ElemTy>::type>::value || \ |
332 | std::is_same<bfloat16_t, \ |
333 | typename std::remove_cv<ElemTy>::type>::value, \ |
334 | "This implementation is for arithmetic values only") |
335 | |
336 | #ifndef MIN |
337 | #define MIN(a, b) (((a) < (b)) ? (a) : (b)) |
338 | #endif |
339 | |
340 | #ifndef MAX |
341 | #define MAX(a, b) (((a) > (b)) ? (a) : (b)) |
342 | #endif |
343 | |
344 | //===----------------------------------------------------------------------===// |
345 | // Convolution |
346 | //===----------------------------------------------------------------------===// |
347 | |
348 | /// This is the floating point implementation of Convolution. |
349 | template <typename ElemTy> |
350 | void BoundInterpreterFunction::fwdConvolutionInstFloatImpl( |
351 | Value *inV, Value *outV, Value *filterV, Value *biasV, |
352 | llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides, |
353 | llvm::ArrayRef<unsigned_t> pads, size_t group, |
354 | llvm::ArrayRef<unsigned_t> dilation) { |
355 | staticAssertFloatingPointType(ElemTy); |
356 | |
357 | auto inW = getWeightHandle<ElemTy>(inV); |
358 | auto outW = getWeightHandle<ElemTy>(outV); |
359 | auto filterW = getWeightHandle<ElemTy>(filterV); |
360 | auto biasW = getWeightHandle<ElemTy>(biasV); |
361 | |
362 | ShapeNHWC odim(outW.dims()); |
363 | ShapeNHWC idim(inW.dims()); |
364 | ShapeHW kdim(kernelSizes); |
365 | ShapeHW sdim(strides); |
366 | |
  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
369 | dim_t inCperG = idim.c / group; |
370 | dim_t outCperG = odim.c / group; |
371 | |
372 | PaddingTLBR pdim(pads); |
373 | |
374 | // For each input in the batch: |
375 | for (dim_t n = 0; n < idim.n; n++) { |
376 | |
377 | // For each group of input channels: |
378 | for (dim_t g = 0; g < group; g++) { |
379 | |
380 | // For each output channel in the group: |
381 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) { |
382 | |
383 | // For each convolution 'jump' in the input tensor: |
384 | ssize_t x = -ssize_t(pdim.top); |
385 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
386 | ssize_t y = -ssize_t(pdim.left); |
387 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
388 | |
389 | // For each element in the convolution-filter: |
390 | float sum = 0; |
391 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
392 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
393 | sdim_t ox = x + fx * dilation[0]; |
394 | sdim_t oy = y + fy * dilation[1]; |
395 | |
396 | // Ignore index access below zero (this is due to padding). |
397 | if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) || |
398 | oy >= ssize_t(idim.w)) { |
399 | continue; |
400 | } |
401 | for (dim_t fd = 0; fd < inCperG; fd++) { |
402 | sum += float( |
403 | filterW.at({d, fx, fy, fd}) * |
404 | inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd})); |
405 | } |
406 | } |
407 | } |
408 | |
409 | sum += float(biasW.at({d})); |
410 | outW.at({n, ax, ay, d}) = ElemTy(sum); |
411 | } // W |
412 | } // H |
413 | } // C |
414 | } // G |
415 | } // N |
416 | } |
417 | |
418 | /// This is the quantized implementation of Convolution. |
419 | template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy> |
420 | void BoundInterpreterFunction::fwdConvolutionInstQuantizedImpl( |
421 | Value *inV, Value *outV, Value *filterV, Value *biasV, |
422 | llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides, |
423 | llvm::ArrayRef<unsigned_t> pads, size_t group, |
424 | llvm::ArrayRef<unsigned_t> dilation) { |
425 | auto inW = getWeightHandle<ElemTy>(inV); |
426 | auto outW = getWeightHandle<ElemTy>(outV); |
427 | auto filterW = getWeightHandle<ElemTy>(filterV); |
428 | auto biasW = getWeightHandle<BiasElemTy>(biasV); |
429 | |
430 | ShapeNHWC odim(outW.dims()); |
431 | ShapeNHWC idim(inW.dims()); |
432 | ShapeHW kdim(kernelSizes); |
433 | ShapeHW sdim(strides); |
434 | |
  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
437 | dim_t inCperG = idim.c / group; |
438 | dim_t outCperG = odim.c / group; |
439 | |
440 | PaddingTLBR pdim(pads); |
441 | auto outTy = outV->getType(); |
442 | auto inTy = inV->getType(); |
443 | auto filterTy = filterV->getType(); |
444 | auto biasTy = biasV->getType(); |
445 | |
446 | int32_t outOffset = outTy->getOffset(); |
447 | int32_t inOffset = inTy->getOffset(); |
448 | int32_t filterOffset = filterTy->getOffset(); |
449 | int32_t biasOffset = biasTy->getOffset(); |
450 | |
451 | float outScale = outTy->getScale(); |
452 | float inScale = inTy->getScale(); |
453 | float filterScale = filterTy->getScale(); |
454 | float biasScale = biasTy->getScale(); |
455 | |
456 | // Calculate the scale of the values that come out of the matrix |
457 | // multiplication part of the calculation. |
458 | float matMulScale = inScale * filterScale; |
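
  // In real terms each input element is inScale * (I - inOffset) and each
  // filter element is filterScale * (F - filterOffset), so the integer dot
  // product accumulated below is the real convolution sum expressed in units
  // of matMulScale. The bias is rescaled from biasScale into those units
  // before being added, and the final sum is rescaled by matMulScale /
  // outScale into the quantized output domain.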
459 | |
460 | // For each input in the batch: |
461 | for (dim_t n = 0; n < idim.n; n++) { |
462 | // For each group of input channels: |
463 | for (dim_t g = 0; g < group; g++) { |
464 | |
465 | // For each output channel in the group: |
466 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) { |
467 | |
468 | // For each convolution 'jump' in the input tensor: |
469 | ssize_t x = -ssize_t(pdim.top); |
470 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
471 | ssize_t y = -ssize_t(pdim.left); |
472 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
473 | |
474 | // For each element in the convolution-filter: |
475 | AccumulatorTy sum = 0; |
476 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
477 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
478 | sdim_t ox = x + fx * dilation[0]; |
479 | sdim_t oy = y + fy * dilation[1]; |
480 | |
481 | // Ignore index access below zero (this is due to padding). |
482 | if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) || |
483 | oy >= sdim_t(idim.w)) { |
484 | continue; |
485 | } |
486 | for (dim_t fd = 0; fd < inCperG; fd++) { |
487 | |
488 | AccumulatorTy F = filterW.at({d, fx, fy, fd}); |
489 | AccumulatorTy I = |
490 | inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}); |
491 | // We represent the element multiplication with offset as |
492 | // (value - offset). |
493 | sum += (F - filterOffset) * (I - inOffset); |
494 | } |
495 | } |
496 | } |
497 | |
498 | // Scale the bias to match the scale of the matrix multiplication. |
499 | AccumulatorTy B = std::round(float(biasW.at({d}) - biasOffset) * |
500 | (biasScale / matMulScale)); |
501 | |
502 | // Add the bias. |
503 | sum += B; |
504 | |
505 | // Scale the result back to the expected destination scale. |
506 | outW.at({n, ax, ay, d}) = quantization::clip<AccumulatorTy, ElemTy>( |
507 | std::round(float(sum) * (matMulScale / outScale) + outOffset)); |
508 | } // W |
509 | } // H |
510 | } // C |
511 | } // G |
512 | } // N |
513 | } |
514 | |
515 | /// This is the floating point implementation of ConvTranspose. |
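/// The output is initialized with the bias and each input element is then
/// scattered through the filter into the output window it influences (the
/// adjoint of the gather loop used by fwdConvolutionInstFloatImpl above).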
516 | template <typename ElemTy> |
517 | void BoundInterpreterFunction::fwdConvTransposeInstFloatImpl( |
518 | Value *inV, Value *outV, Value *filterV, Value *biasV, |
519 | llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides, |
520 | llvm::ArrayRef<unsigned_t> pads, size_t group, |
521 | llvm::ArrayRef<unsigned_t> dilation) { |
522 | staticAssertFloatingPointType(ElemTy); |
523 | |
524 | auto inW = getWeightHandle<ElemTy>(inV); |
525 | auto outW = getWeightHandle<ElemTy>(outV); |
526 | auto filterW = getWeightHandle<ElemTy>(filterV); |
527 | auto biasW = getWeightHandle<ElemTy>(biasV); |
528 | |
529 | ShapeNHWC odim(outW.dims()); |
530 | ShapeNHWC idim(inW.dims()); |
531 | ShapeHW kdim(kernelSizes); |
532 | ShapeHW sdim(strides); |
533 | |
  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
536 | |
537 | dim_t inCperG = idim.c / group; |
538 | dim_t outCperG = odim.c / group; |
539 | |
540 | PaddingTLBR pdim(pads); |
541 | |
542 | // For each input in the batch: |
543 | for (dim_t n = 0; n < idim.n; n++) { |
544 | |
    // Initialize the output with the bias (TODO: factor this out into a
    // separate function when quantized support is added).
546 | for (dim_t ax = 0; ax < odim.h; ax++) { |
547 | for (dim_t ay = 0; ay < odim.w; ay++) { |
548 | for (dim_t d = 0; d < odim.c; d++) { |
549 | outW.at({n, ax, ay, d}) = static_cast<ElemTy>(biasW.at({d})); |
550 | } |
551 | } |
552 | } |
553 | |
554 | // For each group of input channels: |
555 | for (dim_t g = 0; g < group; g++) { |
556 | |
557 | // For each input channel in the group: |
558 | for (dim_t d = g * inCperG; d < (g + 1) * inCperG; d++) { |
559 | |
560 | // For each transposed convolution 'jump' in the input tensor: |
561 | ssize_t x = -ssize_t(pdim.top); |
562 | for (dim_t bx = 0; bx < idim.h; bx++, x += sdim.height) { |
563 | ssize_t y = -ssize_t(pdim.left); |
564 | for (dim_t by = 0; by < idim.w; by++, y += sdim.width) { |
565 | |
          // For each element in the transposed convolution filter:
567 | ElemTy input = inW.at({n, bx, by, d}); |
568 | |
569 | for (dim_t kx = 0; kx < kdim.height; kx++) { |
570 | for (dim_t ky = 0; ky < kdim.width; ky++) { |
571 | ssize_t ax = x + kx * dilation[0]; |
572 | ssize_t ay = y + ky * dilation[1]; |
573 | |
574 | // Ignore index access below zero (this is due to padding). |
575 | if (ax < 0 || ay < 0 || ax >= ssize_t(odim.h) || |
576 | ay >= ssize_t(odim.w)) { |
577 | continue; |
578 | } |
579 | for (dim_t c = 0; c < outCperG; c++) { |
580 | outW.at({n, (dim_t)ax, (dim_t)ay, g * outCperG + c}) += |
581 | filterW.at({c, kx, ky, d}) * input; |
582 | } |
583 | } |
584 | } |
585 | } // W |
586 | } // H |
587 | } // C |
588 | } // G |
589 | } // N |
590 | } |
591 | |
592 | void BoundInterpreterFunction::fwdConvTransposeInst( |
593 | const ConvTransposeInst *I) { |
594 | auto kernelSizes = I->getKernels(); |
595 | auto pads = I->getPads(); |
596 | auto strides = I->getStrides(); |
597 | size_t group = I->getGroup(); |
598 | |
599 | if (I->getSrc()->getType()->isQuantizedType()) { |
600 | llvm_unreachable("Quantized ConvTranspose not supported" ); |
601 | return; |
602 | } |
603 | |
604 | dispatchFloatingPointImpl( |
605 | fwdConvTransposeInstFloatImpl, I->getSrc()->getElementType(), I->getSrc(), |
606 | I->getDest(), I->getFilter(), I->getBias(), kernelSizes, strides, pads, |
607 | group, I->getDilation()); |
608 | } |
609 | |
610 | void BoundInterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) { |
611 | auto kernelSizes = I->getKernels(); |
612 | auto pads = I->getPads(); |
613 | auto strides = I->getStrides(); |
614 | size_t group = I->getGroup(); |
615 | |
616 | if (I->getSrc()->getType()->isQuantizedType()) { |
617 | dispatchQuantizedWithAccumulationAndBiasImpl( |
618 | fwdConvolutionInstQuantizedImpl, I->getSrc()->getElementType(), |
619 | I->getBias()->getElementType(), I->getSrc(), I->getDest(), |
620 | I->getFilter(), I->getBias(), kernelSizes, strides, pads, group, |
621 | I->getDilation()); |
622 | return; |
623 | } |
624 | |
625 | dispatchFloatingPointImpl( |
626 | fwdConvolutionInstFloatImpl, I->getSrc()->getElementType(), I->getSrc(), |
627 | I->getDest(), I->getFilter(), I->getBias(), kernelSizes, strides, pads, |
628 | group, I->getDilation()); |
629 | } |
630 | |
631 | void BoundInterpreterFunction::fwdConcatInst(const ConcatInst *I) { |
632 | (void)I; |
633 | // TODO |
634 | llvm_unreachable("not yet implemented" ); |
635 | } |
636 | |
637 | void BoundInterpreterFunction::fwdConvolutionGradInst( |
638 | const ConvolutionGradInst *I) { |
639 | auto inW = getWeightHandle(I->getSrc()); |
640 | auto inG = getWeightHandle(I->getSrcGrad()); |
641 | auto outG = getWeightHandle(I->getDestGrad()); |
642 | |
643 | auto filterW = getWeightHandle(I->getFilter()); |
644 | auto filterG = getWeightHandle(I->getFilterGrad()); |
645 | auto biasG = getWeightHandle(I->getBiasGrad()); |
646 | |
647 | size_t group = I->getGroup(); |
648 | auto dilation = I->getDilation(); |
649 | |
650 | inG.clear(); |
651 | filterG.clear(); |
652 | biasG.clear(); |
653 | |
654 | ShapeNHWC odim(outG.dims()); |
655 | ShapeNHWC idim(inW.dims()); |
656 | ShapeHW kdim(I->getKernels()); |
657 | ShapeHW sdim(I->getStrides()); |
658 | PaddingTLBR pdim(I->getPads()); |
659 | |
  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
662 | dim_t inCperG = idim.c / group; |
663 | dim_t outCperG = odim.c / group; |
664 | |
665 | // For each input in the batch: |
666 | for (dim_t n = 0; n < odim.n; n++) { |
667 | |
668 | // For each group of input channels: |
669 | for (dim_t g = 0; g < group; g++) { |
670 | |
671 | // Compute the gradient. For each layer in the output tensor: |
672 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) { |
673 | |
674 | // For each convolution 'jump' in the input tensor: |
675 | sdim_t x = -sdim_t(pdim.top); |
676 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
677 | sdim_t y = -sdim_t(pdim.left); |
678 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
679 | |
680 | float chainGrad = outG.at({n, ax, ay, d}); |
681 | |
682 | // For each element in the convolution-filter: |
683 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
684 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
685 | sdim_t ox = x + fx * dilation[0]; |
686 | sdim_t oy = y + fy * dilation[1]; |
687 | |
688 | // Ignore index access below zero (this is due to padding). |
689 | if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) || |
690 | oy >= sdim_t(idim.w)) { |
691 | continue; |
692 | } |
693 | |
694 | for (dim_t fd = 0; fd < inCperG; fd++) { |
695 | filterG.at({d, fx, fy, fd}) += |
696 | inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}) * |
697 | chainGrad; |
698 | inG.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}) += |
699 | filterW.at({d, fx, fy, fd}) * chainGrad; |
700 | } |
701 | } |
702 | } |
703 | |
704 | biasG.at({d}) += chainGrad; |
705 | } // W |
706 | } // H |
707 | } // C |
708 | } // G |
709 | } // N |
710 | } |
711 | |
712 | /// This is the floating point implementation of Convolution3D. |
713 | template <typename ElemTy> |
714 | void BoundInterpreterFunction::fwdConvolution3DInstFloatImpl( |
715 | Value *inV, Value *outV, Value *filterV, Value *biasV, |
716 | llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides, |
717 | llvm::ArrayRef<unsigned_t> pads, size_t group) { |
718 | staticAssertFloatingPointType(ElemTy); |
719 | |
720 | auto inW = getWeightHandle<ElemTy>(inV); |
721 | auto outW = getWeightHandle<ElemTy>(outV); |
722 | auto filterW = getWeightHandle<ElemTy>(filterV); |
723 | auto biasW = getWeightHandle<ElemTy>(biasV); |
724 | |
725 | ShapeNTHWC odim(outW.dims()); |
726 | ShapeNTHWC idim(inW.dims()); |
727 | ShapeTHW kdim(kernelSizes); |
728 | ShapeTHW sdim(strides); |
729 | |
  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
732 | dim_t inCperG = idim.c / group; |
733 | dim_t outCperG = odim.c / group; |
734 | |
735 | PaddingNFTBLR pdim(pads); |
736 | |
737 | // For each input in the batch: |
738 | for (dim_t n = 0; n < idim.n; n++) { |
739 | |
740 | // For each group of input channels: |
741 | for (dim_t ig = 0; ig < group; ig++) { |
742 | |
743 | // For each output channel in the group: |
744 | for (dim_t og = ig * outCperG; og < (ig + 1) * outCperG; og++) { |
745 | |
746 | ssize_t t = -ssize_t(pdim.near); |
747 | for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) { |
748 | // For each convolution 'jump' in the input tensor: |
749 | ssize_t x = -ssize_t(pdim.top); |
750 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
751 | ssize_t y = -ssize_t(pdim.left); |
752 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
753 | // For each element in the 3D convolution-filter: |
754 | float sum = 0; |
755 | for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) { |
756 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
757 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
758 | sdim_t ot = t + ft; |
759 | sdim_t ox = x + fx; |
760 | sdim_t oy = y + fy; |
761 | |
762 | // Ignore index access below zero (this is due to padding). |
763 | if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) || |
764 | ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) { |
765 | continue; |
766 | } |
767 | for (dim_t fg = 0; fg < inCperG; fg++) { |
768 | sum += float(filterW.at({og, ft, fx, fy, fg}) * |
769 | inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, |
770 | ig * inCperG + fg})); |
771 | } |
772 | } |
773 | } |
774 | } |
775 | |
776 | sum += float(biasW.at({og})); |
777 | outW.at({n, at, ax, ay, og}) = ElemTy(sum); |
778 | } // D |
779 | } // W |
780 | } // H |
781 | } // C |
782 | } // G |
783 | } // N |
784 | } |
785 | |
786 | /// This is the quantized implementation of Convolution3D. |
787 | template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy> |
788 | void BoundInterpreterFunction::fwdConvolution3DInstQuantizedImpl( |
789 | Value *inV, Value *outV, Value *filterV, Value *biasV, |
790 | llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides, |
791 | llvm::ArrayRef<unsigned_t> pads, size_t group) { |
792 | auto inW = getWeightHandle<ElemTy>(inV); |
793 | auto outW = getWeightHandle<ElemTy>(outV); |
794 | auto filterW = getWeightHandle<ElemTy>(filterV); |
795 | auto biasW = getWeightHandle<BiasElemTy>(biasV); |
796 | |
797 | ShapeNTHWC odim(outW.dims()); |
798 | ShapeNTHWC idim(inW.dims()); |
799 | ShapeTHW kdim(kernelSizes); |
800 | ShapeTHW sdim(strides); |
801 | |
  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
804 | dim_t inCperG = idim.c / group; |
805 | dim_t outCperG = odim.c / group; |
806 | |
807 | PaddingNFTBLR pdim(pads); |
808 | |
809 | auto outTy = outV->getType(); |
810 | auto inTy = inV->getType(); |
811 | auto filterTy = filterV->getType(); |
812 | auto biasTy = biasV->getType(); |
813 | |
814 | int32_t outOffset = outTy->getOffset(); |
815 | int32_t inOffset = inTy->getOffset(); |
816 | int32_t filterOffset = filterTy->getOffset(); |
817 | int32_t biasOffset = biasTy->getOffset(); |
818 | |
819 | float outScale = outTy->getScale(); |
820 | float inScale = inTy->getScale(); |
821 | float filterScale = filterTy->getScale(); |
822 | float biasScale = biasTy->getScale(); |
823 | |
824 | // Calculate the scale of the values that come out of the matrix |
825 | // multiplication part of the calculation. |
826 | float matMulScale = inScale * filterScale; |
827 | |
828 | // For each input in the batch: |
829 | for (dim_t n = 0; n < idim.n; n++) { |
830 | |
831 | // For each group of input channels: |
832 | for (dim_t ig = 0; ig < group; ig++) { |
833 | |
834 | // For each output channel in the group: |
835 | for (dim_t og = ig * outCperG; og < (ig + 1) * outCperG; og++) { |
836 | |
837 | // For each convolution 'jump' in the input tensor: |
838 | ssize_t t = -ssize_t(pdim.near); |
839 | for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) { |
840 | ssize_t x = -ssize_t(pdim.top); |
841 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
842 | ssize_t y = -ssize_t(pdim.left); |
843 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
844 | |
845 | // For each element in the convolution-filter: |
846 | AccumulatorTy sum = 0; |
847 | for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) { |
848 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
849 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
850 | ssize_t ot = t + ft; |
851 | ssize_t ox = x + fx; |
852 | ssize_t oy = y + fy; |
853 | |
854 | // Ignore index access below zero (this is due to padding). |
855 | if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) || |
856 | ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) { |
857 | continue; |
858 | } |
859 | for (dim_t fg = 0; fg < inCperG; fg++) { |
860 | |
861 | AccumulatorTy F = filterW.at({og, ft, fx, fy, fg}); |
862 | AccumulatorTy I = inW.at({n, (dim_t)ot, (dim_t)ox, |
863 | (dim_t)oy, ig * inCperG + fg}); |
864 | // We represent the element multiplication with offset as |
865 | // (value - offset). |
866 | sum += (F - filterOffset) * (I - inOffset); |
867 | } |
868 | } |
869 | } |
870 | } |
871 | |
872 | // Scale the bias to match the scale of the matrix multiplication. |
873 | AccumulatorTy B = std::round(float(biasW.at({og}) - biasOffset) * |
874 | (biasScale / matMulScale)); |
875 | |
876 | // Add the bias: |
877 | sum += B; |
878 | |
879 | // Scale the result back to the expected destination scale. |
880 | outW.at({n, at, ax, ay, og}) = |
881 | quantization::clip<AccumulatorTy, ElemTy>(std::round( |
882 | float(sum) * (matMulScale / outScale) + outOffset)); |
883 | } // D |
884 | } // W |
885 | } // H |
886 | } // C |
887 | } // G |
888 | } // N |
889 | } |
890 | |
891 | void BoundInterpreterFunction::fwdConvolution3DInst( |
892 | const Convolution3DInst *I) { |
893 | auto kernelSizes = I->getKernels(); |
894 | auto pads = I->getPads(); |
895 | auto strides = I->getStrides(); |
896 | size_t group = I->getGroup(); |
897 | |
898 | if (I->getSrc()->getType()->isQuantizedType()) { |
899 | dispatchQuantizedWithAccumulationAndBiasImpl( |
900 | fwdConvolution3DInstQuantizedImpl, I->getSrc()->getElementType(), |
901 | I->getBias()->getElementType(), I->getSrc(), I->getDest(), |
902 | I->getFilter(), I->getBias(), kernelSizes, strides, pads, group); |
903 | return; |
904 | } |
905 | |
906 | dispatchFloatingPointImpl(fwdConvolution3DInstFloatImpl, |
907 | I->getSrc()->getElementType(), I->getSrc(), |
908 | I->getDest(), I->getFilter(), I->getBias(), |
909 | kernelSizes, strides, pads, group); |
910 | } |
911 | |
912 | void BoundInterpreterFunction::fwdConvolution3DGradInst( |
913 | const Convolution3DGradInst *I) { |
914 | (void)I; |
915 | // TODO |
916 | llvm_unreachable("not yet implemented" ); |
917 | } |
918 | |
919 | //===----------------------------------------------------------------------===// |
920 | // Channelwise quantized Convolution |
921 | //===----------------------------------------------------------------------===// |
922 | template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy> |
923 | void BoundInterpreterFunction::fwdChannelwiseQuantizedConv2DInstImpl( |
924 | const ChannelwiseQuantizedConvolutionInst *I) { |
925 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
926 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
927 | auto filterW = getWeightHandle<ElemTy>(I->getFilter()); |
928 | auto biasW = getWeightHandle<BiasElemTy>(I->getBias()); |
929 | auto filterScales = getWeightHandle<float>(I->getFilterScales()); |
930 | auto filterOffsets = getWeightHandle<int32_t>(I->getFilterOffsets()); |
931 | auto biasScales = getWeightHandle<float>(I->getBiasScales()); |
932 | auto biasOffsets = getWeightHandle<int32_t>(I->getBiasOffsets()); |
933 | |
934 | llvm::ArrayRef<unsigned_t> kernelSizes = I->getKernels(); |
935 | llvm::ArrayRef<unsigned_t> pads = I->getPads(); |
936 | llvm::ArrayRef<unsigned_t> strides = I->getStrides(); |
937 | dim_t group = I->getGroup(); |
938 | llvm::ArrayRef<unsigned_t> dilation = I->getDilation(); |
939 | |
940 | ShapeNHWC odim(outW.dims()); |
941 | ShapeNHWC idim(inW.dims()); |
942 | ShapeHW kdim(kernelSizes); |
943 | ShapeHW sdim(strides); |
944 | |
  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
947 | dim_t inCperG = idim.c / group; |
948 | dim_t outCperG = odim.c / group; |
949 | |
950 | PaddingTLBR pdim(pads); |
951 | |
952 | auto &inTy = inW.getType(); |
953 | auto &outTy = outW.getType(); |
954 | |
955 | float inScale = inTy.getScale(); |
956 | float outScale = outTy.getScale(); |
957 | |
958 | int32_t inOffset = inTy.getOffset(); |
959 | int32_t outOffset = outTy.getOffset(); |
960 | |
961 | // For each input in the batch: |
962 | for (dim_t n = 0; n < idim.n; n++) { |
963 | // For each group of input channels: |
964 | for (dim_t g = 0; g < group; g++) { |
965 | // For each output channel in the group: |
966 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) { |
967 | |
        // Get channel-wise quantization params.
969 | int32_t filterOffset = filterOffsets.at(d); |
970 | float filterScale = filterScales.at(d); |
971 | int32_t biasOffset = biasOffsets.at(d); |
972 | float biasScale = biasScales.at(d); |
973 | float matMulScale = inScale * filterScale; |
974 | |
975 | // For each convolution 'jump' in the input tensor: |
976 | sdim_t x = -sdim_t(pdim.top); |
977 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
978 | sdim_t y = -sdim_t(pdim.left); |
979 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
980 | |
981 | // For each element in the convolution-filter: |
982 | AccumulatorTy sum = 0; |
983 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
984 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
985 | sdim_t ox = x + fx * dilation[0]; |
986 | sdim_t oy = y + fy * dilation[1]; |
987 | |
988 | // Ignore index access below zero (this is due to padding). |
989 | if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) || |
990 | oy >= sdim_t(idim.w)) { |
991 | continue; |
992 | } |
993 | |
994 | // Accumulate along the filter depth. |
995 | for (dim_t fd = 0; fd < inCperG; fd++) { |
996 | AccumulatorTy F = filterW.at({d, fx, fy, fd}); |
997 | AccumulatorTy I = |
998 | inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}); |
999 | // We represent the element multiplication with offset as |
1000 | // (value - offset). |
1001 | sum += (F - filterOffset) * (I - inOffset); |
1002 | } |
1003 | } |
1004 | } |
1005 | |
1006 | // Scale the bias to match the scale of the matrix multiplication. |
1007 | sum += std::round(float(biasW.at({d}) - biasOffset) * |
1008 | (biasScale / matMulScale)); |
1009 | |
1010 | // Scale the result back to the expected destination scale. |
1011 | outW.at({n, ax, ay, d}) = quantization::clip<AccumulatorTy, ElemTy>( |
1012 | std::round(float(sum) * (matMulScale / outScale) + outOffset)); |
1013 | } // W |
1014 | } // H |
1015 | } // C |
1016 | } // G |
1017 | } // N |
1018 | } |
1019 | |
1020 | template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy> |
1021 | void BoundInterpreterFunction::fwdChannelwiseQuantizedConv3DInstImpl( |
1022 | const ChannelwiseQuantizedConvolutionInst *I) { |
1023 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
1024 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
1025 | auto filterW = getWeightHandle<ElemTy>(I->getFilter()); |
1026 | auto biasW = getWeightHandle<BiasElemTy>(I->getBias()); |
1027 | auto filterScales = getWeightHandle<float>(I->getFilterScales()); |
1028 | auto filterOffsets = getWeightHandle<int32_t>(I->getFilterOffsets()); |
1029 | auto biasScales = getWeightHandle<float>(I->getBiasScales()); |
1030 | auto biasOffsets = getWeightHandle<int32_t>(I->getBiasOffsets()); |
1031 | |
1032 | llvm::ArrayRef<unsigned_t> kernelSizes = I->getKernels(); |
1033 | llvm::ArrayRef<unsigned_t> pads = I->getPads(); |
1034 | llvm::ArrayRef<unsigned_t> strides = I->getStrides(); |
1035 | dim_t group = I->getGroup(); |
1036 | |
1037 | ShapeNTHWC odim(outW.dims()); |
1038 | ShapeNTHWC idim(inW.dims()); |
1039 | ShapeTHW kdim(kernelSizes); |
1040 | ShapeTHW sdim(strides); |
1041 | |
  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
1044 | dim_t inCperG = idim.c / group; |
1045 | dim_t outCperG = odim.c / group; |
1046 | |
1047 | PaddingNFTBLR pdim(pads); |
1048 | |
1049 | auto &inTy = inW.getType(); |
1050 | auto &outTy = outW.getType(); |
1051 | |
1052 | float inScale = inTy.getScale(); |
1053 | float outScale = outTy.getScale(); |
1054 | |
1055 | int32_t inOffset = inTy.getOffset(); |
1056 | int32_t outOffset = outTy.getOffset(); |
1057 | |
1058 | // For each input in the batch: |
1059 | for (dim_t n = 0; n < idim.n; n++) { |
1060 | // For each group of input channels: |
1061 | for (dim_t g = 0; g < group; g++) { |
1062 | // For each output channel in the group: |
1063 | for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) { |
1064 | |
        // Get channel-wise quantization params.
1066 | int32_t filterOffset = filterOffsets.at(d); |
1067 | float filterScale = filterScales.at(d); |
1068 | int32_t biasOffset = biasOffsets.at(d); |
1069 | float biasScale = biasScales.at(d); |
1070 | float matMulScale = inScale * filterScale; |
1071 | |
1072 | // For each convolution 'jump' in the input tensor: |
1073 | sdim_t t = -sdim_t(pdim.near); |
1074 | for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) { |
1075 | sdim_t x = -sdim_t(pdim.top); |
1076 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
1077 | sdim_t y = -sdim_t(pdim.left); |
1078 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
1079 | |
1080 | // For each element in the convolution-filter: |
1081 | AccumulatorTy sum = 0; |
1082 | for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) { |
1083 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
1084 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
1085 | sdim_t ot = t + ft; |
1086 | sdim_t ox = x + fx; |
1087 | sdim_t oy = y + fy; |
1088 | |
1089 | // Ignore index access below zero (this is due to |
1090 | // padding). |
1091 | if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) || |
1092 | ox >= ssize_t(idim.h) || oy >= sdim_t(idim.w)) { |
1093 | continue; |
1094 | } |
1095 | |
1096 | // Accumulate along the filter depth. |
1097 | for (dim_t fd = 0; fd < inCperG; fd++) { |
1098 | |
1099 | AccumulatorTy F = filterW.at({d, ft, fx, fy, fd}); |
1100 | AccumulatorTy I = inW.at({n, (dim_t)ot, (dim_t)ox, |
1101 | (dim_t)oy, g * inCperG + fd}); |
1102 | // We represent the element multiplication with offset |
1103 | // as (value - offset). |
1104 | sum += (F - filterOffset) * (I - inOffset); |
1105 | } |
1106 | } |
1107 | } |
1108 | } |
1109 | |
1110 | // Scale the bias to match the scale of the matrix multiplication. |
1111 | sum += std::round(float(biasW.at({d}) - biasOffset) * |
1112 | (biasScale / matMulScale)); |
1113 | |
1114 | // Scale the result back to the expected destination scale. |
1115 | outW.at({n, at, ax, ay, d}) = |
1116 | quantization::clip<AccumulatorTy, ElemTy>(std::round( |
1117 | float(sum) * (matMulScale / outScale) + outOffset)); |
1118 | } // W |
1119 | } // H |
1120 | } // T |
1121 | } // C |
1122 | } // G |
1123 | } // N |
1124 | } |
1125 | |
1126 | void BoundInterpreterFunction::fwdChannelwiseQuantizedConvolutionInst( |
1127 | const ChannelwiseQuantizedConvolutionInst *I) { |
1128 | bool isConv3D = (I->getSrc()->dims().size() == 5); |
1129 | if (isConv3D) { |
1130 | dispatchQuantizedWithAccumulationAndBiasImpl( |
1131 | fwdChannelwiseQuantizedConv3DInstImpl, I->getSrc()->getElementType(), |
1132 | I->getBias()->getElementType(), I); |
1133 | } else { |
1134 | dispatchQuantizedWithAccumulationAndBiasImpl( |
1135 | fwdChannelwiseQuantizedConv2DInstImpl, I->getSrc()->getElementType(), |
1136 | I->getBias()->getElementType(), I); |
1137 | } |
1138 | } |
1139 | |
1140 | template <typename ElemTy> |
1141 | void BoundInterpreterFunction::fwdBatchNormalizationFloatImpl( |
1142 | const BatchNormalizationInst *I, int numDims) { |
1143 | staticAssertFloatingPointType(ElemTy); |
1144 | |
1145 | // input |
1146 | auto inH = getWeightHandle<ElemTy>(I->getSrc()); |
1147 | auto scaleH = getWeightHandle<ElemTy>(I->getScale()); |
1148 | auto biasH = getWeightHandle<ElemTy>(I->getBias()); |
1149 | auto meanH = getWeightHandle<ElemTy>(I->getMean()); |
1150 | auto varH = getWeightHandle<ElemTy>(I->getVar()); |
1151 | unsigned_t channelIdx = I->getChannelIdx(); |
1152 | float epsilon = I->getEpsilon(); |
1153 | |
1154 | // output |
1155 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
1156 | |
1157 | dim_t N, C, sizeN, sizeImg; |
1158 | bool isCMinor; |
1159 | if (numDims == 3) { |
1160 | if (channelIdx == 4) { |
1161 | ShapeNTHWC idim(I->getSrc()->dims()); |
1162 | N = idim.n; |
1163 | C = idim.c; |
1164 | sizeImg = idim.t * idim.h * idim.w; |
1165 | sizeN = idim.c * sizeImg; |
1166 | isCMinor = true; |
1167 | } else { |
1168 | ShapeNCTHW idim(I->getSrc()->dims()); |
1169 | N = idim.n; |
1170 | C = idim.c; |
1171 | sizeImg = idim.t * idim.h * idim.w; |
1172 | sizeN = idim.c * sizeImg; |
1173 | isCMinor = false; |
1174 | } |
1175 | } else if (numDims == 2) { |
1176 | if (channelIdx == 3) { |
1177 | ShapeNHWC idim(I->getSrc()->dims()); |
1178 | N = idim.n; |
1179 | C = idim.c; |
1180 | sizeImg = idim.h * idim.w; |
1181 | sizeN = idim.c * sizeImg; |
1182 | isCMinor = true; |
1183 | } else { |
1184 | ShapeNCHW idim(I->getSrc()->dims()); |
1185 | N = idim.n; |
1186 | C = idim.c; |
1187 | sizeImg = idim.h * idim.w; |
1188 | sizeN = idim.c * sizeImg; |
1189 | isCMinor = false; |
1190 | } |
1191 | } else if (numDims == 1) { |
1192 | N = I->getSrc()->dims()[0]; |
1193 | C = I->getSrc()->dims()[channelIdx]; |
1194 | sizeImg = I->getSrc()->dims()[channelIdx == 2 ? 1 : 2]; |
1195 | sizeN = C * sizeImg; |
1196 | isCMinor = (channelIdx == 2); |
1197 | } else { |
1198 | N = I->getSrc()->dims()[0]; |
1199 | C = I->getSrc()->dims()[channelIdx]; |
1200 | sizeImg = 1; |
1201 | sizeN = C; |
1202 | isCMinor = false; |
1203 | } |
1204 | |
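  // Pre-compute the per-channel fused parameters so the inner loop below is
  // a single multiply-add: y = scale / sqrt(var + epsilon) * (x - mean) +
  // bias.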
1205 | std::vector<float> scale(C), mean(C), bias(C); |
1206 | for (dim_t c = 0; c < C; c++) { |
1207 | scale[c] = float(scaleH.at({c})) / std::sqrt(float(varH.at({c})) + epsilon); |
1208 | bias[c] = biasH.at({c}); |
1209 | mean[c] = meanH.at({c}); |
1210 | } |
1211 | |
1212 | // For each input in the batch: |
1213 | for (dim_t n = 0; n < N; n++) { |
1214 | if (isCMinor) { |
1215 | // For each H*W{*T} of the image |
1216 | for (dim_t i = 0; i < sizeImg; i++) { |
1217 | // For each channel |
1218 | for (dim_t c = 0; c < C; c++) { |
1219 | int index = n * sizeN + i * C + c; |
1220 | outW.raw(index) = |
1221 | ElemTy(scale[c] * (float(inH.raw(index)) - mean[c]) + bias[c]); |
1222 | } // C |
1223 | } // image |
1224 | } else { |
1225 | // For each channel |
1226 | for (dim_t c = 0; c < C; c++) { |
1227 | // For each H*W{*T} of the image |
1228 | for (dim_t i = 0; i < sizeImg; i++) { |
1229 | int index = n * sizeN + c * sizeImg + i; |
1230 | outW.raw(index) = |
1231 | ElemTy(scale[c] * (float(inH.raw(index)) - mean[c]) + bias[c]); |
1232 | } // image |
1233 | } // C |
1234 | } |
1235 | } // N |
1236 | } |
1237 | |
1238 | template <typename ParamTy> |
1239 | void BoundInterpreterFunction::fwdBatchNormalizationI8Impl( |
1240 | const BatchNormalizationInst *I, int numDims) { |
1241 | |
1242 | // input |
1243 | auto inH = getWeightHandle<int8_t>(I->getSrc()); |
1244 | auto scaleH = getWeightHandle<ParamTy>(I->getScale()); |
1245 | auto biasH = getWeightHandle<ParamTy>(I->getBias()); |
1246 | auto meanH = getWeightHandle<ParamTy>(I->getMean()); |
1247 | auto varH = getWeightHandle<ParamTy>(I->getVar()); |
1248 | unsigned_t channelIdx = |
1249 | I->getChannelIdx(); // NOTE: We only support NTHWC, NHWC, NWC and NCW |
1250 | float epsilon = I->getEpsilon(); |
1251 | auto inScale = float(I->getSrc()->getType()->getScale()); |
1252 | auto inZero = I->getSrc()->getType()->getOffset(); |
1253 | |
1254 | // output |
1255 | auto outH = getWeightHandle<int8_t>(I->getDest()); |
1256 | auto outScale = float(I->getDest()->getType()->getScale()); |
1257 | auto outZero = I->getDest()->getType()->getOffset(); |
1258 | |
1259 | dim_t N, C, sizeN, sizeImg; |
1260 | bool isCMinor; |
1261 | if (numDims == 3) { |
1262 | if (channelIdx == 4) { |
1263 | ShapeNTHWC idim(I->getSrc()->dims()); |
1264 | N = idim.n; |
1265 | C = idim.c; |
1266 | sizeImg = idim.t * idim.h * idim.w; |
1267 | sizeN = idim.c * sizeImg; |
1268 | isCMinor = true; |
1269 | |
1270 | } else { |
1271 | ShapeNCTHW idim(I->getSrc()->dims()); |
1272 | N = idim.n; |
1273 | C = idim.c; |
1274 | sizeImg = idim.t * idim.h * idim.w; |
1275 | sizeN = idim.c * sizeImg; |
1276 | isCMinor = false; |
1277 | } |
1278 | } else if (numDims == 2) { |
1279 | if (channelIdx == 3) { |
1280 | ShapeNHWC idim(I->getSrc()->dims()); |
1281 | N = idim.n; |
1282 | C = idim.c; |
1283 | sizeImg = idim.h * idim.w; |
1284 | sizeN = idim.c * sizeImg; |
1285 | isCMinor = true; |
1286 | |
1287 | } else { |
1288 | ShapeNCHW idim(I->getSrc()->dims()); |
1289 | N = idim.n; |
1290 | C = idim.c; |
1291 | sizeImg = idim.h * idim.w; |
1292 | sizeN = idim.c * sizeImg; |
1293 | isCMinor = false; |
1294 | } |
1295 | |
1296 | } else { |
    // numDims == 1. This can happen when an optimization pass sinks a
    // reshape below the batchnorm.
1299 | N = I->getSrc()->dims()[0]; |
1300 | C = I->getSrc()->dims()[channelIdx]; |
1301 | sizeImg = I->getSrc()->dims()[channelIdx == 2 ? 1 : 2]; |
1302 | sizeN = C * sizeImg; |
1303 | isCMinor = (channelIdx == 2); |
1304 | } |
1305 | |
1306 | // See qbatch_norm.cpp/compute_fused_params() for FBGEMM implementation |
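  // Folding batch-norm y = scale * (x - mean) / sqrt(var + eps) + bias with
  // x = inScale * (x_q - inZero) gives, after dividing by outScale,
  // alpha[c] * (x_q - inZero) + beta[c]; this is why the quantization step
  // below uses a scale of 1.0 together with the output zero point.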
1307 | std::vector<ParamTy> alpha(C), beta(C); |
1308 | for (dim_t c = 0; c < C; c++) { |
1309 | float invSigma = 1 / std::sqrt(float(varH.at({c})) + epsilon); |
1310 | alpha[c] = ParamTy(invSigma * float(scaleH.at({c})) * (inScale / outScale)); |
1311 | beta[c] = ParamTy((float(biasH.at({c})) - float(meanH.at({c})) * invSigma * |
1312 | float(scaleH.at({c}))) / |
1313 | outScale); |
1314 | } |
1315 | |
1316 | // See QuantizedOpKernels.cpp/q_batch_norm_kernel() for FBGEMM implementation |
1317 | TensorQuantizationParams outputQ{1.0f, outZero}; |
1318 | // For each input in the batch: |
1319 | for (dim_t n = 0; n < N; n++) { |
1320 | if (isCMinor) { |
1321 | // For each H*W{*T} of the image |
1322 | for (dim_t i = 0; i < sizeImg; i++) { |
1323 | // For each channel |
1324 | for (dim_t c = 0; c < C; c++) { |
1325 | int index = n * sizeN + i * C + c; |
1326 | ParamTy x = inH.raw(index) - inZero; |
1327 | ParamTy y = alpha[c] * x + beta[c]; |
1328 | outH.raw(index) = quantization::quantize(y, outputQ); |
        } // C
      } // image
1331 | } else { |
1332 | // For each channel |
1333 | for (dim_t c = 0; c < C; c++) { |
1334 | // For each H*W{*T} of the image |
1335 | for (dim_t i = 0; i < sizeImg; i++) { |
1336 | int index = n * sizeN + c * sizeImg + i; |
1337 | auto x = ParamTy(inH.raw(index) - inZero); |
1338 | ParamTy y = alpha[c] * x + beta[c]; |
1339 | outH.raw(index) = quantization::quantize(y, outputQ); |
1340 | } // image |
1341 | } // C |
1342 | } |
1343 | } // N |
1344 | } |
1345 | |
1346 | void BoundInterpreterFunction::fwdBatchNormalizationInst( |
1347 | const BatchNormalizationInst *I) { |
1348 | int numDims = I->getSrc()->dims().size() - 2; |
1349 | bool isQuantized = I->getSrc()->getType()->isQuantizedType(); |
1350 | |
1351 | if (isQuantized) { |
1352 | if (I->getScale()->getType()->getElementType() == ElemKind::FloatTy) { |
1353 | fwdBatchNormalizationI8Impl<float>(I, numDims); |
1354 | } else { |
1355 | fwdBatchNormalizationI8Impl<float16_t>(I, numDims); |
1356 | } |
1357 | } else { |
1358 | dispatchFloatingPointImpl(fwdBatchNormalizationFloatImpl, |
1359 | I->getSrc()->getElementType(), I, numDims); |
1360 | } |
1361 | } |
1362 | |
1363 | //===----------------------------------------------------------------------===// |
1364 | // LayerNormalization |
1365 | //===----------------------------------------------------------------------===// |
1366 | |
1367 | template <typename ElemTy> |
1368 | void BoundInterpreterFunction::fwdLayerNormalizationInstFloatImpl( |
1369 | const glow::LayerNormalizationInst *I) { |
1370 | staticAssertFloatingPointType(ElemTy); |
1371 | |
1372 | // input |
1373 | auto inH = getWeightHandle<ElemTy>(I->getSrc()); |
1374 | auto scaleH = getWeightHandle<ElemTy>(I->getScale()); |
1375 | auto biasH = getWeightHandle<ElemTy>(I->getBias()); |
1376 | float epsilon = I->getEpsilon(); |
1377 | |
1378 | // output |
1379 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
1380 | |
1381 | auto N = I->getSrc()->dims()[0]; |
1382 | auto K = I->getSrc()->dims()[1]; |
1383 | |
1384 | std::vector<float> val(K); |
1385 | for (dim_t n = 0; n < N; n++) { |
1386 | // 1. mean = x.mean(dim=-1, keepdim=True) |
1387 | float sum = 0.0f; |
1388 | for (dim_t k = 0; k < K; k++) { |
1389 | val[k] = inH.at({n, k}); |
1390 | sum += val[k]; |
1391 | } |
1392 | float mean = sum / K; |
1393 | |
1394 | // 2. var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) |
1395 | float diff_sqr_sum = 0.0f; |
1396 | for (dim_t k = 0; k < K; k++) { |
1397 | float diff = val[k] - mean; |
1398 | diff_sqr_sum += diff * diff; |
1399 | } |
1400 | float var = diff_sqr_sum / K; |
1401 | |
1402 | // 3. std = (var + epsilon).sqrt() |
1403 | float std = std::sqrt(var + epsilon); |
1404 | |
1405 | for (dim_t k = 0; k < K; k++) { |
1406 | // 4. y = ((x - mean) / std) * scale + bias |
1407 | float scale = scaleH.at({k}); |
1408 | float bias = biasH.at({k}); |
1409 | outW.at({n, k}) = ElemTy((((val[k] - mean) / std) * scale) + bias); |
1410 | } |
1411 | } |
1412 | } |
1413 | |
1414 | void BoundInterpreterFunction::fwdLayerNormalizationInst( |
1415 | const LayerNormalizationInst *I) { |
1416 | dispatchFloatingPointImpl(fwdLayerNormalizationInstFloatImpl, |
1417 | I->getSrc()->getElementType(), I); |
1418 | } |
1419 | |
1420 | //===----------------------------------------------------------------------===// |
1421 | // Pooling |
1422 | //===----------------------------------------------------------------------===// |
1423 | template <class T> |
1424 | static void fwdMaxPool(Tensor *inW, Tensor *outW, Tensor *argmaxW, |
1425 | llvm::ArrayRef<unsigned_t> kernelSizes, |
1426 | llvm::ArrayRef<unsigned_t> strides, |
1427 | llvm::ArrayRef<unsigned_t> pads) { |
1428 | ShapeNHWC odim(outW->dims()); |
1429 | ShapeNHWC idim(inW->dims()); |
1430 | Handle<T> inHandle = inW->getHandle<T>(); |
1431 | Handle<T> outHandle = outW->getHandle<T>(); |
1432 | PaddingTLBR pdim(pads); |
1433 | ShapeHW kdim(kernelSizes); |
1434 | ShapeHW sdim(strides); |
1435 | |
1436 | llvm::Optional<Handle<int64_t>> argmaxH; |
1437 | if (argmaxW) { |
1438 | argmaxH = argmaxW->getHandle<int64_t>(); |
1439 | } |
1440 | // For each input in the batch: |
1441 | for (dim_t n = 0; n < odim.n; n++) { |
1442 | |
1443 | // For each layer in the output tensor: |
1444 | for (dim_t z = 0; z < idim.c; z++) { |
1445 | // For each convolution 'jump' in the input tensor: |
1446 | sdim_t x = -sdim_t(pdim.top); |
1447 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
1448 | sdim_t y = -sdim_t(pdim.left); |
1449 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
1450 | |
          // When the MaxPool window covers only padding pixels, we return
          // 0 for that window by convention.
1453 | bool first = true; |
1454 | T max_value = outW->getType().isQuantizedType() |
1455 | ? static_cast<T>(outW->getType().getOffset()) |
1456 | : static_cast<T>(0); |
1457 | dim_t argmaxNHWC = 0; |
1458 | |
1459 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
1460 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
1461 | sdim_t ox = x + fx; |
1462 | sdim_t oy = y + fy; |
1463 | |
1464 | // Ignore index access below zero (this is due to padding). |
1465 | if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) || |
1466 | oy >= ssize_t(idim.w)) { |
1467 | continue; |
1468 | } |
1469 | |
1470 | T val = inHandle.at({n, (dim_t)ox, (dim_t)oy, z}); |
1471 | if (first || (val >= max_value)) { |
1472 | first = false; |
1473 | max_value = val; |
1474 | if (argmaxW) { |
1475 | argmaxNHWC = &inHandle.at({n, (dim_t)ox, (dim_t)oy, z}) - |
1476 | &inHandle.raw(0); |
1477 | } |
1478 | } |
1479 | } |
1480 | } |
1481 | |
1482 | outHandle.at({n, ax, ay, z}) = max_value; |
1483 | |
1484 | if (argmaxW) { |
1485 | (*argmaxH).at({n, ax, ay, z}) = argmaxNHWC; |
1486 | } |
1487 | } // W |
1488 | } // H |
1489 | } // C |
1490 | } // N |
1491 | } |
1492 | |
1493 | void BoundInterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) { |
1494 | auto inW = getTensor(I->getSrc()); |
1495 | auto outW = getTensor(I->getDest()); |
1496 | |
1497 | if (inW->getType().isQuantizedType()) { |
1498 | dispatchQuantizedImpl(fwdMaxPool, inW->getType().getElementType(), inW, |
1499 | outW, nullptr, I->getKernels(), I->getStrides(), |
1500 | I->getPads()); |
1501 | return; |
1502 | } |
1503 | |
1504 | dispatchFloatingPointImpl(fwdMaxPool, inW->getType().getElementType(), inW, |
1505 | outW, nullptr, I->getKernels(), I->getStrides(), |
1506 | I->getPads()); |
1507 | } |
1508 | |
1509 | void BoundInterpreterFunction::fwdMaxPoolWithArgmaxInst( |
1510 | const MaxPoolWithArgmaxInst *I) { |
1511 | auto inW = getTensor(I->getSrc()); |
1512 | auto outW = getTensor(I->getDest()); |
1513 | auto argmaxW = getTensor(I->getArgmax()); |
1514 | |
1515 | if (inW->getType().isQuantizedType()) { |
1516 | dispatchQuantizedImpl(fwdMaxPool, inW->getType().getElementType(), inW, |
1517 | outW, argmaxW, I->getKernels(), I->getStrides(), |
1518 | I->getPads()); |
1519 | return; |
1520 | } |
1521 | dispatchFloatingPointImpl(fwdMaxPool, inW->getType().getElementType(), inW, |
1522 | outW, argmaxW, I->getKernels(), I->getStrides(), |
1523 | I->getPads()); |
1524 | } |
1525 | |
1526 | template <typename ElemTy> |
1527 | void BoundInterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) { |
1528 | staticAssertFloatingPointType(ElemTy); |
1529 | |
1530 | ShapeNHWC odim(I->getDest()->dims()); |
1531 | ShapeNHWC idim(I->getSrc()->dims()); |
1532 | |
1533 | PaddingTLBR pdim(I->getPads()); |
1534 | ShapeHW kdim(I->getKernels()); |
1535 | ShapeHW sdim(I->getStrides()); |
1536 | // Implement the avg pooling operation as defined here: |
1537 | // https://arxiv.org/abs/1312.4400 |
1538 | float rawFilterArea = kdim.height * kdim.width; |
1539 | |
1540 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
1541 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
1542 | |
1543 | // For each input in the batch: |
1544 | for (dim_t n = 0; n < odim.n; n++) { |
1545 | // For each layer in the output tensor: |
1546 | for (dim_t z = 0; z < idim.c; z++) { |
1547 | // For each convolution 'jump' in the input tensor: |
1548 | ssize_t x = -ssize_t(pdim.top); |
1549 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
1550 | ssize_t y = -ssize_t(pdim.left); |
1551 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
1552 | float sum = 0; |
1553 | float filterArea = rawFilterArea; |
1554 | |
1555 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
1556 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
1557 | sdim_t ox = x + fx; |
1558 | sdim_t oy = y + fy; |
1559 | |
1560 | // Ignore index access below zero (this is due to padding). |
1561 | if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) || |
1562 | oy >= ssize_t(idim.w)) { |
1563 | if (!I->getCountIncludePads()) { |
1564 | filterArea--; |
1565 | } |
1566 | |
1567 | continue; |
1568 | } |
1569 | |
1570 | sum += float(inW.at({n, (dim_t)ox, (dim_t)oy, z})); |
1571 | } |
1572 | } |
1573 | if (filterArea == 0) { |
1574 | outW.at({n, ax, ay, z}) = ElemTy(0); |
1575 | } else { |
1576 | outW.at({n, ax, ay, z}) = ElemTy(sum / filterArea); |
1577 | } |
1578 | } // W |
1579 | } // H |
1580 | } // C |
1581 | } // N |
1582 | } |
1583 | |
1584 | void BoundInterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) { |
1585 | ShapeNHWC odim(I->getDest()->dims()); |
1586 | ShapeNHWC idim(I->getSrc()->dims()); |
1587 | |
1588 | PaddingTLBR pdim(I->getPads()); |
1589 | ShapeHW kdim(I->getKernels()); |
1590 | ShapeHW sdim(I->getStrides()); |
1591 | // Implement the avg pooling operation as defined here: |
1592 | // https://arxiv.org/abs/1312.4400 |
1593 | float rawFilterArea = kdim.height * kdim.width; |
1594 | |
1595 | auto inW = getWeightHandle<int8_t>(I->getSrc()); |
1596 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
1597 | TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(), |
1598 | I->getSrc()->getType()->getOffset()}; |
1599 | TensorQuantizationParams outQP{I->getDest()->getType()->getScale(), |
1600 | I->getDest()->getType()->getOffset()}; |
1601 | |
1602 | // For each input in the batch: |
1603 | for (dim_t n = 0; n < odim.n; n++) { |
1604 | // For each layer in the output tensor: |
1605 | for (dim_t z = 0; z < idim.c; z++) { |
1606 | // For each convolution 'jump' in the input tensor: |
1607 | ssize_t x = -ssize_t(pdim.top); |
1608 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
1609 | ssize_t y = -ssize_t(pdim.left); |
1610 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
1611 | int32_t sum = 0; |
1612 | float filterArea = rawFilterArea; |
1613 | |
1614 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
1615 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
1616 | sdim_t ox = x + fx; |
1617 | sdim_t oy = y + fy; |
1618 | |
1619 | // Ignore index access below zero (this is due to padding). |
1620 | if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) || |
1621 | oy >= ssize_t(idim.w)) { |
1622 | if (!I->getCountIncludePads()) { |
1623 | filterArea--; |
1624 | } |
1625 | |
1626 | continue; |
1627 | } |
1628 | |
1629 | sum += inW.at({n, (dim_t)ox, (dim_t)oy, z}) - inQP.offset; |
1630 | } |
1631 | } |
1632 | if (filterArea == 0) { |
1633 | outW.at({n, ax, ay, z}) = |
1634 | quantization::clip<int32_t, int8_t>(outQP.offset); |
1635 | } else { |
1636 | // Instead of dividing by filterArea, just change scale. |
1637 | outW.at({n, ax, ay, z}) = |
1638 | quantization::clip<int32_t, int8_t>(std::round( |
1639 | float(sum) * (inQP.scale / outQP.scale / filterArea) + |
1640 | outQP.offset)); |
1641 | } |
1642 | } // W |
1643 | } // H |
1644 | } // C |
1645 | } // N |
1646 | } |
1647 | |
1648 | template <typename ElemTy> |
1649 | void BoundInterpreterFunction::fwdAvgPool3DInstFloatImpl(const AvgPoolInst *I) { |
1650 | staticAssertFloatingPointType(ElemTy); |
1651 | |
1652 | ShapeNTHWC odim(I->getDest()->dims()); |
1653 | ShapeNTHWC idim(I->getSrc()->dims()); |
1654 | |
1655 | PaddingNFTBLR pdim(I->getPads()); |
1656 | ShapeTHW kdim(I->getKernels()); |
1657 | ShapeTHW sdim(I->getStrides()); |
1658 | // Implement the avg pooling operation as defined here: |
1659 | // https://arxiv.org/abs/1312.4400 |
1660 | float rawFilterArea = kdim.temporal_frames * kdim.height * kdim.width; |
1661 | |
1662 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
1663 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
1664 | |
1665 | // For each input in the batch: |
1666 | for (dim_t n = 0; n < odim.n; n++) { |
1667 | // For each layer in the output tensor: |
1668 | for (dim_t z = 0; z < idim.c; z++) { |
1669 | // For each convolution 'jump' in the input tensor: |
1670 | ssize_t t = -ssize_t(pdim.near); |
1671 | for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) { |
1672 | ssize_t x = -ssize_t(pdim.top); |
1673 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
1674 | ssize_t y = -ssize_t(pdim.left); |
1675 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
1676 | float sum = 0; |
1677 | float filterArea = rawFilterArea; |
1678 | |
1679 | for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) { |
1680 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
1681 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
1682 | sdim_t ot = t + ft; |
1683 | sdim_t ox = x + fx; |
1684 | sdim_t oy = y + fy; |
1685 | |
1686 | // Ignore index access below zero (this is due to padding). |
1687 | if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) || |
1688 | ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) { |
1689 | if (!I->getCountIncludePads()) { |
1690 | filterArea--; |
1691 | } |
1692 | |
1693 | continue; |
1694 | } |
1695 | |
1696 | sum += float(inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z})); |
1697 | } |
1698 | } |
1699 | } |
            assert(filterArea != 0 && "filterArea can't be 0");
1701 | outW.at({n, at, ax, ay, z}) = ElemTy(sum / filterArea); |
1702 | } // W |
1703 | } // H |
1704 | } // T |
1705 | } // C |
1706 | } // N |
1707 | } |
1708 | |
1709 | void BoundInterpreterFunction::fwdAvgPool3DInstI8Impl(const AvgPoolInst *I) { |
1710 | ShapeNTHWC odim(I->getDest()->dims()); |
1711 | ShapeNTHWC idim(I->getSrc()->dims()); |
1712 | |
1713 | PaddingNFTBLR pdim(I->getPads()); |
1714 | ShapeTHW kdim(I->getKernels()); |
1715 | ShapeTHW sdim(I->getStrides()); |
1716 | // Implement the avg pooling operation as defined here: |
1717 | // https://arxiv.org/abs/1312.4400 |
1718 | float rawFilterArea = kdim.temporal_frames * kdim.height * kdim.width; |
1719 | |
1720 | auto inW = getWeightHandle<int8_t>(I->getSrc()); |
1721 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
1722 | TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(), |
1723 | I->getSrc()->getType()->getOffset()}; |
1724 | TensorQuantizationParams outQP{I->getDest()->getType()->getScale(), |
1725 | I->getDest()->getType()->getOffset()}; |
1726 | |
1727 | // For each input in the batch: |
1728 | for (dim_t n = 0; n < odim.n; n++) { |
1729 | // For each layer in the output tensor: |
1730 | for (dim_t z = 0; z < idim.c; z++) { |
1731 | // For each convolution 'jump' in the input tensor: |
1732 | ssize_t t = -ssize_t(pdim.near); |
1733 | for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) { |
1734 | ssize_t x = -ssize_t(pdim.top); |
1735 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
1736 | ssize_t y = -ssize_t(pdim.left); |
1737 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
1738 | int32_t sum = 0; |
1739 | float filterArea = rawFilterArea; |
1740 | |
1741 | for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) { |
1742 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
1743 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
1744 | sdim_t ot = t + ft; |
1745 | sdim_t ox = x + fx; |
1746 | sdim_t oy = y + fy; |
1747 | |
1748 | // Ignore index access below zero (this is due to padding). |
1749 | if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) || |
1750 | ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) { |
1751 | if (!I->getCountIncludePads()) { |
1752 | filterArea--; |
1753 | } |
1754 | |
1755 | continue; |
1756 | } |
1757 | |
1758 | sum += inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}) - |
1759 | inQP.offset; |
1760 | } |
1761 | } |
1762 | } |
1763 | // Instead of dividing by filterArea, just change scale. |
            assert(filterArea != 0 && "filterArea can't be 0");
1765 | |
1766 | float multiplier = inQP.scale / outQP.scale / filterArea; |
1767 | TensorQuantizationParams outputQ{1.0f / multiplier, outQP.offset}; |
1768 | outW.at({n, at, ax, ay, z}) = |
1769 | quantization::quantize(float(sum), outputQ); |
1770 | } // W |
1771 | } // H |
1772 | } // T |
1773 | } // C |
1774 | } // N |
1775 | } |
1776 | |
1777 | void BoundInterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) { |
1778 | bool isConv3D = is3DData(ConvolutionLayout(I->getLayout())); |
1779 | bool isQuantized = I->getSrc()->getType()->isQuantizedType(); |
1780 | |
1781 | if (isConv3D) { |
1782 | if (isQuantized) { |
1783 | fwdAvgPool3DInstI8Impl(I); |
1784 | } else { |
1785 | dispatchFloatingPointImpl(fwdAvgPool3DInstFloatImpl, |
1786 | I->getSrc()->getElementType(), I); |
1787 | } |
1788 | } else { |
1789 | if (isQuantized) { |
1790 | fwdAvgPoolInstI8Impl(I); |
1791 | } else { |
1792 | dispatchFloatingPointImpl(fwdAvgPoolInstFloatImpl, |
1793 | I->getSrc()->getElementType(), I); |
1794 | } |
1795 | } |
1796 | } |
1797 | |
1798 | template <typename ElemTy> |
1799 | void BoundInterpreterFunction::fwdAdaptiveAvgPoolInstFloatImpl( |
1800 | const AdaptiveAvgPoolInst *I) { |
1801 | staticAssertFloatingPointType(ElemTy); |
1802 | |
1803 | ShapeNHWC odim(I->getDest()->dims()); |
1804 | ShapeNHWC idim(I->getSrc()->dims()); |
1805 | |
1806 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
1807 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
1808 | |
1809 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp |
1810 | #define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b)) |
1811 | #define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b)) |
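// START_IND/END_IND partition the input extent into odim-many bins whose
// union covers the input exactly; neighboring bins overlap by at most one
// element. Illustrative example: mapping idim.h = 5 onto odim.h = 3 yields
// the row ranges [0, 2), [1, 4), [3, 5), i.e. kernel heights 2, 3 and 2.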
1812 | |
1813 | // For each input in the batch: |
1814 | for (dim_t n = 0; n < odim.n; n++) { |
1815 | // For each layer in the output tensor: |
1816 | for (dim_t z = 0; z < idim.c; z++) { |
1817 | // For each value in the output tensor: |
1818 | for (dim_t ax = 0; ax < odim.h; ax++) { |
1819 | |
1820 | dim_t x = START_IND(ax, odim.h, idim.h); |
1821 | dim_t kH = END_IND(ax, odim.h, idim.h) - x; |
1822 | |
1823 | for (dim_t ay = 0; ay < odim.w; ay++) { |
1824 | |
1825 | dim_t y = START_IND(ay, odim.w, idim.w); |
1826 | dim_t kW = END_IND(ay, odim.w, idim.w) - y; |
1827 | |
1828 | float sum = 0; |
1829 | for (dim_t fx = 0; fx < kH; fx++) { |
1830 | for (dim_t fy = 0; fy < kW; fy++) { |
1831 | dim_t ox = x + fx; |
1832 | dim_t oy = y + fy; |
1833 | |
1834 | sum += float(inW.at({n, ox, oy, z})); |
1835 | } |
1836 | } |
1837 | outW.at({n, ax, ay, z}) = ElemTy(sum / kW / kH); |
1838 | } // W |
1839 | } // H |
1840 | } // C |
1841 | } // N |
1842 | #undef START_IND |
1843 | #undef END_IND |
1844 | } |
1845 | |
1846 | void BoundInterpreterFunction::fwdAdaptiveAvgPoolInstI8Impl( |
1847 | const AdaptiveAvgPoolInst *I) { |
1848 | ShapeNHWC odim(I->getDest()->dims()); |
1849 | ShapeNHWC idim(I->getSrc()->dims()); |
1850 | |
1851 | auto inW = getWeightHandle<int8_t>(I->getSrc()); |
1852 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
1853 | |
1854 | TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(), |
1855 | I->getSrc()->getType()->getOffset()}; |
1856 | TensorQuantizationParams outQP{I->getDest()->getType()->getScale(), |
1857 | I->getDest()->getType()->getOffset()}; |
1858 | |
1859 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp |
1860 | #define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b)) |
1861 | #define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b)) |
1862 | |
1863 | // For each input in the batch: |
1864 | for (dim_t n = 0; n < odim.n; n++) { |
1865 | // For each layer in the output tensor: |
1866 | for (dim_t z = 0; z < idim.c; z++) { |
1867 | // For each value in the output tensor: |
1868 | for (dim_t ax = 0; ax < odim.h; ax++) { |
1869 | |
1870 | dim_t x = START_IND(ax, odim.h, idim.h); |
1871 | dim_t kH = END_IND(ax, odim.h, idim.h) - x; |
1872 | |
1873 | for (dim_t ay = 0; ay < odim.w; ay++) { |
1874 | |
1875 | dim_t y = START_IND(ay, odim.w, idim.w); |
1876 | dim_t kW = END_IND(ay, odim.w, idim.w) - y; |
1877 | |
1878 | int32_t sum = 0; |
1879 | for (dim_t fx = 0; fx < kH; fx++) { |
1880 | for (dim_t fy = 0; fy < kW; fy++) { |
1881 | dim_t ox = x + fx; |
1882 | dim_t oy = y + fy; |
1883 | |
1884 | sum += inW.at({n, ox, oy, z}) - inQP.offset; |
1885 | } |
1886 | } |
1887 | |
1888 | outW.at({n, ax, ay, z}) = quantization::clip<int32_t, int8_t>( |
1889 | std::round(float(sum) * (inQP.scale / outQP.scale / kW / kH) + |
1890 | outQP.offset)); |
1891 | } // W |
1892 | } // H |
1893 | } // C |
1894 | } // N |
1895 | #undef START_IND |
1896 | #undef END_IND |
1897 | } |
1898 | |
1899 | void BoundInterpreterFunction::fwdAdaptiveAvgPoolInst( |
1900 | const AdaptiveAvgPoolInst *I) { |
1901 | if (I->getSrc()->getType()->isQuantizedType()) { |
1902 | fwdAdaptiveAvgPoolInstI8Impl(I); |
1903 | return; |
1904 | } |
1905 | |
1906 | dispatchFloatingPointImpl(fwdAdaptiveAvgPoolInstFloatImpl, |
1907 | I->getSrc()->getElementType(), I); |
1908 | } |
1909 | |
1910 | void BoundInterpreterFunction::fwdAdaptiveAvgPoolGradInst( |
1911 | const AdaptiveAvgPoolGradInst *I) { |
1912 | auto inG = getWeightHandle(I->getSrcGrad()); |
1913 | auto outW = getWeightHandle(I->getDest()); |
1914 | auto outG = getWeightHandle(I->getDestGrad()); |
1915 | |
1916 | inG.clear(); |
1917 | |
1918 | ShapeNHWC odim(outW.dims()); |
1919 | ShapeNHWC idim(inG.dims()); |
1920 | |
1921 | const float gradCoefficient = 1. / (odim.h * odim.w); |
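  // Each input element in a bin thus receives outG / (odim.h * odim.w) from
  // the corresponding output cell; see the reference linked below.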
1922 | |
1923 | #define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b)) |
1924 | #define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b)) |
1925 | |
1926 | // https://software.intel.com/en-us/daal-programming-guide-2d-average-pooling-backward-layer |
1927 | // For each input in the batch: |
1928 | for (dim_t n = 0; n < odim.n; n++) { |
1929 | // For each layer in the output tensor: |
1930 | for (dim_t z = 0; z < idim.c; z++) { |
1931 | // For each value in the output tensor: |
1932 | for (dim_t ax = 0; ax < odim.h; ax++) { |
1933 | |
1934 | dim_t x = START_IND(ax, odim.h, idim.h); |
1935 | dim_t kH = END_IND(ax, odim.h, idim.h) - x; |
1936 | |
1937 | for (dim_t ay = 0; ay < odim.w; ay++) { |
1938 | |
1939 | dim_t y = START_IND(ay, odim.w, idim.w); |
1940 | dim_t kW = END_IND(ay, odim.w, idim.w) - y; |
1941 | |
1942 | const float chainGrad = outG.at({n, ax, ay, z}) * gradCoefficient; |
1943 | |
1944 | for (dim_t fx = 0; fx < kH; fx++) { |
1945 | for (dim_t fy = 0; fy < kW; fy++) { |
1946 | dim_t ox = x + fx; |
1947 | dim_t oy = y + fy; |
1948 | |
1949 | inG.at({n, ox, oy, z}) += chainGrad; |
1950 | } |
1951 | } |
1952 | } // W |
1953 | } // H |
1954 | } // C |
1955 | } // N |
1956 | #undef START_IND |
1957 | #undef END_IND |
1958 | } |
1959 | |
1960 | void BoundInterpreterFunction::fwdMaxPoolWithArgmaxGradInst( |
1961 | const MaxPoolWithArgmaxGradInst *I) { |
1962 | auto inG = getWeightHandle(I->getSrcGrad()); |
1963 | auto outW = getWeightHandle(I->getDest()); |
1964 | auto outG = getWeightHandle(I->getDestGrad()); |
1965 | |
1966 | inG.clear(); |
1967 | |
1968 | ShapeNHWC idim(inG.dims()); |
1969 | ShapeNHWC odim(outW.dims()); |
1970 | |
1971 | auto argmax = getWeightHandle<int64_t>(I->getArgmax()); |
1972 | |
1973 | // For each input in the batch: |
1974 | for (dim_t n = 0; n < odim.n; n++) { |
1975 | |
1976 | // Compute the gradient. For each layer in the output tensor: |
1977 | for (dim_t z = 0; z < odim.c; z++) { |
1978 | |
1979 | // For each convolution 'jump' in the input tensor: |
1980 | for (dim_t ax = 0; ax < odim.h; ax++) { |
1981 | for (dim_t ay = 0; ay < odim.w; ay++) { |
1982 | // Reuse precomputed linear index of max element from argmax. |
1983 | float chainGrad = outG.at({n, ax, ay, z}); |
1984 | inG.raw(argmax.at({n, ax, ay, z})) += chainGrad; |
1985 | } // W |
1986 | } // H |
1987 | } // C |
1988 | } // N |
1989 | } |
1990 | |
1991 | void BoundInterpreterFunction::fwdAvgPool2DGradInst(const AvgPoolGradInst *I) { |
1992 | auto inG = getWeightHandle(I->getSrcGrad()); |
1993 | auto outW = getWeightHandle(I->getDest()); |
1994 | auto outG = getWeightHandle(I->getDestGrad()); |
1995 | |
1996 | ShapeNHWC odim(outW.dims()); |
1997 | ShapeNHWC idim(inG.dims()); |
1998 | |
1999 | PaddingTLBR pdim(I->getPads()); |
2000 | ShapeHW kdim(I->getKernels()); |
2001 | ShapeHW sdim(I->getStrides()); |
2002 | |
2003 | inG.clear(); |
2004 | |
2005 | float rawFilterArea = kdim.height * kdim.width; |
2006 | |
2007 | // For each input in the batch: |
2008 | for (dim_t n = 0; n < odim.n; n++) { |
2009 | |
2010 | // For each layer in the output tensor: |
2011 | for (dim_t z = 0; z < odim.c; z++) { |
2012 | // For each convolution 'jump' in the input tensor: |
2013 | ssize_t x = -ssize_t(pdim.top); |
2014 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
2015 | ssize_t y = -ssize_t(pdim.left); |
2016 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
2017 | float filterArea = rawFilterArea; |
2018 | |
          // Exclude the padding area from filterArea when CountIncludePads
          // is false.
2020 | if (!I->getCountIncludePads()) { |
2021 | ssize_t pad_x = (-x > 0 ? -x : 0) + |
2022 | ((x + ssize_t(kdim.height) - ssize_t(idim.h)) > 0 |
2023 | ? (x + ssize_t(kdim.height) - ssize_t(idim.h)) |
2024 | : 0); |
2025 | ssize_t pad_y = (-y > 0 ? -y : 0) + |
2026 | ((y + ssize_t(kdim.width) - ssize_t(idim.w)) > 0 |
2027 | ? (y + ssize_t(kdim.width) - ssize_t(idim.w)) |
2028 | : 0); |
2029 | filterArea = rawFilterArea - pad_x * kdim.width - |
2030 | pad_y * kdim.height + pad_x * pad_y; |
2031 | } |
          assert(filterArea != 0 && "filterArea can't be 0");
2033 | float dy = outG.at({n, ax, ay, z}) / filterArea; |
2034 | |
2035 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
2036 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
2037 | ssize_t ox = x + fx; |
2038 | ssize_t oy = y + fy; |
2039 | |
2040 | // Ignore index access below zero (this is due to padding). |
2041 | if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) || |
2042 | oy >= ssize_t(idim.w)) { |
2043 | continue; |
2044 | } |
2045 | inG.at({n, (dim_t)ox, (dim_t)oy, z}) += dy; |
2046 | } |
2047 | } |
2048 | } // W |
2049 | } // H |
2050 | } // C |
2051 | } // N |
2052 | } |
2053 | |
2054 | void BoundInterpreterFunction::fwdAvgPool3DGradInst(const AvgPoolGradInst *I) { |
2055 | auto inG = getWeightHandle(I->getSrcGrad()); |
2056 | auto outW = getWeightHandle(I->getDest()); |
2057 | auto outG = getWeightHandle(I->getDestGrad()); |
2058 | |
2059 | ShapeNTHWC odim(outW.dims()); |
2060 | ShapeNTHWC idim(inG.dims()); |
2061 | |
2062 | PaddingNFTBLR pdim(I->getPads()); |
2063 | ShapeTHW kdim(I->getKernels()); |
2064 | ShapeTHW sdim(I->getStrides()); |
2065 | |
2066 | inG.clear(); |
2067 | |
2068 | float rawFilterArea = kdim.temporal_frames * kdim.height * kdim.width; |
2069 | |
2070 | // For each input in the batch: |
2071 | for (dim_t n = 0; n < odim.n; n++) { |
2072 | |
2073 | // For each layer in the output tensor: |
2074 | for (dim_t z = 0; z < odim.c; z++) { |
2075 | // For each convolution 'jump' in the input tensor: |
2076 | ssize_t t = -ssize_t(pdim.near); |
2077 | for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) { |
2078 | ssize_t x = -ssize_t(pdim.top); |
2079 | for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) { |
2080 | ssize_t y = -ssize_t(pdim.left); |
2081 | for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) { |
2082 | float filterArea = rawFilterArea; |
2083 | |
            // Exclude the padding area from filterArea when CountIncludePads
            // is false.
2085 | if (!I->getCountIncludePads()) { |
2086 | ssize_t pad_x = |
2087 | (-x > 0 ? -x : 0) + |
2088 | ((x + ssize_t(kdim.height) - ssize_t(idim.h)) > 0 |
2089 | ? (x + ssize_t(kdim.height) - ssize_t(idim.h)) |
2090 | : 0); |
2091 | ssize_t pad_y = (-y > 0 ? -y : 0) + |
2092 | ((y + ssize_t(kdim.width) - ssize_t(idim.w)) > 0 |
2093 | ? (y + ssize_t(kdim.width) - ssize_t(idim.w)) |
2094 | : 0); |
2095 | ssize_t pad_z = |
2096 | (-t > 0 ? -t : 0) + |
2097 | ((t + ssize_t(kdim.temporal_frames) - ssize_t(idim.t)) > 0 |
2098 | ? (t + ssize_t(kdim.temporal_frames) - ssize_t(idim.t)) |
2099 | : 0); |
2100 | filterArea = rawFilterArea - |
2101 | pad_x * kdim.width * kdim.temporal_frames - |
2102 | pad_y * kdim.height * kdim.temporal_frames - |
2103 | pad_z * kdim.height * kdim.width + |
2104 | pad_x * pad_y * kdim.temporal_frames + |
2105 | pad_x * pad_z * kdim.width + |
2106 | pad_y * pad_z * kdim.height - pad_x * pad_y * pad_z; |
2107 | } |
            assert(filterArea != 0 && "filterArea can't be 0");
2109 | float dy = outG.at({n, at, ax, ay, z}) / filterArea; |
2110 | |
2111 | for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) { |
2112 | for (dim_t fx = 0; fx < kdim.height; fx++) { |
2113 | for (dim_t fy = 0; fy < kdim.width; fy++) { |
2114 | ssize_t ot = t + ft; |
2115 | ssize_t ox = x + fx; |
2116 | ssize_t oy = y + fy; |
2117 | |
2118 | // Ignore index access below zero (this is due to padding). |
2119 | if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) || |
2120 | ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) { |
2121 | continue; |
2122 | } |
2123 | inG.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}) += dy; |
2124 | } |
2125 | } |
2126 | } |
2127 | } // W |
2128 | } // H |
2129 | } // T |
2130 | } // C |
2131 | } // N |
2132 | } |
2133 | |
2134 | void BoundInterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) { |
2135 | bool isConv3D = is3DData(ConvolutionLayout(I->getLayout())); |
2136 | |
2137 | if (isConv3D) { |
2138 | fwdAvgPool3DGradInst(I); |
2139 | } else { |
2140 | fwdAvgPool2DGradInst(I); |
2141 | } |
2142 | } |
2143 | |
2144 | //===----------------------------------------------------------------------===// |
2145 | // Activation functions |
2146 | //===----------------------------------------------------------------------===// |
2147 | |
2148 | void BoundInterpreterFunction::fwdReluInst(const ReluInst *) { |
  DCHECK(!"Found ReluInst but Relu is lowered on Interpreter");
2150 | } |
2151 | |
2152 | void BoundInterpreterFunction::fwdClipInst(const ClipInst *) { |
  DCHECK(!"Found ClipInst but Clip is lowered on Interpreter");
2154 | } |
2155 | |
2156 | void BoundInterpreterFunction::fwdLeakyReluInst(const LeakyReluInst *) { |
  DCHECK(!"Found LeakyReluInst but LeakyRelu is lowered on Interpreter");
2158 | } |
2159 | |
2160 | template <typename ElemTy> |
2161 | void BoundInterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) { |
2162 | staticAssertFloatingPointType(ElemTy); |
2163 | |
2164 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
2165 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
2166 | |
2167 | for (dim_t i = 0, e = outW.size(); i < e; i++) { |
2168 | float val = inW.raw(i); |
2169 | outW.raw(i) = ElemTy(1 / (1 + std::exp(-val))); |
2170 | } |
2171 | } |
2172 | |
2173 | void BoundInterpreterFunction::fwdSigmoidInst(const SigmoidInst *I) { |
2174 | dispatchFloatingPointImpl(fwdSigmoidInstFloatImpl, |
2175 | I->getSrc()->getElementType(), I); |
2176 | } |
2177 | |
2178 | template <typename ElemTy> |
2179 | void BoundInterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) { |
2180 | staticAssertFloatingPointType(ElemTy); |
2181 | |
2182 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
2183 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
2184 | |
2185 | for (dim_t i = 0, e = inW.size(); i < e; i++) { |
2186 | float val = inW.raw(i); |
2187 | outW.raw(i) = ElemTy(std::tanh(val)); |
2188 | } |
2189 | } |
2190 | |
2191 | void BoundInterpreterFunction::fwdTanhInst(const TanhInst *I) { |
2192 | dispatchFloatingPointImpl(fwdTanhInstFloatImpl, I->getSrc()->getElementType(), |
2193 | I); |
2194 | } |
2195 | |
2196 | template <typename ElemTy> |
2197 | void BoundInterpreterFunction::fwdSoftPlusInstFloatImpl(const SoftPlusInst *I) { |
2198 | staticAssertFloatingPointType(ElemTy); |
2199 | |
2200 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
2201 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
2202 | |
2203 | for (dim_t i = 0, e = outW.size(); i < e; i++) { |
2204 | float val = inW.raw(i); |
2205 | outW.raw(i) = ElemTy(std::log(1 + std::exp(val))); |
2206 | } |
2207 | } |
2208 | |
2209 | void BoundInterpreterFunction::fwdSoftPlusInst(const SoftPlusInst *I) { |
2210 | dispatchFloatingPointImpl(fwdSoftPlusInstFloatImpl, |
2211 | I->getSrc()->getElementType(), I); |
2212 | } |
2213 | |
2214 | //===----------------------------------------------------------------------===// |
2215 | // Loss Functions (Softmax/regression/...) |
2216 | //===----------------------------------------------------------------------===// |
2217 | |
2218 | template <typename ElemTy> |
2219 | void BoundInterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) { |
2220 | staticAssertFloatingPointType(ElemTy); |
2221 | |
2222 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
2223 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
2224 | auto idim = inW.dims(); |
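
  // Subtracting the per-row maximum before exponentiation is the standard
  // numerical-stability trick: softmax(x) == softmax(x - max(x)), and it
  // keeps std::exp from overflowing. Illustrative example: for
  // x = {1000, 1001}, exp(1000) overflows float, while exp(-1) and exp(0)
  // yield the same normalized result, roughly {0.269, 0.731}.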
2225 | |
2226 | for (dim_t n = 0; n < idim[0]; n++) { |
2227 | // Find Max. |
2228 | float max = float(inW.at({n, 0})); |
2229 | for (dim_t i = 1; i < idim[1]; i++) { |
2230 | max = std::max(max, float(inW.at({n, i}))); |
2231 | } |
2232 | |
2233 | // Compute exp. |
2234 | float sum = 0; |
2235 | for (dim_t i = 0; i < idim[1]; i++) { |
2236 | float e = std::exp(float(inW.at({n, i})) - max); |
2237 | sum += e; |
2238 | outW.at({n, i}) = ElemTy(e); |
2239 | } |
2240 | |
2241 | // Normalize the output. |
2242 | for (dim_t i = 0; i < idim[1]; i++) { |
2243 | outW.at({n, i}) = ElemTy(float(outW.at({n, i})) / sum); |
2244 | } |
2245 | } // N |
2246 | } |
2247 | |
2248 | void BoundInterpreterFunction::fwdSoftMaxInst(const SoftMaxInst *I) { |
2249 | dispatchFloatingPointImpl(fwdSoftMaxInstImpl, I->getSrc()->getElementType(), |
2250 | I); |
2251 | } |
2252 | |
2253 | void BoundInterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) { |
2254 | auto inG = getWeightHandle(I->getSrcGrad()); |
2255 | auto idim = inG.dims(); |
2256 | auto outW = getWeightHandle(I->getOrigDest()); |
2257 | auto selectedH = getWeightHandle<int64_t>(I->getSelected()); |
2258 | |
2259 | inG.clear(); |
2260 | |
2261 | // http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/ |
2262 | // https://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network |
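  // With a cross-entropy loss on top of softmax, the gradient w.r.t. the
  // input reduces to dL/dx_i = p_i - 1{i == selected}, which is what the
  // loop below computes.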
2263 | for (dim_t n = 0; n < idim[0]; n++) { |
2264 | for (dim_t i = 0; i < idim[1]; i++) { |
2265 | float delta = (selectedH.at({n, 0}) == (int64_t)i); |
2266 | inG.at({n, i}) = outW.at({n, i}) - delta; |
2267 | } |
2268 | } |
2269 | } |
2270 | |
2271 | template <typename ElemTy> |
2272 | void BoundInterpreterFunction::fwdLogSoftMaxInstImpl(const LogSoftMaxInst *I) { |
2273 | staticAssertFloatingPointType(ElemTy); |
2274 | |
2275 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
2276 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
2277 | auto idim = inW.dims(); |
2278 | // using log(softmax(x)) = x - max_x - log(sum) |
2279 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L14-L64 |
2280 | for (dim_t n = 0; n < idim[0]; n++) { |
2281 | // Find Max. |
2282 | float max = float(inW.at({n, 0})); |
2283 | for (dim_t i = 1; i < idim[1]; i++) { |
2284 | max = std::max(max, float(inW.at({n, i}))); |
2285 | } |
2286 | |
2287 | // Compute sum of exp(x-max). |
2288 | float sum = 0; |
2289 | for (dim_t i = 0; i < idim[1]; i++) { |
2290 | float e = std::exp(float(inW.at({n, i})) - max); |
2291 | sum += e; |
2292 | } |
2293 | |
2294 | // Output = x - max - log(sum) |
2295 | for (dim_t i = 0; i < idim[1]; i++) { |
2296 | outW.at({n, i}) = ElemTy(float(inW.at({n, i})) - max - std::log(sum)); |
2297 | } |
2298 | } |
2299 | } |
2300 | |
2301 | void BoundInterpreterFunction::fwdLogSoftMaxInst(const LogSoftMaxInst *I) { |
2302 | dispatchFloatingPointImpl(fwdLogSoftMaxInstImpl, |
2303 | I->getSrc()->getElementType(), I); |
2304 | } |
2305 | |
2306 | template <typename ElemTy> |
2307 | void BoundInterpreterFunction::fwdCrossEntropyLossInstFloatImpl( |
2308 | const CrossEntropyLossInst *I) { |
2309 | staticAssertFloatingPointType(ElemTy); |
2310 | |
2311 | auto P = getWeightHandle<ElemTy>(I->getP()); |
2312 | auto labels = getWeightHandle<int64_t>(I->getLabels()); |
2313 | auto CE = getWeightHandle<ElemTy>(I->getCE()); |
2314 | auto dims = P.dims(); |
2315 | CE.clear(); |
2316 | for (dim_t n = 0; n < dims[0]; ++n) { |
    assert(labels.raw(n) >= 0 && "Cannot use negative index.");
2318 | dim_t y = labels.raw(n); |
2319 | float p_n = P.at({n, y}); |
    CE.at({0}) -= std::log(p_n);
2321 | } |
2322 | } |
2323 | |
2324 | void BoundInterpreterFunction::fwdCrossEntropyLossInst( |
2325 | const CrossEntropyLossInst *I) { |
2326 | dispatchFloatingPointImpl(fwdCrossEntropyLossInstFloatImpl, |
2327 | I->getP()->getElementType(), I); |
2328 | } |
2329 | |
2330 | void BoundInterpreterFunction::fwdCrossEntropyLossGradInst( |
2331 | const CrossEntropyLossGradInst *I) { |
2332 | auto P = getWeightHandle(I->getP()); |
2333 | auto Labels = getWeightHandle<int64_t>(I->getLabels()); |
2334 | auto PGrad = getWeightHandle(I->getPgrad()); |
2335 | auto dims = PGrad.dims(); |
2336 | PGrad.clear(); |
2337 | for (dim_t n = 0; n < dims[0]; ++n) { |
    assert(Labels.raw(n) >= 0 && "Cannot use negative index.");
2339 | dim_t y = Labels.raw(n); |
2340 | PGrad.at({n, y}) = -1 / P.at({n, y}); // * CEGrad.at({0}) |
2341 | } |
2342 | } |
2343 | |
2344 | //===----------------------------------------------------------------------===// |
2345 | // Tensor shape (copy/transpose/concat/...) |
2346 | //===----------------------------------------------------------------------===// |
2347 | |
2348 | void BoundInterpreterFunction::fwdCopyInst(const CopyInst *I) { |
2349 | auto inT = getTensor(I->getSrc()); |
2350 | auto outT = getTensor(I->getDest()); |
2351 | outT->copyRawFrom(inT); |
2352 | } |
2353 | |
void BoundInterpreterFunction::fwdTransposeInst(const TransposeInst *I) {
  auto inT = getTensor(I->getSrc());
  auto outT = getTensor(I->getDest());

  assert(outT->size() == inT->size() && "Invalid tensor dimensions");

  // Transpose shuffles raw elements, so quantized and floating-point tensors
  // take the same path.
  inT->transpose(outT, I->getShuffle());
}
2367 | |
2368 | void BoundInterpreterFunction::fwdTensorViewInst(const TensorViewInst *I) { |
2369 | getOrCreateUnownedTensor(I, I->getSrc(), I->getOffsets()); |
2370 | } |
2371 | |
2372 | void BoundInterpreterFunction::fwdSplatInst(const glow::SplatInst *I) { |
2373 | auto *T = getTensor(I->getDest()); |
2374 | ElemKind k = T->getElementType(); |
2375 | |
2376 | if (k == ElemKind::Int32ITy) { |
2377 | return T->getHandle<int32_t>().clear(I->getValue()); |
2378 | } |
2379 | |
2380 | if (k == ElemKind::Int64ITy) { |
2381 | return T->getHandle<int64_t>().clear(I->getValue()); |
2382 | } |
2383 | |
2388 | if (k == ElemKind::FloatTy) { |
2389 | return T->getHandle<float>().clear(I->getValue()); |
2390 | } |
2391 | |
2392 | if (k == ElemKind::Float16Ty) { |
2393 | return T->getHandle<float16_t>().clear(I->getValue()); |
2394 | } |
2395 | |
2396 | if (k == ElemKind::BFloat16Ty) { |
2397 | return T->getHandle<bfloat16_t>().clear(I->getValue()); |
2398 | } |
2399 | |
2400 | if (k == ElemKind::BoolTy) { |
2401 | return T->getHandle<bool>().clear(static_cast<bool>(I->getValue())); |
2402 | } |
2403 | |
2404 | if (k == ElemKind::Int8QTy) { |
2405 | // Quantize the requested floating point splat value into the correct |
2406 | // integer representation. |
2407 | auto destTy = I->getDest()->getType(); |
2408 | TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()}; |
2409 | float val = I->getValue(); |
2410 | return T->getHandle<int8_t>().clear(quantization::quantize(val, destQ)); |
2411 | } |
2412 | |
2413 | if (k == ElemKind::Int16QTy) { |
2414 | // Quantize the requested floating point splat value into the correct |
2415 | // integer representation. |
2416 | auto destTy = I->getDest()->getType(); |
2417 | TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()}; |
2418 | float val = I->getValue(); |
2419 | return T->getHandle<int16_t>().clear(quantization::quantize(val, destQ)); |
2420 | } |
2421 | |
  llvm_unreachable("Unsupported tensor type");
2427 | } |
2428 | |
2429 | void BoundInterpreterFunction::fwdTouchInst(const glow::TouchInst *) { |
2430 | // Do nothing for a TouchInst |
2431 | } |
2432 | |
2433 | void BoundInterpreterFunction::fwdInsertTensorInst( |
2434 | const glow::InsertTensorInst *I) { |
2435 | Tensor *outT = getTensor(I->getDest()); |
2436 | Tensor *inT = getTensor(I->getSrc()); |
2437 | ElemKind k = outT->getElementType(); |
2438 | #define TYPED_INSERT(TY, TYPEKIND) \ |
2439 | if (k == TYPEKIND) { \ |
2440 | auto OH = outT->getHandle<TY>(); \ |
2441 | auto IH = inT->getHandle<TY>(); \ |
2442 | return OH.insertTensors(IH, I->getOffsets(), I->getCount(), I->getAxis()); \ |
2443 | } |
2444 | |
2445 | TYPED_INSERT(int64_t, ElemKind::Int64ITy); |
2446 | TYPED_INSERT(int32_t, ElemKind::Int32ITy); |
2447 | TYPED_INSERT(float, ElemKind::FloatTy); |
2448 | TYPED_INSERT(float16_t, ElemKind::Float16Ty); |
2449 | TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty); |
2450 | TYPED_INSERT(int8_t, ElemKind::Int8QTy); |
2451 | TYPED_INSERT(int16_t, ElemKind::Int16QTy); |
2452 | TYPED_INSERT(bool, ElemKind::BoolTy); |
2453 | #undef TYPED_INSERT |
2454 | |
  llvm_unreachable("Unsupported tensor type");
2456 | } |
2457 | |
void BoundInterpreterFunction::fwdExtractTensorInst(
    const glow::ExtractTensorInst *I) {
2460 | Tensor *outT = getTensor(I->getDest()); |
2461 | Tensor *inT = getTensor(I->getSrc()); |
2462 | ElemKind k = outT->getElementType(); |
2463 | #define TYPED_INSERT(TY, TYPEKIND) \ |
2464 | if (k == TYPEKIND) { \ |
2465 | auto OH = outT->getHandle<TY>(); \ |
2466 | auto IH = inT->getHandle<TY>(); \ |
2467 | return IH.extractTensors(OH, I->getOffsets()); \ |
2468 | } |
2469 | |
2470 | TYPED_INSERT(int64_t, ElemKind::Int64ITy); |
2471 | TYPED_INSERT(float, ElemKind::FloatTy); |
2472 | TYPED_INSERT(float16_t, ElemKind::Float16Ty); |
2473 | TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty); |
2474 | TYPED_INSERT(int8_t, ElemKind::Int8QTy); |
2475 | TYPED_INSERT(int32_t, ElemKind::Int32QTy); |
2476 | TYPED_INSERT(int32_t, ElemKind::Int32ITy); |
2477 | TYPED_INSERT(bool, ElemKind::BoolTy); |
2478 | #undef TYPED_INSERT |
2479 | |
  llvm_unreachable("Unsupported tensor type");
2481 | } |
2482 | |
2483 | template <typename ElemTy> |
2484 | void BoundInterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) { |
2485 | Tensor *dataT = getTensor(I->getData()); |
2486 | auto &dataTy = dataT->getType(); |
2487 | Tensor *indicesT = getTensor(I->getIndices()); |
2488 | Tensor *outT = getTensor(I->getDest()); |
2489 | unsigned_t axis = I->getBatchDims(); |
2490 | |
2491 | size_t out_p = 0; |
2492 | dim_t elementSize = dataTy.getElementSize(); |
  // The size (in bytes) of each sample in the batch.
  dim_t dataSampleSize = dataTy.getSliceSize(axis) * elementSize;
  // The size (in bytes) of the slices that we gather.
  dim_t dataSliceSize = dataTy.getSliceSize(axis + 1) * elementSize;

  // Calculate the number of samples in the batch.
  dim_t numSamples = (dataT->size() * elementSize) / dataSampleSize;

  // The number of slices that can be gathered from within each sample.
  dim_t batchSize = dataTy.dims()[axis];
2503 | (void)batchSize; |
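
  // Illustrative example (not from the original source): data of shape
  // {2, 4, 3} with batchDims = 1 and 5 indices gives numSamples = 2 and
  // batchSize = 4; each gathered slice is 3 elements wide, so dest has
  // shape {2, 5, 3}.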
2504 | |
2505 | // For each sample in the batch: |
2506 | for (dim_t sample = 0; sample < numSamples; sample++) { |
2507 | dim_t sampleStart = sample * dataSampleSize; |
2508 | |
2509 | // For each slice (small fragment) that we copy from the source memory: |
2510 | for (dim_t i = 0, end = indicesT->size(); i < end; i++) { |
2511 | dim_t slice = indicesT->getHandle<ElemTy>().raw(i); |
      assert(slice < batchSize &&
             "Invalid index seen during Gather operation");
2513 | std::copy( |
2514 | &dataT->getUnsafePtr()[sampleStart + dataSliceSize * slice], |
2515 | &dataT->getUnsafePtr()[sampleStart + dataSliceSize * (slice + 1)], |
2516 | &outT->getUnsafePtr()[out_p]); |
2517 | out_p += dataSliceSize; |
2518 | } |
2519 | } |
2520 | } |
2521 | |
2522 | void BoundInterpreterFunction::fwdGatherInst(const glow::GatherInst *I) { |
2523 | switch (I->getIndices()->getElementType()) { |
2524 | case ElemKind::Int64ITy: |
2525 | fwdGatherInstImpl<int64_t>(I); |
2526 | break; |
2527 | case ElemKind::Int32ITy: |
2528 | fwdGatherInstImpl<int32_t>(I); |
2529 | break; |
2530 | default: |
    llvm_unreachable("Unsupported type for indices input of Gather.");
2532 | } |
2533 | } |
2534 | |
2535 | //===----------------------------------------------------------------------===// |
2536 | // Gather Elements |
2537 | //===----------------------------------------------------------------------===// |
2538 | |
2539 | template <typename IndexTy> |
2540 | void BoundInterpreterFunction::fwdGatherElementsInstImpl( |
2541 | const glow::GatherElementsInst *I) { |
2542 | Tensor *outT = getTensor(I->getDest()); |
2543 | Tensor *dataT = getTensor(I->getData()); |
2544 | auto dataDims = dataT->getType().dims(); |
2545 | Tensor *indicesT = getTensor(I->getIndices()); |
2546 | auto indicesDims = indicesT->getType().dims(); |
2547 | const auto dim = I->getDim(); |
2548 | const auto numElems = outT->getRealNumElements(); |
2549 | auto ndims = indicesDims.size(); |
2550 | |
2551 | std::vector<dim_t> ind_dim_off(ndims); |
2552 | std::vector<dim_t> data_dim_off(ndims); |
2553 | |
2554 | ind_dim_off[0] = 1; |
2555 | data_dim_off[0] = 1; |
2556 | for (dim_t i = 1; i < ndims; i++) { |
2557 | ind_dim_off[i] = |
2558 | std::accumulate(indicesDims.begin() + ndims - i, indicesDims.end(), 1, |
2559 | std::multiplies<dim_t>()); |
2560 | data_dim_off[i] = |
2561 | std::accumulate(dataDims.begin() + ndims - i, dataDims.end(), 1, |
2562 | std::multiplies<dim_t>()); |
2563 | } |
2564 | |
2565 | const auto dimPos = ndims - 1 - dim; |
2566 | const auto elemSize = dataT->getType().getElementSize(); |
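
  // Illustrative example: for indices of shape {2, 3} and dim = 1,
  // ind_dim_off = {1, 3} (innermost dimension first), so flat index idx = 4
  // decomposes into row 1, column 1; the coordinate along `dim` is then
  // replaced by the gathered index in the loop below.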
2567 | // Loop over number of elements in indices |
2568 | for (size_t idx = 0; idx < numElems; idx++) { |
2569 | unsigned_t offset = 0; |
2570 | // Loop over number of dimensions to calculate offset |
2571 | for (dim_t i = 0; i < ndims; i++) { |
      // Calculate the coordinate along this axis, i.e. (i, j, k); integer
      // division already floors the quotient.
      const dim_t dim_idx =
          (idx / ind_dim_off[i]) % indicesDims[ndims - (i + 1)];
2576 | offset += (dimPos != i) * dim_idx * data_dim_off[i]; |
2577 | } |
2578 | auto indIdx = indicesT->getHandle<IndexTy>().raw(idx); |
    assert((indIdx < 0 ? -indIdx <= dataDims[dim] : indIdx < dataDims[dim]) &&
           "[GatherElements] Got out of bounds index");
2582 | // In case of negative indices |
2583 | if (indIdx < 0) { |
2584 | indIdx += dataDims[dim]; |
2585 | } |
2586 | const auto dataRawIdx = indIdx * data_dim_off[dimPos] + offset; |
2587 | std::copy(&dataT->getUnsafePtr()[dataRawIdx * elemSize], |
2588 | &dataT->getUnsafePtr()[(dataRawIdx + 1) * elemSize], |
2589 | &outT->getUnsafePtr()[idx * elemSize]); |
2590 | } |
2591 | } |
2592 | |
2593 | void BoundInterpreterFunction::fwdGatherElementsInst( |
2594 | const glow::GatherElementsInst *I) { |
2595 | switch (I->getIndices()->getElementType()) { |
2596 | case ElemKind::Int64ITy: |
2597 | fwdGatherElementsInstImpl<int64_t>(I); |
2598 | break; |
2599 | case ElemKind::Int32ITy: |
2600 | fwdGatherElementsInstImpl<int32_t>(I); |
2601 | break; |
2602 | default: |
    llvm_unreachable("[GatherElements] Unsupported type for indices input of "
                     "GatherElements.");
2605 | } |
2606 | } |
2607 | |
2608 | template <typename ElemTy> |
2609 | void BoundInterpreterFunction::fwdGatherNDInstImpl( |
2610 | const glow::GatherNDInst *I) { |
2611 | |
2612 | Tensor *dataT = getTensor(I->getData()); |
2613 | Tensor *indicesT = getTensor(I->getIndices()); |
2614 | Tensor *outT = getTensor(I->getDest()); |
2615 | auto batchDims = I->getBatchDims(); |
2616 | |
2617 | auto dataDims = I->getData()->dims(); |
2618 | auto indicesDims = I->getIndices()->dims(); |
2619 | dim_t indicesDimLast = indicesDims.back(); |
2620 | |
2621 | // Compute batch count. |
2622 | dim_t batchCount = 1; |
2623 | for (size_t idx = 0; idx < batchDims; ++idx) { |
2624 | batchCount *= dataDims[idx]; |
2625 | } |
2626 | |
2627 | // Compute input slice count. |
2628 | dim_t inpSliceCount = 1; |
2629 | for (size_t idx = batchDims; idx < batchDims + indicesDimLast; ++idx) { |
2630 | inpSliceCount *= dataDims[idx]; |
2631 | } |
2632 | |
2633 | // Compute output slice count. |
2634 | dim_t outSliceCount = 1; |
2635 | for (size_t idx = batchDims; idx < indicesDims.size() - 1; ++idx) { |
2636 | outSliceCount *= indicesDims[idx]; |
2637 | } |
2638 | |
2639 | // Compute slice size (in bytes). |
2640 | dim_t sliceSize = dataT->getType().getElementSize(); |
2641 | for (size_t idx = batchDims + indicesDimLast; idx < dataDims.size(); idx++) { |
2642 | sliceSize *= dataDims[idx]; |
2643 | } |
2644 | |
2645 | // Get indices dimension products. |
2646 | std::vector<dim_t> indicesDimProd(indicesDimLast); |
2647 | indicesDimProd[indicesDimLast - 1] = 1; |
2648 | for (ssize_t idx = static_cast<ssize_t>(indicesDimLast) - 2; idx >= 0; |
2649 | idx--) { |
2650 | indicesDimProd[idx] = |
2651 | indicesDimProd[idx + 1] * dataDims[batchDims + idx + 1]; |
2652 | } |
2653 | |
2654 | // We will view the tensors as equivalent 3D tensors with the dimensions: |
2655 | // data - batchCount x inpSliceCount x sliceSize |
2656 | // indices - batchCount x outSliceCount x indicesDimLast |
2657 | // output - batchCount x outSliceCount x sliceSize |
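  //
  // Illustrative example (not from the original source): data of shape
  // {2, 3, 4} with batchDims = 0 and indices of shape {2, 1} selects whole
  // outer rows: batchCount = 1, inpSliceCount = 2, outSliceCount = 2 and
  // sliceSize = 3 * 4 * elementSize.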
2658 | |
2659 | char *dataPtr = dataT->getUnsafePtr(); |
2660 | ElemTy *indicesPtr = (ElemTy *)indicesT->getUnsafePtr(); |
2661 | char *outPtr = outT->getUnsafePtr(); |
2662 | |
2663 | for (size_t batchIdx = 0; batchIdx < batchCount; ++batchIdx) { |
2664 | for (size_t outSliceIdx = 0; outSliceIdx < outSliceCount; ++outSliceIdx) { |
2665 | |
2666 | // Compute input slice index. |
2667 | dim_t inpSliceIdx = 0; |
2668 | for (size_t idx = 0; idx < indicesDimLast; ++idx) { |
2669 | inpSliceIdx += (*indicesPtr++) * indicesDimProd[idx]; |
2670 | } |
2671 | |
2672 | // Copy data. |
2673 | std::copy(dataPtr + (inpSliceIdx + 0) * sliceSize, |
2674 | dataPtr + (inpSliceIdx + 1) * sliceSize, outPtr); |
2675 | outPtr += sliceSize; |
2676 | } |
2677 | |
2678 | // Increment input pointer for next batch. |
2679 | dataPtr += inpSliceCount * sliceSize; |
2680 | } |
2681 | } |
2682 | |
2683 | void BoundInterpreterFunction::fwdGatherNDInst(const glow::GatherNDInst *I) { |
2684 | switch (I->getIndices()->getElementType()) { |
2685 | case ElemKind::Int64ITy: |
2686 | fwdGatherNDInstImpl<int64_t>(I); |
2687 | break; |
2688 | case ElemKind::Int32ITy: |
2689 | fwdGatherNDInstImpl<int32_t>(I); |
2690 | break; |
2691 | default: |
    llvm_unreachable("Unsupported type for indices input of GatherND.");
2693 | } |
2694 | } |
2695 | |
2696 | template <typename ElemTy> |
2697 | void BoundInterpreterFunction::fwdGatherRangesInstImpl( |
2698 | const glow::GatherRangesInst *I) { |
2699 | Tensor *dataT = getTensor(I->getData()); |
2700 | auto &dataTy = dataT->getType(); |
2701 | Tensor *rangesT = getTensor(I->getRanges()); |
2702 | auto &rangesTy = rangesT->getType(); |
2703 | Tensor *outT = getTensor(I->getOutput()); |
2704 | Tensor *lengthsT = getTensor(I->getLengths()); |
2705 | |
2706 | // Offset into the output tensor that keeps track of where to start |
2707 | // copying data. |
2708 | size_t outP = 0; |
2709 | |
2710 | unsigned dataElementSize = dataTy.getElementSize(); |
2711 | dim_t numExamples = rangesTy.dims()[0]; |
2712 | dim_t exampleSize = rangesTy.dims()[1]; |
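
  // Ranges has shape {numExamples, exampleSize, 2}; each innermost pair is
  // (startIdx, length) into the flattened data tensor. Illustrative example:
  // the pair (5, 3) copies data elements 5, 6 and 7 into the output.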
2713 | |
2714 | // Keep track of the total number of elements gathered across all |
2715 | // examples for a sanity check later. |
2716 | dim_t grandTotalLen = 0; |
2717 | |
2718 | // For each example in ranges: |
2719 | for (dim_t example = 0; example < numExamples; ++example) { |
2720 | // Keep a running total of the lengths of all ranges in this example |
2721 | // to record into lengthsT once the entire example is processed. |
2722 | ElemTy totalLen = 0; |
2723 | |
2724 | // For each range in the example: |
2725 | for (dim_t range = 0; range < exampleSize; ++range) { |
2726 | // Get the start index and range length. |
2727 | ElemTy startIdx = rangesT->getHandle<ElemTy>().at({example, range, 0}); |
2728 | ElemTy len = rangesT->getHandle<ElemTy>().at({example, range, 1}); |
2729 | |
2730 | // Add the length of this current range to the example length counter. |
2731 | totalLen += len; |
2732 | |
2733 | // Compute the start and end offsets. |
2734 | dim_t startOffset = startIdx * dataElementSize; |
2735 | dim_t endOffset = startOffset + (len * dataElementSize); |
2736 | |
2737 | // Sanity checks on the offsets. |
2738 | assert(startOffset < dataT->getSizeInBytes()); |
2739 | assert(endOffset <= dataT->getSizeInBytes()); |
2740 | assert(endOffset >= startOffset); |
2741 | assert(outP < outT->getSizeInBytes()); |
2742 | assert((outP + (len * dataElementSize)) <= outT->getSizeInBytes()); |
2743 | |
2744 | // Copy the specified data to outT. |
2745 | std::copy(&dataT->getUnsafePtr()[startOffset], |
2746 | &dataT->getUnsafePtr()[endOffset], &outT->getUnsafePtr()[outP]); |
2747 | |
2748 | // Advance the offset into outT. |
2749 | outP += len * dataElementSize; |
2750 | } |
2751 | |
2752 | // Record the total number of elements gathered for the example in |
2753 | // lengthsT. |
2754 | lengthsT->getHandle<ElemTy>().at({example}) = totalLen; |
2755 | |
2756 | // Add the total length of the entire example to the grand total. |
2757 | grandTotalLen += static_cast<size_t>(totalLen); |
2758 | } |
2759 | |
2760 | // Make sure that number of elements written to outT is equal to the |
2761 | // total of all elements in lengthsT. |
2762 | assert(grandTotalLen == (outP / dataElementSize)); |
2763 | } |
2764 | |
2765 | void BoundInterpreterFunction::fwdGatherRangesInst( |
2766 | const glow::GatherRangesInst *I) { |
2767 | switch (I->getRanges()->getElementType()) { |
2768 | case ElemKind::Int64ITy: |
2769 | fwdGatherRangesInstImpl<int64_t>(I); |
2770 | break; |
2771 | case ElemKind::Int32ITy: |
2772 | fwdGatherRangesInstImpl<int32_t>(I); |
2773 | break; |
2774 | default: |
    llvm_unreachable("Unsupported type for ranges input of GatherRanges.");
2776 | } |
2777 | } |
2778 | |
2779 | template <typename ElemTy, typename IndicesElemTy> |
2780 | void BoundInterpreterFunction::fwdScatterDataInstCopyImpl( |
2781 | const glow::ScatterDataInst *I) { |
2782 | Tensor *dataT = getTensor(I->getData()); |
2783 | Tensor *indicesT = getTensor(I->getIndices()); |
2784 | Tensor *slicesT = getTensor(I->getSlices()); |
2785 | |
  assert(indicesT->dims().size() == 2 &&
         "Index should be stored in 2D tensor!");
2788 | const dim_t dataSliceSize = slicesT->size() / slicesT->dims()[0] * |
2789 | slicesT->getType().getElementSize(); |
2790 | |
2791 | auto IH = indicesT->getHandle<IndicesElemTy>(); |
2792 | // For each index, copy from the slice at that index into the location in |
2793 | // data given the offset from the indices tensor. |
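  // The inner loop flattens each multi-dimensional index via Horner's rule:
  // an index {i0, i1} into data of dims {D0, D1, ...} becomes i0 * D1 + i1,
  // which is then scaled by dataSliceSize to obtain a byte offset.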
2794 | for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) { |
2795 | dim_t destDataIdx = 0; |
2796 | for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) { |
2797 | destDataIdx *= dataT->dims()[j]; |
2798 | destDataIdx += IH.at({i, j}); |
2799 | } |
2800 | std::copy(&slicesT->getUnsafePtr()[i * dataSliceSize], |
2801 | &slicesT->getUnsafePtr()[(i + 1) * dataSliceSize], |
2802 | &dataT->getUnsafePtr()[dataSliceSize * destDataIdx]); |
2803 | } |
2804 | } |
2805 | |
2806 | template <typename ElemTy, typename IndicesElemTy> |
2807 | void BoundInterpreterFunction::fwdScatterDataInstAddFloatImpl( |
2808 | const glow::ScatterDataInst *I) { |
2809 | Tensor *dataT = getTensor(I->getData()); |
2810 | Tensor *indicesT = getTensor(I->getIndices()); |
2811 | Tensor *slicesT = getTensor(I->getSlices()); |
2812 | |
  assert(!dataT->getType().isQuantizedType() && "Should be float type!");
  assert(!slicesT->getType().isQuantizedType() && "Should be float type!");
2815 | |
2816 | const size_t numSlices = slicesT->size() / slicesT->dims()[0]; |
2817 | |
2818 | auto IH = indicesT->getHandle<IndicesElemTy>(); |
2819 | // For each index, copy from the slice at that index into the location in |
2820 | // data given the offset from the indices tensor. |
  assert(indicesT->dims().size() == 2 &&
         "Multi-dimensional index should be stored in 2D tensor!");
2823 | auto D = dataT->getHandle<ElemTy>(), S = slicesT->getHandle<ElemTy>(); |
2824 | for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) { |
2825 | size_t destDataIdx = 0; |
2826 | for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) { |
2827 | destDataIdx *= dataT->dims()[j]; |
2828 | destDataIdx += IH.at({i, j}); |
2829 | } |
2830 | for (dim_t j = 0; j < numSlices; j++) { |
2831 | D.raw(destDataIdx * numSlices + j) += S.raw(i * numSlices + j); |
2832 | } |
2833 | } |
2834 | } |
2835 | |
2836 | template <typename ElemTy, typename IndicesElemTy> |
2837 | void BoundInterpreterFunction::fwdScatterDataInstAddQuantizedImpl( |
2838 | const glow::ScatterDataInst *I) { |
2839 | Tensor *dataT = getTensor(I->getData()); |
2840 | Tensor *indicesT = getTensor(I->getIndices()); |
2841 | Tensor *slicesT = getTensor(I->getSlices()); |
2842 | |
  assert(dataT->getType().isQuantizedType() && "Should be quantized type!");
  assert(slicesT->getType().isQuantizedType() && "Should be quantized type!");
2845 | |
2846 | const dim_t numSlices = slicesT->size() / slicesT->dims()[0]; |
2847 | |
2848 | TensorQuantizationParams dataQ{dataT->getType().getScale(), |
2849 | dataT->getType().getOffset()}; |
2850 | TensorQuantizationParams sliceQ{slicesT->getType().getScale(), |
2851 | slicesT->getType().getOffset()}; |
2852 | |
2853 | auto IH = indicesT->getHandle<IndicesElemTy>(); |
2854 | // For each index, copy from the slice at that index into the location in |
2855 | // data given the offset from the indices tensor. |
  assert(indicesT->dims().size() == 2 &&
         "Multi-dimensional index should be stored in 2D tensor!");
2858 | auto D = dataT->getHandle<ElemTy>(), S = slicesT->getHandle<ElemTy>(); |
2859 | for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) { |
2860 | dim_t destDataIdx = 0; |
2861 | for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) { |
2862 | destDataIdx *= dataT->dims()[j]; |
2863 | destDataIdx += IH.at({i, j}); |
2864 | } |
2865 | for (dim_t j = 0; j < numSlices; j++) { |
2866 | float lhs = |
2867 | quantization::dequantize(D.raw(destDataIdx * numSlices + j), dataQ); |
2868 | float rhs = quantization::dequantize(S.raw(i * numSlices + j), sliceQ); |
2869 | ElemTy result = quantization::quantize(lhs + rhs, dataQ); |
2870 | D.raw(destDataIdx * numSlices + j) = result; |
2871 | } |
2872 | } |
2873 | } |
2874 | |
2875 | void BoundInterpreterFunction::fwdScatterDataInst( |
2876 | const glow::ScatterDataInst *I) { |
2877 | const auto indicesAreInt64 = |
2878 | I->getIndices()->getElementType() == ElemKind::Int64ITy; |
2879 | |
2880 | if (I->getCumulative()) { |
2881 | switch (I->getData()->getElementType()) { |
2882 | case ElemKind::FloatTy: |
2883 | if (indicesAreInt64) { |
2884 | fwdScatterDataInstAddFloatImpl<float, int64_t>(I); |
2885 | } else { |
2886 | fwdScatterDataInstAddFloatImpl<float, int32_t>(I); |
2887 | } |
2888 | |
2889 | break; |
2890 | case ElemKind::Int8QTy: |
2891 | if (indicesAreInt64) { |
2892 | fwdScatterDataInstAddQuantizedImpl<int8_t, int64_t>(I); |
2893 | } else { |
2894 | fwdScatterDataInstAddQuantizedImpl<int8_t, int32_t>(I); |
2895 | } |
2896 | break; |
2897 | default: |
    llvm_unreachable("Unsupported type for ScatterData.");
2899 | } |
2900 | } else { |
2901 | switch (I->getData()->getElementType()) { |
2902 | case ElemKind::FloatTy: |
2903 | if (indicesAreInt64) { |
2904 | fwdScatterDataInstCopyImpl<float, int64_t>(I); |
2905 | } else { |
2906 | fwdScatterDataInstCopyImpl<float, int32_t>(I); |
2907 | } |
2908 | break; |
2909 | case ElemKind::Int8QTy: |
2910 | if (indicesAreInt64) { |
2911 | fwdScatterDataInstCopyImpl<int8_t, int64_t>(I); |
2912 | } else { |
2913 | fwdScatterDataInstCopyImpl<int8_t, int32_t>(I); |
2914 | } |
2915 | break; |
2916 | default: |
    llvm_unreachable("Unsupported type for ScatterData.");
2918 | } |
2919 | } |
2920 | } |
2921 | |
2922 | template <typename ElemTy> |
2923 | void BoundInterpreterFunction::fwdBatchOneHotImpl( |
2924 | const glow::BatchOneHotInst *I) { |
2925 | auto dataH = getWeightHandle<ElemTy>(I->getData()); |
2926 | auto lengthsH = getWeightHandle<int32_t>(I->getLengths()); |
2927 | auto valuesH = getWeightHandle<ElemTy>(I->getValues()); |
2928 | auto destH = getWeightHandle<ElemTy>(I->getDest()); |
2929 | |
2930 | auto batchSize = dataH.dims()[0]; |
2931 | auto featureCnt = dataH.dims()[1]; |
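
  // Each feature contributes lengthsH[featureId] one-hot columns, so the
  // output width is the sum of all lengths. Illustrative example: lengths
  // {2, 3} produce a dest of shape {batchSize, 5}.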
2932 | |
2933 | for (dim_t batchId = 0; batchId < batchSize; batchId++) { |
2934 | size_t offset = 0; |
2935 | for (dim_t featureId = 0; featureId < featureCnt; featureId++) { |
2936 | auto curValue = dataH.at({batchId, featureId}); |
2937 | auto curLength = lengthsH.at({featureId}); |
2938 | for (dim_t i = offset, e = offset + curLength; i != e; i++) { |
2939 | destH.at({batchId, i}) = curValue == valuesH.at({i}); |
2940 | } |
2941 | offset += curLength; |
2942 | } |
    assert(offset == destH.dims()[1] &&
           "Sum of Lengths must be equal to size of Values");
2945 | } |
2946 | } |
2947 | |
2948 | void BoundInterpreterFunction::fwdBatchOneHotInst( |
2949 | const glow::BatchOneHotInst *I) { |
2950 | switch (I->getData()->getElementType()) { |
2951 | case ElemKind::Int64ITy: |
2952 | fwdBatchOneHotImpl<int64_t>(I); |
2953 | break; |
2954 | case ElemKind::Int32ITy: |
2955 | fwdBatchOneHotImpl<int32_t>(I); |
2956 | break; |
2957 | case ElemKind::Int8QTy: |
2958 | fwdBatchOneHotImpl<int8_t>(I); |
2959 | break; |
2960 | default: |
2961 | dispatchFloatingPointImpl(fwdBatchOneHotImpl, |
2962 | I->getData()->getElementType(), I); |
2963 | } |
2964 | } |
2965 | |
2966 | template <typename ElemTy> |
2967 | void BoundInterpreterFunction::fwdSpaceToDepthInstImpl( |
2968 | const glow::SpaceToDepthInst *I) { |
2969 | auto *inT = getTensor(I->getSrc()); |
2970 | auto *outT = getTensor(I->getDest()); |
2971 | |
2972 | auto inH = inT->getHandle<ElemTy>(); |
2973 | auto outH = outT->getHandle<ElemTy>(); |
2974 | |
2975 | unsigned blockSize = I->getBlockSize(); |
2976 | |
2977 | dim_t inDepth = inT->dims()[3]; |
2978 | |
2979 | dim_t outBatch = outT->dims()[0]; |
2980 | dim_t outHeight = outT->dims()[1]; |
2981 | dim_t outWidth = outT->dims()[2]; |
2982 | dim_t outDepth = outT->dims()[3]; |
2983 | |
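  // Illustrative example (hypothetical shapes): an NHWC input of shape
  // (1, 2, 2, 1) holding {1, 2, 3, 4} with blockSize = 2 folds each 2x2
  // spatial block into depth, producing an output of shape (1, 1, 1, 4)
  // holding {1, 2, 3, 4}.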
2984 | for (dim_t ob = 0; ob < outBatch; ++ob) { |
2985 | for (dim_t oh = 0; oh < outHeight; ++oh) { |
2986 | for (dim_t ow = 0; ow < outWidth; ++ow) { |
2987 | for (dim_t oc = 0; oc < outDepth; ++oc) { |
          // The block layer this output channel belongs to within the folded
          // depth dimension.
          dim_t blockDepthLayer = oc / inDepth;
          // The width offset wraps around every blockSize block layers.
          dim_t iw = ow * blockSize + blockDepthLayer % blockSize;
          // The height offset advances by one every blockSize block layers.
          dim_t ih = oh * blockSize + blockDepthLayer / blockSize;
          // The input depth index wraps around every inDepth output channels.
          dim_t ic = oc % inDepth;
2996 | |
2997 | outH.at({ob, oh, ow, oc}) = inH.at({ob, ih, iw, ic}); |
2998 | } |
2999 | } |
3000 | } |
3001 | } |
3002 | } |
3003 | |
3004 | void BoundInterpreterFunction::fwdSpaceToDepthInst( |
3005 | const glow::SpaceToDepthInst *I) { |
3006 | switch (I->getSrc()->getElementType()) { |
3007 | case ElemKind::FloatTy: |
3008 | fwdSpaceToDepthInstImpl<float>(I); |
3009 | break; |
3010 | case ElemKind::Int8QTy: |
3011 | fwdSpaceToDepthInstImpl<int8_t>(I); |
3012 | break; |
3013 | default: |
    llvm_unreachable("Type is not supported");
3015 | break; |
3016 | } |
3017 | } |
3018 | |
3019 | template <typename ElemTy> |
3020 | void BoundInterpreterFunction::fwdResizeNearestInstImpl( |
3021 | const ResizeNearestInst *I) { |
3022 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
3023 | auto scale = I->getScale(); |
3024 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3025 | |
3026 | auto outputDims = outW.dims(); |
3027 | auto inputDims = inW.dims(); |
3028 | |
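  // Each output coordinate maps back to its nearest source coordinate as
  // i = min(floor(o / scale[d]), inputDims[d] - 1). For example (hypothetical
  // values), with scale 2.0 the output indices {0, 1, 2, 3} read from input
  // indices {0, 0, 1, 1}.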
3029 | for (dim_t oa = 0; oa < outputDims[0]; ++oa) { |
3030 | auto ia = std::min(dim_t(oa / scale[0]), inputDims[0] - 1); |
3031 | for (dim_t ob = 0; ob < outputDims[1]; ++ob) { |
3032 | auto ib = std::min(dim_t(ob / scale[1]), inputDims[1] - 1); |
3033 | for (dim_t oc = 0; oc < outputDims[2]; ++oc) { |
3034 | auto ic = std::min(dim_t(oc / scale[2]), inputDims[2] - 1); |
3035 | if (outputDims.size() > 3) { |
3036 | for (dim_t od = 0; od < outputDims[3]; ++od) { |
3037 | auto id = std::min(dim_t(od / scale[3]), inputDims[3] - 1); |
3038 | if (outputDims.size() > 4) { |
3039 | for (dim_t oe = 0; oe < outputDims[4]; ++oe) { |
3040 | auto ie = std::min(dim_t(oe / scale[4]), inputDims[4] - 1); |
3041 | if (outputDims.size() > 5) { |
                  for (dim_t of = 0; of < outputDims[5]; ++of) {
3043 | auto f = std::min(dim_t(of / scale[5]), inputDims[5] - 1); |
3044 | outW.at({oa, ob, oc, od, oe, of}) = |
3045 | inW.at({ia, ib, ic, id, ie, f}); |
3046 | } |
3047 | } else { |
3048 | outW.at({oa, ob, oc, od, oe}) = inW.at({ia, ib, ic, id, ie}); |
3049 | } |
3050 | } |
3051 | } else { |
3052 | outW.at({oa, ob, oc, od}) = inW.at({ia, ib, ic, id}); |
3053 | } |
3054 | } |
3055 | } else { |
3056 | outW.at({oa, ob, oc}) = inW.at({ia, ib, ic}); |
3057 | } |
3058 | } |
3059 | } |
3060 | } |
3061 | } |
3062 | |
3063 | void BoundInterpreterFunction::fwdResizeNearestInst( |
3064 | const ResizeNearestInst *I) { |
3065 | if (getTensor(I->getSrc())->getType().isQuantizedType()) { |
3066 | dispatchQuantizedImpl(fwdResizeNearestInstImpl, |
3067 | I->getSrc()->getElementType(), I); |
3068 | return; |
3069 | } |
3070 | |
3071 | dispatchImpl(fwdResizeNearestInstImpl, I->getSrc()->getElementType(), I); |
3072 | } |
3073 | |
3074 | template <typename ElemTy> |
3075 | void BoundInterpreterFunction::fwdResizeBilinearInstImpl( |
3076 | const ResizeBilinearInst *I) { |
3077 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
3078 | auto scale = I->getScale(); |
3079 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3080 | |
3081 | ShapeNHWC odim(outW.dims()); |
3082 | ShapeNHWC idim(inW.dims()); |
3083 | |
  CHECK_EQ(scale[0], 1.0) << "Scaling batch not supported.";
  CHECK_EQ(scale[3], 1.0) << "Scaling channel not supported.";
3086 | |
3087 | for (dim_t ob = 0; ob < odim.n; ++ob) { |
3088 | for (dim_t oh = 0; oh < odim.h; ++oh) { |
3089 | for (dim_t ow = 0; ow < odim.w; ++ow) { |
3090 | |
3091 | float ihf = oh / scale[1]; |
3092 | float iwf = ow / scale[2]; |
3093 | dim_t ih = dim_t(ihf); |
3094 | dim_t iw = dim_t(iwf); |
3095 | |
3096 | auto ih0 = std::min(ih, idim.h - 1); |
3097 | auto ih1 = std::min(ih + 1, idim.h - 1); |
3098 | auto iw0 = std::min(iw, idim.w - 1); |
3099 | auto iw1 = std::min(iw + 1, idim.w - 1); |
3100 | |
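        // Interpolate along H between the rows at iw0 and iw1, then along W
        // between those two results. Illustrative example (hypothetical
        // values): corners v00=0, v01=2, v10=4, v11=6 with fractional offsets
        // 0.5/0.5 give hd=2, hw=4 and a result of 3.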
3101 | for (dim_t oc = 0; oc < odim.c; ++oc) { |
3102 | auto v00 = inW.at({ob, ih0, iw0, oc}); |
3103 | auto v01 = inW.at({ob, ih0, iw1, oc}); |
3104 | auto v10 = inW.at({ob, ih1, iw0, oc}); |
3105 | auto v11 = inW.at({ob, ih1, iw1, oc}); |
3106 | |
3107 | auto hd = (float)v00 + (float)(v10 - v00) * (ihf - ih); |
3108 | auto hw = (float)v01 + (float)(v11 - v01) * (ihf - ih); |
3109 | float result = hd + (hw - hd) * (iwf - iw); |
3110 | outW.at({ob, oh, ow, oc}) = result; |
3111 | } |
3112 | } |
3113 | } |
3114 | } |
3115 | } |
3116 | |
3117 | void BoundInterpreterFunction::fwdResizeBilinearInst( |
3118 | const ResizeBilinearInst *I) { |
3119 | if (getTensor(I->getSrc())->getType().isQuantizedType()) { |
3120 | dispatchQuantizedImpl(fwdResizeBilinearInstImpl, |
3121 | I->getSrc()->getElementType(), I); |
3122 | return; |
3123 | } |
3124 | |
3125 | dispatchImpl(fwdResizeBilinearInstImpl, I->getSrc()->getElementType(), I); |
3126 | } |
3127 | |
3128 | //===----------------------------------------------------------------------===// |
3129 | // Local Response Normalization |
3130 | //===----------------------------------------------------------------------===// |
3131 | |
3132 | template <typename ElemTy> |
3133 | void BoundInterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl( |
3134 | const glow::LocalResponseNormalizationInst *I) { |
3135 | staticAssertFloatingPointType(ElemTy); |
3136 | |
3137 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
3138 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3139 | auto scaleCache = getWeightHandle<ElemTy>(I->getScale()); |
3140 | |
3141 | ShapeNHWC odim(outW.dims()); |
3142 | ShapeNHWC idim(inW.dims()); |
3143 | |
3144 | (void)odim; |
3145 | |
3146 | // LRN node does not change the shape of the input. |
  assert(odim == idim && "Output of LRN node must be same shape as input");
3148 | |
3149 | // LRN node normalizes across channels, so the input must have a minimum |
3150 | // depth of 1. |
  assert(idim.c > 0 && "Input of LRN node must have a minimum depth of 1");
3152 | |
3153 | auto halfWindowSize = (size_t)I->getHalfWindowSize(); |
3154 | auto k = I->getK(); |
3155 | auto beta = I->getBeta(); |
3156 | auto windowSize = 2 * halfWindowSize + 1; |
3157 | auto normedAlpha = I->getAlpha() / windowSize; |
3158 | |
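  // Per element, LRN computes:
  //   out(n, h, w, c) = in(n, h, w, c) *
  //       (k + normedAlpha * sum over window(c) of in(n, h, w, i)^2)^(-beta)
  // where window(c) is [c - halfWindowSize, c + halfWindowSize] clamped to
  // the valid channel range.
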
3159 | // For every input in the batch: |
3160 | for (dim_t n = 0; n < idim.n; n++) { |
3161 | |
3162 | // For every row: |
3163 | for (dim_t h = 0; h < idim.h; h++) { |
3164 | |
3165 | // For every column: |
3166 | for (dim_t w = 0; w < idim.w; w++) { |
3167 | |
3168 | // For every channel: |
3169 | for (dim_t c = 0; c < idim.c; c++) { |
3170 | float squareSum = 0.0; |
3171 | for (dim_t i = (c >= halfWindowSize ? c - halfWindowSize : 0); |
3172 | i <= std::min<dim_t>(c + halfWindowSize, (size_t)idim.c - 1); |
3173 | i++) { |
3174 | float val = inW.at({n, h, w, i}); |
3175 | squareSum += val * val; |
3176 | } |
3177 | |
3178 | auto scale = k + normedAlpha * squareSum; |
3179 | |
3180 | // This will be used to accelerate the backward pass. |
3181 | scaleCache.at({n, h, w, c}) = ElemTy(scale); |
3182 | |
3183 | auto normFactor = std::pow(scale, -beta); |
3184 | outW.at({n, h, w, c}) = |
3185 | ElemTy(float(inW.at({n, h, w, c})) * normFactor); |
3186 | } |
3187 | } |
3188 | } |
3189 | } |
3190 | } |
3191 | |
3192 | void BoundInterpreterFunction::fwdLocalResponseNormalizationInst( |
3193 | const LocalResponseNormalizationInst *I) { |
3194 | dispatchFloatingPointImpl(fwdLocalResponseNormalizationInstFloatImpl, |
3195 | I->getSrc()->getElementType(), I); |
3196 | } |
3197 | |
3198 | void BoundInterpreterFunction::fwdLocalResponseNormalizationGradInst( |
3199 | const glow::LocalResponseNormalizationGradInst *I) { |
3200 | auto inW = getWeightHandle(I->getSrc()); |
3201 | auto inG = getWeightHandle(I->getSrcGrad()); |
3202 | auto outW = getWeightHandle(I->getDest()); |
3203 | auto outG = getWeightHandle(I->getDestGrad()); |
3204 | auto scaleCache = getWeightHandle(I->getScale()); |
3205 | |
3206 | ShapeNHWC odim(outW.dims()); |
3207 | |
3208 | auto halfWindowSize = I->getHalfWindowSize(); |
3209 | auto beta = I->getBeta(); |
3210 | auto windowSize = 2 * halfWindowSize + 1; |
3211 | auto normedAlpha = I->getAlpha() / windowSize; |
3212 | |
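  // For every channel the gradient needs the sum of outG * (outW / scale)
  // over that channel's window. The loop below maintains this as a sliding
  // window: it is seeded for channel 0, then updated per channel by
  // subtracting the element leaving the window and adding the one entering.
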
3213 | // For every input in the batch: |
3214 | for (dim_t n = 0; n < odim.n; n++) { |
3215 | |
3216 | // For every row: |
3217 | for (dim_t h = 0; h < odim.h; h++) { |
3218 | |
3219 | // For every column: |
3220 | for (dim_t w = 0; w < odim.w; w++) { |
3221 | |
3222 | float sum = 0.0; |
3223 | |
3224 | // Compute sum for first channel. |
3225 | for (dim_t c = 0; c <= halfWindowSize && c < odim.c; c++) { |
3226 | auto outw = outW.at({n, h, w, c}); |
3227 | auto scale = scaleCache.at({n, h, w, c}); |
3228 | auto outg = outG.at({n, h, w, c}); |
3229 | sum += (outg * (outw / scale)); |
3230 | } |
3231 | |
3232 | // For every channel: |
3233 | for (dim_t c = 0; c < odim.c; c++) { |
3234 | auto outg = outG.at({n, h, w, c}); |
3235 | auto scale = scaleCache.at({n, h, w, c}); |
3236 | auto inw = inW.at({n, h, w, c}); |
3237 | |
3238 | inG.at({n, h, w, c}) = outg * std::pow(scale, -beta) - |
3239 | 2 * normedAlpha * beta * inw * sum; |
3240 | |
3241 | // Modify sum for next channel. |
3242 | auto subIndex = c - halfWindowSize; |
3243 | auto addIndex = c + halfWindowSize + 1; |
3244 | |
3245 | if (c >= halfWindowSize) { |
3246 | auto outw = outW.at({n, h, w, subIndex}); |
3247 | auto scale = scaleCache.at({n, h, w, subIndex}); |
3248 | auto outg = outG.at({n, h, w, subIndex}); |
3249 | |
3250 | // Subtract "rear" end of this window. |
3251 | sum -= (outg * (outw / scale)); |
3252 | } |
3253 | |
3254 | if (addIndex < odim.c) { |
3255 | auto outw = outW.at({n, h, w, addIndex}); |
3256 | auto scale = scaleCache.at({n, h, w, addIndex}); |
3257 | auto outg = outG.at({n, h, w, addIndex}); |
3258 | |
3259 | // Add "front" end of next window. |
3260 | sum += (outg * (outw / scale)); |
3261 | } |
3262 | } |
3263 | } |
3264 | } |
3265 | } |
3266 | } |
3267 | |
3268 | //===--------------------------------------------------------------------===// |
3269 | // Bucketing |
3270 | //===--------------------------------------------------------------------===// |
3271 | |
3272 | void BoundInterpreterFunction::fwdBucketizeInst(const BucketizeInst *I) { |
3273 | auto inputH = getTensor(I->getSrc())->getHandle<float>(); |
3274 | auto outputH = getTensor(I->getDest())->getHandle<int32_t>(); |
3275 | const auto boundaries = I->getBoundaries(); |
3276 | |
3277 | const auto numItems = inputH.size(); |
3278 | |
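  // std::lower_bound returns the first boundary >= the input, so each output
  // is the number of boundaries strictly below the input. Illustrative
  // example (hypothetical values): with boundaries {1.0, 4.0, 7.0}, inputs
  // {0.5, 4.0, 8.0} map to buckets {0, 1, 3}.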
3279 | for (size_t i = 0; i < numItems; ++i) { |
3280 | outputH.raw(i) = |
3281 | std::lower_bound(boundaries.begin(), boundaries.end(), inputH.raw(i)) - |
3282 | boundaries.begin(); |
3283 | } |
3284 | } |
3285 | |
3286 | //===----------------------------------------------------------------------===// |
3287 | // Arithmetic operations |
3288 | //===----------------------------------------------------------------------===// |
3289 | void BoundInterpreterFunction::fwdElementAddInstI8Impl( |
3290 | const ElementAddInst *I) { |
  assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
         "Wrong function");
3293 | auto lhsTy = I->getLHS()->getType(); |
3294 | auto rhsTy = I->getRHS()->getType(); |
3295 | auto destTy = I->getDest()->getType(); |
3296 | |
3297 | float lhsScale = lhsTy->getScale(); |
3298 | float rhsScale = rhsTy->getScale(); |
3299 | float destScale = destTy->getScale(); |
3300 | |
3301 | int32_t lhsOffset = lhsTy->getOffset(); |
3302 | int32_t rhsOffset = rhsTy->getOffset(); |
3303 | int32_t destOffset = destTy->getOffset(); |
3304 | |
3305 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
3306 | auto lhsW = getWeightHandle<int8_t>(I->getLHS()); |
3307 | auto rhsW = getWeightHandle<int8_t>(I->getRHS()); |
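  // Each operand represents the real value scale * (q - offset). The two
  // operands generally have different scales, so both are rescaled into a
  // common fixed-point representation, summed there, and the sum is then
  // requantized to the destination scale and offset.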
3308 | for (dim_t i = 0, e = outW.size(); i < e; i++) { |
3309 | int32_t L = lhsW.raw(i); |
3310 | int32_t R = rhsW.raw(i); |
3311 | |
    // Widen into a 16-bit fixed-point representation to prevent overflow.
    const float largeScale = float(1) / (1 << 15);
    // Rescale both sides from their 8-bit scales into the common 16-bit scale.
3315 | int32_t L32 = std::round(float(L - lhsOffset) * (lhsScale / largeScale)); |
3316 | int32_t R32 = std::round(float(R - rhsOffset) * (rhsScale / largeScale)); |
3317 | int32_t sum32 = L32 + R32; |
3318 | sum32 = std::round(float(sum32) * (largeScale / destScale) + destOffset); |
3319 | outW.raw(i) = quantization::clip<int32_t, int8_t>(sum32); |
3320 | } |
3321 | } |
3322 | |
3323 | template <typename ElemTy> |
3324 | void BoundInterpreterFunction::fwdElementAddInstArithmeticImpl( |
3325 | const ElementAddInst *I) { |
3326 | staticAssertArithmeticType(ElemTy); |
3327 | |
3328 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3329 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
3330 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
3331 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3332 | outW.raw(i) = lhsW.raw(i) + rhsW.raw(i); |
3333 | } |
3334 | } |
3335 | |
3336 | void BoundInterpreterFunction::fwdElementAddInst(const ElementAddInst *I) { |
3337 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
3338 | fwdElementAddInstI8Impl(I); |
3339 | return; |
3340 | } |
3341 | |
3342 | dispatchArithmeticImpl(fwdElementAddInstArithmeticImpl, |
3343 | I->getLHS()->getType()->getElementType(), I); |
3344 | } |
3345 | |
3346 | template <typename ElemTy> |
3347 | void BoundInterpreterFunction::fwdElementSubInstArithmeticImpl( |
3348 | const ElementSubInst *I) { |
3349 | staticAssertArithmeticType(ElemTy); |
3350 | |
3351 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3352 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
3353 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
3354 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3355 | outW.raw(i) = lhsW.raw(i) - rhsW.raw(i); |
3356 | } |
3357 | } |
3358 | |
3359 | void BoundInterpreterFunction::fwdElementSubInst(const ElementSubInst *I) { |
3360 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
3361 | auto destTy = I->getDest()->getType(); |
3362 | auto lhsTy = I->getLHS()->getType(); |
3363 | auto rhsTy = I->getRHS()->getType(); |
3364 | |
3365 | float destScale = destTy->getScale(); |
3366 | float lhsScale = lhsTy->getScale(); |
3367 | float rhsScale = rhsTy->getScale(); |
3368 | |
3369 | int32_t destOffset = destTy->getOffset(); |
3370 | int32_t lhsOffset = lhsTy->getOffset(); |
3371 | int32_t rhsOffset = rhsTy->getOffset(); |
3372 | |
3373 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
3374 | auto lhsW = getWeightHandle<int8_t>(I->getLHS()); |
3375 | auto rhsW = getWeightHandle<int8_t>(I->getRHS()); |
3376 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3377 | // s_d * (i_d - o_d) = s_l * (i_l - o_l) - s_r * (i_r - o_r) |
3378 | // => i_d = (s_l / s_d) * (i_l - o_l) - (s_r / s_d) * (i_r - o_r) + o_d |
3379 | float l = (lhsScale / destScale) * float(lhsW.raw(i) - lhsOffset); |
3380 | float r = (rhsScale / destScale) * float(rhsW.raw(i) - rhsOffset); |
3381 | int32_t q = std::round(l - r + destOffset); |
3382 | outW.raw(i) = quantization::clip<int32_t, int8_t>(q); |
3383 | } |
3384 | return; |
3385 | } |
3386 | |
3387 | dispatchArithmeticImpl(fwdElementSubInstArithmeticImpl, |
3388 | I->getDest()->getElementType(), I); |
3389 | } |
3390 | |
3391 | template <typename ElemTy> |
3392 | void BoundInterpreterFunction::fwdElementMulInstArithmeticImpl( |
3393 | const ElementMulInst *I) { |
3394 | staticAssertArithmeticType(ElemTy); |
3395 | |
3396 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3397 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
3398 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
3399 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3400 | outW.raw(i) = lhsW.raw(i) * rhsW.raw(i); |
3401 | } |
3402 | } |
3403 | |
3404 | void BoundInterpreterFunction::fwdElementMulInst(const ElementMulInst *I) { |
3405 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
3406 | auto lhsTy = I->getLHS()->getType(); |
3407 | auto rhsTy = I->getRHS()->getType(); |
3408 | auto destTy = I->getDest()->getType(); |
3409 | |
3410 | TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()}; |
3411 | TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()}; |
3412 | TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()}; |
3413 | |
3414 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
3415 | auto lhsW = getWeightHandle<int8_t>(I->getLHS()); |
3416 | auto rhsW = getWeightHandle<int8_t>(I->getRHS()); |
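    // s_d * (i_d - o_d) = s_l * (i_l - o_l) * s_r * (i_r - o_r)
    //   => i_d = (s_l * s_r / s_d) * (i_l - o_l) * (i_r - o_r) + o_d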
3417 | float scale = lhsQ.scale * rhsQ.scale / destQ.scale; |
3418 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3419 | int32_t mul = (lhsW.raw(i) - lhsQ.offset) * (rhsW.raw(i) - rhsQ.offset); |
3420 | outW.raw(i) = quantization::clip<int32_t, int8_t>( |
3421 | std::round(mul * scale) + destQ.offset); |
3422 | } |
3423 | return; |
3424 | } |
3425 | |
3426 | dispatchArithmeticImpl(fwdElementMulInstArithmeticImpl, |
3427 | I->getDest()->getElementType(), I); |
3428 | } |
3429 | |
3430 | void BoundInterpreterFunction::fwdElementFmodInst(const ElementFmodInst *I) { |
3431 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
3432 | auto destTy = I->getDest()->getType(); |
3433 | auto lhsTy = I->getLHS()->getType(); |
3434 | auto rhsTy = I->getRHS()->getType(); |
3435 | |
3436 | float destScale = destTy->getScale(); |
3437 | float lhsScale = lhsTy->getScale(); |
3438 | float rhsScale = rhsTy->getScale(); |
3439 | |
3440 | int32_t destOffset = destTy->getOffset(); |
3441 | int32_t lhsOffset = lhsTy->getOffset(); |
3442 | int32_t rhsOffset = rhsTy->getOffset(); |
3443 | |
3444 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
3445 | auto lhsW = getWeightHandle<int8_t>(I->getLHS()); |
3446 | auto rhsW = getWeightHandle<int8_t>(I->getRHS()); |
3447 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
      // i_d = fmod(s_l * (i_l - o_l), s_r * (i_r - o_r)) / s_d + o_d
      float l = lhsScale * float(lhsW.raw(i) - lhsOffset);
      float r = rhsScale * float(rhsW.raw(i) - rhsOffset);
      int32_t q = std::round(std::fmod(l, r) / destScale + destOffset);
3451 | outW.raw(i) = quantization::clip<int32_t, int8_t>(q); |
3452 | } |
3453 | return; |
3454 | } |
3455 | |
3456 | #define FMOD_LOOP(TYPE_) \ |
3457 | staticAssertArithmeticType(TYPE_); \ |
3458 | auto outW = getWeightHandle<TYPE_>(I->getDest()); \ |
3459 | auto lhsW = getWeightHandle<TYPE_>(I->getLHS()); \ |
3460 | auto rhsW = getWeightHandle<TYPE_>(I->getRHS()); \ |
3461 | for (size_t i = 0, e = outW.size(); i < e; i++) { \ |
3462 | outW.raw(i) = std::fmod((float)lhsW.raw(i), (float)rhsW.raw(i)); \ |
3463 | } |
3464 | |
3465 | auto *T = getTensor(I->getDest()); |
3466 | switch (T->getElementType()) { |
3467 | case ElemKind::Int64ITy: { |
3468 | FMOD_LOOP(int64_t); |
3469 | return; |
3470 | } |
3471 | case ElemKind::Int32ITy: { |
3472 | FMOD_LOOP(int32_t); |
3473 | return; |
3474 | } |
3475 | case ElemKind::FloatTy: { |
3476 | FMOD_LOOP(float); |
3477 | return; |
3478 | } |
3479 | case ElemKind::Float16Ty: { |
3480 | FMOD_LOOP(float16_t); |
3481 | return; |
3482 | } |
3483 | case ElemKind::BFloat16Ty: { |
3484 | FMOD_LOOP(bfloat16_t); |
3485 | return; |
3486 | } |
3487 | |
3488 | default: |
    llvm_unreachable("Unsupported type for Fmod.");
3490 | } |
3491 | } |
3492 | |
3493 | void BoundInterpreterFunction::fwdElementDivInst(const ElementDivInst *I) { |
3494 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
3495 | auto destTy = I->getDest()->getType(); |
3496 | auto lhsTy = I->getLHS()->getType(); |
3497 | auto rhsTy = I->getRHS()->getType(); |
3498 | |
3499 | float destScale = destTy->getScale(); |
3500 | float lhsScale = lhsTy->getScale(); |
3501 | float rhsScale = rhsTy->getScale(); |
3502 | |
3503 | int32_t destOffset = destTy->getOffset(); |
3504 | int32_t lhsOffset = lhsTy->getOffset(); |
3505 | int32_t rhsOffset = rhsTy->getOffset(); |
3506 | |
3507 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
3508 | auto lhsW = getWeightHandle<int8_t>(I->getLHS()); |
3509 | auto rhsW = getWeightHandle<int8_t>(I->getRHS()); |
3510 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3511 | // s_d * (i_d - o_d) = (s_l * (i_l - o_l)) / (s_r * (i_r - o_r)) |
3512 | // => i_d = (s_l * (i_l - o_l)) / (s_d * s_r * (i_r - o_r)) + o_d |
3513 | float l = lhsScale * float(lhsW.raw(i) - lhsOffset); |
3514 | float r = rhsScale * destScale * float(rhsW.raw(i) - rhsOffset); |
3515 | int32_t q = std::round(l / r + destOffset); |
3516 | outW.raw(i) = quantization::clip<int32_t, int8_t>(q); |
3517 | } |
3518 | return; |
3519 | } |
3520 | |
3521 | #define DIV_LOOP(TYPE_) \ |
3522 | auto outW = getWeightHandle<TYPE_>(I->getDest()); \ |
3523 | auto lhsW = getWeightHandle<TYPE_>(I->getLHS()); \ |
3524 | auto rhsW = getWeightHandle<TYPE_>(I->getRHS()); \ |
3525 | for (size_t i = 0, e = outW.size(); i < e; i++) { \ |
3526 | outW.raw(i) = lhsW.raw(i) / rhsW.raw(i); \ |
3527 | } |
3528 | |
3529 | auto *T = getTensor(I->getDest()); |
3530 | switch (T->getElementType()) { |
3531 | case ElemKind::Int64ITy: { |
3532 | DIV_LOOP(int64_t); |
3533 | return; |
3534 | } |
3535 | case ElemKind::Int32ITy: { |
3536 | DIV_LOOP(int32_t); |
3537 | return; |
3538 | } |
3539 | case ElemKind::FloatTy: { |
3540 | DIV_LOOP(float); |
3541 | return; |
3542 | } |
3543 | case ElemKind::Float16Ty: { |
3544 | DIV_LOOP(float16_t); |
3545 | return; |
3546 | } |
3547 | case ElemKind::BFloat16Ty: { |
3548 | DIV_LOOP(bfloat16_t); |
3549 | return; |
3550 | } |
3551 | default: |
    llvm_unreachable("Unsupported type for Div.");
3553 | } |
3554 | } |
3555 | |
3556 | void BoundInterpreterFunction::fwdElementMaxInstI8Impl( |
3557 | const ElementMaxInst *I) { |
  assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
         "Wrong function");
3560 | auto lhsTy = I->getLHS()->getType(); |
3561 | auto rhsTy = I->getRHS()->getType(); |
3562 | auto destTy = I->getDest()->getType(); |
3563 | |
3564 | TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()}; |
3565 | TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()}; |
3566 | TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()}; |
3567 | |
3568 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
3569 | auto lhsW = getWeightHandle<int8_t>(I->getLHS()); |
3570 | auto rhsW = getWeightHandle<int8_t>(I->getRHS()); |
3571 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3572 | // Convert both sides to the destination scale and perform a regular |
3573 | // comparison. |
3574 | int8_t L = quantization::quantize( |
3575 | quantization::dequantize(lhsW.raw(i), lhsQ), destQ); |
3576 | int8_t R = quantization::quantize( |
3577 | quantization::dequantize(rhsW.raw(i), rhsQ), destQ); |
3578 | outW.raw(i) = std::max(L, R); |
3579 | } |
3580 | } |
3581 | |
3582 | template <typename ElemTy> |
3583 | void BoundInterpreterFunction::fwdElementMaxInstArithmeticImpl( |
3584 | const ElementMaxInst *I) { |
3585 | staticAssertArithmeticType(ElemTy); |
3586 | |
3587 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3588 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
3589 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
3590 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3591 | outW.raw(i) = std::max(lhsW.raw(i), rhsW.raw(i)); |
3592 | } |
3593 | } |
3594 | |
3595 | void BoundInterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) { |
3596 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
3597 | fwdElementMaxInstI8Impl(I); |
3598 | return; |
3599 | } |
3600 | |
3601 | dispatchArithmeticImpl(fwdElementMaxInstArithmeticImpl, |
3602 | I->getLHS()->getType()->getElementType(), I); |
3603 | } |
3604 | |
3605 | template <typename ElemTy> |
3606 | void BoundInterpreterFunction::fwdElementMinInstArithmeticImpl( |
3607 | const ElementMinInst *I) { |
3608 | staticAssertArithmeticType(ElemTy); |
3609 | |
3610 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3611 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
3612 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
3613 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3614 | outW.raw(i) = std::min(lhsW.raw(i), rhsW.raw(i)); |
3615 | } |
3616 | } |
3617 | |
3618 | void BoundInterpreterFunction::fwdElementMinInst(const ElementMinInst *I) { |
3619 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
3620 | auto lhsTy = I->getLHS()->getType(); |
3621 | auto rhsTy = I->getRHS()->getType(); |
3622 | auto destTy = I->getDest()->getType(); |
3623 | |
3624 | TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()}; |
3625 | TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()}; |
3626 | TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()}; |
3627 | |
3628 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
3629 | auto lhsW = getWeightHandle<int8_t>(I->getLHS()); |
3630 | auto rhsW = getWeightHandle<int8_t>(I->getRHS()); |
3631 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3632 | // Convert both sides to the destination scale and perform a regular |
3633 | // comparison. |
3634 | int8_t L = quantization::quantize( |
3635 | quantization::dequantize(lhsW.raw(i), lhsQ), destQ); |
3636 | int8_t R = quantization::quantize( |
3637 | quantization::dequantize(rhsW.raw(i), rhsQ), destQ); |
3638 | outW.raw(i) = std::min(L, R); |
3639 | } |
3640 | return; |
3641 | } |
3642 | |
3643 | dispatchArithmeticImpl(fwdElementMinInstArithmeticImpl, |
3644 | I->getDest()->getElementType(), I); |
3645 | } |
3646 | |
3647 | template <typename ElemTy> |
3648 | void BoundInterpreterFunction::fwdElementBitwiseOrInstImpl( |
3649 | const ElementBitwiseOrInst *I) { |
3650 | |
3651 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3652 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
3653 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
3654 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3655 | outW.raw(i) = lhsW.raw(i) | rhsW.raw(i); |
3656 | } |
3657 | } |
3658 | |
3659 | void BoundInterpreterFunction::fwdElementBitwiseOrInst( |
3660 | const ElementBitwiseOrInst *I) { |
3661 | |
3662 | dispatchBitwiseImpl(fwdElementBitwiseOrInstImpl, |
3663 | I->getDest()->getElementType(), I); |
3664 | } |
3665 | |
3666 | template <typename ElemTy> |
3667 | void BoundInterpreterFunction::fwdElementBitwiseAndInstImpl( |
3668 | const ElementBitwiseAndInst *I) { |
3669 | |
3670 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3671 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
3672 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
3673 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3674 | outW.raw(i) = lhsW.raw(i) & rhsW.raw(i); |
3675 | } |
3676 | } |
3677 | |
3678 | void BoundInterpreterFunction::fwdElementBitwiseAndInst( |
3679 | const ElementBitwiseAndInst *I) { |
3680 | |
3681 | dispatchBitwiseImpl(fwdElementBitwiseAndInstImpl, |
3682 | I->getDest()->getElementType(), I); |
3683 | } |
3684 | |
3685 | template <typename ElemTy> |
3686 | void BoundInterpreterFunction::fwdElementBitwiseXorInstImpl( |
3687 | const ElementBitwiseXorInst *I) { |
3688 | staticAssertArithmeticType(ElemTy); |
3689 | |
3690 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
3691 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
3692 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
3693 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
3694 | outW.raw(i) = lhsW.raw(i) ^ rhsW.raw(i); |
3695 | } |
3696 | } |
3697 | |
3698 | void BoundInterpreterFunction::fwdElementBitwiseXorInst( |
3699 | const ElementBitwiseXorInst *I) { |
3700 | |
3701 | dispatchBitwiseImpl(fwdElementBitwiseXorInstImpl, |
3702 | I->getDest()->getElementType(), I); |
3703 | } |
3704 | |
3705 | //===----------------------------------------------------------------------===// |
3706 | // Logical operations |
3707 | //===----------------------------------------------------------------------===// |
3708 | void BoundInterpreterFunction::fwdElementNotInst(const ElementNotInst *I) { |
3709 | auto inpW = getWeightHandle<bool>(I->getSrc()); |
3710 | auto outW = getWeightHandle<bool>(I->getDest()); |
3711 | for (size_t i = 0, e = outW.size(); i < e; ++i) { |
3712 | outW.raw(i) = (!inpW.raw(i)); |
3713 | } |
3714 | } |
3715 | |
3716 | void BoundInterpreterFunction::fwdElementAndInst(const ElementAndInst *I) { |
3717 | auto lhsW = getWeightHandle<bool>(I->getLHS()); |
3718 | auto rhsW = getWeightHandle<bool>(I->getRHS()); |
3719 | auto outW = getWeightHandle<bool>(I->getDest()); |
3720 | for (size_t i = 0, e = outW.size(); i < e; ++i) { |
3721 | outW.raw(i) = (lhsW.raw(i) && rhsW.raw(i)); |
3722 | } |
3723 | } |
3724 | |
3725 | void BoundInterpreterFunction::fwdElementOrInst(const ElementOrInst *I) { |
3726 | auto lhsW = getWeightHandle<bool>(I->getLHS()); |
3727 | auto rhsW = getWeightHandle<bool>(I->getRHS()); |
3728 | auto outW = getWeightHandle<bool>(I->getDest()); |
3729 | for (size_t i = 0, e = outW.size(); i < e; ++i) { |
3730 | outW.raw(i) = (lhsW.raw(i) || rhsW.raw(i)); |
3731 | } |
3732 | } |
3733 | |
3734 | void BoundInterpreterFunction::fwdElementXorInst(const ElementXorInst *I) { |
3735 | auto lhsW = getWeightHandle<bool>(I->getLHS()); |
3736 | auto rhsW = getWeightHandle<bool>(I->getRHS()); |
3737 | auto outW = getWeightHandle<bool>(I->getDest()); |
3738 | for (size_t i = 0, e = outW.size(); i < e; ++i) { |
3739 | outW.raw(i) = (lhsW.raw(i) ^ rhsW.raw(i)); |
3740 | } |
3741 | } |
3742 | |
3743 | //===----------------------------------------------------------------------===// |
3744 | // Unary arithmetic operations |
3745 | //===----------------------------------------------------------------------===// |
3746 | template <typename ElemTy, typename InstKind> |
3747 | void BoundInterpreterFunction::fwdUnaryArithmeticImpl( |
3748 | const InstKind *I, std::function<float(float)> func) { |
3749 | Value *inpV = I->getSrc(); |
3750 | Value *outV = I->getDest(); |
3751 | auto inpTy = inpV->getType(); |
3752 | auto outTy = outV->getType(); |
3753 | auto inpH = getWeightHandle<ElemTy>(inpV); |
3754 | auto outH = getWeightHandle<ElemTy>(outV); |
3755 | |
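  // Quantized inputs are dequantized to float, transformed by func, and
  // requantized to the output type; all other element types round-trip
  // through float directly.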
3756 | if (inpTy->isQuantizedType()) { |
3757 | float inpScale = inpTy->getScale(); |
3758 | int32_t inpOffset = inpTy->getOffset(); |
3759 | float outScale = outTy->getScale(); |
3760 | int32_t outOffset = outTy->getOffset(); |
3761 | for (size_t i = 0, e = outH.size(); i < e; ++i) { |
3762 | float inpVal = |
3763 | quantization::dequantize<ElemTy>(inpH.raw(i), {inpScale, inpOffset}); |
3764 | float outVal = func(inpVal); |
3765 | outH.raw(i) = |
3766 | quantization::quantize<ElemTy>(outVal, {outScale, outOffset}); |
3767 | } |
3768 | } else { |
3769 | for (size_t i = 0, e = outH.size(); i < e; ++i) { |
3770 | float inpVal = static_cast<float>(inpH.raw(i)); |
3771 | float outVal = func(inpVal); |
3772 | outH.raw(i) = static_cast<ElemTy>(outVal); |
3773 | } |
3774 | } |
3775 | } |
3776 | |
3777 | void BoundInterpreterFunction::fwdElementBitwiseNotInst( |
3778 | const ElementBitwiseNotInst *I) { |
3779 | auto func = [](int64_t i) -> int64_t { return ~i; }; |
3780 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3781 | } |
3782 | |
3783 | void BoundInterpreterFunction::fwdElementAbsInst(const ElementAbsInst *I) { |
3784 | auto func = [](float x) -> float { return std::abs(x); }; |
3785 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3786 | } |
3787 | |
3788 | void BoundInterpreterFunction::fwdElementNegInst(const ElementNegInst *I) { |
3789 | auto func = [](float x) -> float { return -x; }; |
3790 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3791 | } |
3792 | |
3793 | void BoundInterpreterFunction::fwdElementFloorInst(const ElementFloorInst *I) { |
3794 | auto func = [](float x) -> float { return std::floor(x); }; |
3795 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3796 | } |
3797 | |
3798 | void BoundInterpreterFunction::fwdElementSignInst(const ElementSignInst *I) { |
3799 | auto func = [](float x) -> float { return ((x > 0) - (x < 0)); }; |
3800 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3801 | } |
3802 | |
3803 | void BoundInterpreterFunction::fwdElementCeilInst(const ElementCeilInst *I) { |
3804 | auto func = [](float x) -> float { return std::ceil(x); }; |
3805 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3806 | } |
3807 | |
3808 | void BoundInterpreterFunction::fwdElementTruncateInst( |
3809 | const ElementTruncateInst *I) { |
3810 | auto func = [](float x) -> float { return std::trunc(x); }; |
3811 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3812 | } |
3813 | |
3814 | void BoundInterpreterFunction::fwdElementRoundInst(const ElementRoundInst *I) { |
  // The rounding mode required by ONNX, NumPy, and TensorFlow is
  // round-half-to-even: values whose fractional part is exactly 0.5 round to
  // the nearest even integer. std::nearbyintf follows the current FP rounding
  // mode, which defaults to round-to-nearest-even.
  auto func = [](float x) -> float { return std::nearbyintf(x); };
3818 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3819 | } |
3820 | |
3821 | void BoundInterpreterFunction::fwdElementSqrtInst(const ElementSqrtInst *I) { |
3822 | auto func = [](float x) -> float { return std::sqrt(x); }; |
3823 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3824 | } |
3825 | |
3826 | void BoundInterpreterFunction::fwdElementRsqrtInst(const ElementRsqrtInst *I) { |
3827 | auto func = [](float x) -> float { return 1 / std::sqrt(x); }; |
3828 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3829 | } |
3830 | |
3831 | void BoundInterpreterFunction::fwdElementReciprocalInst( |
3832 | const ElementReciprocalInst *I) { |
3833 | auto func = [](float x) -> float { return 1 / x; }; |
3834 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3835 | } |
3836 | |
3837 | void BoundInterpreterFunction::fwdElementSinInst(const ElementSinInst *I) { |
3838 | auto func = [](float x) -> float { return std::sin(x); }; |
3839 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3840 | } |
3841 | |
3842 | void BoundInterpreterFunction::fwdElementCosInst(const ElementCosInst *I) { |
3843 | auto func = [](float x) -> float { return std::cos(x); }; |
3844 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3845 | } |
3846 | |
3847 | void BoundInterpreterFunction::fwdElementErfInst(const ElementErfInst *I) { |
3848 | auto func = [](float x) -> float { return std::erf(x); }; |
3849 | dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func); |
3850 | } |
3851 | |
3852 | //===----------------------------------------------------------------------===// |
3853 | // Compare operations |
3854 | //===----------------------------------------------------------------------===// |
3855 | template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy, |
3856 | typename CmpTy, typename InstCmpKind> |
3857 | void BoundInterpreterFunction::fwdElementCmpHelperImpl( |
3858 | const InstCmpKind *I, std::function<bool(CmpTy LHS, CmpTy RHS)> cmpHelper) { |
3859 | Value *lhsV = I->getLHS(); |
3860 | Value *rhsV = I->getRHS(); |
3861 | Value *outV = I->getDest(); |
3862 | |
3863 | auto lhsH = getWeightHandle<ElemTy>(lhsV); |
3864 | auto rhsH = getWeightHandle<ElemTy>(rhsV); |
3865 | auto oH = getWeightHandle<bool>(outV); |
3866 | |
3867 | ElemScaleTy lhsScale = 1.0f; |
3868 | ElemScaleTy rhsScale = 1.0f; |
3869 | ElemOffsetTy lhsOffset = 0; |
3870 | ElemOffsetTy rhsOffset = 0; |
3871 | |
3872 | auto lhsTy = lhsV->getType(); |
3873 | auto rhsTy = rhsV->getType(); |
3874 | |
3875 | if (lhsV->getType()->isQuantizedType()) { |
3876 | lhsScale = lhsTy->getScale(); |
3877 | rhsScale = rhsTy->getScale(); |
3878 | |
3879 | lhsOffset = lhsTy->getOffset(); |
3880 | rhsOffset = rhsTy->getOffset(); |
3881 | } |
3882 | |
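  // With the defaults above (scale = 1, offset = 0) the expression below is a
  // no-op for non-quantized types; for quantized types it compares the
  // operands after rescaling them toward their real values.
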
  // For each element:
3884 | for (size_t i = 0, e = oH.size(); i < e; i++) { |
3885 | oH.raw(i) = cmpHelper(lhsScale * (lhsH.raw(i) - lhsOffset), |
3886 | rhsScale * (rhsH.raw(i) - rhsOffset)); |
3887 | } |
3888 | } |
3889 | |
3890 | template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy, |
3891 | typename CmpTy> |
3892 | void BoundInterpreterFunction::fwdElementCmpLTEInstImpl( |
3893 | const ElementCmpLTEInst *I) { |
3894 | auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS <= RHS; }; |
3895 | fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy, |
3896 | ElementCmpLTEInst>(I, cmpHelper); |
3897 | } |
3898 | |
3899 | void BoundInterpreterFunction::fwdElementCmpLTEInst( |
3900 | const ElementCmpLTEInst *I) { |
3901 | auto *T = getTensor(I->getLHS()); |
3902 | |
3903 | if (T->getType().isQuantizedType()) { |
3904 | switch (T->getElementType()) { |
3905 | case ElemKind::Int8QTy: |
3906 | fwdElementCmpLTEInstImpl<int8_t, int32_t, float, int32_t>(I); |
3907 | break; |
3908 | case ElemKind::Int16QTy: |
3909 | fwdElementCmpLTEInstImpl<int16_t, int32_t, float, int32_t>(I); |
3910 | break; |
3911 | default: |
      llvm_unreachable("Type is not supported");
3913 | } |
3914 | return; |
3915 | } |
3916 | |
3917 | switch (T->getElementType()) { |
3918 | case ElemKind::FloatTy: |
3919 | fwdElementCmpLTEInstImpl<float, float, float>(I); |
3920 | break; |
3921 | case ElemKind::Float16Ty: |
3922 | fwdElementCmpLTEInstImpl<float16_t, float16_t, float16_t>(I); |
3923 | break; |
3924 | case ElemKind::BFloat16Ty: |
3925 | fwdElementCmpLTEInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I); |
3926 | break; |
3927 | case ElemKind::Int32ITy: |
3928 | fwdElementCmpLTEInstImpl<int32_t, int32_t, float>(I); |
3929 | break; |
3930 | case ElemKind::Int64ITy: |
3931 | fwdElementCmpLTEInstImpl<int64_t, int64_t, float>(I); |
3932 | break; |
3933 | default: |
    llvm_unreachable("Type is not supported");
3935 | } |
3936 | } |
3937 | |
3938 | template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy, |
3939 | typename CmpTy> |
3940 | void BoundInterpreterFunction::fwdElementCmpEQInstImpl( |
3941 | const ElementCmpEQInst *I) { |
3942 | auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS == RHS; }; |
3943 | fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy, |
3944 | ElementCmpEQInst>(I, cmpHelper); |
3945 | } |
3946 | |
3947 | void BoundInterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) { |
3948 | auto *T = getTensor(I->getLHS()); |
3949 | |
3950 | if (T->getType().isQuantizedType()) { |
3951 | switch (T->getElementType()) { |
3952 | case ElemKind::Int8QTy: |
3953 | fwdElementCmpEQInstImpl<int8_t, int32_t, float, int32_t>(I); |
3954 | break; |
3955 | case ElemKind::Int16QTy: |
3956 | fwdElementCmpEQInstImpl<int16_t, int32_t, float, int32_t>(I); |
3957 | break; |
3958 | default: |
      llvm_unreachable("Type is not supported");
3960 | } |
3961 | return; |
3962 | } |
3963 | |
3964 | switch (T->getElementType()) { |
3965 | case ElemKind::FloatTy: |
3966 | fwdElementCmpEQInstImpl<float, float, float>(I); |
3967 | break; |
3968 | case ElemKind::Float16Ty: |
3969 | fwdElementCmpEQInstImpl<float16_t, float16_t, float16_t>(I); |
3970 | break; |
3971 | case ElemKind::BFloat16Ty: |
3972 | fwdElementCmpEQInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I); |
3973 | break; |
3974 | case ElemKind::Int32ITy: |
3975 | fwdElementCmpEQInstImpl<int32_t, int32_t, float>(I); |
3976 | break; |
3977 | case ElemKind::Int64ITy: |
3978 | fwdElementCmpEQInstImpl<int64_t, int64_t, float>(I); |
3979 | break; |
3980 | default: |
    llvm_unreachable("Type is not supported");
3982 | } |
3983 | } |
3984 | |
3985 | template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy, |
3986 | typename CmpTy> |
3987 | void BoundInterpreterFunction::fwdElementCmpNEQInstImpl( |
3988 | const ElementCmpNEQInst *I) { |
3989 | auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return !(LHS == RHS); }; |
3990 | fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy, |
3991 | ElementCmpNEQInst>(I, cmpHelper); |
3992 | } |
3993 | |
3994 | void BoundInterpreterFunction::fwdElementCmpNEQInst( |
3995 | const ElementCmpNEQInst *I) { |
3996 | auto *T = getTensor(I->getLHS()); |
3997 | |
3998 | if (T->getType().isQuantizedType()) { |
3999 | switch (T->getElementType()) { |
4000 | case ElemKind::Int8QTy: |
4001 | fwdElementCmpNEQInstImpl<int8_t, int32_t, float, int32_t>(I); |
4002 | break; |
4003 | case ElemKind::Int16QTy: |
4004 | fwdElementCmpNEQInstImpl<int16_t, int32_t, float, int32_t>(I); |
4005 | break; |
4006 | default: |
      llvm_unreachable("Type is not supported");
4008 | } |
4009 | return; |
4010 | } |
4011 | |
4012 | switch (T->getElementType()) { |
4013 | case ElemKind::FloatTy: |
4014 | fwdElementCmpNEQInstImpl<float, float, float>(I); |
4015 | break; |
4016 | case ElemKind::Float16Ty: |
4017 | fwdElementCmpNEQInstImpl<float16_t, float16_t, float16_t>(I); |
4018 | break; |
4019 | case ElemKind::BFloat16Ty: |
4020 | fwdElementCmpNEQInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I); |
4021 | break; |
4022 | case ElemKind::Int32ITy: |
4023 | fwdElementCmpNEQInstImpl<int32_t, int32_t, float>(I); |
4024 | break; |
4025 | case ElemKind::Int64ITy: |
4026 | fwdElementCmpNEQInstImpl<int64_t, int64_t, float>(I); |
4027 | break; |
4028 | default: |
    llvm_unreachable("Type is not supported");
4030 | } |
4031 | } |
4032 | |
4033 | template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy, |
4034 | typename CmpTy> |
4035 | void BoundInterpreterFunction::fwdElementCmpLTInstImpl( |
4036 | const ElementCmpLTInst *I) { |
4037 | auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS < RHS; }; |
4038 | fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy, |
4039 | ElementCmpLTInst>(I, cmpHelper); |
4040 | } |
4041 | |
4042 | void BoundInterpreterFunction::fwdElementCmpLTInst(ElementCmpLTInst const *I) { |
4043 | auto *T = getTensor(I->getLHS()); |
4044 | if (T->getType().isQuantizedType()) { |
4045 | switch (T->getElementType()) { |
4046 | case ElemKind::Int8QTy: |
4047 | fwdElementCmpLTInstImpl<int8_t, int32_t, float, int32_t>(I); |
4048 | break; |
4049 | case ElemKind::Int16QTy: |
4050 | fwdElementCmpLTInstImpl<int16_t, int32_t, float, int32_t>(I); |
4051 | break; |
4052 | default: |
      llvm_unreachable("Type is not supported");
4054 | } |
4055 | return; |
4056 | } |
4057 | |
4058 | switch (T->getElementType()) { |
4059 | case ElemKind::FloatTy: |
4060 | fwdElementCmpLTInstImpl<float, float, float>(I); |
4061 | break; |
4062 | case ElemKind::Float16Ty: |
4063 | fwdElementCmpLTInstImpl<float16_t, float16_t, float16_t>(I); |
4064 | break; |
4065 | case ElemKind::BFloat16Ty: |
4066 | fwdElementCmpLTInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I); |
4067 | break; |
4068 | case ElemKind::Int32ITy: |
4069 | fwdElementCmpLTInstImpl<int32_t, int32_t, float>(I); |
4070 | break; |
4071 | case ElemKind::Int64ITy: |
4072 | fwdElementCmpLTInstImpl<int64_t, int64_t, float>(I); |
4073 | break; |
4074 | default: |
    llvm_unreachable("Type is not supported");
4076 | } |
4077 | } |
4078 | |
4079 | template <typename ElemTy> |
4080 | void BoundInterpreterFunction::fwdElementPowInstFloatImpl( |
4081 | const ElementPowInst *I) { |
4082 | staticAssertFloatingPointType(ElemTy); |
4083 | |
4084 | auto baseW = getWeightHandle<ElemTy>(I->getLHS()); |
4085 | auto expW = getWeightHandle<ElemTy>(I->getRHS()); |
4086 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
4087 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
4088 | outW.raw(i) = ElemTy(pow(float(baseW.raw(i)), float(expW.raw(i)))); |
4089 | } |
4090 | } |
4091 | |
4092 | void BoundInterpreterFunction::fwdElementPowInstI8Impl( |
4093 | const ElementPowInst *I) { |
  assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
         "Expect quantized type");
4096 | auto baseTy = I->getLHS()->getType(); |
4097 | auto expTy = I->getRHS()->getType(); |
4098 | auto destTy = I->getDest()->getType(); |
4099 | |
4100 | float baseScale = baseTy->getScale(); |
4101 | int32_t baseOffset = baseTy->getOffset(); |
4102 | TensorQuantizationParams baseTQP{baseScale, baseOffset}; |
4103 | |
4104 | float expScale = expTy->getScale(); |
4105 | int32_t expOffset = expTy->getOffset(); |
4106 | TensorQuantizationParams expTQP{expScale, expOffset}; |
4107 | |
4108 | float destScale = destTy->getScale(); |
4109 | int32_t destOffset = destTy->getOffset(); |
4110 | TensorQuantizationParams destTQP{destScale, destOffset}; |
4111 | |
4112 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
4113 | auto baseW = getWeightHandle<int8_t>(I->getLHS()); |
4114 | auto expW = getWeightHandle<int8_t>(I->getRHS()); |
4115 | |
4116 | for (dim_t i = 0, e = outW.size(); i < e; i++) { |
4117 | float base = quantization::dequantize(baseW.raw(i), baseTQP); |
4118 | float exp = quantization::dequantize(expW.raw(i), expTQP); |
4119 | outW.raw(i) = quantization::quantize(std::pow(base, exp), destTQP); |
4120 | } |
4121 | } |
4122 | |
4123 | void BoundInterpreterFunction::fwdElementPowInst( |
4124 | const glow::ElementPowInst *I) { |
4125 | auto *T = getTensor(I->getLHS()); |
4126 | if (T->getType().isQuantizedType()) { |
4127 | fwdElementPowInstI8Impl(I); |
4128 | return; |
4129 | } |
4130 | |
4131 | dispatchFloatingPointImpl(fwdElementPowInstFloatImpl, |
4132 | I->getLHS()->getElementType(), I); |
4133 | } |
4134 | |
4135 | template <typename ElemTy> |
4136 | void BoundInterpreterFunction::fwdElementIsNaNInstFloatImpl( |
4137 | const ElementIsNaNInst *I) { |
4138 | staticAssertFloatingPointType(ElemTy); |
4139 | |
4140 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
4141 | auto outW = getWeightHandle<bool>(I->getDest()); |
4142 | for (size_t i = 0, e = inW.size(); i < e; i++) { |
4143 | float val = inW.raw(i); |
4144 | outW.raw(i) = std::isnan(val); |
4145 | } |
4146 | } |
4147 | |
4148 | void BoundInterpreterFunction::fwdElementIsNaNInst( |
4149 | const glow::ElementIsNaNInst *I) { |
4150 | dispatchFloatingPointImpl(fwdElementIsNaNInstFloatImpl, |
4151 | I->getSrc()->getElementType(), I); |
4152 | } |
4153 | |
4154 | template <typename ElemTy> |
4155 | void BoundInterpreterFunction::fwdElementLogInstFloatImpl( |
4156 | const ElementLogInst *I) { |
4157 | staticAssertFloatingPointType(ElemTy); |
4158 | |
4159 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
4160 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
4161 | for (size_t i = 0, e = inW.size(); i < e; i++) { |
4162 | float val = inW.raw(i); |
4163 | outW.raw(i) = ElemTy(log(val)); |
4164 | } |
4165 | } |
4166 | |
4167 | void BoundInterpreterFunction::fwdElementLogInst(const ElementLogInst *I) { |
4168 | dispatchFloatingPointImpl(fwdElementLogInstFloatImpl, |
4169 | I->getSrc()->getElementType(), I); |
4170 | } |
4171 | |
4172 | template <typename ElemTy> |
4173 | void BoundInterpreterFunction::fwdElementExpInstFloatImpl( |
4174 | const ElementExpInst *I) { |
4175 | staticAssertFloatingPointType(ElemTy); |
4176 | |
4177 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
4178 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
4179 | for (size_t i = 0, e = inW.size(); i < e; i++) { |
4180 | float val = inW.raw(i); |
4181 | outW.raw(i) = ElemTy(exp(val)); |
4182 | } |
4183 | } |
4184 | |
4185 | void BoundInterpreterFunction::fwdElementExpInst(const ElementExpInst *I) { |
4186 | dispatchFloatingPointImpl(fwdElementExpInstFloatImpl, |
4187 | I->getSrc()->getElementType(), I); |
4188 | } |
4189 | |
4190 | void BoundInterpreterFunction::fwdNonZeroInst(const NonZeroInst *I) { |
4191 | auto *T = getTensor(I->getDest()); |
4192 | T->zero(); |
4193 | auto outW = T->getHandle<int32_t>(); |
4194 | auto condW = getWeightHandle<bool>(I->getCond()); |
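  // Write the flat index of every true condition element, in order, to the
  // front of the destination; the remaining outputs stay zero from the
  // T->zero() call above.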
4195 | for (size_t condIdx = 0, outIdx = 0, n = condW.size(); condIdx < n; |
4196 | condIdx++) { |
4197 | if (condW.raw(condIdx)) { |
4198 | outW.raw(outIdx) = condIdx; |
4199 | outIdx++; |
4200 | } |
4201 | } |
4202 | } |
4203 | |
4204 | template <typename ElemTy> |
4205 | void BoundInterpreterFunction::fwdElementSelectInstFloatImpl( |
4206 | const glow::ElementSelectInst *I) { |
4207 | staticAssertFloatingPointType(ElemTy); |
4208 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
4209 | auto condW = getWeightHandle<bool>(I->getCond()); |
4210 | auto lhsW = getWeightHandle<ElemTy>(I->getLHS()); |
4211 | auto rhsW = getWeightHandle<ElemTy>(I->getRHS()); |
4212 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
4213 | outW.raw(i) = condW.raw(i) ? lhsW.raw(i) : rhsW.raw(i); |
4214 | } |
4215 | } |
4216 | |
4217 | void BoundInterpreterFunction::fwdElementSelectInst( |
4218 | const glow::ElementSelectInst *I) { |
4219 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
4220 | auto destTy = I->getDest()->getType(); |
4221 | auto lhsTy = I->getLHS()->getType(); |
4222 | auto rhsTy = I->getRHS()->getType(); |
4223 | |
4224 | float destScale = destTy->getScale(); |
4225 | float lhsScale = lhsTy->getScale(); |
4226 | float rhsScale = rhsTy->getScale(); |
4227 | |
4228 | int32_t destOffset = destTy->getOffset(); |
4229 | int32_t lhsOffset = lhsTy->getOffset(); |
4230 | int32_t rhsOffset = rhsTy->getOffset(); |
4231 | |
4232 | auto outW = getWeightHandle<int8_t>(I->getDest()); |
4233 | auto condW = getWeightHandle<bool>(I->getCond()); |
4234 | auto lhsW = getWeightHandle<int8_t>(I->getLHS()); |
4235 | auto rhsW = getWeightHandle<int8_t>(I->getRHS()); |
4236 | for (size_t i = 0, e = outW.size(); i < e; i++) { |
4237 | float val = condW.raw(i) ? lhsScale * (lhsW.raw(i) - lhsOffset) |
4238 | : rhsScale * (rhsW.raw(i) - rhsOffset); |
4239 | int32_t q = std::round(val / destScale + destOffset); |
4240 | outW.raw(i) = quantization::clip<int32_t, int8_t>(q); |
4241 | } |
4242 | return; |
4243 | } |
4244 | |
4245 | dispatchFloatingPointImpl(fwdElementSelectInstFloatImpl, |
4246 | I->getLHS()->getElementType(), I); |
4247 | } |
4248 | |
4249 | template <typename ElemTy> |
4250 | void BoundInterpreterFunction::fwdModuloInstImpl(glow::ModuloInst const *I) { |
4251 | auto srcH = getTensor(I->getSrc())->getHandle<ElemTy>(); |
4252 | auto destH = getTensor(I->getDest())->getHandle<ElemTy>(); |
4253 | |
4254 | auto divisor = I->getDivisor(); |
4255 | auto signFollowDivisor = I->getSignFollowDivisor(); |
4256 | |
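  // With signFollowDivisor, a negative remainder is shifted so its sign
  // matches the divisor (Python-style modulo). Illustrative example
  // (hypothetical values): -7 % 3 is -1 in C++; adding the divisor yields 2.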
4257 | for (size_t i = 0, e = srcH.size(); i < e; i++) { |
4258 | auto res = srcH.raw(i) % divisor; |
4259 | if (signFollowDivisor && res < 0) { |
4260 | res += divisor; |
4261 | } |
4262 | destH.raw(i) = res; |
4263 | } |
4264 | } |
4265 | |
4266 | void BoundInterpreterFunction::fwdModuloInst(glow::ModuloInst const *I) { |
4267 | dispatchIndexTypeImpl(fwdModuloInstImpl, I->getSrc()->getElementType(), I); |
4268 | } |
4269 | |
4270 | ///=============== Trigonometric Operators=============== |
4271 | template <typename ElemTy, typename InstKind> |
4272 | void BoundInterpreterFunction::fwdUnaryTrigonometricImpl( |
4273 | const InstKind *I, std::function<float(float)> func) { |
4274 | Value *inpV = I->getSrc(); |
4275 | Value *outV = I->getDest(); |
4276 | auto inpTy = inpV->getType(); |
4277 | auto outTy = outV->getType(); |
4278 | auto inpH = getWeightHandle<ElemTy>(inpV); |
4279 | auto outH = getWeightHandle<ElemTy>(outV); |
4280 | |
4281 | if (inpTy->isQuantizedType()) { |
4282 | float inpScale = inpTy->getScale(); |
4283 | int32_t inpOffset = inpTy->getOffset(); |
4284 | float outScale = outTy->getScale(); |
4285 | int32_t outOffset = outTy->getOffset(); |
4286 | for (size_t i = 0, e = outH.size(); i < e; ++i) { |
4287 | float inpVal = |
4288 | quantization::dequantize<ElemTy>(inpH.raw(i), {inpScale, inpOffset}); |
4289 | float outVal = func(inpVal); |
4290 | outH.raw(i) = |
4291 | quantization::quantize<ElemTy>(outVal, {outScale, outOffset}); |
4292 | } |
4293 | } else { |
4294 | for (size_t i = 0, e = outH.size(); i < e; ++i) { |
4295 | float inpVal = static_cast<float>(inpH.raw(i)); |
4296 | float outVal = func(inpVal); |
4297 | outH.raw(i) = static_cast<ElemTy>(outVal); |
4298 | } |
4299 | } |
4300 | } |
4301 | |
4302 | void BoundInterpreterFunction::fwdElementAcosInst(const ElementAcosInst *I) { |
4303 | auto func = [](float x) -> float { return std::acos(x); }; |
4304 | dispatchImpl(fwdUnaryTrigonometricImpl, I->getSrc()->getElementType(), I, |
4305 | func); |
4306 | } |
4307 | |
4308 | void BoundInterpreterFunction::fwdElementAsinInst(const ElementAsinInst *I) { |
4309 | auto func = [](float x) -> float { return std::asin(x); }; |
4310 | dispatchImpl(fwdUnaryTrigonometricImpl, I->getSrc()->getElementType(), I, |
4311 | func); |
4312 | } |
4313 | |
4314 | void BoundInterpreterFunction::fwdElementAtanInst(const ElementAtanInst *I) { |
4315 | auto func = [](float x) -> float { return std::atan(x); }; |
4316 | dispatchImpl(fwdUnaryTrigonometricImpl, I->getSrc()->getElementType(), I, |
4317 | func); |
4318 | } |
4319 | //===----------------------------------------------------------------------===// |
4320 | // Mat Mul |
4321 | //===----------------------------------------------------------------------===// |
4322 | template <typename ElemTy, typename AccumulatorTy> |
4323 | void BoundInterpreterFunction::fwdMatMulInstQuantizedImpl( |
4324 | const glow::MatMulInst *I) { |
4325 | assert(getTensor(I->getLHS())->getType().isQuantizedType()); |
4326 | auto lhs = getWeightHandle<ElemTy>(I->getLHS()); |
4327 | auto rhs = getWeightHandle<ElemTy>(I->getRHS()); |
4328 | |
4329 | auto dest = getWeightHandle<ElemTy>(I->getDest()); |
4330 | |
4331 | auto destDim = dest.dims(); |
4332 | auto lhsDim = lhs.dims(); |
4333 | |
4334 | auto destTy = I->getDest()->getType(); |
4335 | auto lhsTy = I->getLHS()->getType(); |
4336 | auto rhsTy = I->getRHS()->getType(); |
4337 | |
4338 | dest.clear(0); |
4339 | |
  // Each product term contributes s_l * s_r * (L - o_l) * (R - o_r) in real
  // units, so the integer accumulator is requantized to the destination with
  // the combined factor (L.scale * R.scale / D.scale); the operand offsets
  // are subtracted inside the inner loop below.
4343 | float scale = lhsTy->getScale() * rhsTy->getScale() / destTy->getScale(); |
4344 | int32_t lhsOffset = lhsTy->getOffset(); |
4345 | int32_t rhsOffset = rhsTy->getOffset(); |
4346 | int32_t destOffset = destTy->getOffset(); |
4347 | |
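  // Illustrative example (hypothetical values): with s_l = s_r = 0.1,
  // s_d = 0.5, and zero offsets, an accumulated sum of 500 represents
  // 500 * 0.01 = 5.0 in real units and requantizes to 5.0 / 0.5 = 10.
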
4348 | // For each (x,y) in the destination matrix: |
4349 | for (dim_t x = 0; x < destDim[0]; x++) { |
4350 | for (dim_t y = 0; y < destDim[1]; y++) { |
4351 | |
      // Perform the dot product of the row and column.
4353 | AccumulatorTy sum = 0; |
4354 | for (dim_t i = 0; i < lhsDim[1]; i++) { |
4355 | AccumulatorTy L = lhs.at({x, i}); |
4356 | AccumulatorTy R = rhs.at({i, y}); |
4357 | // We represent the element multiplication with offset as |
4358 | // (value - offset). |
4359 | sum += (L - lhsOffset) * (R - rhsOffset); |
4360 | } |
4361 | |
4362 | dest.at({x, y}) = quantization::clip<AccumulatorTy, ElemTy>( |
4363 | std::round(scale * sum + destOffset)); |
4364 | } |
4365 | } |
4366 | } |
4367 | |
4368 | template <typename ElemTy> |
4369 | void BoundInterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) { |
4370 | staticAssertFloatingPointType(ElemTy); |
4371 | |
4372 | auto lhs = getWeightHandle<ElemTy>(I->getLHS()); |
4373 | auto rhs = getWeightHandle<ElemTy>(I->getRHS()); |
4374 | auto dest = getWeightHandle<ElemTy>(I->getDest()); |
4375 | |
4376 | auto destDim = dest.dims(); |
4377 | auto lhsDim = lhs.dims(); |
4378 | |
4379 | dest.clear(0); |
4380 | |
4381 | // For each (x,y) in the destination matrix: |
4382 | for (dim_t x = 0; x < destDim[0]; x++) { |
4383 | for (dim_t y = 0; y < destDim[1]; y++) { |
4384 | |
      // Perform the dot product of the row and column.
4386 | float sum = 0; |
4387 | for (dim_t i = 0; i < lhsDim[1]; i++) { |
4388 | sum += float(lhs.at({x, i}) * rhs.at({i, y})); |
4389 | } |
4390 | dest.at({x, y}) = ElemTy(sum); |
4391 | } |
4392 | } |
4393 | } |
4394 | |
4395 | template <typename ElemTy> |
4396 | void BoundInterpreterFunction::fwdBatchMatMulInstFloatImpl( |
4397 | const BatchMatMulInst *I) { |
4398 | staticAssertFloatingPointType(ElemTy); |
4399 | |
4400 | auto lhs = getWeightHandle<ElemTy>(I->getLHS()); |
4401 | auto rhs = getWeightHandle<ElemTy>(I->getRHS()); |
4402 | auto dest = getWeightHandle<ElemTy>(I->getDest()); |
4403 | |
4404 | auto destDim = dest.dims(); |
4405 | auto lhsDim = lhs.dims(); |
4406 | |
4407 | dest.clear(0); |
4408 | |
4409 | for (dim_t batch = 0; batch < destDim[0]; batch++) { |
4410 | // For each (x,y) in the destination matrix: |
4411 | for (dim_t x = 0; x < destDim[1]; x++) { |
4412 | for (dim_t y = 0; y < destDim[2]; y++) { |
        // Perform the DOT product of the row and column.
4414 | float sum = 0; |
4415 | for (dim_t i = 0; i < lhsDim[2]; i++) { |
4416 | sum += float(lhs.at({batch, x, i})) * float(rhs.at({batch, i, y})); |
4417 | } |
4418 | dest.at({batch, x, y}) = ElemTy(sum); |
4419 | } |
4420 | } |
4421 | } |
4422 | } |
4423 | |
4424 | void BoundInterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) { |
4425 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
4426 | dispatchQuantizedWithAccumulationImpl(fwdMatMulInstQuantizedImpl, |
4427 | I->getLHS()->getElementType(), I); |
4428 | return; |
4429 | } |
4430 | |
4431 | dispatchFloatingPointImpl(fwdMatMulInstFloatImpl, |
4432 | I->getLHS()->getElementType(), I); |
4433 | } |
4434 | |
4435 | void BoundInterpreterFunction::fwdBatchMatMulInst( |
4436 | const glow::BatchMatMulInst *I) { |
4437 | if (getTensor(I->getLHS())->getType().isQuantizedType()) { |
4438 | DCHECK(!"Quantized implementation for BatchMatmul not supported yet." ); |
4439 | return; |
4440 | } |
4441 | |
4442 | dispatchFloatingPointImpl(fwdBatchMatMulInstFloatImpl, |
4443 | I->getLHS()->getElementType(), I); |
4444 | } |
4445 | |
4446 | void BoundInterpreterFunction::fwdReluGradInst(const glow::ReluGradInst *I) { |
4447 | DCHECK(!"Found ReluGradInst but ReluGrad is lowered on Interpreter" ); |
4448 | } |
4449 | |
4450 | //===----------------------------------------------------------------------===// |
4451 | // FC |
4452 | //===----------------------------------------------------------------------===// |
4453 | template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy> |
4454 | void BoundInterpreterFunction::fwdFullyConnectedInstQuantizedImpl( |
4455 | const glow::FullyConnectedInst *I) { |
4456 | assert(getTensor(I->getSrc())->getType().isQuantizedType()); |
4457 | |
4458 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
4459 | auto weightsW = getWeightHandle<ElemTy>(I->getWeights()); |
4460 | auto biasW = getWeightHandle<BiasElemTy>(I->getBias()); |
4461 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
4462 | |
4463 | auto inTy = inW.getType(); |
4464 | auto weightsTy = weightsW.getType(); |
4465 | auto biasTy = biasW.getType(); |
4466 | auto outTy = outW.getType(); |
4467 | |
4468 | int32_t inOffset = inTy.getOffset(); |
4469 | int32_t weightsOffset = weightsTy.getOffset(); |
4470 | int32_t biasOffset = biasTy.getOffset(); |
4471 | int32_t outOffset = outTy.getOffset(); |
4472 | |
4473 | float outScale = outTy.getScale(); |
4474 | float weightsScale = weightsTy.getScale(); |
4475 | float biasScale = biasTy.getScale(); |
4476 | float inScale = inTy.getScale(); |
4477 | |
4478 | ShapeHW idim(inW.dims()); |
4479 | ShapeHW odim(outW.dims()); |
4480 | |
4481 | // Calculate the scale of the values that come out of the matrix |
4482 | // multiplication part of the calculation. |
4483 | float matMulScale = weightsScale * inScale; |
4484 | |
4485 | outW.clear(0); |
4486 | |
4487 | for (dim_t i = 0; i < idim.height; i++) { |
4488 | for (dim_t j = 0; j < odim.width; j++) { |
4489 | AccumulatorTy sum = 0; |
4490 | for (dim_t k = 0; k < idim.width; k++) { |
4491 | AccumulatorTy W = weightsW.at({k, j}); |
4492 | AccumulatorTy A = inW.at({i, k}); |
4493 | sum += (W - weightsOffset) * (A - inOffset); |
4494 | } |
4495 | |
4496 | // Scale the bias to match the scale of the matrix multiplication. |
4497 | AccumulatorTy B = std::round(float(biasW.at({j}) - biasOffset) * |
4498 | (biasScale / matMulScale)); |
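      // E.g. (illustrative values): with biasScale = 0.02 and
      // matMulScale = 0.01, a shifted bias value of 10 (i.e. 0.2 in float)
      // becomes B = 20, the same 0.2 expressed in the accumulator's scale.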
4499 | |
4500 | // Add the bias. |
4501 | sum += B; |
4502 | |
4503 | // Scale the result back to the expected destination scale. |
4504 | outW.at({i, j}) = quantization::clip<AccumulatorTy, ElemTy>( |
4505 | std::round(float(sum) * (matMulScale / outScale)) + outOffset); |
4506 | } |
4507 | } |
4508 | } |
4509 | |
4510 | template <typename ElemTy> |
4511 | void BoundInterpreterFunction::fwdFullyConnectedInstFloatImpl( |
4512 | const FullyConnectedInst *I) { |
4513 | staticAssertFloatingPointType(ElemTy); |
4514 | |
4515 | auto inW = getWeightHandle<ElemTy>(I->getSrc()); |
4516 | auto weightsW = getWeightHandle<ElemTy>(I->getWeights()); |
4517 | auto biasW = getWeightHandle<ElemTy>(I->getBias()); |
4518 | auto outW = getWeightHandle<ElemTy>(I->getDest()); |
4519 | |
4520 | ShapeHW idim(inW.dims()); |
4521 | ShapeHW odim(outW.dims()); |
4522 | |
4523 | outW.clear(0); |
4524 | |
4525 | for (dim_t i = 0; i < idim.height; i++) { |
4526 | for (dim_t j = 0; j < odim.width; j++) { |
4527 | float sum = 0; |
4528 | for (dim_t k = 0; k < idim.width; k++) { |
4529 | sum += float(inW.at({i, k})) * float(weightsW.at({k, j})); |
4530 | } |
4531 | |
4532 | outW.at({i, j}) = sum + float(biasW.at({j})); |
4533 | } |
4534 | } |
4535 | } |
4536 | |
4537 | void BoundInterpreterFunction::fwdFullyConnectedInst( |
4538 | const glow::FullyConnectedInst *I) { |
4539 | |
  if (getTensor(I->getSrc())->getType().isQuantizedType()) {
    dispatchQuantizedWithAccumulationAndBiasImpl(
        fwdFullyConnectedInstQuantizedImpl, I->getSrc()->getElementType(),
        I->getBias()->getElementType(), I);
  } else {
    dispatchFloatingPointImpl(fwdFullyConnectedInstFloatImpl,
                              I->getSrc()->getElementType(), I);
  }
4549 | } |
4550 | |
4551 | //===----------------------------------------------------------------------===// |
4552 | // Dynamic quantized FC |
4553 | //===----------------------------------------------------------------------===// |
4554 | |
4555 | template <typename ElemTy, typename OutputTy, typename AccumulatorTy> |
4556 | void BoundInterpreterFunction::fwdDynRowwiseQuantizedFullyConnectedInstImpl( |
4557 | Handle<ElemTy> inW, Handle<OutputTy> &outW, dim_t baseRow, |
4558 | Handle<ElemTy> weightsW, Handle<float> biasW, Handle<float> scalesW, |
4559 | Handle<int32_t> offsetsW) { |
4560 | ShapeHW idim(inW.dims()); |
4561 | ShapeHW odim(outW.dims()); |
4562 | auto inTy = inW.getType(); |
4563 | int32_t inOffset = inTy.getOffset(); |
4564 | float inScale = inTy.getScale(); |
4565 | |
4566 | for (dim_t i = 0; i < idim.height; i++) { |
4567 | for (dim_t j = 0; j < odim.width; j++) { |
4568 | float matMulScale = inScale * static_cast<float>(scalesW.raw(j)); |
4569 | AccumulatorTy sum = 0; |
4570 | for (dim_t k = 0; k < idim.width; k++) { |
4571 | AccumulatorTy W = weightsW.at({k, j}); |
4572 | AccumulatorTy A = inW.at({i, k}); |
4573 | sum += (W - offsetsW.raw(j)) * (A - inOffset); |
4574 | } |
4575 | |
4576 | float B = float(biasW.at({j})); |
4577 | |
      // Dequantize the accumulated result back to float and add the bias.
4580 | outW.at({baseRow + i, j}) = float(sum) * matMulScale + B; |
4581 | } |
4582 | } |
4583 | } |
4584 | |
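/// Dynamically quantized FC: each input row is quantized on the fly with its
/// own quantization parameters, multiplied against the int8 weights, and the
/// result is dequantized straight back to float16. E.g. (illustrative): an
/// input row spanning [-1.0, 3.0] is widened to the symmetric range
/// [-3.0, 3.0], so its offset is zero and only a per-row scale is needed.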
4585 | void BoundInterpreterFunction::fwdDynRowwiseQuantizedFullyConnectedInstPreimpl( |
4586 | Tensor *inputTensor, Tensor *weightsTensor, Tensor *biasTensor, |
4587 | Tensor *resultTensor, Tensor *wScaleTensor, Tensor *wOffsetTensor, |
4588 | bool isSymmetric, bool isPerBatchElement) { |
4589 | |
4590 | /* Check the options */ |
4591 | assert(isSymmetric && "Only symmetric quantization is supported." ); |
4592 | assert(isPerBatchElement && "Only quantized per batch element is supported." ); |
4593 | |
4594 | auto weightsW = weightsTensor->getHandle<int8_t>(); |
4595 | |
4596 | /* Dynamic Quantization */ |
4597 | auto resultHandle = resultTensor->getHandle<float16_t>(); |
4598 | auto offsetsW = wOffsetTensor->getHandle<int32_t>(); |
4599 | dim_t N = inputTensor->dims()[0]; |
4600 | dim_t L = inputTensor->dims()[1]; |
4601 | if (isPerBatchElement && isSymmetric) { |
    // We slice the N x L input tensor into N tensors of shape 1 x L. For
    // each batch element we compute quantization parameters, quantize, run
    // the FC, and dequantize independently, then write the result rows back
    // into the output.
4605 | for (dim_t i = 0; i < N; i++) { |
4606 | Tensor slicedInputTensor = inputTensor->getOwnedSlice({1, L}, {i, 0}); |
4607 | auto slicedInputHandle = slicedInputTensor.getHandle<float16_t>(); |
4608 | auto minMax = slicedInputHandle.minMaxArg(); |
4609 | auto qMin = slicedInputHandle.raw(minMax.first); |
4610 | auto qMax = slicedInputHandle.raw(minMax.second); |
4611 | |
      // TODO: Currently we only support symmetric quantization; we should
      // support both symmetric and asymmetric schemas based on isSymmetric.
4614 | auto qParams = quantization::chooseQuantizationParams( |
4615 | {qMin, qMax}, quantization::Schema::Symmetric, ElemKind::Int8QTy); |
4616 | Tensor qInputTensor = quantization::quantizeTensor( |
4617 | slicedInputTensor, {qParams.scale, qParams.offset}, |
4618 | ElemKind::Int8QTy); |
4619 | auto inW = qInputTensor.getHandle<int8_t>(); |
4620 | |
4621 | auto biasW = biasTensor->getHandle<float>(); |
4622 | auto scalesW = wScaleTensor->getHandle<float>(); |
4623 | fwdDynRowwiseQuantizedFullyConnectedInstImpl<int8_t, float16_t, int32_t>( |
4624 | inW, resultHandle, i, weightsW, biasW, scalesW, offsetsW); |
4625 | } |
4626 | } |
4627 | } |
4628 | |
4629 | void BoundInterpreterFunction::fwdDynamicRowwiseQuantizedFullyConnectedInst( |
4630 | const glow::DynamicRowwiseQuantizedFullyConnectedInst *I) { |
4631 | auto *inputTensor = getTensor(I->getSrc()); |
4632 | auto *weightsTensor = getTensor(I->getWeights()); |
4633 | auto *biasTensor = getTensor(I->getBias()); |
4634 | auto *resultTensor = getTensor(I->getDest()); |
4635 | auto *scaleTensor = getTensor(I->getScales()); |
4636 | auto *offsetTensor = getTensor(I->getOffsets()); |
4637 | auto isSymmetric = I->getIsSymmetric(); |
4638 | auto isPerBatchElement = I->getIsPerBatchElement(); |
4639 | fwdDynRowwiseQuantizedFullyConnectedInstPreimpl( |
4640 | inputTensor, weightsTensor, biasTensor, resultTensor, scaleTensor, |
4641 | offsetTensor, isSymmetric, isPerBatchElement); |
4642 | } |
4643 | |
4644 | void BoundInterpreterFunction::fwdDynamicQuantizedFullyConnectedInst( |
4645 | const glow::DynamicQuantizedFullyConnectedInst *I) { |
4646 | |
4647 | auto *inputTensor = getTensor(I->getSrc()); |
4648 | auto *weightsTensor = getTensor(I->getWeights()); |
4649 | auto *biasTensor = getTensor(I->getBias()); |
4650 | auto *resultTensor = getTensor(I->getDest()); |
4651 | auto isSymmetric = I->getIsSymmetric(); |
4652 | auto isPerBatchElement = I->getIsPerBatchElement(); |
4653 | dim_t M = resultTensor->dims()[1]; |
4654 | |
  // Compute per-channel quantization parameters from the weights tensor.
4656 | Tensor scaleTensor = Tensor(ElemKind::FloatTy, {M}); |
4657 | Tensor offsetTensor = Tensor(ElemKind::Int32ITy, {M}); |
4658 | auto scalesW = scaleTensor.getHandle<float>(); |
4659 | auto offsetsW = offsetTensor.getHandle<int32_t>(); |
4660 | auto weightsW = weightsTensor->getHandle<int8_t>(); |
  for (dim_t i = 0; i < M; i++) {
4662 | scalesW.raw(i) = weightsW.getType().getScale(); |
4663 | offsetsW.raw(i) = weightsW.getType().getOffset(); |
4664 | } |
4665 | fwdDynRowwiseQuantizedFullyConnectedInstPreimpl( |
4666 | inputTensor, weightsTensor, biasTensor, resultTensor, &scaleTensor, |
4667 | &offsetTensor, isSymmetric, isPerBatchElement); |
4668 | } |
4669 | |
4670 | //===----------------------------------------------------------------------===// |
4671 | // Row-wise quantized FC |
4672 | //===----------------------------------------------------------------------===// |
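/// Row-wise quantized FC: the weights carry one (scale, offset) pair per
/// output row, so the effective matMulScale differs per output index j.
/// E.g. (illustrative): with inScale = 0.5 and scalesW = {0.1, 0.2}, output
/// column 0 accumulates at scale 0.05 while column 1 accumulates at 0.1.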
4673 | template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy> |
4674 | void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInstImpl( |
4675 | Value *inV, Value *outV, Value *weightsV, Value *biasV, Value *scalesV, |
4676 | Value *offsetsV) { |
4677 | auto inW = getWeightHandle<ElemTy>(inV); |
4678 | auto outW = getWeightHandle<ElemTy>(outV); |
4679 | auto weightsW = getWeightHandle<ElemTy>(weightsV); |
4680 | auto biasW = getWeightHandle<BiasElemTy>(biasV); |
4681 | auto scalesW = getWeightHandle<float>(scalesV); |
4682 | auto offsetsW = getWeightHandle<int32_t>(offsetsV); |
4683 | ShapeHW idim(inW.dims()); |
4684 | ShapeHW odim(outW.dims()); |
4685 | auto inTy = inW.getType(); |
4686 | auto biasTy = biasW.getType(); |
4687 | auto outTy = outW.getType(); |
4688 | int32_t outOffset = outTy.getOffset(); |
4689 | int32_t inOffset = inTy.getOffset(); |
4690 | int32_t biasOffset = biasTy.getOffset(); |
4691 | float outScale = outTy.getScale(); |
4692 | float inScale = inTy.getScale(); |
4693 | float biasScale = biasTy.getScale(); |
4694 | |
4695 | for (dim_t i = 0; i < idim.height; i++) { |
4696 | for (dim_t j = 0; j < odim.width; j++) { |
4697 | float matMulScale = scalesW.raw(j) * inScale; |
4698 | AccumulatorTy sum = 0; |
4699 | for (dim_t k = 0; k < idim.width; k++) { |
4700 | AccumulatorTy W = weightsW.at({j, k}); |
4701 | AccumulatorTy A = inW.at({i, k}); |
4702 | sum += (W - offsetsW.raw(j)) * (A - inOffset); |
4703 | } |
4704 | |
4705 | // Scale the bias to match the scale of the matrix multiplication. |
4706 | AccumulatorTy B = std::round(float(biasW.at({j}) - biasOffset) * |
4707 | (biasScale / matMulScale)); |
4708 | |
4709 | // Add the bias. |
4710 | sum += B; |
4711 | |
4712 | // Scale the result back to the expected destination scale. |
4713 | outW.at({i, j}) = quantization::clip<AccumulatorTy, ElemTy>( |
4714 | std::round(float(sum) * (matMulScale / outScale) + outOffset)); |
4715 | } |
4716 | } |
4717 | } |
4718 | |
4719 | void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInst( |
4720 | const RowwiseQuantizedFullyConnectedInst *I) { |
4721 | dispatchQuantizedWithAccumulationAndBiasImpl( |
4722 | fwdRowwiseQuantizedFullyConnectedInstImpl, I->getSrc()->getElementType(), |
4723 | I->getBias()->getElementType(), I->getSrc(), I->getDest(), |
4724 | I->getWeights(), I->getBias(), I->getScales(), I->getOffsets()); |
4725 | } |
4726 | |
4727 | //===----------------------------------------------------------------------===// |
4728 | // Batched operations |
4729 | //===----------------------------------------------------------------------===// |
4730 | template <typename ElemTy, typename AccumulatorTy, typename SliceElemTy> |
4731 | static void fwdBatchedAdd(Tensor *batch, Tensor *slice, Tensor *dest) { |
4732 | auto batchH = batch->getHandle<ElemTy>(); |
4733 | auto sliceH = slice->getHandle<SliceElemTy>(); |
4734 | auto destH = dest->getHandle<ElemTy>(); |
4735 | |
4736 | auto batchTy = batch->getType(); |
4737 | auto sliceTy = slice->getType(); |
4738 | auto destTy = dest->getType(); |
4739 | |
4740 | float sliceScale = sliceTy.getScale(); |
4741 | float batchScale = batchTy.getScale(); |
4742 | float destScale = destTy.getScale(); |
4743 | |
4744 | int32_t sliceOffset = sliceTy.getOffset(); |
4745 | int32_t batchOffset = batchTy.getOffset(); |
4746 | int32_t destOffset = destTy.getOffset(); |
4747 | |
4748 | auto bdim = flattenCdr(batchH.dims()); |
4749 | assert(sliceH.size() == bdim.second && "Invalid slice size" ); |
4750 | assert(batchH.dims().drop_front() == sliceH.dims() && "Invalid batch size" ); |
4751 | |
4752 | // For each layer in the batch: |
4753 | for (dim_t n = 0; n < bdim.first; n++) { |
4754 | size_t base = batchH.getElementPtr({n}); |
4755 | |
4756 | // For each element in the slice. |
4757 | for (dim_t i = 0; i < bdim.second; i++) { |
4758 | AccumulatorTy batchVal = batchH.raw(base + i); |
4759 | AccumulatorTy sliceVal = sliceH.raw(i); |
      // We rescale the values into 16-bit fixed point (a common scale of
      // 2^-15) for more accurate arithmetic.
      const float largeScale = float(1) / (1 << 15);
      // Rescale both sides from their 8-bit scales to the 16-bit scale.
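      // E.g. (illustrative values): with batchScale = 0.1, a raw difference
      // of 12 becomes round(1.2 * 32768) = 39322 in the 2^-15 scale; the sum
      // is mapped back with (largeScale / destScale) below.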
4764 | AccumulatorTy B = |
4765 | std::round(float(batchVal - batchOffset) * (batchScale / largeScale)); |
4766 | AccumulatorTy S = |
4767 | std::round(float(sliceVal - sliceOffset) * (sliceScale / largeScale)); |
4768 | AccumulatorTy R = B + S; |
4769 | destH.raw(base + i) = quantization::clip<AccumulatorTy, ElemTy>( |
4770 | std::round(float(R) * (largeScale / destScale) + destOffset)); |
4771 | } |
4772 | } |
4773 | } |
4774 | |
4775 | template <typename ElemTy> |
4776 | void BoundInterpreterFunction::fwdBatchedAddInstFloatImpl( |
4777 | const glow::BatchedAddInst *I) { |
4778 | staticAssertFloatingPointType(ElemTy); |
4779 | |
4780 | auto batch = getWeightHandle<ElemTy>(I->getBatch()); |
4781 | auto slice = getWeightHandle<ElemTy>(I->getSlice()); |
4782 | auto dest = getWeightHandle<ElemTy>(I->getDest()); |
4783 | |
4784 | auto bdim = flattenCdr(batch.dims()); |
4785 | assert(slice.size() == bdim.second && "Invalid slice size" ); |
4786 | assert(batch.dims().drop_front() == slice.dims() && "Invalid batch size" ); |
4787 | |
4788 | // For each layer in the batch: |
4789 | for (dim_t n = 0; n < bdim.first; n++) { |
4790 | size_t base = batch.getElementPtr({n}); |
4791 | |
4792 | // For each element in the slice. |
4793 | for (dim_t i = 0; i < bdim.second; i++) { |
4794 | dest.raw(base + i) = batch.raw(base + i) + slice.raw(i); |
4795 | } |
4796 | } |
4797 | } |
4798 | |
4799 | void BoundInterpreterFunction::fwdBatchedAddInst( |
4800 | const glow::BatchedAddInst *I) { |
4801 | if (getTensor(I->getBatch())->getType().isQuantizedType()) { |
4802 | dispatchQuantizedWithAccumulationAndBiasImpl( |
4803 | fwdBatchedAdd, I->getBatch()->getElementType(), |
4804 | I->getSlice()->getElementType(), getTensor(I->getBatch()), |
4805 | getTensor(I->getSlice()), getTensor(I->getDest())); |
4806 | return; |
4807 | } |
4808 | dispatchFloatingPointImpl(fwdBatchedAddInstFloatImpl, |
4809 | I->getBatch()->getElementType(), I); |
4810 | } |
4811 | |
4812 | // Macro to define the ReduceAdd/Prod kernel implementation. |
4813 | #define DEFINE_REDUCEADDPROD_INST_IMPL(func, init, op, inst) \ |
4814 | template <typename ElemTy> \ |
4815 | void BoundInterpreterFunction::fwdBatched##func##inst( \ |
4816 | Value *batch, Value *dest, unsigned_t axis, \ |
4817 | const ShapeVector &eBatchDims, const ShapeVector &eDestDims) { \ |
4818 | /*Get unowned handles of the batch and dest with these new expanded \ |
4819 | * dims.*/ \ |
4820 | auto eBatch = getTensor(batch)->getUnowned(eBatchDims); \ |
4821 | auto eDest = getTensor(dest)->getUnowned(eDestDims); \ |
4822 | auto eBatchH = eBatch.getHandle<ElemTy>(); \ |
4823 | auto eDestH = eDest.getHandle<ElemTy>(); \ |
4824 | eDestH.clear(init); \ |
4825 | \ |
4826 | /* We can use this loop for all shapes. Use the same indices for both the \ |
4827 | * batch and dest, except for setting the axis index in the dest to 0.*/ \ |
4828 | for (dim_t x = 0; x < eBatchDims[0]; x++) { \ |
4829 | for (dim_t y = 0; y < eBatchDims[1]; y++) { \ |
4830 | for (dim_t z = 0; z < eBatchDims[2]; z++) { \ |
4831 | for (dim_t w = 0; w < eBatchDims[3]; w++) { \ |
4832 | for (dim_t q = 0; q < eBatchDims[4]; q++) { \ |
4833 | for (dim_t r = 0; r < eBatchDims[5]; r++) { \ |
4834 | dim_t destIndices[] = {x, y, z, w, q, r}; \ |
4835 | destIndices[axis] = 0; \ |
4836 | eDestH.at(destIndices) = \ |
4837 | eDestH.at(destIndices) op eBatchH.at({x, y, z, w, q, r}); \ |
4838 | } \ |
4839 | } \ |
4840 | } \ |
4841 | } \ |
4842 | } \ |
4843 | } \ |
4844 | } |
4845 | |
4846 | /// Define fwdBatchedReduceAddInstImpl |
4847 | DEFINE_REDUCEADDPROD_INST_IMPL(ReduceAdd, 0, +, InstImpl) |
4848 | |
/// Define fwdBatchedReduceProdInstFloatImpl
4850 | DEFINE_REDUCEADDPROD_INST_IMPL(ReduceProd, 1, *, InstFloatImpl) |
4851 | |
4852 | #undef DEFINE_REDUCEADDPROD_INST_IMPL |
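// E.g. (illustrative): reducing a {2, 3} batch over axis = 1 with ReduceAdd
// expands both shapes to max_tensor_dimensions, clamps the dest axis to 1,
// and produces a {2, 1} result with dest[i][0] = sum over j of batch[i][j].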
4853 | |
4854 | void BoundInterpreterFunction::fwdBatchedReduceAddInst( |
4855 | const glow::BatchedReduceAddInst *I) { |
4856 | static_assert(max_tensor_dimensions == 6, |
4857 | "Loops below assume max_tensor_dimensions = 6." ); |
4858 | |
4859 | auto *batch = I->getBatch(); |
4860 | auto *dest = I->getDest(); |
4861 | const auto axis = I->getAxis(); |
4862 | |
  // Initialize both the expanded batch and dest dims to the expanded batch
  // dims. This lets us iterate over the tensor with max_tensor_dimensions
  // loops below, regardless of its actual shape.
4866 | ShapeVector eBatchDims = expandDimsToMax(batch->dims()); |
4867 | ShapeVector eDestDims = eBatchDims; |
4868 | |
4869 | // Set the destination axis dimension (the one we are reducing) to 1. |
4870 | eDestDims[axis] = 1; |
4871 | |
4872 | if (getTensor(batch)->getType().isQuantizedType()) { |
4873 | auto destTy = dest->getType(); |
4874 | auto batchTy = batch->getType(); |
4875 | |
4876 | float destScale = destTy->getScale(); |
4877 | float batchScale = batchTy->getScale(); |
4878 | |
4879 | int32_t destOffset = destTy->getOffset(); |
4880 | int32_t batchOffset = batchTy->getOffset(); |
4881 | |
4882 | // Get unowned handles of the batch and dest with these new expanded dims. |
4883 | auto eBatch = getTensor(batch)->getUnowned(eBatchDims); |
4884 | auto eDest = getTensor(dest)->getUnowned(eDestDims); |
4885 | auto eBatchH = eBatch.getHandle<int8_t>(); |
4886 | auto eDestH = eDest.getHandle<int8_t>(); |
4887 | eDestH.clear(); |
4888 | |
4889 | // For quantization, we must accumulate in the inner-most loop into a |
4890 | // local float and then clip the result back into the dest tensor. Here |
4891 | // are the max_tensor_dimensions cases for this, to ensure the axis is |
4892 | // used as the inner-most loop. |
4893 | switch (axis) { |
4894 | #define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS) \ |
4895 | case _D5_AXIS: \ |
4896 | for (dim_t i##_D0 = 0; i##_D0 < eBatchDims[_D0]; i##_D0++) \ |
4897 | for (dim_t i##_D1 = 0; i##_D1 < eBatchDims[_D1]; i##_D1++) \ |
4898 | for (dim_t i##_D2 = 0; i##_D2 < eBatchDims[_D2]; i##_D2++) \ |
4899 | for (dim_t i##_D3 = 0; i##_D3 < eBatchDims[_D3]; i##_D3++) \ |
4900 | for (dim_t i##_D4 = 0; i##_D4 < eBatchDims[_D4]; i##_D4++) { \ |
4901 | float sum = 0.0; \ |
4902 | for (dim_t i##_D5_AXIS = 0; i##_D5_AXIS < eBatchDims[_D5_AXIS]; \ |
4903 | i##_D5_AXIS++) { \ |
4904 | sum += eBatchH.at({i0, i1, i2, i3, i4, i5}) - batchOffset; \ |
4905 | } \ |
4906 | dim_t i##_D5_AXIS = 0; \ |
4907 | int32_t res = \ |
4908 | std::round(sum * batchScale / destScale) + destOffset; \ |
4909 | eDestH.at({i0, i1, i2, i3, i4, i5}) = \ |
4910 | quantization::clip<int32_t, int8_t>(res); \ |
4911 | } \ |
4912 | return; |
4913 | |
4914 | // Each loop order, with the inner-most dimension/index equal to the |
4915 | // axis. |
4916 | LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0); |
4917 | LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1); |
4918 | LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2); |
4919 | LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3); |
4920 | LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4); |
4921 | LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5); |
4922 | #undef LOOP_AXIS_CASE |
4923 | default: |
4924 | llvm_unreachable("Axis should be less than max_tensor_dimensions." ); |
4925 | } |
4926 | } |
4927 | dispatchFloatingPointAndInt32Impl(fwdBatchedReduceAddInstImpl, |
4928 | batch->getElementType(), batch, dest, axis, |
4929 | eBatchDims, eDestDims); |
4930 | } |
4931 | |
4932 | void BoundInterpreterFunction::fwdBatchedReduceProdInst( |
4933 | const glow::BatchedReduceProdInst *I) { |
4934 | static_assert(max_tensor_dimensions == 6, |
4935 | "Loops below assume max_tensor_dimensions = 6." ); |
4936 | |
4937 | auto *batch = I->getBatch(); |
4938 | auto *dest = I->getDest(); |
4939 | const auto axis = I->getAxis(); |
4940 | |
  // Initialize both the expanded batch and dest dims to the expanded batch
  // dims. This lets us iterate over the tensor with max_tensor_dimensions
  // loops below, regardless of its actual shape.
4944 | ShapeVector eBatchDims = expandDimsToMax(batch->dims()); |
4945 | ShapeVector eDestDims = eBatchDims; |
4946 | |
4947 | // Set the destination axis dimension (the one we are reducing) to 1. |
4948 | eDestDims[axis] = 1; |
4949 | |
4950 | assert(!batch->getType()->isQuantizedType() && |
4951 | "Quantized implementation for ReduceProd not supported yet." ); |
4952 | |
4953 | dispatchArithmeticImpl(fwdBatchedReduceProdInstFloatImpl, |
4954 | batch->getElementType(), batch, dest, axis, eBatchDims, |
4955 | eDestDims); |
4956 | } |
4957 | |
4958 | /// Macro to define ReduceMin/Max kernel implementation. |
4959 | #define DEFINE_REDUCEMINMAX_INST_IMPL(func, compare) \ |
4960 | template <typename ElemTy> \ |
4961 | void BoundInterpreterFunction::fwdBatched##func##InstImpl( \ |
4962 | Value *batch, Value *dest, const ShapeVector &eBatchDims, \ |
4963 | const ShapeVector &eDestDims, ElemTy init) { \ |
4964 | static_assert(max_tensor_dimensions == 6, \ |
4965 | "Loops below assume max_tensor_dimensions = 6."); \ |
4966 | /* Get unowned handles of the batch and dest with these new expanded \ |
4967 | * dims.*/ \ |
4968 | auto eBatch = getTensor(batch)->getUnowned(eBatchDims); \ |
4969 | auto eDest = getTensor(dest)->getUnowned(eDestDims); \ |
4970 | auto eBatchH = eBatch.getHandle<ElemTy>(); \ |
4971 | auto eDestH = eDest.getHandle<ElemTy>(); \ |
4972 | eDestH.clear(init); \ |
4973 | \ |
4974 | unsigned int axes[max_tensor_dimensions]; \ |
4975 | for (dim_t i = 0; i < max_tensor_dimensions; i++) { \ |
4976 | axes[i] = (eDestDims[i] > 1); \ |
4977 | } \ |
4978 | \ |
4979 | /* We can use this loop for all shapes. Use the same indices for both the \ |
4980 | * batch and dest, except for setting the axis index in the dest to 0.*/ \ |
4981 | for (dim_t x = 0, dx = 0; x < eBatchDims[0]; x++, dx += axes[0]) { \ |
4982 | for (dim_t y = 0, dy = 0; y < eBatchDims[1]; y++, dy += axes[1]) { \ |
4983 | for (dim_t z = 0, dz = 0; z < eBatchDims[2]; z++, dz += axes[2]) { \ |
4984 | for (dim_t w = 0, dw = 0; w < eBatchDims[3]; w++, dw += axes[3]) { \ |
4985 | for (dim_t q = 0, dq = 0; q < eBatchDims[4]; q++, dq += axes[4]) { \ |
4986 | for (dim_t r = 0, dr = 0; r < eBatchDims[5]; \ |
4987 | r++, dr += axes[5]) { \ |
4988 | dim_t destIndices[] = {dx, dy, dz, dw, dq, dr}; \ |
4989 | dim_t srcIndices[] = {x, y, z, w, q, r}; \ |
4990 | eDestH.at(destIndices) = \ |
4991 | compare(eDestH.at(destIndices), eBatchH.at(srcIndices)); \ |
4992 | } \ |
4993 | } \ |
4994 | } \ |
4995 | } \ |
4996 | } \ |
4997 | } \ |
4998 | } |
4999 | |
5000 | /// Define fwdBatchedReduceMaxInstImpl. |
5001 | DEFINE_REDUCEMINMAX_INST_IMPL(ReduceMax, std::max) |
5002 | |
5003 | /// Define fwdBatchedReduceMinInstImpl. |
5004 | DEFINE_REDUCEMINMAX_INST_IMPL(ReduceMin, std::min) |
5005 | |
5006 | #undef DEFINE_REDUCEMINMAX_INST_IMPL |
5007 | |
5008 | /// Macro to define ReduceMin/Max instruction. |
5009 | #define DEFINE_REDUCEMINMAX_INST(func, init_func) \ |
5010 | void BoundInterpreterFunction::fwdBatched##func##Inst( \ |
5011 | const glow::Batched##func##Inst *I) { \ |
5012 | \ |
5013 | auto *batch = I->getBatch(); \ |
5014 | auto *dest = I->getDest(); \ |
5015 | const auto axes = I->getAxes(); \ |
5016 | \ |
    /* Initialize both the expanded batch and dest dims to the expanded      \
    batch dims. This lets us iterate over the tensor with                    \
    max_tensor_dimensions loops below, regardless of its actual shape.*/     \
5020 | ShapeVector eBatchDims = expandDimsToMax(batch->dims()); \ |
5021 | ShapeVector eDestDims = eBatchDims; \ |
5022 | /* Set the destination axes dimensions (the one we are reducing) to 1.*/ \ |
5023 | for (dim_t i = 0; i < axes.size(); i++) { \ |
5024 | eDestDims[axes[i]] = 1; \ |
5025 | } \ |
5026 | \ |
5027 | if (batch->getElementType() == ElemKind::Int8QTy) { \ |
5028 | dispatchQuantizedImpl( \ |
5029 | fwdBatched##func##InstImpl, batch->getElementType(), batch, dest, \ |
5030 | eBatchDims, eDestDims, std::numeric_limits<int8_t>::init_func()); \ |
5031 | } else { \ |
5032 | dispatchArithmeticImpl( \ |
5033 | fwdBatched##func##InstImpl, batch->getElementType(), batch, dest, \ |
5034 | eBatchDims, eDestDims, std::numeric_limits<int32_t>::init_func()); \ |
5035 | } \ |
5036 | } |
5037 | |
/// Define fwdBatchedReduceMinInst
5039 | DEFINE_REDUCEMINMAX_INST(ReduceMin, max) |
5040 | |
/// Define fwdBatchedReduceMaxInst
5042 | DEFINE_REDUCEMINMAX_INST(ReduceMax, min) |
5043 | |
5044 | #undef DEFINE_REDUCEMINMAX_INST |
5045 | |
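/// CumSum semantics (illustrative): for input {1, 2, 3} along dim 0, the
/// inclusive scan yields {1, 3, 6}, the exclusive scan yields {0, 1, 3}, and
/// the reverse (inclusive) scan yields {6, 5, 3}. The exclusive flag below
/// controls whether the current slice is accumulated before or after the
/// running total is written out.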
5046 | template <typename ElemTy> |
5047 | void BoundInterpreterFunction::fwdCumSumInstImpl(Value *input, Value *dest, |
5048 | int64_t dim, bool exclusive, |
5049 | bool reverse) { |
5050 | auto eInputH = getTensor(input)->getHandle<ElemTy>(); |
5051 | auto eDestH = getTensor(dest)->getHandle<ElemTy>(); |
5052 | eDestH.clear(); |
5053 | |
  // Handle a negative dim by wrapping it around, Python-style.
5055 | if (dim < 0) { |
5056 | dim += eInputH.dims().size(); |
5057 | } |
  assert(static_cast<size_t>(dim) < eInputH.dims().size() &&
         "Dim must be less than the number of dimensions of input tensor");
5060 | |
5061 | std::vector<dim_t> accumDims(eInputH.dims()); |
5062 | accumDims[dim] = 1; |
5063 | |
5064 | Tensor accum(eInputH.getElementType(), accumDims); |
5065 | auto accumH = accum.getHandle<ElemTy>(); |
5066 | accumH.clear(); |
5067 | |
5068 | sdim_t s = 0; |
5069 | sdim_t n = eInputH.dims()[dim]; |
5070 | sdim_t dir = 1; |
5071 | |
5072 | if (reverse) { |
5073 | s = n - 1; |
5074 | n = -1; |
5075 | dir = -1; |
5076 | } |
5077 | |
5078 | for (sdim_t i = s; i != n; i += dir) { |
5079 | std::vector<dim_t> offset(eInputH.dims().size()); |
5080 | offset[dim] = i; |
5081 | |
5082 | Tensor temp(eInputH.getElementType(), accumDims); |
5083 | auto tempH = temp.getHandle<ElemTy>(); |
5084 | eInputH.extractTensors(tempH, offset); |
5085 | |
5086 | if (!exclusive) { |
5087 | for (auto accumIt = accumH.begin(), tempIt = tempH.begin(); |
5088 | accumIt != accumH.end(); ++accumIt, ++tempIt) { |
5089 | *accumIt += *tempIt; |
5090 | } |
5091 | } |
5092 | |
5093 | eDestH.insertTensors(accumH, offset, 1, dim); |
5094 | |
5095 | if (exclusive) { |
5096 | for (auto accumIt = accumH.begin(), tempIt = tempH.begin(); |
5097 | accumIt != accumH.end(); ++accumIt, ++tempIt) { |
5098 | *accumIt += *tempIt; |
5099 | } |
5100 | } |
5101 | } |
5102 | } |
5103 | |
void BoundInterpreterFunction::fwdCumSumInst(const glow::CumSumInst *I) {
5105 | dispatchArithmeticImpl(fwdCumSumInstImpl, I->getInput()->getElementType(), |
5106 | I->getInput(), I->getDest(), I->getDim(), |
5107 | I->getExclusive(), I->getReverse()); |
5108 | } |
5109 | |
5110 | template <typename ElemTy> |
5111 | void BoundInterpreterFunction::fwdLengthsSumInstFloatImpl( |
5112 | const LengthsSumInst *I) { |
5113 | staticAssertFloatingPointType(ElemTy); |
5114 | |
5115 | auto out = getTensor(I->getDest()); |
5116 | auto data = getTensor(I->getData()); |
5117 | auto lengths = getTensor(I->getLengths()); |
5118 | |
5119 | out->zero(); |
5120 | |
5121 | auto LH = lengths->getHandle<int32_t>(); |
5122 | |
5123 | size_t segments = lengths->dims()[0]; |
5124 | size_t sliceSize = data->size() / data->dims()[0]; |
5125 | |
5126 | auto DH = data->getHandle<ElemTy>(); |
5127 | auto OH = out->getHandle<ElemTy>(); |
5128 | |
5129 | size_t offsetIn = 0; |
5130 | size_t offsetOut = 0; |
5131 | for (dim_t i = 0; i < segments; i++) { |
5132 | for (int32_t j = 0, e = LH.raw(i); j < e; j++) { |
5133 | for (dim_t k = 0; k < sliceSize; k++) { |
5134 | OH.raw(offsetOut + k) += DH.raw(offsetIn + k); |
5135 | } |
5136 | offsetIn += sliceSize; |
5137 | } |
5138 | offsetOut += sliceSize; |
5139 | } |
5140 | |
5141 | assert(offsetIn == data->size() && "All values in Data should be consumed" ); |
5142 | assert(offsetOut == out->size() && "All values in Dest should be written to" ); |
5143 | } |
5144 | |
5145 | void BoundInterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) { |
  dispatchFloatingPointImpl(fwdLengthsSumInstFloatImpl,
                            I->getData()->getElementType(), I);
5148 | } |
5149 | |
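/// SparseLengthsSum semantics (illustrative): with indices = {2, 0, 1, 2}
/// and lengths = {2, 2}, out[0] = data[2] + data[0] and
/// out[1] = data[1] + data[2]. The int8 path below dequantizes each row,
/// accumulates in float, and requantizes the per-segment result.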
5150 | template <typename TI> |
5151 | void BoundInterpreterFunction::fwdSparseLengthsSumInstI8Impl( |
5152 | const SparseLengthsSumInst *I) { |
5153 | |
5154 | auto out = getTensor(I->getDest()); |
5155 | auto data = getTensor(I->getData()); |
5156 | auto indices = getTensor(I->getIndices()); |
5157 | auto lengths = getTensor(I->getLengths()); |
5158 | |
5159 | out->zero(); |
5160 | |
5161 | auto IH = indices->getHandle<TI>(); |
5162 | auto LH = lengths->getHandle<int32_t>(); |
5163 | |
5164 | size_t segments = lengths->dims()[0]; |
5165 | size_t totalLength = 0; |
5166 | for (size_t i = 0; i < segments; i++) { |
5167 | totalLength += LH.raw(i); |
5168 | } |
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5171 | |
5172 | size_t lineSize = data->size() / data->dims()[0]; |
5173 | |
5174 | auto DH = data->getHandle<int8_t>(); |
5175 | auto OH = out->getHandle<int8_t>(); |
5176 | |
5177 | auto TQP = [](Tensor *T) { |
5178 | return TensorQuantizationParams{T->getType().getScale(), |
5179 | T->getType().getOffset()}; |
5180 | }; |
5181 | |
5182 | size_t curIdx = 0; |
5183 | for (size_t i = 0; i < segments; i++) { |
5184 | std::vector<float> accum(lineSize, 0.0f); |
5185 | for (int32_t j = 0; j < LH.raw(i); j++) { |
5186 | size_t offsetIn = IH.raw(curIdx) * lineSize; |
5187 | for (size_t k = 0; k < lineSize; k++) { |
5188 | accum[k] += quantization::dequantize(DH.raw(offsetIn++), TQP(data)); |
5189 | } |
5190 | curIdx++; |
5191 | } |
5192 | size_t offsetOut = i * lineSize; |
5193 | for (size_t k = 0; k < lineSize; k++) { |
5194 | OH.raw(offsetOut++) = quantization::quantize(accum[k], TQP(out)); |
5195 | } |
5196 | } |
5197 | } |
5198 | |
5199 | template <typename ElemTy, typename TI> |
5200 | void BoundInterpreterFunction::fwdSparseLengthsSumInstFloatImpl( |
5201 | const SparseLengthsSumInst *I) { |
5202 | staticAssertFloatingPointType(ElemTy); |
5203 | |
5204 | auto out = getTensor(I->getDest()); |
5205 | auto data = getTensor(I->getData()); |
5206 | auto indices = getTensor(I->getIndices()); |
5207 | auto lengths = getTensor(I->getLengths()); |
5208 | |
5209 | out->zero(); |
5210 | |
5211 | auto IH = indices->getHandle<TI>(); |
5212 | auto LH = lengths->getHandle<int32_t>(); |
5213 | |
5214 | size_t segments = lengths->dims()[0]; |
5215 | size_t totalLength = 0; |
5216 | for (size_t i = 0; i < segments; i++) { |
5217 | totalLength += LH.raw(i); |
5218 | } |
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5221 | |
5222 | size_t lineSize = data->size() / data->dims()[0]; |
5223 | |
5224 | auto DH = data->getHandle<ElemTy>(); |
5225 | auto OH = out->getHandle<ElemTy>(); |
5226 | |
5227 | size_t curIdx = 0; |
5228 | for (size_t i = 0; i < segments; i++) { |
5229 | for (size_t j = 0, e = LH.raw(i); j < e; j++) { |
5230 | size_t offsetIn = IH.raw(curIdx++) * lineSize; |
5231 | size_t offsetOut = i * lineSize; |
5232 | for (size_t k = 0; k < lineSize; k++) |
5233 | OH.raw(offsetOut++) += DH.raw(offsetIn++); |
5234 | } |
5235 | } |
5236 | } |
5237 | |
5238 | void BoundInterpreterFunction::fwdSparseLengthsSumInst( |
5239 | const SparseLengthsSumInst *I) { |
5240 | if (I->getDest()->getType()->isQuantizedType()) { |
5241 | dispatchIndexTypeImpl(fwdSparseLengthsSumInstI8Impl, |
5242 | I->getIndices()->getElementType(), I); |
5243 | return; |
5244 | } |
5245 | dispatchFloatingPointAndIndexImpl(fwdSparseLengthsSumInstFloatImpl, |
5246 | I->getData()->getElementType(), |
5247 | I->getIndices()->getElementType(), I); |
5248 | } |
5249 | |
5250 | template <typename ElemTy, typename TI> |
5251 | void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl( |
5252 | const SparseLengthsWeightedSumInst *I) { |
5253 | staticAssertFloatingPointType(ElemTy); |
5254 | |
5255 | auto out = getTensor(I->getDest()); |
5256 | auto data = getTensor(I->getData()); |
5257 | auto weights = getTensor(I->getWeights()); |
5258 | auto indices = getTensor(I->getIndices()); |
5259 | auto lengths = getTensor(I->getLengths()); |
5260 | |
5261 | out->zero(); |
5262 | |
5263 | auto IH = indices->getHandle<TI>(); |
5264 | auto LH = lengths->getHandle<int32_t>(); |
5265 | |
5266 | size_t segments = lengths->dims()[0]; |
5267 | size_t totalLength = 0; |
5268 | for (dim_t i = 0; i < segments; i++) { |
5269 | totalLength += LH.raw(i); |
5270 | } |
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5273 | |
5274 | dim_t lineSize = data->size() / data->dims()[0]; |
5275 | |
5276 | auto DH = data->getHandle<ElemTy>(); |
5277 | auto WH = weights->getHandle<ElemTy>(); |
5278 | auto OH = out->getHandle<ElemTy>(); |
5279 | |
5280 | dim_t curIdx = 0; |
5281 | for (dim_t i = 0; i < segments; i++) { |
5282 | for (dim_t j = 0, e = LH.raw(i); j < e; j++) { |
5283 | ElemTy weight = WH.raw(curIdx); |
5284 | size_t offsetIn = IH.raw(curIdx++) * lineSize; |
5285 | size_t offsetOut = i * lineSize; |
5286 | for (dim_t k = 0; k < lineSize; k++) |
5287 | OH.raw(offsetOut++) += DH.raw(offsetIn++) * weight; |
5288 | } |
5289 | } |
5290 | } |
5291 | |
5292 | template <typename TI> |
5293 | void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl( |
5294 | const SparseLengthsWeightedSumInst *I) { |
5295 | |
5296 | auto out = getTensor(I->getDest()); |
5297 | auto data = getTensor(I->getData()); |
5298 | auto weights = getTensor(I->getWeights()); |
5299 | auto indices = getTensor(I->getIndices()); |
5300 | auto lengths = getTensor(I->getLengths()); |
5301 | |
5302 | out->zero(); |
5303 | |
5304 | auto IH = indices->getHandle<TI>(); |
5305 | auto LH = lengths->getHandle<int32_t>(); |
5306 | |
5307 | dim_t segments = lengths->dims()[0]; |
5308 | dim_t totalLength = 0; |
5309 | for (dim_t i = 0; i < segments; i++) { |
5310 | totalLength += LH.raw(i); |
5311 | } |
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5314 | |
5315 | dim_t lineSize = data->size() / data->dims()[0]; |
5316 | |
5317 | auto DH = data->getHandle<int8_t>(); |
5318 | auto WH = weights->getHandle<int8_t>(); |
5319 | auto OH = out->getHandle<int8_t>(); |
5320 | |
5321 | auto TQP = [](Tensor *T) { |
5322 | return TensorQuantizationParams{T->getType().getScale(), |
5323 | T->getType().getOffset()}; |
5324 | }; |
5325 | using namespace quantization; |
5326 | |
5327 | dim_t curIdx = 0; |
5328 | for (dim_t i = 0; i < segments; i++) { |
5329 | std::vector<float> accum(lineSize, 0.0f); |
5330 | for (int32_t j = 0; j < LH.raw(i); j++) { |
5331 | float weight = dequantize(WH.raw(curIdx), TQP(weights)); |
5332 | size_t offsetIn = IH.raw(curIdx) * lineSize; |
5333 | for (dim_t k = 0; k < lineSize; k++) { |
5334 | accum[k] += weight * dequantize(DH.raw(offsetIn++), TQP(data)); |
5335 | } |
5336 | curIdx++; |
5337 | } |
5338 | dim_t offsetOut = i * lineSize; |
5339 | for (dim_t k = 0; k < lineSize; k++) { |
5340 | OH.raw(offsetOut++) = quantize(accum[k], TQP(out)); |
5341 | } |
5342 | } |
5343 | } |
5344 | |
5345 | void BoundInterpreterFunction::fwdSparseLengthsSumGradInst( |
5346 | const SparseLengthsSumGradInst * /*I*/) { |
5347 | DCHECK(!"Found SparseLengthsSumGradInst but SparseLengthsSum is lowered on " |
5348 | "Interpreter" ); |
5349 | } |
5350 | |
5351 | void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInst( |
5352 | const SparseLengthsWeightedSumInst *I) { |
5353 | if (I->getDest()->getType()->isQuantizedType()) { |
5354 | dispatchIndexTypeImpl(fwdSparseLengthsWeightedSumInstI8Impl, |
5355 | I->getIndices()->getElementType(), I); |
5356 | return; |
5357 | } |
5358 | dispatchFloatingPointAndIndexImpl(fwdSparseLengthsWeightedSumInstFloatImpl, |
5359 | I->getData()->getElementType(), |
5360 | I->getIndices()->getElementType(), I); |
5361 | } |
5362 | |
5363 | void BoundInterpreterFunction::fwdSparseLengthsWeightedSumGradInst( |
5364 | const SparseLengthsWeightedSumGradInst *I) { |
5365 | assert(I->getDataGrad()->getType()->getElementType() == ElemKind::FloatTy && |
5366 | "Input type must be float" ); |
5367 | |
5368 | auto destGrad = getTensor(I->getDestGrad()); |
5369 | auto data = getTensor(I->getData()); |
5370 | auto dataGrad = getTensor(I->getDataGrad()); |
5371 | auto weightsGrad = getTensor(I->getWeightsGrad()); |
5372 | auto weights = getTensor(I->getWeights()); |
5373 | auto indices = getTensor(I->getIndices()); |
5374 | auto lengths = getTensor(I->getLengths()); |
5375 | |
5376 | // The data gradients not touched by this operation should |
5377 | // be 0, so set the entire buffer to 0 to start with. |
5378 | dataGrad->zero(); |
5379 | |
5380 | auto LH = lengths->getHandle<int32_t>(); |
5381 | auto IH = indices->getHandle<int64_t>(); |
5382 | |
5383 | size_t segments = lengths->dims()[0]; |
5384 | size_t totalLength = 0; |
5385 | for (size_t i = 0; i < segments; ++i) { |
5386 | totalLength += LH.raw(i); |
5387 | } |
5388 | assert(totalLength == indices->dims()[0] && |
5389 | "sum(Lengths) must be equal to len(Indices)" ); |
5390 | |
5391 | size_t lineSize = dataGrad->size() / dataGrad->dims()[0]; |
5392 | |
5393 | auto IGH = destGrad->getHandle(); |
5394 | auto WH = weights->getHandle(); |
5395 | auto WGH = weightsGrad->getHandle(); |
5396 | auto DH = data->getHandle(); |
5397 | auto OGH = dataGrad->getHandle(); |
5398 | |
5399 | // For each index in each segment: |
5400 | // 1) accumulate into the corresponding data gradient the product of the |
5401 | // gradient of the result it was added to and the weight that it was |
5402 | // multiplied by during the SparseLengthsWeightedSum operation. |
5403 | // |
5404 | // 2) accumulate into each weight gradient the reduced sum of the |
5405 | // elementwise product of the result slice that the corresponding weight |
5406 | // produced and the input slice that the weight was multiplied with. |
5407 | for (size_t i = 0, curIdx = 0; i < segments; ++i) { |
5408 | size_t destOffset = i * lineSize; |
5409 | for (size_t j = 0, e = LH.raw(i); j < e; ++j, ++curIdx) { |
5410 | float weightGrad = 0.0f; |
5411 | float weight = WH.raw(curIdx); |
5412 | size_t dataOffset = IH.raw(curIdx) * lineSize; |
5413 | |
5414 | for (size_t k = 0; k < lineSize; ++k) { |
5415 | OGH.raw(dataOffset + k) += IGH.raw(destOffset + k) * weight; |
5416 | weightGrad += IGH.raw(destOffset + k) * DH.raw(dataOffset + k); |
5417 | } |
5418 | |
5419 | WGH.raw(curIdx) = weightGrad; |
5420 | } |
5421 | } |
5422 | } |
5423 | |
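/// EmbeddingBag delimits segments with offsets rather than lengths.
/// E.g. (illustrative): indices = {3, 1, 4, 2} with offsets = {0, 2, 4} and
/// hasEndOffset = true describes two segments, {3, 1} and {4, 2}; without an
/// end offset the same offsets tensor would describe three segments, the
/// last one running to the end of indices.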
5424 | template <typename ElemTy, typename IndexType> |
5425 | void BoundInterpreterFunction::fwdEmbeddingBagInstFloatImpl( |
5426 | const EmbeddingBagInst *I) { |
5427 | staticAssertFloatingPointType(ElemTy); |
5428 | |
5429 | auto out = getTensor(I->getDest()); |
5430 | auto data = getTensor(I->getData()); |
5431 | auto weights = getTensor(I->getWeights()); |
5432 | auto indices = getTensor(I->getIndices()); |
5433 | auto offsets = getTensor(I->getOffsets()); |
5434 | bool hasEndOffset = I->getHasEndOffset(); |
5435 | |
5436 | out->zero(); |
5437 | |
5438 | auto IH = indices->getHandle<IndexType>(); |
5439 | auto OFFH = offsets->getHandle<IndexType>(); |
5440 | |
  // If an end offset is present to mark the end of the last segment, it must
  // be excluded when counting the segments.
5443 | size_t segments = hasEndOffset ? offsets->dims()[0] - 1 : offsets->dims()[0]; |
5444 | size_t numIndices = indices->dims()[0]; |
5445 | |
5446 | size_t lineSize = data->size() / data->dims()[0]; |
5447 | |
5448 | auto DH = data->getHandle<ElemTy>(); |
5449 | auto WH = weights->getHandle<ElemTy>(); |
5450 | auto OH = out->getHandle<ElemTy>(); |
5451 | |
5452 | dim_t curIdx = 0; |
5453 | for (dim_t i = 0; i < segments; i++) { |
5454 | dim_t start = OFFH.raw(i); |
5455 | dim_t end; |
5456 | if (!hasEndOffset) { |
5457 | // Note that in this case we have to use numIndices to find the end of |
5458 | // the last segment. This is an issue though because it relies on |
5459 | // knowing the total length of the indices tensor which may not be |
5460 | // possible. Future implementations of this operator should always give |
5461 | // an end offset so eventually this case should be removed. |
5462 | end = i == segments - 1 ? numIndices : OFFH.raw(i + 1); |
5463 | } else { |
5464 | end = OFFH.raw(i + 1); |
5465 | } |
5466 | if (start == end) { |
5467 | continue; |
5468 | } else if (start > end) { |
5469 | break; |
5470 | } |
5471 | for (dim_t j = start; j < end; j++) { |
5472 | ElemTy weight = WH.raw(curIdx); |
5473 | dim_t offsetIn = IH.raw(curIdx++) * lineSize; |
5474 | dim_t offsetOut = i * lineSize; |
5475 | for (dim_t k = 0; k < lineSize; k++) { |
5476 | OH.raw(offsetOut++) += DH.raw(offsetIn++) * weight; |
5477 | } |
5478 | } |
5479 | } |
5480 | } |
5481 | |
5482 | void BoundInterpreterFunction::fwdEmbeddingBagInst(const EmbeddingBagInst *I) { |
5483 | dispatchFloatingPointAndIndexImpl(fwdEmbeddingBagInstFloatImpl, |
5484 | I->getData()->getElementType(), |
5485 | I->getIndices()->getElementType(), I); |
5486 | } |
5487 | |
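/// Embedding lookup (illustrative): with indices = {1, 3} and padIdx = 3,
/// output row 0 is a copy of weights row 1 while output row 1 stays all
/// zeros, since entries equal to padIdx are skipped after the output is
/// cleared.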
5488 | template <typename ElemTy> |
5489 | void BoundInterpreterFunction::fwdEmbeddingInstImpl(Tensor *wtT, Tensor *indT, |
5490 | Tensor *outT, |
5491 | int64_t padIdx, bool sparse, |
5492 | bool scale, |
5493 | dim_t embedding_dim) { |
5494 | |
5495 | staticAssertFloatingPointType(ElemTy); |
5496 | |
5497 | assert(!scale && "Currently only support scale_grad_by_freq == 'false'" ); |
5498 | assert(!sparse && "Currently only support sparse == 'false'" ); |
5499 | |
  // The indices tensor can have an arbitrary shape; flatten it to a 1D
  // vector of size indLen. The output tensor can likewise have an arbitrary
  // shape; view it as a 2D tensor of size (indLen, embedding_dim).
5504 | dim_t indLen = 1; |
5505 | for (dim_t idx = 0; idx < indT->dims().size(); ++idx) { |
5506 | indLen *= indT->dims()[idx]; |
5507 | } |
5508 | auto fIndT = indT->getUnowned({indLen}); |
5509 | auto fOutT = outT->getUnowned({indLen, embedding_dim}); |
5510 | |
5511 | fOutT.zero(); |
5512 | |
5513 | auto WH = wtT->getHandle<ElemTy>(); |
5514 | auto OH = fOutT.getHandle<ElemTy>(); |
5515 | auto IH = fIndT.getHandle<int32_t>(); |
5516 | |
5517 | for (dim_t i = 0; i < indLen; i++) { |
5518 | dim_t index = IH.at(i); |
5519 | if (index != padIdx) { |
5520 | for (dim_t j = 0; j < embedding_dim; j++) { |
5521 | OH.at({i, j}) = WH.at({index, j}); |
5522 | } |
5523 | } |
5524 | } |
5525 | } |
5526 | |
5527 | void BoundInterpreterFunction::fwdEmbeddingInst(const EmbeddingInst *I) { |
5528 | auto wtT = getTensor(I->getWeights()); |
5529 | auto indT = getTensor(I->getIndices()); |
5530 | auto outT = getTensor(I->getDest()); |
5531 | auto padIdx = I->getPadIdx(); |
5532 | bool sparse = I->getSparse(); |
5533 | bool scale = I->getScale(); |
5534 | dim_t embedding_dim = wtT->dims()[1]; |
5535 | auto elemTy = wtT->getElementType(); |
5536 | |
5537 | if (padIdx > -1) { |
5538 | assert(static_cast<dim_t>(padIdx) < wtT->dims()[0] && |
5539 | "padIdx should be within num_embeddings" ); |
5540 | } |
5541 | |
5542 | dispatchFloatingPointImpl(fwdEmbeddingInstImpl, elemTy, wtT, indT, outT, |
5543 | padIdx, sparse, scale, embedding_dim); |
5544 | } |
5545 | |
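/// Row-wise quantized SLWS stores one (scale, offset) pair per data row in
/// the separate Scales/Offsets tensors. Each looked-up byte is dequantized
/// (roughly q * scale + offset, see quantization::dequantizeWithFloatOffset)
/// before being weighted and accumulated in AccumT.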
5546 | template <typename T, typename AccumT, typename TI> |
5547 | void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumImpl( |
5548 | const RowwiseQuantizedSparseLengthsWeightedSumInst *I) { |
5549 | auto *out = getTensor(I->getDest()); |
5550 | auto *data = getTensor(I->getData()); |
5551 | auto *dataScales = getTensor(I->getScales()); |
5552 | auto *dataOffsets = getTensor(I->getOffsets()); |
5553 | auto *weights = getTensor(I->getWeights()); |
5554 | auto *indices = getTensor(I->getIndices()); |
5555 | auto *lengths = getTensor(I->getLengths()); |
5556 | |
5557 | out->zero(); |
5558 | |
5559 | auto IH = indices->getHandle<TI>(); |
5560 | auto LH = lengths->getHandle<int32_t>(); |
5561 | |
5562 | dim_t segments = lengths->dims()[0]; |
5563 | dim_t totalLength = 0; |
5564 | for (dim_t i = 0; i < segments; i++) { |
5565 | totalLength += LH.raw(i); |
5566 | } |
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5569 | |
5570 | dim_t lineSize = data->size() / data->dims()[0]; |
5571 | |
5572 | auto DH = data->getHandle<uint8_t>(); |
5573 | auto DSH = dataScales->getHandle<T>(); |
5574 | auto DOH = dataOffsets->getHandle<T>(); |
5575 | auto WH = weights->getHandle<T>(); |
5576 | auto OH = out->getHandle<T>(); |
5577 | |
5578 | dim_t curIdx = 0; |
5579 | for (dim_t i = 0; i < segments; i++) { |
5580 | std::vector<AccumT> accum(lineSize, 0.0f); |
5581 | for (dim_t j = 0, e = LH.raw(i); j < e; j++) { |
5582 | const float weight = static_cast<float>(WH.raw(curIdx)); |
5583 | const dim_t rowIdx = IH.raw(curIdx++); |
5584 | const float scale = static_cast<float>(DSH.at({rowIdx})); |
5585 | const float offset = static_cast<float>(DOH.at({rowIdx})); |
5586 | size_t offsetIn = rowIdx * lineSize; |
5587 | for (dim_t k = 0; k < lineSize; k++) { |
5588 | float d = quantization::dequantizeWithFloatOffset(DH.raw(offsetIn++), |
5589 | scale, offset); |
5590 | accum[k] += d * weight; |
5591 | } |
5592 | } |
    // Accumulation complete; copy the result back to the output, cast to T.
5594 | size_t offsetOut = i * lineSize; |
5595 | for (size_t k = 0; k < lineSize; k++) { |
5596 | OH.raw(offsetOut++) = static_cast<T>(accum[k]); |
5597 | } |
5598 | } |
5599 | } |
5600 | |
5601 | void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumInst( |
5602 | const RowwiseQuantizedSparseLengthsWeightedSumInst *I) { |
5603 | const auto ity = I->getIndices()->getElementType(); |
5604 | switch (I->getDest()->getElementType()) { |
5605 | case ElemKind::FloatTy: |
5606 | if (ity == ElemKind::Int32ITy) { |
5607 | fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, int32_t>(I); |
5608 | } else if (ity == ElemKind::Int64ITy) { |
5609 | fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, int64_t>(I); |
5610 | } else { |
5611 | llvm_unreachable("Index type is not supported" ); |
5612 | } |
5613 | break; |
5614 | case ElemKind::Float16Ty: |
5615 | if (I->getUseFP16Accumulation()) { |
5616 | if (ity == ElemKind::Int32ITy) { |
5617 | fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float16_t, |
5618 | int32_t>(I); |
5619 | } else if (ity == ElemKind::Int64ITy) { |
5620 | fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float16_t, |
5621 | int64_t>(I); |
5622 | } else { |
5623 | llvm_unreachable("Index type is not supported" ); |
5624 | } |
5625 | } else { |
5626 | if (ity == ElemKind::Int32ITy) { |
5627 | fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float, |
5628 | int32_t>(I); |
5629 | } else if (ity == ElemKind::Int64ITy) { |
5630 | fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float, |
5631 | int64_t>(I); |
5632 | } else { |
5633 | llvm_unreachable("Index type is not supported" ); |
5634 | } |
5635 | } |
5636 | break; |
5637 | default: |
5638 | llvm_unreachable("Type is not supported" ); |
5639 | } |
5640 | } |
5641 | |
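/// Fused row-wise quantization stores each row's (scale, offset) inline at
/// the end of the row instead of in separate tensors. For the 4-bit fused
/// kinds two elements are packed per byte, which is why element index k
/// reads byte k / 2, with odd elements taken from the high nibble.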
5642 | template <typename T, typename AccumT, typename TI> |
5643 | void BoundInterpreterFunction:: |
5644 | fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl( |
5645 | const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I) { |
5646 | auto *out = getTensor(I->getDest()); |
5647 | auto *data = getTensor(I->getData()); |
5648 | auto *weights = getTensor(I->getWeights()); |
5649 | auto *indices = getTensor(I->getIndices()); |
5650 | auto *lengths = getTensor(I->getLengths()); |
5651 | |
5652 | out->zero(); |
5653 | |
5654 | auto IH = indices->getHandle<TI>(); |
5655 | auto LH = lengths->getHandle<int32_t>(); |
5656 | |
5657 | size_t segments = lengths->dims()[0]; |
5658 | size_t totalLength = 0; |
5659 | for (size_t i = 0; i < segments; i++) { |
5660 | totalLength += LH.raw(i); |
5661 | } |
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");
5664 | |
5665 | const bool using4BitQuantization = |
5666 | data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy || |
5667 | data->getType().getElementType() == ElemKind::UInt4FusedQTy; |
5668 | |
5669 | const size_t outLineSize = out->size() / out->dims()[0]; |
5670 | |
5671 | auto DH = data->getHandle<uint8_t>(); |
5672 | auto WH = weights->getHandle<T>(); |
5673 | auto OH = out->getHandle<T>(); |
5674 | |
5675 | dim_t curIdx = 0; |
5676 | for (dim_t i = 0; i < segments; i++) { |
5677 | std::vector<AccumT> accum(outLineSize, 0.0f); |
5678 | for (dim_t j = 0, e = LH.raw(i); j < e; j++) { |
5679 | const float weight = static_cast<float>(WH.raw(curIdx)); |
5680 | const dim_t rowIdx = IH.raw(curIdx++); |
      // The scale and offset data type for fused types need not match the
      // output tensor type T.
5683 | float scale, offset; |
5684 | switch ( |
5685 | getScaleOffsetElemKindFromFused(data->getType().getElementType())) { |
5686 | case ElemKind::FloatTy: |
5687 | std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<float>(rowIdx); |
5688 | break; |
5689 | case ElemKind::Float16Ty: |
5690 | std::tie(scale, offset) = |
5691 | DH.getFusedScaleOffsetFromRow<float16_t>(rowIdx); |
5692 | break; |
5693 | default: |
5694 | llvm_unreachable("Type is not supported" ); |
5695 | break; |
5696 | } |
5697 | |
5698 | for (dim_t k = 0; k < outLineSize; k++) { |
5699 | float d = 0.0f; |
5700 | if (!using4BitQuantization) { |
5701 | d = quantization::dequantizeWithFloatOffset( |
5702 | DH.at({rowIdx, k}), static_cast<float>(scale), |
5703 | static_cast<float>(offset)); |
5704 | } else { |
5705 | const bool isMSB = (k % 2 == 1); |
5706 | d = quantization::dequantize4BitWithFloatOffset( |
5707 | DH.at({rowIdx, k / 2}), static_cast<float>(scale), |
5708 | static_cast<float>(offset), isMSB); |
5709 | } |
5710 | accum[k] += d * weight; |
5711 | } |
5712 | } |
    // Accumulation complete; copy the result back to the output, cast to T.
5714 | dim_t offsetOut = i * outLineSize; |
5715 | for (dim_t k = 0; k < outLineSize; k++) { |
5716 | OH.raw(offsetOut++) = static_cast<T>(accum[k]); |
5717 | } |
5718 | } |
5719 | } |
5720 | |
5721 | void BoundInterpreterFunction:: |
5722 | fwdFusedRowwiseQuantizedSparseLengthsWeightedSumInst( |
5723 | const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I) { |
5724 | const auto ity = I->getIndices()->getElementType(); |
5725 | const bool fp32FusedScaleOffset = |
5726 | (I->getData()->getElementType() == ElemKind::UInt4FusedQTy) || |
5727 | (I->getData()->getElementType() == ElemKind::UInt8FusedQTy); |
5728 | |
5729 | switch (I->getDest()->getElementType()) { |
5730 | case ElemKind::FloatTy: |
5731 | if (ity == ElemKind::Int32ITy) { |
5732 | fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, |
5733 | int32_t>(I); |
5734 | } else if (ity == ElemKind::Int64ITy) { |
5735 | fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, |
5736 | int64_t>(I); |
5737 | } else { |
5738 | llvm_unreachable("Index type is not supported" ); |
5739 | } |
5740 | break; |
5741 | case ElemKind::Float16Ty: |
5742 | if (I->getUseFP16Accumulation() && !fp32FusedScaleOffset) { |
5743 | if (ity == ElemKind::Int32ITy) { |
5744 | fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl< |
5745 | float16_t, float16_t, int32_t>(I); |
5746 | } else if (ity == ElemKind::Int64ITy) { |
5747 | fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl< |
5748 | float16_t, float16_t, int64_t>(I); |
5749 | } else { |
5750 | llvm_unreachable("Index type is not supported" ); |
5751 | } |
5752 | } else { |
5753 | if (ity == ElemKind::Int32ITy) { |
5754 | fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float, |
5755 | int32_t>(I); |
5756 | } else if (ity == ElemKind::Int64ITy) { |
5757 | fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float, |
5758 | int64_t>(I); |
5759 | } else { |
5760 | llvm_unreachable("Index type is not supported" ); |
5761 | } |
5762 | } |
5763 | break; |
5764 | default: |
5765 | llvm_unreachable("Type is not supported" ); |
5766 | } |
5767 | } |
5768 | |
5769 | void BoundInterpreterFunction::fwdFusedRowwiseQuantizedSparseLengthsSumInst( |
5770 | const FusedRowwiseQuantizedSparseLengthsSumInst *I) { |
5771 | llvm_unreachable("Not supported" ); |
5772 | } |
5773 | |
5774 | template <typename T, typename AccumT, typename IndexT> |
5775 | void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsImpl( |
5776 | const EmbeddingBagByteRowwiseOffsetsInst *I) { |
5777 | auto *out = getTensor(I->getDest()); |
5778 | auto *data = getTensor(I->getData()); |
5779 | auto *weights = getTensor(I->getWeights()); |
5780 | auto *indices = getTensor(I->getIndices()); |
5781 | auto *offsets = getTensor(I->getOffsets()); |
5782 | bool hasEndOffset = I->getHasEndOffset(); |
5783 | |
5784 | out->zero(); |
5785 | |
5786 | auto IH = indices->getHandle<IndexT>(); |
5787 | auto OFFH = offsets->getHandle<IndexT>(); |
5788 | |
  // If an end offset is present to mark the end of the last segment, then it
  // must be subtracted to get the correct number of segments.
5791 | size_t segments = hasEndOffset ? offsets->dims()[0] - 1 : offsets->dims()[0]; |
5792 | dim_t numIndices = indices->dims()[0]; |
5793 | |
  // Both fused 4-bit kinds (fp16 or fp32 scale/offset) pack two elements per
  // byte.
  const bool using4BitQuantization =
      data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy ||
      data->getType().getElementType() == ElemKind::UInt4FusedQTy;
5796 | |
5797 | const size_t outLineSize = out->size() / out->dims()[0]; |
5798 | |
5799 | auto DH = data->getHandle<uint8_t>(); |
5800 | auto WH = weights->getHandle<T>(); |
5801 | auto OH = out->getHandle<T>(); |
5802 | |
5803 | for (dim_t i = 0; i < segments; i++) { |
5804 | std::vector<AccumT> accum(outLineSize, 0.0f); |
5805 | size_t start = OFFH.raw(i); |
5806 | dim_t end; |
5807 | if (!hasEndOffset) { |
      // Note that in this case we have to use numIndices to find the end of
      // the last segment. This is problematic because it relies on knowing
      // the total length of the indices tensor, which may not always be
      // available. Future implementations of this operator should always
      // provide an end offset, so eventually this case should be removed.
5813 | end = i == segments - 1 ? numIndices : OFFH.raw(i + 1); |
5814 | } else { |
5815 | end = OFFH.raw(i + 1); |
5816 | } |
5817 | if (start == end) { |
5818 | continue; |
5819 | } else if (start > end) { |
5820 | break; |
5821 | } |
5822 | |
5823 | for (dim_t j = start; j < end; j++) { |
5824 | const float weight = static_cast<float>(WH.raw(j)); |
5825 | const dim_t rowIdx = IH.raw(j); |
5826 | T scale, offset; |
5827 | std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<T>(rowIdx); |
5828 | for (dim_t k = 0; k < outLineSize; k++) { |
5829 | float d = 0.0f; |
5830 | if (!using4BitQuantization) { |
5831 | d = quantization::dequantizeWithFloatOffset( |
5832 | DH.at({rowIdx, k}), static_cast<float>(scale), |
5833 | static_cast<float>(offset)); |
5834 | } else { |
5835 | const bool isMSB = (k % 2 == 1); |
5836 | d = quantization::dequantize4BitWithFloatOffset( |
5837 | DH.at({rowIdx, k / 2}), static_cast<float>(scale), |
5838 | static_cast<float>(offset), isMSB); |
5839 | } |
5840 | accum[k] += d * weight; |
5841 | } |
5842 | } |
    // Accumulation (in AccumT) complete; copy back to the output, casting to T.
5844 | dim_t offsetOut = i * outLineSize; |
5845 | for (dim_t k = 0; k < outLineSize; k++) { |
5846 | OH.raw(offsetOut++) = static_cast<T>(accum[k]); |
5847 | } |
5848 | } |
5849 | } |
5850 | |
5851 | void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsInst( |
5852 | const EmbeddingBagByteRowwiseOffsetsInst *I) { |
5853 | const auto ity = I->getIndices()->getElementType(); |
5854 | const bool fp32FusedScaleOffset = |
5855 | (I->getData()->getElementType() == ElemKind::UInt4FusedQTy) || |
5856 | (I->getData()->getElementType() == ElemKind::UInt8FusedQTy); |
5857 | |
5858 | switch (I->getDest()->getElementType()) { |
5859 | case ElemKind::FloatTy: |
5860 | if (ity == ElemKind::Int32ITy) { |
5861 | fwdEmbeddingBagByteRowwiseOffsetsImpl<float, float, int32_t>(I); |
5862 | } else if (ity == ElemKind::Int64ITy) { |
5863 | fwdEmbeddingBagByteRowwiseOffsetsImpl<float, float, int64_t>(I); |
5864 | } else { |
5865 | llvm_unreachable("Index type is not supported" ); |
5866 | } |
5867 | break; |
5868 | case ElemKind::Float16Ty: |
5869 | if (I->getUseFP16Accumulation() && !fp32FusedScaleOffset) { |
5870 | if (ity == ElemKind::Int32ITy) { |
5871 | fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float16_t, int32_t>(I); |
5872 | } else if (ity == ElemKind::Int64ITy) { |
5873 | fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float16_t, int64_t>(I); |
5874 | } else { |
5875 | llvm_unreachable("Index type is not supported" ); |
5876 | } |
5877 | } else { |
5878 | if (ity == ElemKind::Int32ITy) { |
5879 | fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float, int32_t>(I); |
5880 | } else if (ity == ElemKind::Int64ITy) { |
5881 | fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float, int64_t>(I); |
5882 | } else { |
5883 | llvm_unreachable("Index type is not supported" ); |
5884 | } |
5885 | } |
5886 | break; |
5887 | default: |
5888 | llvm_unreachable("Type is not supported" ); |
5889 | } |
5890 | } |
5891 | |
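/// Converts a vector of segment lengths into (start offset, length) pairs.
/// For illustration (hypothetical values): lengths = [2, 3, 1] yields
/// ranges = [[0, 2], [2, 3], [5, 1]], each segment starting where the
/// previous one ended.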
5892 | void BoundInterpreterFunction::fwdLengthsToRangesInst( |
5893 | const LengthsToRangesInst *I) { |
5894 | auto ranges = getTensor(I->getDest())->getHandle<int32_t>(); |
5895 | auto lengths = getTensor(I->getLengths())->getHandle<int32_t>(); |
5896 | int32_t offset = 0; |
5897 | for (dim_t i = 0; i < lengths.dims()[0]; i++) { |
5898 | auto length = lengths.at({i}); |
5899 | ranges.at({i, 0}) = offset; |
5900 | ranges.at({i, 1}) = length; |
5901 | offset += length; |
5902 | } |
5903 | } |
5904 | |
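/// Expands each length L into the run 0, 1, ..., L-1 and concatenates the
/// runs. For illustration (hypothetical values): lengths = [2, 3] yields
/// [0, 1, 0, 1, 2].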
5905 | void BoundInterpreterFunction::fwdLengthsRangeFillInst( |
5906 | const LengthsRangeFillInst *I) { |
5907 | auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>(); |
5908 | auto resultH = getTensor(I->getDest())->getHandle<int32_t>(); |
5909 | dim_t curIdx = 0; |
5910 | for (dim_t i = 0, e = lengthsH.dims()[0]; i < e; i++) { |
5911 | for (int32_t j = 0, f = lengthsH.at({i}); j < f; j++) { |
5912 | resultH.at({curIdx++}) = j; |
5913 | } |
5914 | } |
5915 | } |
5916 | |
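/// Fills the destination tensor with float16 samples from a normal
/// distribution N(mean, scale), drawn from a Mersenne Twister engine seeded
/// by the instruction so that results are reproducible.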
5917 | void BoundInterpreterFunction::fwdGaussianFillInst(const GaussianFillInst *I) { |
5918 | std::mt19937 rnd(I->getSeed()); |
5919 | std::normal_distribution<float> dist(I->getMean(), I->getScale()); |
5920 | auto outT = getTensor(I->getDest()); |
5921 | auto outH = outT->getHandle<float16_t>(); |
5922 | for (auto &elem : outH) { |
5923 | elem = dist(rnd); |
5924 | } |
5925 | } |
5926 | |
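/// Scatters sparse (index, value) pairs into a dense row per batch entry;
/// positions without a pair keep the default value. For illustration
/// (hypothetical values): lengths = [1, 2], indices = [3, 0, 2],
/// values = [a, b, c], denseLastDim = 4 and defaultValue = d give
/// row 0 = [d, d, d, a] and row 1 = [b, d, c, d].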
5927 | template <typename ElemTy, typename LengthsTy, typename IndicesTy> |
5928 | void BoundInterpreterFunction::fwdBatchSparseToDenseInstImpl2( |
5929 | const BatchSparseToDenseInst *I) { |
5930 | auto outH = getWeightHandle<ElemTy>(I->getDest()); |
5931 | auto lengthsH = getWeightHandle<LengthsTy>(I->getLengths()); |
5932 | auto valuesH = getWeightHandle<ElemTy>(I->getValues()); |
5933 | auto indicesH = getWeightHandle<IndicesTy>(I->getIndices()); |
5934 | auto denseLastDim = I->getDenseLastDim(); |
5935 | auto defaultValue = I->getDefaultValue(); |
5936 | outH.clear(defaultValue); |
5937 | |
5938 | // Verifying input sizes. |
5939 | size_t lengthsSum = 0; |
5940 | auto batchSize = lengthsH.size(); |
5941 | for (dim_t i = 0; i < batchSize; ++i) { |
5942 | lengthsSum += lengthsH.at(i); |
5943 | } |
5944 | CHECK_EQ(lengthsSum, indicesH.size()); |
5945 | |
5946 | dim_t k = 0; |
5947 | for (dim_t i = 0; i < batchSize; ++i) { |
5948 | for (dim_t j = 0; j < lengthsH.at(i); ++j) { |
      CHECK_LT(indicesH.at(k), denseLastDim);
5950 | outH.at({static_cast<dim_t>(i), static_cast<dim_t>(indicesH.at(k))}) = |
5951 | valuesH.at(k); |
5952 | k++; |
5953 | } |
5954 | } |
5955 | } |
5956 | |
5957 | template <typename ElemTy, typename LengthsTy> |
5958 | void BoundInterpreterFunction::fwdBatchSparseToDenseInstImpl1( |
5959 | const BatchSparseToDenseInst *I) { |
5960 | switch (I->getLengths()->getElementType()) { |
5961 | case ElemKind::Int32ITy: |
5962 | fwdBatchSparseToDenseInstImpl2<ElemTy, LengthsTy, int32_t>(I); |
5963 | break; |
5964 | case ElemKind::Int64ITy: |
5965 | fwdBatchSparseToDenseInstImpl2<ElemTy, LengthsTy, int64_t>(I); |
5966 | break; |
5967 | default: |
5968 | llvm_unreachable("Index type is not supported" ); |
5969 | } |
5970 | } |
5971 | |
5972 | void BoundInterpreterFunction::fwdBatchSparseToDenseInst( |
5973 | const BatchSparseToDenseInst *I) { |
5974 | dispatchFloatingPointAndIndexImpl(fwdBatchSparseToDenseInstImpl1, |
5975 | I->getDest()->getElementType(), |
5976 | I->getLengths()->getElementType(), I); |
5977 | } |
5978 | |
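/// Expands data rows into a larger batch according to an indicator: output
/// row i receives the next unused data row when indicator[i] is nonzero and
/// stays zero otherwise. For illustration (hypothetical values):
/// indicator = [1, 0, 1] with data rows [A, B] yields output rows [A, 0, B].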
5979 | template <typename ElemTy, typename IndicatorTy> |
5980 | void BoundInterpreterFunction::fwdFillExamplesWithIndicatorInstImpl2( |
5981 | const FillExamplesWithIndicatorInst *I) { |
5982 | auto outT = getTensor(I->getDest()); |
5983 | auto dataT = getTensor(I->getData()); |
5984 | auto elemSize = dataT->getType().getElementSize(); |
5985 | auto indicatorH = getWeightHandle<IndicatorTy>(I->getIndicator()); |
5986 | |
5987 | size_t numBatches = indicatorH.dims()[0]; |
5988 | outT->zero(); |
5989 | |
5990 | size_t nonzero = 0; |
5991 | for (size_t i = 0; i < numBatches; ++i) { |
5992 | if (static_cast<bool>(indicatorH.at(i))) { |
5993 | nonzero++; |
5994 | } |
5995 | } |
5996 | CHECK_EQ(dataT->dims()[0], nonzero); |
5997 | |
5998 | // Calculate size of last n-1 data dims |
5999 | size_t blockSize = 1; |
6000 | for (size_t i = 1; i < dataT->dims().size(); i++) { |
6001 | blockSize *= dataT->dims()[i]; |
6002 | } |
6003 | size_t blockByteSize = blockSize * elemSize; |
6004 | size_t dataP = 0; |
6005 | for (size_t i = 0; i < numBatches; i++) { |
6006 | if (static_cast<bool>(indicatorH.at(i))) { |
6007 | std::copy(&dataT->getUnsafePtr()[dataP], |
6008 | &dataT->getUnsafePtr()[dataP + blockByteSize], |
6009 | &outT->getUnsafePtr()[i * blockByteSize]); |
6010 | dataP += blockByteSize; |
6011 | } |
6012 | } |
6013 | } |
6014 | |
6015 | template <typename ElemTy> |
6016 | void BoundInterpreterFunction::fwdFillExamplesWithIndicatorInstImpl1( |
6017 | const FillExamplesWithIndicatorInst *I) { |
6018 | switch (I->getIndicator()->getElementType()) { |
6019 | case ElemKind::Int32ITy: |
6020 | fwdFillExamplesWithIndicatorInstImpl2<ElemTy, int32_t>(I); |
6021 | break; |
6022 | case ElemKind::Int64ITy: |
6023 | fwdFillExamplesWithIndicatorInstImpl2<ElemTy, int64_t>(I); |
6024 | break; |
6025 | case ElemKind::BoolTy: |
6026 | fwdFillExamplesWithIndicatorInstImpl2<ElemTy, bool>(I); |
6027 | break; |
6028 | default: |
6029 | llvm_unreachable("Indicator type is not supported" ); |
6030 | } |
6031 | } |
6032 | |
6033 | void BoundInterpreterFunction::fwdFillExamplesWithIndicatorInst( |
6034 | const FillExamplesWithIndicatorInst *I) { |
6035 | dispatchArithmeticImpl(fwdFillExamplesWithIndicatorInstImpl1, |
6036 | I->getDest()->getElementType(), I); |
6037 | } |
6038 | void BoundInterpreterFunction::fwdSparseToDenseMaskInst( |
6039 | const SparseToDenseMaskInst *I) { |
6040 | auto out = getTensor(I->getDest()); |
6041 | auto values = getTensor(I->getValues()); |
6042 | auto defaultValue = getTensor(I->getDefaultValue()); |
6043 | |
6044 | auto indicesH = getTensor(I->getIndices())->getHandle<int64_t>(); |
6045 | auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>(); |
6046 | |
6047 | const std::vector<dim_t> &mask = I->getMask(); |
6048 | size_t maskSize = mask.size(); |
6049 | // Create a reverse map from ID to its position in the mask. |
6050 | std::unordered_map<int64_t, size_t> reverseMap; |
6051 | for (size_t i = 0; i < maskSize; i++) { |
    assert(reverseMap.find(mask[i]) == reverseMap.end() &&
           "duplicate IDs in the mask");
6054 | reverseMap[mask[i]] = i; |
6055 | } |
6056 | |
6057 | auto valueSize = defaultValue->getSizeInBytes(); |
6058 | |
6059 | // First un-processed index-value pair. |
6060 | size_t posIn = 0; |
6061 | // Beginning of output block for first unprocessed batch. |
6062 | size_t byteOffsetOut = 0; |
6063 | // Lengths can be scalar, which means that all pairs belong to one batch. |
6064 | size_t numBatches = lengthsH.dims().empty() ? 1 : lengthsH.dims()[0]; |
6065 | for (size_t batch = 0; batch < numBatches; batch++) { |
6066 | // Fill everything with maskSize copies of defaultValue. |
6067 | for (size_t i = 0; i < maskSize; i++) { |
6068 | std::copy(defaultValue->getUnsafePtr(), |
6069 | &defaultValue->getUnsafePtr()[valueSize], |
6070 | &out->getUnsafePtr()[byteOffsetOut + valueSize * i]); |
6071 | } |
6072 | // Go through input pairs and find matches. |
6073 | for (size_t i = 0, batchLen = lengthsH.raw(batch); i < batchLen; |
6074 | i++, posIn++) { |
6075 | int64_t idx = indicesH.raw(posIn); |
6076 | auto it = reverseMap.find(idx); |
6077 | // Skip if ID is not present in the mask. |
6078 | if (it == reverseMap.end()) |
6079 | continue; |
6080 | size_t to = it->second; |
6081 | |
6082 | std::copy(&values->getUnsafePtr()[posIn * valueSize], |
6083 | &values->getUnsafePtr()[(posIn + 1) * valueSize], |
6084 | &out->getUnsafePtr()[byteOffsetOut + valueSize * to]); |
6085 | } |
6086 | |
6087 | byteOffsetOut += maskSize * valueSize; |
6088 | } |
6089 | |
  assert(posIn == indicesH.dims()[0] &&
         "Sum of Lengths must be equal to size of indices.");
6092 | } |
6093 | |
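/// Splits (label index, value) pairs into one row per label, recording for
/// every input pair which example produced it and its position within the
/// label's row. For illustration (hypothetical values): numLabels = 2,
/// lengths = [2, 2], indices = [0, 1, 1, 0], values = [a, b, c, d] yield
/// labelValues = [[a, d], [b, c]] and exampleIds = [[0, 1], [0, 1]].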
6094 | void BoundInterpreterFunction::fwdSparseLabelSplitInst( |
6095 | const SparseLabelSplitInst *I) { |
6096 | auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>(); |
6097 | auto indicesH = getTensor(I->getIndices())->getHandle<int64_t>(); |
6098 | auto valuesH = getTensor(I->getValues())->getHandle(); |
6099 | |
6100 | const auto numLabels = I->getNumLabels(); |
6101 | |
6102 | auto labelValuesH = getTensor(I->getLabelValues())->getHandle(); |
6103 | auto exampleIdsH = getTensor(I->getExampleIds())->getHandle<int32_t>(); |
6104 | auto gradientOffsetMapH = |
6105 | getTensor(I->getGradientOffsetMap())->getHandle<int32_t>(); |
6106 | |
6107 | // Verifying input sizes. |
6108 | size_t lengthsSum = 0; |
6109 | for (size_t i = 0; i < lengthsH.size(); ++i) { |
6110 | lengthsSum += lengthsH.at(i); |
6111 | } |
6112 | CHECK_EQ(lengthsSum, indicesH.size()); |
6113 | CHECK_EQ(indicesH.size(), valuesH.size()); |
6114 | |
6115 | // Verifying that outputs have same sizes. |
6116 | const auto numValuesPerRow = indicesH.size() / numLabels; |
6117 | std::vector<size_t> numExamplesPerTask(numLabels, 0); |
6118 | for (size_t i = 0; i < indicesH.size(); ++i) { |
6119 | numExamplesPerTask[indicesH.at(i)] += 1; |
6120 | } |
6121 | for (size_t i = 0; i < numLabels; ++i) { |
6122 | CHECK_EQ(numValuesPerRow, numExamplesPerTask[i]) |
6123 | << "Unexpected number of values at " << i; |
6124 | } |
6125 | |
6126 | // Populating outputs |
6127 | size_t pos = 0; |
6128 | std::fill(numExamplesPerTask.begin(), numExamplesPerTask.end(), 0); |
6129 | for (size_t i = 0; i < lengthsH.size(); ++i) { |
6130 | for (size_t l = 0; l < lengthsH.at(i); ++l) { |
6131 | auto ind = indicesH.at(pos); |
6132 | auto val = valuesH.at(pos); |
6133 | |
6134 | auto posOutput = numExamplesPerTask[ind]++; |
6135 | gradientOffsetMapH.at(pos) = posOutput; |
6136 | |
6137 | labelValuesH.at( |
6138 | {static_cast<dim_t>(ind), static_cast<dim_t>(posOutput)}) = val; |
6139 | exampleIdsH.at({static_cast<dim_t>(ind), static_cast<dim_t>(posOutput)}) = |
6140 | i; |
6141 | pos++; |
6142 | } |
6143 | } |
6144 | } |
6145 | |
6146 | //===----------------------------------------------------------------------===// |
6147 | // Instructions used by RNN |
6148 | //===----------------------------------------------------------------------===// |
6149 | template <typename T, typename TI> |
6150 | static void fwdTopK(Tensor *outW, Tensor *indW, Tensor *inW, size_t k) { |
6151 | auto values = outW->getHandle<T>(); |
6152 | auto indices = indW->getHandle<TI>(); |
6153 | auto in = inW->getHandle<T>(); |
6154 | size_t n = in.dims().back(); |
6155 | |
6156 | size_t in_p = 0, out_p = 0; |
6157 | size_t tensor_end = in.size(); |
6158 | using pairType = std::pair<float, size_t>; |
6159 | std::vector<pairType> buf(n); |
6160 | |
6161 | while (in_p < tensor_end) { |
6162 | for (size_t i = 0; i < n; i++) { |
6163 | buf[i].first = in.raw(in_p++); |
6164 | buf[i].second = i; |
6165 | } |
    // NOTE: an O(N + K log K) selection is possible, while this version is
    // O(N log N).
6167 | std::sort(buf.begin(), buf.end(), [](const pairType &a, const pairType &b) { |
6168 | if (a.first != b.first) |
6169 | return a.first > b.first; |
6170 | return a.second < b.second; |
6171 | }); |
6172 | for (size_t i = 0; i < k; i++) { |
6173 | values.raw(out_p) = buf[i].first; |
6174 | indices.raw(out_p) = buf[i].second; |
6175 | out_p++; |
6176 | } |
6177 | } |
6178 | } |
6179 | |
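/// Computes the index of the maximum element along \p axis, which is kept as
/// a size-1 dimension. For illustration (hypothetical values): input
/// [[1, 3], [7, 2]] with axis = 1 yields [[1], [0]].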
6180 | template <typename inpType, typename outType> |
6181 | static void fwdArgMax(Tensor *inpT, Tensor *outT, size_t axis) { |
6182 | |
6183 | // Get input/output handles with dimensions expanded to maximum. |
6184 | ShapeVector inpDims = expandDimsToMax(inpT->dims()); |
6185 | ShapeVector outDims = inpDims; |
6186 | outDims[axis] = 1; |
6187 | auto eInpT = inpT->getUnowned(inpDims); |
6188 | auto eOutT = outT->getUnowned(outDims); |
6189 | auto inpH = eInpT.getHandle<inpType>(); |
6190 | auto outH = eOutT.getHandle<outType>(); |
6191 | |
  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");
6194 | |
6195 | for (dim_t idx0 = 0; idx0 < outDims[0]; idx0++) { |
6196 | for (dim_t idx1 = 0; idx1 < outDims[1]; idx1++) { |
6197 | for (dim_t idx2 = 0; idx2 < outDims[2]; idx2++) { |
6198 | for (dim_t idx3 = 0; idx3 < outDims[3]; idx3++) { |
6199 | for (dim_t idx4 = 0; idx4 < outDims[4]; idx4++) { |
6200 | for (dim_t idx5 = 0; idx5 < outDims[5]; idx5++) { |
6201 | |
6202 | // Initialize maximum value/index. |
6203 | inpType maxVal = std::numeric_limits<inpType>::lowest(); |
6204 | outType maxIdx = 0; |
6205 | |
6206 | // Iterate input axis dimension. |
6207 | for (dim_t axisIdx = 0; axisIdx < inpDims[axis]; axisIdx++) { |
6208 | std::vector<dim_t> inpIdx = {idx0, idx1, idx2, |
6209 | idx3, idx4, idx5}; |
6210 | inpIdx[axis] = axisIdx; |
6211 | inpType inpVal = inpH.at(inpIdx); |
6212 | if (inpVal > maxVal) { |
6213 | maxVal = inpVal; |
6214 | maxIdx = axisIdx; |
6215 | } |
6216 | } |
6217 | |
6218 | // Store maximum index. |
6219 | outH.at({idx0, idx1, idx2, idx3, idx4, idx5}) = maxIdx; |
6220 | } |
6221 | } |
6222 | } |
6223 | } |
6224 | } |
6225 | } |
6226 | } |
6227 | |
6228 | template <typename inpType, typename outType> |
6229 | static void fwdArgMin(Tensor *inpT, Tensor *outT, size_t axis) { |
6230 | |
6231 | // Get input/output handles with dimensions expanded to maximum. |
6232 | ShapeVector inpDims = expandDimsToMax(inpT->dims()); |
6233 | ShapeVector outDims = inpDims; |
6234 | outDims[axis] = 1; |
6235 | auto eInpT = inpT->getUnowned(inpDims); |
6236 | auto eOutT = outT->getUnowned(outDims); |
6237 | auto inpH = eInpT.getHandle<inpType>(); |
6238 | auto outH = eOutT.getHandle<outType>(); |
6239 | |
  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");
6242 | |
6243 | for (dim_t idx0 = 0; idx0 < outDims[0]; idx0++) { |
6244 | for (dim_t idx1 = 0; idx1 < outDims[1]; idx1++) { |
6245 | for (dim_t idx2 = 0; idx2 < outDims[2]; idx2++) { |
6246 | for (dim_t idx3 = 0; idx3 < outDims[3]; idx3++) { |
6247 | for (dim_t idx4 = 0; idx4 < outDims[4]; idx4++) { |
6248 | for (dim_t idx5 = 0; idx5 < outDims[5]; idx5++) { |
6249 | |
6250 | // Initialize minimum value/index. |
6251 | inpType minVal = std::numeric_limits<inpType>::max(); |
6252 | outType minIdx = 0; |
6253 | |
6254 | // Iterate input axis dimension. |
6255 | for (dim_t axisIdx = 0; axisIdx < inpDims[axis]; axisIdx++) { |
6256 | std::vector<dim_t> inpIdx = {idx0, idx1, idx2, |
6257 | idx3, idx4, idx5}; |
6258 | inpIdx[axis] = axisIdx; |
6259 | inpType inpVal = inpH.at(inpIdx); |
6260 | if (inpVal < minVal) { |
6261 | minVal = inpVal; |
6262 | minIdx = axisIdx; |
6263 | } |
6264 | } |
6265 | |
6266 | // Store minimum index. |
6267 | outH.at({idx0, idx1, idx2, idx3, idx4, idx5}) = minIdx; |
6268 | } |
6269 | } |
6270 | } |
6271 | } |
6272 | } |
6273 | } |
6274 | } |
6275 | |
6276 | //===----------------------------------------------------------------------===// |
6277 | // Sorting operators |
6278 | //===----------------------------------------------------------------------===// |
6279 | template <typename ElemTy> |
6280 | static void |
6281 | CollectRpnProposalsHelper(const std::vector<std::vector<ElemTy>> &roisIn, |
6282 | const std::vector<ElemTy> &scores, |
6283 | std::vector<std::vector<ElemTy>> &rois, |
6284 | dim_t rpnPostNmsTopN) { |
  // O(N + K log K) implementation, where K is rpnPostNmsTopN.
6286 | dim_t roisSecondDim = roisIn[0].size(); |
6287 | |
6288 | std::vector<dim_t> idx(scores.size()); |
6289 | std::iota(idx.begin(), idx.end(), 0); |
6290 | |
6291 | auto comp = [&](dim_t leftIdx, dim_t rightIdx) -> bool { |
6292 | if (scores[leftIdx] > scores[rightIdx]) { |
6293 | return true; |
6294 | } |
6295 | if (scores[leftIdx] < scores[rightIdx]) { |
6296 | return false; |
6297 | } |
6298 | // Sort on indices if two scores are equal |
6299 | return leftIdx < rightIdx; |
6300 | }; |
6301 | |
6302 | dim_t k = rpnPostNmsTopN; |
  // Partition the indices by score: nth_element places the Kth element at
  // its sorted position in O(N), with all higher scores before it.
6305 | if (k < roisIn.size()) { |
6306 | std::nth_element(idx.begin(), idx.begin() + k, idx.end(), comp); |
6307 | } else { |
6308 | k = roisIn.size(); |
6309 | } |
6310 | |
6311 | rois.resize(k, std::vector<ElemTy>(roisSecondDim)); |
6312 | |
  // Sort the selected top-K indices: O(K log K), with K = rpnPostNmsTopN.
6314 | std::sort(idx.begin(), idx.begin() + k, comp); |
6315 | |
6316 | // Output rois according to new order of indices |
6317 | for (dim_t i = 0; i < k; i++) { |
6318 | rois[i] = roisIn[idx[i]]; |
6319 | } |
6320 | } |
6321 | |
6322 | template <typename ElemTy> |
6323 | void BoundInterpreterFunction::fwdCollectRpnProposalsInstImpl( |
6324 | const glow::CollectRpnProposalsInst *I) { |
6325 | |
6326 | // Get params |
6327 | dim_t rpnPostNmsTopN = I->getRpnPostNmsTopN(); |
6328 | int64_t rpnMaxLevel = I->getRpnMaxLevel(); |
6329 | int64_t rpnMinLevel = I->getRpnMinLevel(); |
6330 | int64_t rpnLevels = rpnMaxLevel - rpnMinLevel + 1; |
6331 | |
6332 | std::vector<std::vector<ElemTy>> roisIn; |
6333 | std::vector<ElemTy> scores; |
6334 | std::vector<std::vector<ElemTy>> roisOut; |
6335 | |
6336 | auto getRoisAndScores = [&](Tensor *input, bool isroi) { |
6337 | auto inputHndl = input->getHandle<ElemTy>(); |
6338 | |
6339 | if (isroi) { |
6340 | for (dim_t i = 0; i < input->dims()[0]; i++) { |
6341 | std::vector<ElemTy> roi; |
6342 | |
6343 | for (dim_t j = 0; j < input->dims()[1]; j++) { |
6344 | roi.push_back(inputHndl.at({i, j})); |
6345 | } |
6346 | |
6347 | roisIn.push_back(roi); |
6348 | } |
6349 | } else { |
6350 | for (dim_t i = 0; i < input->dims()[0]; i++) { |
6351 | scores.push_back(inputHndl.at({i})); |
6352 | } |
6353 | } |
6354 | }; |
6355 | |
  // Inputs start at operand index 1: rois occupy operands 1..rpnLevels and
  // their scores occupy the next rpnLevels operands.
6357 | for (dim_t idx = 1; idx <= rpnLevels; idx++) { |
6358 | getRoisAndScores(getTensor(I->getOperand(idx).first), true); |
6359 | getRoisAndScores(getTensor(I->getOperand(idx + rpnLevels).first), false); |
6360 | } |
6361 | |
6362 | // Sorting the roisIn according to scores limited by rpnPostNmsTopN |
6363 | CollectRpnProposalsHelper<ElemTy>(roisIn, scores, roisOut, rpnPostNmsTopN); |
6364 | |
6365 | // Storing roisOut in result tensor |
6366 | dim_t roisSecondDim = roisIn[0].size(); |
6367 | Tensor *result = getTensor(I->getResult()); |
6368 | |
  for (dim_t i = 0; i < rpnPostNmsTopN; i++) {
    for (dim_t j = 0; j < roisSecondDim; j++) {
      // Rows beyond the number of collected rois are invalid and set to 0.
      result->getHandle<ElemTy>().at({i, j}) =
          i < roisOut.size() ? roisOut[i][j] : ElemTy(0);
    }
  }
6378 | } |
6379 | |
6380 | void BoundInterpreterFunction::fwdCollectRpnProposalsInst( |
6381 | const glow::CollectRpnProposalsInst *I) { |
6382 | dispatchFloatingPointImpl(fwdCollectRpnProposalsInstImpl, |
6383 | I->getOperand(1).first->getElementType(), I); |
6384 | } |
6385 | |
6386 | void BoundInterpreterFunction::fwdTopKInst(const TopKInst *I) { |
6387 | auto outW = getTensor(I->getValues()); |
6388 | auto indW = getTensor(I->getIndices()); |
6389 | auto inW = getTensor(I->getInput()); |
6390 | size_t k = I->getK(); |
6391 | |
6392 | if (inW->getType().isQuantizedType()) { |
    if (indW->getElementType() == ElemKind::Int64ITy) {
      fwdTopK<int8_t, int64_t>(outW, indW, inW, k);
6396 | } else if (indW->getElementType() == ElemKind::Int32ITy) { |
6397 | fwdTopK<int8_t, int32_t>(outW, indW, inW, k); |
6398 | } |
6399 | return; |
6400 | } |
6401 | |
6402 | dispatchFloatingPointAndIndexImpl(fwdTopK, inW->getElementType(), |
6403 | indW->getElementType(), outW, indW, inW, k); |
6404 | } |
6405 | |
6406 | void BoundInterpreterFunction::fwdBatchedUnaryEmbeddingsBagsInst( |
6407 | const BatchedUnaryEmbeddingsBagsInst *I) { |
6408 | dispatchFloatingPointAndIndexImpl(fwdBatchedUnaryEmbeddingsBagsInstImpl, |
6409 | I->getWeights()->getElementType(), |
6410 | I->getIndices()->getElementType(), I); |
6411 | } |
6412 | |
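/// Pools unary embedding weights per (task, batch, table) triple. The weights
/// tensor stores every table back to back for each task, so a row's flat
/// index is n * sumTable + tableOffsets[t] + index. For illustration
/// (hypothetical values): tableOffsets = [0, 4, 10] describes two tables of 4
/// and 6 rows (sumTable = 10), so index 2 of table 1 for task n lives at
/// n * 10 + 4 + 2.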
6413 | template <typename ElemTy, typename IndexType> |
6414 | void BoundInterpreterFunction::fwdBatchedUnaryEmbeddingsBagsInstImpl( |
6415 | const BatchedUnaryEmbeddingsBagsInst *I) { |
6416 | staticAssertFloatingPointType(ElemTy); |
6417 | |
6418 | auto out = getTensor(I->getDest()); |
6419 | auto weights = getTensor(I->getWeights()); |
6420 | auto tableOffsets = getTensor(I->getTableOffsets()); |
6421 | auto indices = getTensor(I->getIndices()); |
6422 | auto offsets = getTensor(I->getOffsets()); |
6423 | |
6424 | out->zero(); |
6425 | |
6426 | auto indicesH = indices->getHandle<IndexType>(); |
6427 | auto offsetsH = offsets->getHandle<IndexType>(); |
6428 | auto weightsH = weights->getHandle<ElemTy>(); |
6429 | auto tableOffsetsH = tableOffsets->getHandle<IndexType>(); |
6430 | auto outH = out->getHandle<ElemTy>(); |
6431 | |
6432 | size_t numTasks = weightsH.dims()[0]; |
6433 | size_t numTables = tableOffsets->size() - 1; |
6434 | size_t numBatches = (offsets->size() - 1) / numTables; |
6435 | |
6436 | IndexType sumTable = tableOffsetsH.raw(numTables); |
6437 | |
6438 | for (size_t n = 0; n < numTasks; n++) { |
6439 | for (size_t b = 0; b < numBatches; b++) { |
6440 | for (size_t t = 0; t < numTables; t++) { |
6441 | IndexType indicesStart = offsetsH.raw(t * numBatches + b); |
6442 | IndexType indicesEnd = offsetsH.raw(t * numBatches + b + 1); |
6443 | ElemTy sum = 0; |
6444 | for (IndexType i = indicesStart; i < indicesEnd; i++) { |
6445 | IndexType idx = n * sumTable + tableOffsetsH.raw(t) + indicesH.raw(i); |
          assert(idx < weightsH.size() &&
                 "Index shall be within weights boundary.");
6448 | sum += weightsH.raw(idx); |
6449 | } |
6450 | outH.raw((n * numBatches + b) * numTables + t) = sum; |
6451 | } |
6452 | } |
6453 | } |
6454 | } |
6455 | |
6456 | void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingBagsInst( |
6457 | const IntNBitSplitEmbeddingBagsInst *I) { |
6458 | dispatchIndexAndOutputTypeImpl(fwdIntNBitSplitEmbeddingBagsInstImpl, |
6459 | I->getIndices()->getElementType(), |
6460 | I->getOutputDType(), I); |
6461 | } |
6462 | |
6463 | void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingWeightedBagsInst( |
6464 | const IntNBitSplitEmbeddingWeightedBagsInst *I) { |
6465 | dispatchIndexAndOutputTypeImpl(fwdIntNBitSplitEmbeddingWeightedBagsInstImpl, |
6466 | I->getIndices()->getElementType(), |
6467 | I->getOutputDType(), I); |
6468 | } |
6469 | |
6470 | template <typename IndexTy, typename OutputTy> |
6471 | void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingBagsInstImpl( |
6472 | const IntNBitSplitEmbeddingBagsInst *I) { |
6473 | auto out = getTensor(I->getDest()); |
6474 | auto devWeights = getTensor(I->getDevWeights()); |
6475 | auto uvmWeights = getTensor(I->getUvmWeights()); |
6476 | auto weightsPlacements = getTensor(I->getWeightsPlacements()); |
6477 | auto weightsTys = getTensor(I->getWeightsTys()); |
6478 | auto dimOffsets = getTensor(I->getDimOffsets()); |
6479 | auto indices = getTensor(I->getIndices()); |
6480 | auto offsets = getTensor(I->getOffsets()); |
6481 | auto weightsOffsets = getTensor(I->getWeightsOffsets()); |
6482 | auto poolingMode = I->getPoolingMode(); |
6483 | auto totalDims = I->getTotalDims(); |
6484 | auto outputDType = I->getOutputDType(); |
6485 | |
6486 | fwdIntNBitSplitEmbeddingWeightedBagsImpl<IndexTy, OutputTy>( |
6487 | out, devWeights, uvmWeights, weightsPlacements, weightsTys, dimOffsets, |
6488 | indices, offsets, weightsOffsets, poolingMode, |
6489 | /* indiceWeights */ nullptr, totalDims, outputDType); |
6490 | } |
6491 | |
6492 | template <typename IndexTy, typename OutputTy> |
6493 | void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingWeightedBagsInstImpl( |
6494 | const IntNBitSplitEmbeddingWeightedBagsInst *I) { |
6495 | auto out = getTensor(I->getDest()); |
6496 | auto devWeights = getTensor(I->getDevWeights()); |
6497 | auto uvmWeights = getTensor(I->getUvmWeights()); |
6498 | auto weightsPlacements = getTensor(I->getWeightsPlacements()); |
6499 | auto weightsTys = getTensor(I->getWeightsTys()); |
6500 | auto dimOffsets = getTensor(I->getDimOffsets()); |
6501 | auto indices = getTensor(I->getIndices()); |
6502 | auto offsets = getTensor(I->getOffsets()); |
6503 | auto weightsOffsets = getTensor(I->getWeightsOffsets()); |
6504 | auto poolingMode = I->getPoolingMode(); |
6505 | auto totalDims = I->getTotalDims(); |
6506 | auto outputDType = I->getOutputDType(); |
6507 | auto indiceWeights = getTensor(I->getIndiceWeight()); |
6508 | |
6509 | fwdIntNBitSplitEmbeddingWeightedBagsImpl<IndexTy, OutputTy>( |
6510 | out, devWeights, uvmWeights, weightsPlacements, weightsTys, dimOffsets, |
6511 | indices, offsets, weightsOffsets, poolingMode, indiceWeights, totalDims, |
6512 | outputDType); |
6513 | } |
6514 | |
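/// Shared implementation for the weighted and unweighted int-N split
/// embedding bags. For each table t and batch b, the rows selected by
/// indices[offsets[t * numBatches + b] .. offsets[t * numBatches + b + 1])
/// are dequantized, optionally scaled by a per-index weight, summed (or
/// averaged under mean pooling), and stored into the output row in the
/// requested output sparse type.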
6515 | template <typename IndexTy, typename OutputTy> |
6516 | void BoundInterpreterFunction::fwdIntNBitSplitEmbeddingWeightedBagsImpl( |
6517 | Tensor *out, Tensor *devWeights, Tensor *uvmWeights, |
6518 | Tensor *weightsPlacements, Tensor *weightsTys, Tensor *dimOffsets, |
6519 | Tensor *indices, Tensor *offsets, Tensor *weightsOffsets, |
6520 | int64_t poolingMode, Tensor *indiceWeights, int64_t totalDims, |
6521 | int64_t outputDType) { |
6522 | out->zero(); |
6523 | |
6524 | auto indicesH = indices->getHandle<IndexTy>(); |
6525 | auto offsetsH = offsets->getHandle<IndexTy>(); |
6526 | auto devWeightsH = devWeights->getHandle<uint8_t>(); |
6527 | auto uvmWeightsH = uvmWeights->getHandle<uint8_t>(); |
6528 | auto weightsPlacementH = weightsPlacements->getHandle<int32_t>(); |
6529 | auto weightsTysH = weightsTys->getHandle<uint8_t>(); |
6530 | auto dimOffsetsH = dimOffsets->getHandle<int32_t>(); |
6531 | auto weightsOffsetsH = weightsOffsets->getHandle<int32_t>(); |
6532 | llvm::Optional<Handle<float>> indiceWeightsH; |
6533 | if (indiceWeights) { |
6534 | indiceWeightsH = indiceWeights->getHandle<float>(); |
6535 | } |
6536 | auto outH = out->getHandle<uint8_t>(); |
6537 | |
6538 | size_t numTables = dimOffsets->size() - 1; |
6539 | size_t numBatches = (offsets->size() - 1) / numTables; |
6540 | auto outputSparseType = static_cast<SplitEmbeddingSparseType>(outputDType); |
6541 | auto numTotalBytes = IntNBitSplitEmbeddingBagsHelper::unpaddedRowSizeInBytes( |
6542 | totalDims, outputSparseType); |
6543 | int64_t maxDevWeightsBytes = devWeights->getSizeInBytes(); |
6544 | int64_t maxUvmWeightsBytes = uvmWeights->getSizeInBytes(); |
6545 | |
6546 | for (int32_t t = 0; t < numTables; t++) { |
6547 | const int32_t dimStart = dimOffsetsH.raw(t); |
6548 | const int32_t numDims = dimOffsetsH.raw(t + 1) - dimOffsetsH.raw(t); |
6549 | const auto placement = weightsPlacementH.raw(t); |
6550 | assert(placement != WeightsPlacement::DEVICE); |
6551 | auto weightsH = uvmWeightsH; |
6552 | auto maxWeightsBytes = maxUvmWeightsBytes; |
6553 | if (placement == WeightsPlacement::HOST) { |
6554 | weightsH = devWeightsH; |
6555 | maxWeightsBytes = maxDevWeightsBytes; |
6556 | } |
6557 | auto weightTy = static_cast<SplitEmbeddingSparseType>(weightsTysH.raw(t)); |
    assert(weightTy != SplitEmbeddingSparseType::EST_INT2 &&
           "Int2 sparse type isn't supported yet.");
6560 | auto numDimBytes = IntNBitSplitEmbeddingBagsHelper::paddedRowSizeInBytes( |
6561 | numDims, weightTy); |
6562 | |
6563 | for (int32_t b = 0; b < numBatches; b++) { |
6564 | IndexTy indicesStart = offsetsH.raw(t * numBatches + b); |
6565 | IndexTy indicesEnd = offsetsH.raw(t * numBatches + b + 1); |
6566 | for (int32_t d = 0; d < numDims; d++) { |
6567 | OutputTy sum = 0; |
6568 | for (IndexTy i = indicesStart; i < indicesEnd; i++) { |
6569 | int64_t idxRow = |
6570 | weightsOffsetsH.raw(t) + indicesH.raw(i) * numDimBytes; |
6571 | int64_t idxData = |
6572 | idxRow + IntNBitSplitEmbeddingBagsHelper::unpaddedRowSizeInBytes( |
6573 | d, weightTy); |
          assert(idxData < maxWeightsBytes &&
                 "Index shall be within weights boundary.");
6576 | sum = IntNBitSplitEmbeddingBagsHelper::add( |
6577 | sum, &weightsH.raw(idxRow), &weightsH.raw(idxData), weightTy, |
6578 | d % 2, indiceWeights ? indiceWeightsH->raw(i) : 1.0); |
6579 | } |
6580 | int64_t idxOut = |
6581 | b * numTotalBytes + |
6582 | IntNBitSplitEmbeddingBagsHelper::unpaddedRowSizeInBytes( |
6583 | dimStart + d, outputSparseType); |
6584 | if (poolingMode == SplitEmbeddingPoolingMode::EP_MEAN && |
6585 | indicesEnd > indicesStart) { |
6586 | OutputTy scale = (indicesEnd - indicesStart); |
6587 | IntNBitSplitEmbeddingBagsHelper::save(sum / scale, &outH.raw(idxOut), |
6588 | outputSparseType); |
6589 | } else { |
6590 | IntNBitSplitEmbeddingBagsHelper::save(sum, &outH.raw(idxOut), |
6591 | outputSparseType); |
6592 | } |
6593 | } |
6594 | } |
6595 | } |
6596 | } |
6597 | |
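/// Dispatches \p functionName over the element type pairs supported by
/// ArgMax/ArgMin: float, float16 and int8 inputs paired with int64 or int32
/// index outputs.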
6598 | #define DISPATCH_ARG_MIN_MAX(functionName, elemTy, elemTyIndex, ...) \ |
6599 | switch (elemTy) { \ |
6600 | case ElemKind::FloatTy: \ |
6601 | if (elemTyIndex == ElemKind::Int64ITy) { \ |
6602 | functionName<float, int64_t>(__VA_ARGS__); \ |
6603 | } else if (elemTyIndex == ElemKind::Int32ITy) { \ |
6604 | functionName<float, int32_t>(__VA_ARGS__); \ |
6605 | } \ |
6606 | break; \ |
6607 | case ElemKind::Float16Ty: \ |
6608 | if (elemTyIndex == ElemKind::Int64ITy) { \ |
6609 | functionName<float16_t, int64_t>(__VA_ARGS__); \ |
6610 | } else if (elemTyIndex == ElemKind::Int32ITy) { \ |
6611 | functionName<float16_t, int32_t>(__VA_ARGS__); \ |
6612 | } \ |
6613 | break; \ |
6614 | case ElemKind::Int8QTy: \ |
6615 | if (elemTyIndex == ElemKind::Int64ITy) { \ |
6616 | functionName<int8_t, int64_t>(__VA_ARGS__); \ |
6617 | } else if (elemTyIndex == ElemKind::Int32ITy) { \ |
6618 | functionName<int8_t, int32_t>(__VA_ARGS__); \ |
6619 | } \ |
6620 | break; \ |
6621 | default: \ |
6622 | llvm_unreachable("Type is not supported"); \ |
6623 | } |
6624 | |
6625 | void BoundInterpreterFunction::fwdArgMaxInst(const ArgMaxInst *I) { |
6626 | auto inpT = getTensor(I->getSrc()); |
6627 | auto outT = getTensor(I->getDest()); |
6628 | size_t axis = I->getAxis(); |
6629 | auto inpElemType = inpT->getElementType(); |
6630 | auto outElemType = outT->getElementType(); |
6631 | DISPATCH_ARG_MIN_MAX(fwdArgMax, inpElemType, outElemType, inpT, outT, axis); |
6632 | } |
6633 | |
6634 | void BoundInterpreterFunction::fwdArgMinInst(const ArgMinInst *I) { |
6635 | auto inpT = getTensor(I->getSrc()); |
6636 | auto outT = getTensor(I->getDest()); |
6637 | size_t axis = I->getAxis(); |
6638 | auto inpElemType = inpT->getElementType(); |
6639 | auto outElemType = outT->getElementType(); |
6640 | DISPATCH_ARG_MIN_MAX(fwdArgMin, inpElemType, outElemType, inpT, outT, axis); |
6641 | } |
6642 | #undef DISPATCH_ARG_MIN_MAX |
6643 | |
6644 | //===----------------------------------------------------------------------===// |
6645 | // Tensor allocation operations |
6646 | //===----------------------------------------------------------------------===// |
6647 | |
6648 | void BoundInterpreterFunction::fwdAllocActivationInst( |
6649 | const AllocActivationInst *I) { |
6650 | getOrCreateTensor(I); |
6651 | } |
6652 | |
6653 | void BoundInterpreterFunction::fwdDeallocActivationInst( |
6654 | const DeallocActivationInst *I) { |
6655 | deleteTensor(I->getSrc()); |
6656 | } |
6657 | |
6658 | //===----------------------------------------------------------------------===// |
6659 | // Debug instructions |
6660 | //===----------------------------------------------------------------------===// |
6661 | /// Prints a value of the instruction's operand. |
6662 | /// In most cases it will be the name of the variable and the value of the |
6663 | /// tensor. |
6664 | void BoundInterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) { |
6665 | auto *V = I->getSrc(); |
6666 | auto *T = getTensor(V); |
6667 | std::string format = I->getFormat(); |
6668 | std::string filename = I->getFileName(); |
6669 | |
6670 | if (format == "console" ) { |
6671 | // Dump tensor in console. |
6672 | llvm::outs() << I->getName() << ": " ; |
6673 | V->dump(); |
6674 | llvm::outs() << "\n" ; |
6675 | dumpImpl(T); |
6676 | llvm::outs() << "\n" ; |
6677 | } else if (format == "bin" ) { |
6678 | TensorSerializationOptions opts; |
6679 | opts.withType = true; |
6680 | glow::dumpTensorToBinaryFile(*T, filename, opts); |
6681 | } else if (format == "txt" ) { |
6682 | TensorSerializationOptions opts; |
6683 | opts.withType = true; |
6684 | glow::dumpTensorToTextFile(*T, filename, opts); |
6685 | } else if (format == "rawbin" ) { |
6686 | TensorSerializationOptions opts; |
6687 | opts.withType = false; |
6688 | glow::dumpTensorToBinaryFile(*T, filename, opts); |
6689 | } else if (format == "rawtxt" ) { |
6690 | TensorSerializationOptions opts; |
6691 | opts.withType = false; |
6692 | glow::dumpTensorToTextFile(*T, filename, opts); |
6693 | } else { |
6694 | llvm_unreachable("DebugPrint format not supported!" ); |
6695 | } |
6696 | } |
6697 | |
6698 | void BoundInterpreterFunction::fwdTraceEventInst(const TraceEventInst *I) { |
6699 | auto T = getTensor(I->getData()); |
6700 | auto IH = T->getHandle<int64_t>(); |
6701 | size_t index = I->getIndex(); |
6702 | IH.raw(index) = std::chrono::duration_cast<std::chrono::microseconds>( |
6703 | std::chrono::steady_clock::now().time_since_epoch()) |
6704 | .count(); |
6705 | } |
6706 | |
6707 | void BoundInterpreterFunction::fwdInstrumentInst(const InstrumentInst *I) { |
  // The instrument instruction is not implemented on the Interpreter backend.
  // We cannot throw an error though, because the Interpreter can potentially
  // be used when constant folding parts of the graph while compiling for the
  // CPU backend with IR instrumentation.
6712 | } |
6713 | |
6714 | //===----------------------------------------------------------------------===// |
6715 | // Instructions used by Quantization |
6716 | //===----------------------------------------------------------------------===// |
6717 | void BoundInterpreterFunction::fwdQuantizationProfileInst( |
6718 | const glow::QuantizationProfileInst *I) { |
6719 | auto inputTensor = getWeightHandle(I->getInputTensor()); |
6720 | auto currentHistogram = getWeightHandle(I->getHistogram()); |
6721 | auto computationInfo = getWeightHandle(I->getComputationInfo()); |
6722 | |
6723 | float &min = computationInfo.raw(0); |
6724 | float &max = computationInfo.raw(1); |
6725 | |
6726 | // Update current histogram, min and max based on the inputTensor data. |
6727 | quantization::generateTensorHistogram(inputTensor, currentHistogram, min, |
6728 | max); |
6729 | } |
6730 | |
6731 | /// Quantize floating point tensor. Scale and Offset are based on return type |
6732 | /// of the instruction \p I. |
6733 | void BoundInterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) { |
6734 | auto *srcTensor = getTensor(I->getSrc()); |
6735 | auto *destTensor = getTensor(I->getDest()); |
6736 | auto destTy = destTensor->getType(); |
6737 | Tensor qTensor = quantization::quantizeTensor( |
6738 | *srcTensor, {destTy.getScale(), destTy.getOffset()}, |
6739 | destTy.getElementType()); |
6740 | destTensor->assign(&qTensor); |
6741 | } |
6742 | |
6743 | /// Dequantize integer tensor. Scale and Offset are based |
6744 | /// on the source tensor type. |
6745 | void BoundInterpreterFunction::fwdDequantizeInst( |
6746 | const glow::DequantizeInst *I) { |
6747 | auto *srcTensor = getTensor(I->getSrc()); |
6748 | auto *destTensor = getTensor(I->getDest()); |
6749 | auto destTy = destTensor->getType(); |
6750 | Tensor fTensor = |
6751 | quantization::dequantizeTensor(*srcTensor, destTy.getElementType()); |
6752 | destTensor->assign(&fTensor); |
6753 | } |
6754 | |
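/// Rescales quantized values by round-tripping through floating point: each
/// element is dequantized with the source parameters and requantized with the
/// destination parameters, i.e. dest = quantize(dequantize(src, srcQ), destQ).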
6755 | template <class eTy> |
6756 | void BoundInterpreterFunction::fwdRescaleQuantizedInstImpl( |
6757 | Value *src, Value *dest, TensorQuantizationParams &srcQ, |
6758 | TensorQuantizationParams &destQ) { |
6759 | |
6760 | auto srcH = getWeightHandle<eTy>(src); |
6761 | auto destH = getWeightHandle<eTy>(dest); |
6762 | |
6763 | for (size_t i = 0, e = destH.size(); i < e; ++i) { |
6764 | float val = quantization::dequantize(srcH.raw(i), srcQ); |
6765 | destH.raw(i) = quantization::quantize(val, destQ); |
6766 | } |
6767 | } |
6768 | |
6769 | void BoundInterpreterFunction::fwdRescaleQuantizedInst( |
6770 | const glow::RescaleQuantizedInst *I) { |
6771 | auto src = I->getSrc(); |
6772 | auto dest = I->getDest(); |
6773 | auto srcTy = src->getType(); |
6774 | auto destTy = dest->getType(); |
6775 | |
6776 | TensorQuantizationParams srcQ{srcTy->getScale(), srcTy->getOffset()}; |
6777 | TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()}; |
6778 | |
6779 | dispatchQuantizedImpl(fwdRescaleQuantizedInstImpl, destTy->getElementType(), |
6780 | src, dest, srcQ, destQ); |
6781 | } |
6782 | |
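/// Element-wise lookup through a mapping table: each quantized input value is
/// shifted to a non-negative index (by 128 for int8, by 32768 for int16) and
/// the value at that position in the mapping becomes the output.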
6783 | void BoundInterpreterFunction::fwdIntLookupTableInst( |
6784 | const IntLookupTableInst *I) { |
6785 | if (I->getSrc()->getElementType() == ElemKind::Int8QTy) { |
6786 | auto srcH = getWeightHandle<int8_t>(I->getSrc()); |
6787 | auto destH = getWeightHandle<int8_t>(I->getDest()); |
6788 | auto mappingH = getWeightHandle<int8_t>(I->getMapping()); |
6789 | for (size_t i = 0, e = destH.size(); i < e; i++) { |
6790 | destH.raw(i) = mappingH.raw((int)srcH.raw(i) + 128); |
6791 | } |
6792 | } else if (I->getSrc()->getElementType() == ElemKind::Int16QTy) { |
6793 | auto srcH = getWeightHandle<int16_t>(I->getSrc()); |
6794 | auto destH = getWeightHandle<int16_t>(I->getDest()); |
6795 | auto mappingH = getWeightHandle<int16_t>(I->getMapping()); |
6796 | for (size_t i = 0, e = destH.size(); i < e; i++) { |
6797 | destH.raw(i) = mappingH.raw((int)srcH.raw(i) + 32768); |
6798 | } |
6799 | } else { |
6800 | llvm_unreachable("Type not supported for IntLookupTable!" ); |
6801 | } |
6802 | } |
6803 | |
6804 | void BoundInterpreterFunction::fwdLookupTableInst(const LookupTableInst *I) { |
6805 | llvm_unreachable("LookupTable instruction is not supported yet" ); |
6806 | } |
6807 | |
6808 | void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) { |
6809 | Tensor *source = getTensor(I->getInput()); |
6810 | Tensor *dest = getTensor(I->getResult()); |
6811 | auto srcElType = source->getType().getElementType(); |
6812 | auto destElType = dest->getType().getElementType(); |
6813 | if (srcElType == destElType) { |
6814 | // This is a noop conversion. |
6815 | dest->copyRawFrom(source); |
6816 | return; |
6817 | } |
6818 | |
6819 | #define CONVERT(T_FROM, T_TO, DTY_FROM, DTY_TO) \ |
6820 | if (srcElType == DTY_FROM && destElType == DTY_TO) { \ |
6821 | dest->copyWithCast<T_TO, T_FROM>(source); \ |
6822 | return; \ |
6823 | } |
6824 | CONVERT(float, float16_t, ElemKind::FloatTy, ElemKind::Float16Ty) |
6825 | CONVERT(float, bfloat16_t, ElemKind::FloatTy, ElemKind::BFloat16Ty) |
6826 | CONVERT(float, bool, ElemKind::FloatTy, ElemKind::BoolTy) |
6827 | CONVERT(float, int32_t, ElemKind::FloatTy, ElemKind::Int32ITy) |
6828 | CONVERT(float, int64_t, ElemKind::FloatTy, ElemKind::Int64ITy) |
6829 | CONVERT(float16_t, float, ElemKind::Float16Ty, ElemKind::FloatTy) |
6830 | CONVERT(float16_t, bfloat16_t, ElemKind::Float16Ty, ElemKind::BFloat16Ty) |
6831 | CONVERT(float16_t, int32_t, ElemKind::Float16Ty, ElemKind::Int32ITy) |
6832 | CONVERT(float16_t, int64_t, ElemKind::Float16Ty, ElemKind::Int64ITy) |
6833 | CONVERT(bfloat16_t, float, ElemKind::BFloat16Ty, ElemKind::FloatTy) |
6834 | CONVERT(bfloat16_t, float16_t, ElemKind::BFloat16Ty, ElemKind::Float16Ty) |
6835 | CONVERT(bfloat16_t, int32_t, ElemKind::BFloat16Ty, ElemKind::Int32ITy) |
6836 | CONVERT(bfloat16_t, int64_t, ElemKind::BFloat16Ty, ElemKind::Int64ITy) |
6837 | CONVERT(bool, float, ElemKind::BoolTy, ElemKind::FloatTy) |
6838 | CONVERT(bool, bfloat16_t, ElemKind::BoolTy, ElemKind::BFloat16Ty) |
6839 | CONVERT(int32_t, float, ElemKind::Int32ITy, ElemKind::FloatTy) |
6840 | CONVERT(int32_t, float16_t, ElemKind::Int32ITy, ElemKind::Float16Ty) |
6841 | CONVERT(int32_t, bfloat16_t, ElemKind::Int32ITy, ElemKind::BFloat16Ty) |
6842 | CONVERT(int32_t, int64_t, ElemKind::Int32ITy, ElemKind::Int64ITy) |
6843 | CONVERT(int64_t, float, ElemKind::Int64ITy, ElemKind::FloatTy) |
6844 | CONVERT(int64_t, float16_t, ElemKind::Int64ITy, ElemKind::Float16Ty) |
6845 | CONVERT(int64_t, bfloat16_t, ElemKind::Int64ITy, ElemKind::BFloat16Ty) |
6846 | CONVERT(int64_t, int32_t, ElemKind::Int64ITy, ElemKind::Int32ITy) |
6847 | CONVERT(bool, int32_t, ElemKind::BoolTy, ElemKind::Int32ITy) |
6848 | #undef CONVERT |
6849 | |
6850 | if (srcElType == ElemKind::UInt8FusedQTy && |
6851 | destElType == ElemKind::UInt8FusedFP16QTy) { |
6852 | Tensor result = source->getCopyConvertedToType(ElemKind::UInt8FusedFP16QTy); |
6853 | dest->assign(&result); |
6854 | return; |
6855 | } |
6856 | |
6857 | if ((srcElType == ElemKind::UInt8FusedFP16QTy || |
6858 | srcElType == ElemKind::UInt4FusedFP16QTy) && |
6859 | destElType == ElemKind::UInt8FusedQTy) { |
6860 | Tensor result = source->getCopyConvertedToType(ElemKind::UInt8FusedQTy); |
6861 | dest->assign(&result); |
6862 | return; |
6863 | } |
6864 | |
6865 | if (srcElType == ElemKind::UInt4FusedFP16QTy && |
6866 | destElType == ElemKind::UInt4FusedQTy) { |
6867 | Tensor result = source->getCopyConvertedToType(ElemKind::UInt4FusedQTy); |
6868 | dest->assign(&result); |
6869 | return; |
6870 | } |
6871 | |
6872 | llvm_unreachable("Type not supported" ); |
6873 | } |
6874 | |
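/// Computes, per batch element, the dot product of every unordered pair of
/// input vectors. Pairs are emitted in the order (1,0), (2,0), (2,1), ...,
/// so for illustration three inputs x0, x1, x2 produce the output columns
/// [x1.x0, x2.x0, x2.x1].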
6875 | template <typename ElemTy> |
6876 | void BoundInterpreterFunction::fwdBatchedPairwiseDotProductInstImpl( |
6877 | const BatchedPairwiseDotProductInst *I) { |
6878 | auto destT = getTensor(I->getDest()); |
6879 | auto destH = destT->getHandle<ElemTy>(); |
6880 | |
6881 | dim_t batchCount = destT->getType().dims()[0]; |
6882 | |
6883 | // Gather all batched vector operands into an array so that they can be |
6884 | // indexed easily. |
6885 | std::vector<Value *> srcs; |
6886 | for (unsigned i = 1, e = I->getNumOperands(); i < e; ++i) { |
6887 | auto op = I->getOperand(i); |
6888 | srcs.emplace_back(op.first); |
6889 | } |
6890 | |
6891 | // pairIdx is the total number of pairs (i, j) that have been processed. |
6892 | unsigned pairIdx = 0; |
6893 | |
6894 | // For each src operand: |
6895 | for (unsigned i = 1, e = I->getNumInputs(); i < e; ++i) { |
6896 | auto vAH = getTensor(srcs[i])->getHandle<ElemTy>(); |
6897 | dim_t vectorSize = getTensor(srcs[i])->getType().dims()[1]; |
6898 | |
6899 | // Compute the dot product of src[i] with every other vector with a |
6900 | // smaller index. |
6901 | for (unsigned j = 0; j < i; ++j) { |
6902 | auto vBH = getTensor(srcs[j])->getHandle<ElemTy>(); |
6903 | |
6904 | // Process all batches for a given pair (i, j). |
6905 | for (dim_t b = 0; b < batchCount; ++b) { |
6906 | ElemTy accum = 0; |
6907 | |
6908 | for (dim_t k = 0; k < vectorSize; ++k) { |
6909 | accum += vAH.at({b, k}) * vBH.at({b, k}); |
6910 | } |
6911 | |
6912 | destH.at({b, pairIdx}) = accum; |
6913 | } |
6914 | |
6915 | ++pairIdx; |
6916 | } |
6917 | } |
6918 | } |
6919 | |
6920 | void BoundInterpreterFunction::fwdBatchedPairwiseDotProductInst( |
6921 | const BatchedPairwiseDotProductInst *I) { |
6922 | dispatchImpl(fwdBatchedPairwiseDotProductInstImpl, |
6923 | I->getDest()->getElementType(), I); |
6924 | } |
6925 | |
6926 | template <typename ElemTy> |
6927 | void BoundInterpreterFunction::fwdBatchedPairwiseDotProductGradInstImpl( |
6928 | const BatchedPairwiseDotProductGradInst *I) { |
6929 | auto destGradT = getTensor(I->getDestGrad()); |
6930 | auto destGradH = destGradT->getHandle<ElemTy>(); |
6931 | |
6932 | dim_t batchCount = destGradT->getType().dims()[0]; |
6933 | |
6934 | // Gather all batched vector operands into arrays so that they can be |
6935 | // indexed easily. Operands 1 -> numInputs are gradients of inputs, and |
6936 | // operands numInputs + 1 -> numOperands - 1 are the corresponding original |
6937 | // inputs. |
6938 | std::vector<Value *> srcs, srcGrads; |
6939 | for (unsigned i = 0, e = I->getNumInputs(); i < e; ++i) { |
6940 | auto gradOp = I->getOperand(i + 1); |
6941 | auto inputOp = I->getOperand(i + 1 + e); |
6942 | |
6943 | srcGrads.emplace_back(gradOp.first); |
6944 | srcs.emplace_back(inputOp.first); |
6945 | } |
6946 | |
6947 | // Zero initialize all srcGrad tensors. |
6948 | for (auto &s : srcGrads) { |
6949 | getTensor(s)->zero(); |
6950 | } |
6951 | |
6952 | // pairIdx is the total number of pairs (i, j) that have been processed. |
6953 | unsigned pairIdx = 0; |
6954 | |
6955 | // For each srcGrad operand: |
6956 | for (unsigned i = 0, e = I->getNumInputs(); i < e; ++i) { |
6957 | auto dvAH = getTensor(srcGrads[i])->getHandle<ElemTy>(); |
6958 | dim_t vectorSize = getTensor(srcs[i])->getType().dims()[1]; |
6959 | |
    // Accumulate into it the product of the gradient of each dot product
    // that src[i] contributed to and the corresponding vector that src[i]
    // was dotted with.
6963 | for (unsigned j = i + 1; j < e; ++j) { |
6964 | auto vBH = getTensor(srcs[j])->getHandle<ElemTy>(); |
6965 | |
6966 | // Process all batches for a given pair (i, j). |
6967 | for (dim_t b = 0; b < batchCount; ++b) { |
6968 | ElemTy grad = destGradH.at({b, pairIdx}); |
6969 | |
6970 | for (dim_t k = 0; k < vectorSize; ++k) { |
6971 | dvAH.at({b, k}) += grad * vBH.at({b, k}); |
6972 | } |
6973 | } |
6974 | |
6975 | ++pairIdx; |
6976 | } |
6977 | } |
6978 | } |
6979 | |
6980 | void BoundInterpreterFunction::fwdBatchedPairwiseDotProductGradInst( |
6981 | const BatchedPairwiseDotProductGradInst *I) { |
6982 | dispatchImpl(fwdBatchedPairwiseDotProductGradInstImpl, |
6983 | I->getDestGrad()->getElementType(), I); |
6984 | } |
6985 | |
6986 | template <typename ElemTy> |
6987 | void BoundInterpreterFunction::fwdFlipInstImpl(const FlipInst *I) { |
6988 | |
  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");
6991 | |
6992 | auto *src = I->getSrc(); |
6993 | auto *dest = I->getDest(); |
6994 | |
6995 | // Get unowned handles of src and dest with dims expanded to maximum. |
6996 | ShapeVector eDims = expandDimsToMax(src->dims()); |
6997 | auto eSrc = getTensor(src)->getUnowned(eDims); |
6998 | auto eDest = getTensor(dest)->getUnowned(eDims); |
6999 | auto srcH = eSrc.getHandle<ElemTy>(); |
7000 | auto destH = eDest.getHandle<ElemTy>(); |
7001 | |
7002 | #define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5) \ |
7003 | for (dim_t idx0 = 0; idx0 < eDims[0]; idx0++) \ |
7004 | for (dim_t idx1 = 0; idx1 < eDims[1]; idx1++) \ |
7005 | for (dim_t idx2 = 0; idx2 < eDims[2]; idx2++) \ |
7006 | for (dim_t idx3 = 0; idx3 < eDims[3]; idx3++) \ |
7007 | for (dim_t idx4 = 0; idx4 < eDims[4]; idx4++) \ |
7008 | for (dim_t idx5 = 0; idx5 < eDims[5]; idx5++) { \ |
7009 | destH.at({_D0, _D1, _D2, _D3, _D4, _D5}) = \ |
7010 | srcH.at({idx0, idx1, idx2, idx3, idx4, idx5}); \ |
7011 | } \ |
7012 | return; |
7013 | |
7014 | switch (I->getAxis()) { |
7015 | case 0: |
7016 | LOOP_AXIS_CASE(eDims[0] - 1 - idx0, idx1, idx2, idx3, idx4, idx5); |
7017 | case 1: |
7018 | LOOP_AXIS_CASE(idx0, eDims[1] - 1 - idx1, idx2, idx3, idx4, idx5); |
7019 | case 2: |
7020 | LOOP_AXIS_CASE(idx0, idx1, eDims[2] - 1 - idx2, idx3, idx4, idx5); |
7021 | case 3: |
7022 | LOOP_AXIS_CASE(idx0, idx1, idx2, eDims[3] - 1 - idx3, idx4, idx5); |
7023 | case 4: |
7024 | LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, eDims[4] - 1 - idx4, idx5); |
7025 | case 5: |
7026 | LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, idx4, eDims[5] - 1 - idx5); |
7027 | default: |
7028 | llvm_unreachable("Axis should be less than max_tensor_dimensions." ); |
7029 | } |
7030 | } |
7031 | |
7032 | void BoundInterpreterFunction::fwdFlipInst(const FlipInst *I) { |
7033 | dispatchImpl(fwdFlipInstImpl, I->getSrc()->getElementType(), I); |
7034 | } |
7035 | |
7036 | //===----------------------------------------------------------------------===// |
7037 | // Instructions used by ObjectDetection |
7038 | //===----------------------------------------------------------------------===// |
7039 | static void maxMin(float lhs, float rhs, float &min, float &max) { |
7040 | if (lhs >= rhs) { |
7041 | min = rhs; |
7042 | max = lhs; |
7043 | } else { |
7044 | min = lhs; |
7045 | max = rhs; |
7046 | } |
7047 | } |
7048 | |
7049 | using ClassBox = std::pair<float, dim_t>; |
7050 | |
7051 | struct Box { |
7052 | float classValue{0.0f}; |
7053 | dim_t batchIndex{0}; |
7054 | dim_t classIndex{0}; |
7055 | dim_t boxIndex{0}; |
7056 | }; |
7057 | |
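/// Returns true when the intersection-over-union of the selected and the
/// candidate box exceeds \p iouThreshold; the check is written as
/// intersection > threshold * union to avoid a division. With
/// \p centerPointBox == 0 the four coordinates are two diagonal corners of
/// the box; with 1 they are the center point plus width and height.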
7058 | template <typename ElemTy> |
7059 | static bool doIOU(Handle<ElemTy> &boxes, dim_t batchIndex, |
7060 | dim_t selectedBoxIndex, dim_t candidateBoxIndex, |
7061 | int centerPointBox, float iouThreshold, bool isV4) { |
7062 | float sx[] = {0.0f, 0.0f, 0.0f, 0.0f}; |
7063 | float cx[] = {0.0f, 0.0f, 0.0f, 0.0f}; |
7064 | |
7065 | if (isV4) { |
7066 | for (dim_t i = 0; i < 4; ++i) { |
7067 | sx[i] = boxes.at({selectedBoxIndex, i}); |
7068 | cx[i] = boxes.at({candidateBoxIndex, i}); |
7069 | } |
7070 | } else { |
7071 | for (dim_t i = 0; i < 4; ++i) { |
7072 | sx[i] = boxes.at({batchIndex, selectedBoxIndex, i}); |
7073 | cx[i] = boxes.at({batchIndex, candidateBoxIndex, i}); |
7074 | } |
7075 | } |
7076 | |
7077 | float xSMin = 0.0f; |
7078 | float ySMin = 0.0f; |
7079 | float xSMax = 0.0f; |
7080 | float ySMax = 0.0f; |
7081 | |
7082 | float xCMin = 0.0f; |
7083 | float yCMin = 0.0f; |
7084 | float xCMax = 0.0f; |
7085 | float yCMax = 0.0f; |
7086 | |
  // Standardize coordinates so that (xmin, ymin) is the upper left corner of
  // a box and (xmax, ymax) is the lower right corner of the box.
7089 | if (!centerPointBox) { |
7090 | // 0 means coordinates for diagonal ends of a box. |
7091 | // Coordinates can either be absolute or normalized. |
7092 | maxMin(sx[0], sx[2], xSMin, xSMax); |
7093 | maxMin(sx[1], sx[3], ySMin, ySMax); |
7094 | |
7095 | maxMin(cx[0], cx[2], xCMin, xCMax); |
7096 | maxMin(cx[1], cx[3], yCMin, yCMax); |
7097 | } else { |
7098 | float halfWidthS = sx[2] / 2.0f; |
7099 | float halfHeightS = sx[3] / 2.0f; |
7100 | float halfWidthC = cx[2] / 2.0f; |
7101 | float halfHeightC = cx[3] / 2.0f; |
7102 | |
7103 | xSMin = sx[0] - halfWidthS; |
7104 | ySMin = sx[1] - halfHeightS; |
7105 | xSMax = sx[0] + halfWidthS; |
7106 | ySMax = sx[1] + halfHeightS; |
7107 | |
7108 | xCMin = cx[0] - halfWidthC; |
7109 | yCMin = cx[1] - halfHeightC; |
7110 | xCMax = cx[0] + halfWidthC; |
7111 | yCMax = cx[1] + halfHeightC; |
7112 | } |
7113 | |
  // Find the upper left and lower right corners of the box formed by the
  // intersection.
7116 | float xMin = std::max(xSMin, xCMin); |
7117 | float yMin = std::max(ySMin, yCMin); |
7118 | float xMax = std::min(xSMax, xCMax); |
7119 | float yMax = std::min(ySMax, yCMax); |
7120 | |
7121 | float intersectionArea = |
7122 | std::max(0.0f, xMax - xMin) * std::max(0.0f, yMax - yMin); |
7123 | |
7124 | if (intersectionArea == 0.0f) { |
7125 | return false; |
7126 | } |
7127 | |
7128 | float sArea = (xSMax - xSMin) * (ySMax - ySMin); |
7129 | float cArea = (xCMax - xCMin) * (yCMax - yCMin); |
7130 | float unionArea = sArea + cArea - intersectionArea; |
7131 | |
7132 | return intersectionArea > iouThreshold * unionArea; |
7133 | } |
7134 | |
7135 | template <typename T> |
7136 | void BoundInterpreterFunction::fwdNonMaxSuppressionInstImpl( |
7137 | glow::NonMaxSuppressionInst const *I) { |
7138 | |
7139 | auto boxes = I->getBoxes(); |
7140 | auto scores = I->getScores(); |
7141 | auto indices = I->getIndices(); |
7142 | auto numDetected = I->getNumberOfSelectedIndices(); |
7143 | float iouThreshold = I->getIouThreshold(); |
7144 | dim_t maxBoxesPerClass = I->getMaxOutputBoxesPerClass(); |
7145 | float scoreThreshold = I->getScoreThreshold(); |
7146 | unsigned centerPointBox = I->getCenterPointBox(); |
7147 | bool isV4 = I->getIsTFVersion4(); |
7148 | |
7149 | auto boxesH = getTensor(boxes)->getHandle<float>(); |
7150 | auto scoresH = getTensor(scores)->getHandle<float>(); |
7151 | auto indicesH = getTensor(indices)->getHandle<T>(); |
7152 | auto numDetectedH = getTensor(numDetected)->getHandle<T>(); |
7153 | |
7154 | int boxesBoxDim = boxes->dims().size() - 2; |
7155 | |
7156 | dim_t numBatches = 1; |
7157 | dim_t numClasses = 1; |
7158 | dim_t numBoxes = boxes->dims()[boxesBoxDim]; |
7159 | |
7160 | size_t maxOutputPerBatch = 0; |
7161 | |
7162 | if (!isV4) { |
7163 | int boxesBatchDim = boxes->dims().size() - 3; |
7164 | |
7165 | int scoresBatchDim = scores->dims().size() - 3; |
7166 | int scoresBoxDim = scores->dims().size() - 1; |
7167 | int scoresClassDim = scores->dims().size() - 2; |
    assert(scores->dims()[scoresBoxDim] == boxes->dims()[boxesBoxDim] &&
           "Mismatch between number of scores and number of boxes.");
    assert(scores->dims()[scoresBatchDim] == boxes->dims()[boxesBatchDim] &&
           "Mismatch in batch dimension.");
7172 | (void)boxesBatchDim; |
7173 | (void)scoresBoxDim; |
7174 | numBatches = scores->dims()[scoresBatchDim]; |
7175 | numClasses = scores->dims()[scoresClassDim]; |
7176 | numBoxes = boxes->dims()[boxesBoxDim]; |
7177 | maxOutputPerBatch = |
7178 | indices->dims()[indices->dims().size() - 2] / numBatches; |
7179 | } else { |
7180 | maxOutputPerBatch = |
7181 | indices->dims()[indices->dims().size() - 1] / numBatches; |
7182 | } |
7183 | |
7184 | auto cmpFunc = [](const ClassBox &a, const ClassBox &b) { |
7185 | return a.first < b.first; |
7186 | }; |
7187 | |
7188 | std::vector<ClassBox> selectedIndices(numBoxes); |
7189 | dim_t outPutBoxIndex = 0; |
7190 | |
7191 | for (dim_t batchIndex = 0; batchIndex < numBatches; ++batchIndex) { |
7192 | Box minBox{scoresH.raw(batchIndex * numClasses * numBoxes), batchIndex, 0, |
7193 | 0}; |
7194 | int32_t detectedPerBatch = 0; |
7195 | for (dim_t classIndex = 0; classIndex < numClasses; ++classIndex) { |
7196 | selectedIndices.clear(); |
7197 | size_t detectedPerClass = 0; |
7198 | std::priority_queue<ClassBox, std::vector<ClassBox>, decltype(cmpFunc)> |
7199 | queue(cmpFunc); |
7200 | |
7201 | for (size_t boxIndex = 0; boxIndex < numBoxes; ++boxIndex) { |
7202 | float classValue = scoresH.raw( |
7203 | (batchIndex * numClasses + classIndex) * numBoxes + boxIndex); |
7204 | if (classValue > scoreThreshold) { |
7205 | queue.emplace(classValue, boxIndex); |
7206 | } |
7207 | } |
7208 | |
7209 | float tScore = minBox.classValue; |
7210 | while (!queue.empty()) { |
7211 | auto priorBox = queue.top(); |
7212 | queue.pop(); |
7213 | |
7214 | bool selected = true; |
7215 | for (auto &sBox : selectedIndices) { |
7216 | if (doIOU(boxesH, batchIndex, sBox.second, priorBox.second, |
7217 | centerPointBox, iouThreshold, isV4)) { |
7218 | selected = false; |
7219 | break; |
7220 | } |
7221 | } |
7222 | |
7223 | if (selected) { |
7224 | selectedIndices.emplace_back(priorBox); |
7225 | if (isV4) { |
7226 | indicesH.at({outPutBoxIndex}) = priorBox.second; |
7227 | tScore = scoresH.at({priorBox.second}); |
7228 | } else { |
7229 | indicesH.at({outPutBoxIndex, 0}) = batchIndex; |
7230 | indicesH.at({outPutBoxIndex, 1}) = classIndex; |
7231 | indicesH.at({outPutBoxIndex, 2}) = priorBox.second; |
7232 | tScore = scoresH.at({batchIndex, classIndex, priorBox.second}); |
7233 | } |
7234 | |
7235 | ++outPutBoxIndex; |
7236 | ++detectedPerClass; |
7237 | ++detectedPerBatch; |
7238 | } |
7239 | if (maxBoxesPerClass == detectedPerClass) { |
7240 | break; |
7241 | } |
7242 | } |
7243 | |
7244 | if (tScore < minBox.classValue) { |
7245 | minBox.classValue = tScore; |
7246 | minBox.classIndex = classIndex; |
7247 | if (isV4) { |
7248 | minBox.boxIndex = indicesH.at({outPutBoxIndex - 1}); |
7249 | } else { |
7250 | minBox.boxIndex = indicesH.at({outPutBoxIndex - 1, 2}); |
7251 | } |
7252 | } |
7253 | } |
7254 | |
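    // Pad any remaining output slots for this batch with minBox, which
    // tracks the lowest-scoring selection seen so far, so that the indices
    // tensor is fully populated.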
7255 | for (size_t i = detectedPerBatch; i < maxOutputPerBatch; ++i) { |
7256 | if (isV4) { |
7257 | indicesH.at({outPutBoxIndex}) = minBox.boxIndex; |
7258 | } else { |
7259 | indicesH.at({outPutBoxIndex, 0}) = minBox.batchIndex; |
7260 | indicesH.at({outPutBoxIndex, 1}) = minBox.classIndex; |
7261 | indicesH.at({outPutBoxIndex, 2}) = minBox.boxIndex; |
7262 | } |
7263 | |
7264 | ++outPutBoxIndex; |
7265 | } |
    // For ONNX NMS this output is unused; for TF the batch dimension is 1.
7267 | for (dim_t i = 0; i < maxBoxesPerClass; ++i) { |
7268 | numDetectedH.at({batchIndex * maxBoxesPerClass + i}) = detectedPerBatch; |
7269 | } |
7270 | } |
7271 | } |
7272 | |
7273 | void BoundInterpreterFunction::fwdNonMaxSuppressionInst( |
7274 | glow::NonMaxSuppressionInst const *I) { |
7275 | switch (I->getBoxes()->getElementType()) { |
7276 | case ElemKind::FloatTy: |
7277 | if (I->getIndices()->getElementType() == ElemKind::Int32ITy) { |
7278 | fwdNonMaxSuppressionInstImpl<int32_t>(I); |
7279 | } else if (I->getIndices()->getElementType() == ElemKind::Int64ITy) { |
7280 | fwdNonMaxSuppressionInstImpl<int64_t>(I); |
7281 | } else { |
      llvm_unreachable("Output type is not supported.");
7283 | } |
7284 | break; |
7285 | default: |
    llvm_unreachable("Type is not supported.");
7287 | break; |
7288 | } |
7289 | } |
7290 | |
7291 | //===----------------------------------------------------------------------===// |
7292 | // TensorFlowLite NonMaxSuppression |
7293 | //===----------------------------------------------------------------------===// |
7294 | static int32_t partition(int32_t *arr, int32_t low, int32_t high, |
7295 | float *values) { |
7296 | float pivot = values[high]; |
7297 | int32_t i = (low - 1); |
7298 | float swap_float; |
7299 | int32_t swap_int; |
7300 | |
7301 | for (int32_t j = low; j <= high - 1; j++) { |
7302 | if (values[j] > pivot) { |
7303 | i++; |
7304 | |
7305 | swap_float = values[i]; |
7306 | values[i] = values[j]; |
7307 | values[j] = swap_float; |
7308 | |
7309 | swap_int = arr[i]; |
7310 | arr[i] = arr[j]; |
7311 | arr[j] = swap_int; |
7312 | } |
7313 | } |
7314 | |
7315 | swap_float = values[i + 1]; |
7316 | values[i + 1] = values[high]; |
7317 | values[high] = swap_float; |
7318 | |
7319 | swap_int = arr[i + 1]; |
7320 | arr[i + 1] = arr[high]; |
7321 | arr[high] = swap_int; |
7322 | |
7323 | return (i + 1); |
7324 | } |
7325 | |
7326 | static void partial_sort(int32_t *arr, int32_t i, int32_t j, int32_t k, |
7327 | float *values) { |
7328 | int32_t p; |
7329 | if (i < j) { |
7330 | p = partition(arr, i, j, values); |
7331 | |
7332 | partial_sort(arr, i, p - 1, k, values); |
7333 | |
7334 | if (p < k - 1) |
7335 | partial_sort(arr, p + 1, j, k, values); |
7336 | } |
7337 | } |
7338 | |
7339 | static void iota(int32_t *first, int32_t *last, int32_t value) { |
7340 | while (first != last) { |
7341 | *first++ = value; |
7342 | value++; |
7343 | } |
7344 | } |
7345 | |
7346 | static void decreasing_partial_arg_sort(float *values, int32_t num_values, |
7347 | int32_t num_to_sort, int32_t *indices, |
7348 | float *aux_values) { |
7349 | iota(indices, indices + num_values, 0); |
7350 | |
7351 | memcpy(aux_values, values, sizeof(float) * num_values); |
7352 | |
7353 | partial_sort(indices, 0, num_values - 1, num_to_sort, aux_values); |
7354 | } |
7355 | |
7356 | static void select_detection_above_score_threshold( |
7357 | float *scores, int32_t num_scores, float threshold, float *keep_values, |
7358 | int32_t *keep_indices, int32_t *num_indices) { |
7359 | int32_t idx = 0; |
7360 | for (int32_t i = 0; i < num_scores; i++) { |
7361 | if (scores[i] >= threshold) { |
7362 | keep_indices[idx] = i; |
7363 | keep_values[idx] = scores[i]; |
7364 | idx++; |
7365 | } |
7366 | } |
7367 | *num_indices = idx; |
7368 | } |
7369 | |
7370 | /// Compute the IOU (Intersection Over Union) metric between two boxes. Each |
7371 | /// of box1 and box2 is a vector with 4 floating-point values with the box |
7372 | /// coordinates in the following format: [ymin, xmin, ymax, xmax]. |
7373 | static float tflite_compute_iou(float *box1, float *box2) { |
7374 | |
7375 | // Compute the areas of the two boxes. |
7376 | float box1Area = (box1[2] - box1[0]) * (box1[3] - box1[1]); |
7377 | float box2Area = (box2[2] - box2[0]) * (box2[3] - box2[1]); |
7378 | |
7379 | // If box coordinates are invalid we return 0. |
7380 | if (box1Area <= 0 || box2Area <= 0) { |
7381 | return 0.0f; |
7382 | } |
7383 | |
7384 | // Determine the coordinates of the intersection rectangle. |
7385 | float iYmin = MAX(box1[0], box2[0]); |
7386 | float iXmin = MAX(box1[1], box2[1]); |
7387 | float iYmax = MIN(box1[2], box2[2]); |
7388 | float iXmax = MIN(box1[3], box2[3]); |
7389 | |
7390 | // Compute the area of the intersection rectangle. |
7391 | float iArea = MAX(0.0f, iXmax - iXmin) * MAX(0.0f, iYmax - iYmin); |
7392 | |
  // Compute the area of the union rectangle.
7394 | float uArea = box1Area + box2Area - iArea; |
7395 | |
7396 | // Compute the Intersection Over Union metric. |
7397 | return iArea / uArea; |
7398 | } |
7399 | |
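/// Greedy per-class non-maximum suppression used by the TFLite
/// post-processing below: boxes scoring at least \p nms_score_threshold are
/// kept, visited in decreasing score order, and any later box whose IOU with
/// an already selected box exceeds \p nms_iou_threshold is suppressed; at
/// most \p max_detections boxes are selected. Note that \p selected is
/// reused as the sort's auxiliary float buffer before receiving the final
/// selections, and \p keep_scores is repurposed as the active-candidate flag
/// array once the sorted scores have been consumed.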
static void tflite_helper(float *boxesPtr, int32_t num_boxes,
                          float nms_score_threshold, float nms_iou_threshold,
7402 | float *class_scores, int32_t num_scores, |
7403 | int32_t *selected, int32_t *num_selected, |
7404 | int32_t max_detections, int32_t *keep_indices, |
7405 | float *keep_scores, int32_t *sorted_indices_helper) { |
7406 | |
7407 | *num_selected = 0; |
7408 | |
7409 | int32_t num_scores_kept; |
7410 | select_detection_above_score_threshold(class_scores, num_boxes, |
7411 | nms_score_threshold, keep_scores, |
7412 | keep_indices, &num_scores_kept); |
7413 | |
7414 | decreasing_partial_arg_sort(keep_scores, num_scores_kept, num_scores_kept, |
7415 | sorted_indices_helper, (float *)selected); |
7416 | |
7417 | int32_t num_boxes_kept = num_scores_kept; |
7418 | int32_t output_size = MIN(num_boxes_kept, max_detections); |
7419 | |
7420 | int32_t num_active_candidate = num_boxes_kept; |
7421 | |
7422 | uint8_t *active_box_candidate = (uint8_t *)keep_scores; |
7423 | |
7424 | for (int32_t row = 0; row < num_boxes_kept; row++) { |
7425 | active_box_candidate[row] = 1; |
7426 | } |
7427 | |
7428 | for (int32_t i = 0; i < num_boxes_kept; i++) { |
7429 | if (num_active_candidate == 0 || *num_selected >= output_size) |
7430 | break; |
7431 | if (active_box_candidate[i] == 1) { |
7432 | selected[*num_selected] = keep_indices[sorted_indices_helper[i]]; |
7433 | (*num_selected)++; |
7434 | active_box_candidate[i] = 0; |
7435 | num_active_candidate--; |
7436 | } else { |
7437 | continue; |
7438 | } |
7439 | |
7440 | for (int32_t j = i + 1; j < num_boxes_kept; ++j) { |
7441 | if (active_box_candidate[j] == 1) { |
7442 | |
7443 | float *box1 = boxesPtr + 4 * keep_indices[sorted_indices_helper[i]]; |
7444 | float *box2 = boxesPtr + 4 * keep_indices[sorted_indices_helper[j]]; |
7445 | float iou = tflite_compute_iou(box1, box2); |
7446 | |
        if (iou > nms_iou_threshold) {
7448 | active_box_candidate[j] = 0; |
7449 | num_active_candidate--; |
7450 | } |
7451 | } |
7452 | } |
7453 | } |
7454 | } |
7455 | |
7456 | static void tflite_detection_post_process_f( |
7457 | float *boxes, float *scores, float *anchors, float *detectionBoxes, |
7458 | int32_t *detectionClasses, float *detectionScores, int32_t *numDetections, |
7459 | int8_t *scratch, int32_t numBoxes, int32_t numTotalClasses, |
7460 | int32_t numClasses, int32_t maxDetections, int32_t maxClassesPerDetection, |
7461 | int32_t maxDetectionsPerClass, float iouThreshold, float scoreThreshold, |
7462 | float xScaleInv, float yScaleInv, float hScaleInv, float wScaleInv, |
7463 | bool regularNMS) { |
7464 | |
7465 | // Decode the box coordinates in-place using the anchors. |
7466 | for (int32_t i = 0; i < numBoxes; i++) { |
7467 | |
7468 | float *box = &boxes[i * 4]; |
7469 | float *anchor = &anchors[i * 4]; |
7470 | |
7471 | float ycenter = box[0] * yScaleInv * anchor[2] + anchor[0]; |
7472 | float xcenter = box[1] * xScaleInv * anchor[3] + anchor[1]; |
7473 | |
7474 | float half_h = 0.5f * expf(box[2] * hScaleInv) * anchor[2]; |
7475 | float half_w = 0.5f * expf(box[3] * wScaleInv) * anchor[3]; |
7476 | |
7477 | box[0] = ycenter - half_h; |
7478 | box[1] = xcenter - half_w; |
7479 | box[2] = ycenter + half_h; |
7480 | box[3] = xcenter + half_w; |
7481 | } |
7482 | |
7483 | int32_t max_categories_per_anchor = maxClassesPerDetection; |
7484 | int32_t num_categories_per_anchor = |
7485 | MIN(max_categories_per_anchor, numClasses); |
7486 | int32_t label_offset = numTotalClasses - numClasses; |
7487 | |
7488 | if (regularNMS) { |
7489 | int32_t num_detections_per_class = maxDetectionsPerClass; |
7490 | |
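    // The caller-provided scratch buffer is carved into the per-class
    // working arrays below, in order:
    //   class_scores                  : numBoxes floats
    //   box_indices_after_regular_nms : (numBoxes + maxDetections) int32s
    //   scores_after_regular_nms      : (numBoxes + maxDetections) floats
    //   sorted_indices                : (numBoxes + maxDetections) int32s
    //   sorted_values                 : min(numBoxes, maxDetectionsPerClass)
    //                                   floats
    //   selected                      : numBoxes int32s
    //   keep_indices                  : numBoxes int32s
    //   keep_scores                   : numBoxes floats
    //   sorted_indices_helper         : numBoxes int32s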
7491 | float *class_scores = (float *)(scratch); |
7492 | scratch += numBoxes * sizeof(float); |
7493 | |
7494 | int32_t *box_indices_after_regular_nms = (int32_t *)(scratch); |
7495 | scratch += (numBoxes + maxDetections) * sizeof(int32_t); |
7496 | |
7497 | float *scores_after_regular_nms = (float *)(scratch); |
7498 | scratch += (numBoxes + maxDetections) * sizeof(float); |
7499 | |
7500 | int32_t size_of_sorted_indices = 0; |
7501 | |
7502 | int32_t *sorted_indices = (int32_t *)(scratch); |
7503 | scratch += (numBoxes + maxDetections) * sizeof(int32_t); |
7504 | |
7505 | float *sorted_values = (float *)(scratch); |
7506 | scratch += MIN(numBoxes, maxDetectionsPerClass) * sizeof(float); |
7507 | |
7508 | int32_t *selected = (int32_t *)scratch; |
7509 | scratch += numBoxes * sizeof(int32_t); |
7510 | |
7511 | int32_t *keep_indices = (int32_t *)(scratch); |
7512 | scratch += numBoxes * sizeof(int32_t); |
7513 | |
7514 | float *keep_scores = (float *)(scratch); |
7515 | scratch += numBoxes * sizeof(float); |
7516 | |
7517 | int32_t *sorted_indices_helper = (int32_t *)scratch; |
7518 | scratch += numBoxes * sizeof(int32_t); |
7519 | |
7520 | for (int32_t col = 0; col < numClasses; col++) { |
7521 | for (int32_t row = 0; row < numBoxes; row++) { |
7522 | class_scores[row] = |
7523 | *(scores + row * numTotalClasses + col + label_offset); |
7524 | } |
7525 | |
7526 | int32_t num_selected; |
7527 | tflite_helper(boxes, numBoxes, scoreThreshold, iouThreshold, class_scores, |
7528 | numBoxes, selected, &num_selected, num_detections_per_class, |
7529 | keep_indices, keep_scores, sorted_indices_helper); |
7530 | |
7531 | int32_t output_index = size_of_sorted_indices; |
7532 | for (int32_t i = 0; i < num_selected; i++) { |
7533 | int32_t selected_index = selected[i]; |
7534 | box_indices_after_regular_nms[output_index] = |
7535 | (selected_index * numTotalClasses + col + label_offset); |
7536 | scores_after_regular_nms[output_index] = class_scores[selected_index]; |
7537 | output_index++; |
7538 | } |
7539 | |
7540 | int32_t num_indices_to_sort = MIN(output_index, maxDetections); |
7541 | |
7542 | decreasing_partial_arg_sort(scores_after_regular_nms, output_index, |
7543 | num_indices_to_sort, sorted_indices, |
7544 | keep_scores); |
7545 | |
7546 | for (int32_t row = 0; row < num_indices_to_sort; row++) { |
7547 | int32_t temp = sorted_indices[row]; |
7548 | sorted_indices[row] = box_indices_after_regular_nms[temp]; |
7549 | sorted_values[row] = scores_after_regular_nms[temp]; |
7550 | } |
7551 | |
7552 | for (int32_t row = 0; row < num_indices_to_sort; row++) { |
7553 | box_indices_after_regular_nms[row] = sorted_indices[row]; |
7554 | scores_after_regular_nms[row] = sorted_values[row]; |
7555 | } |
7556 | |
7557 | size_of_sorted_indices = num_indices_to_sort; |
7558 | } |
7559 | |
7560 | for (int32_t output_box_index = 0; |
7561 | output_box_index < size_of_sorted_indices; output_box_index++) { |
7562 | |
7563 | int32_t anchor_index = |
7564 | box_indices_after_regular_nms[output_box_index] / numTotalClasses; |
7565 | int32_t class_index = box_indices_after_regular_nms[output_box_index] - |
7566 | anchor_index * numTotalClasses - label_offset; |
7567 | float selected_score = scores_after_regular_nms[output_box_index]; |
7568 | float *box = boxes + anchor_index * 4; |
7569 | |
7570 | *detectionBoxes++ = *box++; |
7571 | *detectionBoxes++ = *box++; |
7572 | *detectionBoxes++ = *box++; |
7573 | *detectionBoxes++ = *box++; |
7574 | *detectionClasses++ = class_index; |
7575 | *detectionScores++ = selected_score; |
7576 | } |
7577 | |
7578 | *numDetections = size_of_sorted_indices; |
7579 | } else { |
7580 | float *max_scores = (float *)scratch; |
7581 | scratch += numBoxes * sizeof(float); |
7582 | |
7583 | int32_t *sorted_classes_indices = (int32_t *)scratch; |
7584 | scratch += numBoxes * MIN(maxDetections, numClasses) * sizeof(int32_t); |
7585 | |
7586 | int32_t *selected = (int32_t *)scratch; |
7587 | scratch += numBoxes * sizeof(int32_t); |
7588 | |
7589 | int32_t *keep_indices = (int32_t *)(scratch); |
7590 | scratch += numBoxes * sizeof(int32_t); |
7591 | |
7592 | float *keep_scores = (float *)(scratch); |
7593 | scratch += numBoxes * sizeof(float); |
7594 | |
7595 | int32_t *sorted_indices_helper = (int32_t *)scratch; |
7596 | scratch += numBoxes * sizeof(int32_t); |
7597 | |
7598 | for (int32_t row = 0; row < numBoxes; row++) { |
7599 | float *box_scores = scores + row * numTotalClasses + label_offset; |
7600 | int32_t *class_indices = |
7601 | sorted_classes_indices + row * num_categories_per_anchor; |
7602 | |
7603 | decreasing_partial_arg_sort(box_scores, numClasses, |
7604 | num_categories_per_anchor, keep_indices, |
7605 | keep_scores); |
7606 | |
7607 | for (int32_t i = 0; i < num_categories_per_anchor; i++) { |
7608 | class_indices[i] = keep_indices[i]; |
7609 | } |
7610 | |
7611 | max_scores[row] = box_scores[class_indices[0]]; |
7612 | } |
7613 | |
7614 | int32_t selected_size = 0; |
7615 | tflite_helper(boxes, numBoxes, scoreThreshold, iouThreshold, max_scores, |
7616 | numBoxes, selected, &selected_size, maxDetections, |
7617 | keep_indices, keep_scores, sorted_indices_helper); |
7618 | |
7619 | int32_t num_detections = 0; |
7620 | for (int32_t i = 0; i < selected_size; i++) { |
7621 | |
7622 | int32_t selected_index = selected[i]; |
7623 | float *box = boxes + selected_index * 4; |
7624 | float *box_scores = |
7625 | scores + selected_index * numTotalClasses + label_offset; |
7626 | int32_t *class_indices = |
7627 | sorted_classes_indices + selected_index * num_categories_per_anchor; |
7628 | |
7629 | for (int32_t col = 0; (col < num_categories_per_anchor) && |
7630 | (num_detections <= selected_size); |
7631 | ++col) { |
7632 | *detectionBoxes++ = box[0]; |
7633 | *detectionBoxes++ = box[1]; |
7634 | *detectionBoxes++ = box[2]; |
7635 | *detectionBoxes++ = box[3]; |
7636 | *detectionClasses++ = class_indices[col]; |
7637 | *detectionScores++ = box_scores[class_indices[col]]; |
7638 | num_detections++; |
7639 | } |
7640 | } |
7641 | |
7642 | *numDetections = selected_size; |
7643 | } |
7644 | } |
7645 | |
7646 | void BoundInterpreterFunction::fwdTFLiteDetectionPostProcessInst( |
7647 | glow::TFLiteDetectionPostProcessInst const *I) { |
7648 | auto boxes = I->getBoxes(); |
7649 | auto scores = I->getScores(); |
7650 | auto anchors = I->getAnchors(); |
7651 | auto detectionBoxes = I->getDetectionBoxes(); |
7652 | auto detectionClasses = I->getDetectionClasses(); |
7653 | auto detectionScores = I->getDetectionScores(); |
7654 | auto numDetections = I->getNumDetections(); |
7655 | auto scratch = I->getScratch(); |
7656 | |
7657 | // Get raw pointers. |
7658 | float *boxesPtr = (float *)getTensor(boxes)->getUnsafePtr(); |
7659 | float *scoresPtr = (float *)getTensor(scores)->getUnsafePtr(); |
7660 | float *anchorsPtr = (float *)getTensor(anchors)->getUnsafePtr(); |
7661 | float *detectionBoxesPtr = (float *)getTensor(detectionBoxes)->getUnsafePtr(); |
7662 | int32_t *detectionClassesPtr = |
7663 | (int32_t *)getTensor(detectionClasses)->getUnsafePtr(); |
7664 | float *detectionScoresPtr = |
7665 | (float *)getTensor(detectionScores)->getUnsafePtr(); |
7666 | int32_t *numDetectionsPtr = |
7667 | (int32_t *)getTensor(numDetections)->getUnsafePtr(); |
7668 | int8_t *scratchPtr = (int8_t *)getTensor(scratch)->getUnsafePtr(); |
7669 | |
7670 | // Get parameters. |
7671 | int32_t numBoxes = boxes->dims()[1]; |
7672 | int32_t numTotalClasses = scores->dims()[2]; |
7673 | int32_t numClasses = I->getNumClasses(); |
7674 | int32_t maxDetections = I->getMaxDetections(); |
7675 | int32_t maxClassesPerDetection = I->getMaxClassesPerDetection(); |
7676 | int32_t maxDetectionsPerClass = I->getMaxDetectionsPerClass(); |
7677 | float iouThreshold = I->getIouThreshold(); |
7678 | float scoreThreshold = I->getScoreThreshold(); |
7679 | float xScaleInv = 1.0f / I->getXScale(); |
7680 | float yScaleInv = 1.0f / I->getYScale(); |
7681 | float hScaleInv = 1.0f / I->getHScale(); |
7682 | float wScaleInv = 1.0f / I->getWScale(); |
7683 | bool regularNMS = I->getRegularNMS(); |
7684 | |
7685 | // Compute TFLite NMS. |
7686 | tflite_detection_post_process_f( |
7687 | boxesPtr, scoresPtr, anchorsPtr, detectionBoxesPtr, detectionClassesPtr, |
7688 | detectionScoresPtr, numDetectionsPtr, scratchPtr, numBoxes, |
7689 | numTotalClasses, numClasses, maxDetections, maxClassesPerDetection, |
7690 | maxDetectionsPerClass, iouThreshold, scoreThreshold, xScaleInv, yScaleInv, |
7691 | hScaleInv, wScaleInv, regularNMS); |
7692 | } |
7693 | |
7694 | void BoundInterpreterFunction::fwdAudioSpectrogramInstFloatImpl( |
7695 | glow::AudioSpectrogramInst const *I) { |
7696 | |
7697 | auto spectrogram = I->getSpectrogram(); |
7698 | auto input = I->getInput(); |
7699 | auto window = I->getWindow(); |
7700 | int64_t windowSize = I->getWindowSize(); |
7701 | int64_t windowStride = I->getWindowStride(); |
7702 | |
7703 | auto spectrogramH = getTensor(spectrogram)->getHandle<float>(); |
7704 | auto inputH = getTensor(input)->getHandle<float>(); |
7705 | auto windowH = getTensor(window)->getHandle<float>(); |
7706 | |
7707 | // Compute window count. |
7708 | int64_t inputLength = input->size(); |
7709 | int64_t windowCount = |
7710 | std::floor((inputLength - windowSize) / windowStride) + 1; |
7711 | |
7712 | // Compute FFT length (next power of 2) and spectrogram length. |
7713 | dim_t fftLen = 1 << (dim_t)std::ceil(std::log2((double)windowSize)); |
7714 | dim_t specLen = fftLen / 2 + 1; |
7715 | |
7716 | // Allocate temporary buffers. |
7717 | auto winOut = std::make_unique<float[]>(windowSize); |
7718 | auto fftRealOut = std::make_unique<float[]>(specLen); |
7719 | auto fftImagOut = std::make_unique<float[]>(specLen); |
7720 | |
7721 | // Compute the spectrogram. |
7722 | for (dim_t winIdx = 0; int64_t(winIdx) < windowCount; winIdx++) { |
7723 | |
7724 | // Windowing. |
7725 | for (int64_t n = 0; n < windowSize; n++) { |
7726 | winOut[n] = inputH.raw(winIdx * windowStride + n) * windowH.raw(n); |
7727 | } |
7728 | |
    // Compute the spectrum using a direct DFT (O(windowSize * specLen) per
    // window, not a fast FFT); only the first fftLen / 2 + 1 bins are needed
    // because the spectrum of a real signal is conjugate-symmetric.
7730 | for (dim_t k = 0; k < specLen; k++) { |
7731 | fftRealOut[k] = 0; |
7732 | fftImagOut[k] = 0; |
7733 | for (int n = 0; n < windowSize; n++) { |
7734 | fftRealOut[k] += |
7735 | winOut[n] * cos(2.0 * M_PI * (double)(n * k) / (double)(fftLen)); |
7736 | fftImagOut[k] -= |
7737 | winOut[n] * sin(2.0 * M_PI * (double)(n * k) / (double)(fftLen)); |
7738 | } |
7739 | } |
7740 | |
7741 | // Compute spectrum magnitude/power. |
7742 | if (I->getMagnitudeSquared()) { |
7743 | for (dim_t k = 0; k < specLen; k++) { |
7744 | spectrogramH.at({winIdx, k}) = |
7745 | fftRealOut[k] * fftRealOut[k] + fftImagOut[k] * fftImagOut[k]; |
7746 | } |
7747 | } else { |
7748 | for (dim_t k = 0; k < specLen; k++) { |
7749 | spectrogramH.at({winIdx, k}) = |
7750 | sqrt(fftRealOut[k] * fftRealOut[k] + fftImagOut[k] * fftImagOut[k]); |
7751 | } |
7752 | } |
7753 | } |
7754 | } |
7755 | |
7756 | void BoundInterpreterFunction::fwdAudioSpectrogramInst( |
7757 | glow::AudioSpectrogramInst const *I) { |
7758 | auto inputTy = I->getInput()->getElementType(); |
7759 | auto spectrogramTy = I->getSpectrogram()->getElementType(); |
7760 | if ((inputTy == ElemKind::FloatTy) && (spectrogramTy == ElemKind::FloatTy)) { |
7761 | fwdAudioSpectrogramInstFloatImpl(I); |
7762 | } else { |
    llvm_unreachable("Type is not supported.");
7764 | } |
7765 | } |
7766 | |
7767 | void BoundInterpreterFunction::fwdMFCCInstFloatImpl(glow::MFCCInst const *I) { |
7768 | |
7769 | auto coefficients = I->getCoefficients(); |
7770 | auto spectrogram = I->getSpectrogram(); |
7771 | auto melWeights = I->getMelWeights(); |
7772 | auto melRanges = I->getMelRanges(); |
7773 | auto dctMat = I->getDctMat(); |
7774 | int64_t filterBankCount = I->getFilterBankCount(); |
7775 | int64_t numCoefficients = I->getNumCoefficients(); |
7776 | |
7777 | auto coefficientsH = getTensor(coefficients)->getHandle<float>(); |
7778 | auto spectrogramH = getTensor(spectrogram)->getHandle<float>(); |
7779 | auto melWeightsH = getTensor(melWeights)->getHandle<float>(); |
7780 | auto melRangesH = getTensor(melRanges)->getHandle<int32_t>(); |
7781 | auto dctMatH = getTensor(dctMat)->getHandle<float>(); |
7782 | |
7783 | // Perform MFCC for all the windows. |
7784 | auto winNum = spectrogram->dims()[0]; |
7785 | auto melBuff = std::make_unique<float[]>(filterBankCount); |
7786 | for (dim_t winIdx = 0; winIdx < winNum; winIdx++) { |
7787 | |
7788 | // Apply Mel filter bank mapping. We use sqrt for the spectrogram since we |
7789 | // assume the spectrogram is a power value and not a magnitude. |
7790 | dim_t melBinCoeffIdx = 0; |
7791 | for (int64_t melIdx = 0; melIdx < filterBankCount; melIdx++) { |
7792 | int32_t freqIdxStart = melRangesH.raw(2 * melIdx + 0); |
7793 | int32_t freqIdxStop = melRangesH.raw(2 * melIdx + 1); |
7794 | float melPwr = 0.0f; |
7795 | for (dim_t freqIdx = freqIdxStart; int32_t(freqIdx) <= freqIdxStop; |
7796 | freqIdx++) { |
7797 | melPwr += std::sqrt(spectrogramH.at({winIdx, freqIdx})) * |
7798 | melWeightsH.raw(melBinCoeffIdx++); |
7799 | } |
7800 | melBuff[melIdx] = melPwr; |
7801 | } |
7802 | |
7803 | // Take logarithm in-place (avoid log(0)). |
7804 | for (int64_t melIdx = 0; melIdx < filterBankCount; melIdx++) { |
7805 | float melPwr = melBuff[melIdx]; |
7806 | melBuff[melIdx] = (melPwr == 0.0) |
7807 | ? logf(std::numeric_limits<float>::min()) |
7808 | : logf(melPwr); |
7809 | } |
7810 | |
7811 | // Compute DCT transform. |
7812 | for (dim_t k = 0; int64_t(k) < numCoefficients; k++) { |
7813 | float dctOut = 0.0f; |
7814 | for (dim_t n = 0; int64_t(n) < filterBankCount; n++) { |
7815 | dctOut += dctMatH.at({k, n}) * melBuff[n]; |
7816 | } |
7817 | coefficientsH.at({winIdx, k}) = dctOut; |
7818 | } |
7819 | } |
7820 | } |
7821 | |
7822 | void BoundInterpreterFunction::fwdMFCCInst(glow::MFCCInst const *I) { |
7823 | auto spectrogramTy = I->getSpectrogram()->getElementType(); |
7824 | auto coefficientsTy = I->getCoefficients()->getElementType(); |
7825 | if ((spectrogramTy == ElemKind::FloatTy) && |
7826 | (coefficientsTy == ElemKind::FloatTy)) { |
7827 | fwdMFCCInstFloatImpl(I); |
7828 | } else { |
    llvm_unreachable("Type is not supported.");
7830 | } |
7831 | } |
7832 | |
7833 | namespace { |
7834 | /// Positions of the input values to be used for bilinear interpolation for |
7835 | /// each sample point and the weights to use for each. |
7836 | template <typename T> struct BinGrid { |
7837 | dim_t left; |
7838 | dim_t top; |
7839 | dim_t right; |
7840 | dim_t bottom; |
7841 | T leftW; |
7842 | T topW; |
7843 | T rightW; |
7844 | T bottomW; |
7845 | }; |
7846 | } // namespace |
7847 | |
7848 | /// Function to calculate the xy coordinates of the resized image (grid) |
7849 | /// Ref: https://arxiv.org/pdf/1703.06870.pdf and OnnxRuntime implementation |
7850 | /// \p featureMapHeight and \p featureMapWidth are the dimensions of the |
7851 | /// input feature map to the operator, \p outputHeight and \p outputWidth are |
7852 | /// the dimensions of the operator output tensor, \p samplingRatioH and \p |
7853 | /// samplingRatioW are the number of sampling points to use in each bin in the |
7854 | /// height and width directions respectively (total sample points is |
7855 | /// samplingRatioH * samplingRatioW), \p boxHeight and \p boxWidth are the |
7856 | /// height and width of the RoI box, \p yRef and \p xRef are the adjustment to |
/// be made for each sampling point, this is either the top left corner of the
7858 | /// box for RoiAlign or a vector to be added to center point after rotation |
7859 | /// for RoiAlignRotated, \p rotated is true if the op is RoiAlignRotated, \p |
7860 | /// theta is the rotation angle in the case of RoiAlignRotated and is unused |
7861 | /// in RoiAlign, \p boxCenterH and \p boxCenterW are the center of the box |
7862 | /// used for rotation in the case of RoiAlignRotated and unused in the case of |
7863 | /// RoiAlign. \returns a vector of BinGrids, each one to be used to compute a |
7864 | /// single sample point value. |
7865 | template <typename T> |
7866 | static std::vector<BinGrid<T>> getROIAlignInterpolationCoordinates( |
7867 | dim_t featureMapHeight, dim_t featureMapWidth, dim_t outputHeight, |
7868 | dim_t outputWidth, dim_t samplingRatioH, dim_t samplingRatioW, T boxHeight, |
7869 | T boxWidth, T yRef, T xRef, bool rotated, T theta, T boxCenterH, |
7870 | T boxCenterW) { |
7871 | |
7872 | T sinTheta = T(0.0f); |
7873 | T cosTheta = T(0.0f); |
7874 | if (rotated) { |
7875 | sinTheta = T(std::sin(float(theta))); |
7876 | cosTheta = T(std::cos(float(theta))); |
7877 | } |
7878 | |
7879 | std::vector<BinGrid<T>> binGrids; |
7880 | |
  // height and width of each bin in the final output
7882 | const T binH = boxHeight / T(outputHeight); |
7883 | const T binW = boxWidth / T(outputWidth); |
7884 | const T roiBinSizeH = binH / T(samplingRatioH); |
7885 | const T roiBinSizeW = binW / T(samplingRatioW); |
7886 | for (dim_t oh = 0; oh < outputHeight; oh++) { |
7887 | for (dim_t ow = 0; ow < outputWidth; ow++) { |
7888 | for (dim_t gh = 0; gh < samplingRatioH; gh++) { |
7889 | for (dim_t gw = 0; gw < samplingRatioW; gw++) { |
7890 | // x,y coordinates or vector w.r.t input dimensions |
7891 | T inY = yRef + (T(oh) * binH) + ((T(gh) + T(0.5f)) * roiBinSizeH); |
7892 | T inX = xRef + (T(ow) * binW) + ((T(gw) + T(0.5f)) * roiBinSizeW); |
7893 | |
7894 | // If ROI is rotated, rotate by theta around the box center then |
7895 | // translate |
7896 | if (rotated) { |
7897 | T inYY = inY; |
7898 | T inXX = inX; |
7899 | inY = inYY * cosTheta - inXX * sinTheta + boxCenterH; |
7900 | inX = inXX * cosTheta + inYY * sinTheta + boxCenterW; |
7901 | } |
7902 | |
          // zero pad mal-formed boxes
          if (inY < T(-1) || inY > T(featureMapHeight) || inX < T(-1) ||
              inX > T(featureMapWidth)) {
            BinGrid<T> bg = BinGrid<T>{0, 0, 0, 0, 0, 0, 0, 0};
            binGrids.push_back(bg);
            continue;
          }
7914 | |
7915 | // clip to input dimensions |
7916 | T y = std::min(std::max(inY, T(0)), T(featureMapHeight - 1)); |
7917 | T x = std::min(std::max(inX, T(0)), T(featureMapWidth - 1)); |
7918 | |
7919 | // calc interpolation parameters |
7920 | const dim_t yl = dim_t(std::floor(float(y))); |
7921 | const dim_t xl = dim_t(std::floor(float(x))); |
7922 | const dim_t yh = std::min(yl + 1, featureMapHeight - 1); |
7923 | const dim_t xh = std::min(xl + 1, featureMapWidth - 1); |
7924 | |
7925 | BinGrid<T> bg; |
7926 | bg.left = xl; |
7927 | bg.top = yl; |
7928 | bg.right = xh; |
7929 | bg.bottom = yh; |
7930 | |
7931 | bg.rightW = x - T(xl); |
7932 | bg.bottomW = y - T(yl); |
7933 | bg.leftW = T(1.0) - bg.rightW; |
7934 | bg.topW = T(1.0) - bg.bottomW; |
7935 | |
7936 | binGrids.push_back(bg); |
7937 | } // end of w |
7938 | } // end of h |
7939 | } // end of W |
7940 | } // end of H |
7941 | |
7942 | return binGrids; |
7943 | } |
7944 | |
7945 | // Implementation of ROIAlign as described in |
7946 | // https://arxiv.org/pdf/1703.06870.pdf ROIAlign is similar to crop_and_resize |
7947 | // + pooling with minor modifications in the crop_and_resize. |
7948 | template <typename T> |
7949 | void BoundInterpreterFunction::fwdROIAlignInstFloatImpl( |
7950 | glow::ROIAlignInst const *I) { |
7951 | auto featureMap = I->getFeatureMap(); |
7952 | auto boxes = I->getBoxes(); |
7953 | auto batchIndices = I->getBatchIndices(); |
7954 | auto result = I->getResult(); |
7955 | |
7956 | auto boxesH = getTensor(boxes)->getHandle<T>(); |
7957 | auto featureMapH = getTensor(featureMap)->getHandle<T>(); |
7958 | auto resultH = getTensor(result)->getHandle<T>(); |
7959 | |
7960 | const bool rotated = I->getRotated(); |
7961 | const PoolingMode mode = PoolingMode(I->getMode()); |
7962 | const bool aligned = I->getAligned(); |
7963 | const dim_t samplingRatio = I->getSamplingRatio(); |
7964 | const T spatialScale = I->getSpatialScale(); |
7965 | |
7966 | const dim_t featureMapHeight = featureMapH.dims()[1]; |
7967 | const dim_t featureMapWidth = featureMapH.dims()[2]; |
7968 | const dim_t numBoxes = resultH.dims()[0]; |
7969 | const dim_t outputHeight = resultH.dims()[1]; |
7970 | const dim_t outputWidth = resultH.dims()[2]; |
7971 | const dim_t depth = resultH.dims()[3]; |
7972 | |
7973 | const T offset = aligned ? T(0.5) : T(0); |
7974 | |
7975 | bool useSeparateBatchIndexVector = true; |
7976 | dim_t boxesStartCol = 0; |
7977 | if (rotated || boxes->dims()[1] == 5) { |
7978 | boxesStartCol = 1; |
7979 | useSeparateBatchIndexVector = false; |
7980 | } |
7981 | |
  // Extract batch indices from the batchIndices tensor if it is used (only
  // used by ONNX, which may provide Int64ITy tensors).
  std::vector<dim_t> batchIndicesExtracted;
7985 | if (useSeparateBatchIndexVector) { |
7986 | Tensor *batchIndicesTensor = getTensor(batchIndices); |
7987 | auto batchIndicesElemKind = batchIndicesTensor->getElementType(); |
7988 | for (dim_t b = 0; b < numBoxes; b++) { |
7989 | if (batchIndicesElemKind == ElemKind::Int32ITy) { |
7990 | batchIndicesExtracted.push_back( |
7991 | batchIndicesTensor->getHandle<int32_t>().at({b})); |
7992 | } else { |
7993 | batchIndicesExtracted.push_back( |
7994 | batchIndicesTensor->getHandle<int64_t>().at({b})); |
7995 | } |
7996 | } |
7997 | } |
7998 | |
7999 | for (dim_t b = 0; b < numBoxes; b++) { |
8000 | dim_t batchIndex; |
8001 | if (useSeparateBatchIndexVector) { |
8002 | batchIndex = batchIndicesExtracted[b]; |
8003 | } else { |
8004 | batchIndex = dim_t(float(boxesH.at({b, 0}))); |
8005 | } |
8006 | |
8007 | // Values used to determine sampling points during bilinear interpolation. |
    // yRef and xRef have different interpretations for the rotated and
    // unrotated cases (vector vs. coordinates) but are used very similarly.
8010 | T yRef; |
8011 | T xRef; |
8012 | T boxHeight; |
8013 | T boxWidth; |
8014 | |
8015 | // Values only used in rotated case. |
8016 | T theta = T(0.0); |
8017 | T boxCenterH = T(0.0); |
8018 | T boxCenterW = T(0.0); |
8019 | |
8020 | if (rotated) { |
8021 | // Do not round |
8022 | boxCenterW = boxesH.at({b, boxesStartCol + 0}) * spatialScale - offset; |
8023 | boxCenterH = boxesH.at({b, boxesStartCol + 1}) * spatialScale - offset; |
8024 | boxWidth = boxesH.at({b, boxesStartCol + 2}) * spatialScale; |
8025 | boxHeight = boxesH.at({b, boxesStartCol + 3}) * spatialScale; |
8026 | theta = boxesH.at({b, boxesStartCol + 4}) * T(M_PI) / T(180.0); |
8027 | |
8028 | if (aligned) { |
        assert(boxWidth >= T(0.0) && boxHeight >= T(0.0) &&
               "ROIs in ROIAlign must not have negative size!");
8031 | } else { // backward compatibility |
8032 | // Force malformed ROIs to be 1x1 |
8033 | boxHeight = std::max(boxHeight, T(1.0)); |
8034 | boxWidth = std::max(boxWidth, T(1.0)); |
8035 | } |
8036 | |
8037 | // These are computed wrt the center of RoI (x, y). |
8038 | // Appropriate translation needs to be applied after. |
8039 | yRef = (T(-1.0) * boxHeight) / T(2.0); |
8040 | xRef = (T(-1.0) * boxWidth) / T(2.0); |
8041 | } else { |
8042 | llvm::SmallVector<T, 4> box = { |
8043 | boxesH.at({b, boxesStartCol + 0}) * spatialScale - offset, |
8044 | boxesH.at({b, boxesStartCol + 1}) * spatialScale - offset, |
8045 | boxesH.at({b, boxesStartCol + 2}) * spatialScale - offset, |
8046 | boxesH.at({b, boxesStartCol + 3}) * spatialScale - offset}; |
8047 | |
8048 | if (aligned) { |
        CHECK_GE(box[3] - box[1], T(0.0)) << "Roi height cannot be negative.";
        CHECK_GE(box[2] - box[0], T(0.0)) << "Roi width cannot be negative.";
8051 | } else { |
8052 | // Caffe2 backwards compatibility for mal-formed ROIs: |
8053 | // Force ROI size to be at least 1x1. |
8054 | box[2] = std::max(box[2], box[0] + T(1.0)); |
8055 | box[3] = std::max(box[3], box[1] + T(1.0)); |
8056 | } |
8057 | |
8058 | yRef = box[1]; |
8059 | xRef = box[0]; |
8060 | boxHeight = (box[3] - box[1]); |
8061 | boxWidth = (box[2] - box[0]); |
8062 | } |
8063 | |
8064 | const dim_t samplingRatioH = |
8065 | (samplingRatio > 0) ? samplingRatio |
8066 | : std::ceil(float(boxHeight) / outputHeight); |
8067 | const dim_t samplingRatioW = (samplingRatio > 0) |
8068 | ? samplingRatio |
8069 | : std::ceil(float(boxWidth) / outputWidth); |
8070 | |
    // get the xy coordinates in the resized image (grid)
8072 | std::vector<BinGrid<T>> binGrids = getROIAlignInterpolationCoordinates<T>( |
8073 | featureMapHeight, featureMapWidth, outputHeight, outputWidth, |
8074 | samplingRatioH, samplingRatioW, boxHeight, boxWidth, yRef, xRef, |
8075 | rotated, theta, boxCenterH, boxCenterW); |
8076 | |
8077 | uint64_t binCount = 0; |
8078 | for (dim_t oh = 0; oh < outputHeight; ++oh) { |
8079 | for (dim_t ow = 0; ow < outputWidth; ++ow) { |
8080 | for (dim_t d = 0; d < depth; ++d) { |
8081 | std::vector<T> values; |
8082 | for (dim_t gh = 0; gh < samplingRatioH; ++gh) { |
8083 | for (dim_t gw = 0; gw < samplingRatioW; ++gw) { |
8084 | BinGrid<T> bg = binGrids[binCount++]; |
8085 | // The four values of the i/p image surrounding the point of |
8086 | // interest (POI) in the resized image |
8087 | const T topLeft = |
8088 | featureMapH.at({batchIndex, bg.top, bg.left, d}); |
8089 | const T topRight = |
8090 | featureMapH.at({batchIndex, bg.top, bg.right, d}); |
8091 | const T bottomLeft = |
8092 | featureMapH.at({batchIndex, bg.bottom, bg.left, d}); |
8093 | const T bottomRight = |
8094 | featureMapH.at({batchIndex, bg.bottom, bg.right, d}); |
8095 | |
8096 | // bilinear interpolation |
8097 | const T value = (topLeft * (bg.topW * bg.leftW)) + |
8098 | (topRight * (bg.topW * bg.rightW)) + |
8099 | (bottomLeft * (bg.bottomW * bg.leftW)) + |
8100 | (bottomRight * (bg.bottomW * bg.rightW)); |
8101 | // interpolation along vertical line |
8102 | values.push_back(value); |
8103 | } // end of w |
8104 | } // end of h |
8105 | // {Average or Max} pooling |
8106 | resultH.at({b, oh, ow, d}) = |
8107 | (mode == PoolingMode::AVG) |
8108 | ? std::accumulate(values.begin(), values.end(), T(0.0)) / |
8109 | T(values.size()) |
8110 | : *std::max_element(values.begin(), values.end()); |
8111 | |
8112 | binCount = binCount - (samplingRatioH * samplingRatioW); |
8113 | } // end of d |
8114 | binCount = binCount + (samplingRatioH * samplingRatioW); |
8115 | } // end of W |
8116 | } // end of H |
8117 | } // end of b |
8118 | } |
8119 | |
8120 | void BoundInterpreterFunction::fwdROIAlignInst(glow::ROIAlignInst const *I) { |
8121 | dispatchFloatingPointImpl(fwdROIAlignInstFloatImpl, |
8122 | I->getFeatureMap()->getElementType(), I); |
8123 | } |
8124 | |
8125 | // Forward transform that maps proposal boxes to ground-truth boxes using |
8126 | // bounding-box regression deltas. |
8127 | // boxes: pixel coordinates of the bounding boxes |
8128 | // size (M, 4), format [x1; y1; x2; y2], x2 >= x1, y2 >= y1 |
8129 | // deltas: bounding box translations and scales |
8130 | // size (M, 4), format [dx; dy; dw; dh] |
8131 | // dx, dy: scale-invariant translation of the center of the bounding box |
8132 | // dw, dh: log-space scaling of the width and height of the bounding box |
8133 | // weights: weights [wx, wy, ww, wh] for the deltas |
8134 | // bboxXformClip: minimum bounding box width and height in log-space after |
8135 | // transofmration |
8136 | // correct_transform_coords: Correct bounding box transform coordates. Set to |
8137 | // true to match the detectron code, set to false for backward |
8138 | // compatibility |
8139 | // return: pixel coordinates of the bounding boxes |
8140 | // size (M, 4), format [x1; y1; x2; y2] |
8141 | // see "Rich feature hierarchies for accurate object detection and semantic |
8142 | // segmentation" Appendix C for more details |
8143 | // reference: detectron/lib/utils/boxes.py bbox_transform() |
8144 | template <typename T> |
8145 | static void bbox_transform_upright( |
8146 | Handle<T> &boxesOut, const Handle<T> &boxes, const Handle<T> &deltas, |
8147 | dim_t startRowBoxesOut, dim_t startColBoxesOut, dim_t startRowBoxes, |
8148 | dim_t startColBoxes, dim_t startRowDeltas, dim_t startColDeltas, dim_t rows, |
8149 | llvm::ArrayRef<float> weights, const T &bboxXformClip, T scaleBeforeInv, |
8150 | const bool legacyPlusOne = false) { |
8151 | |
8152 | if (boxes.dims()[0] == 0) { |
8153 | return; |
8154 | } |
8155 | |
8156 | std::vector<T> widths(rows), heights(rows), ctrX(rows), ctrY(rows); |
8157 | for (dim_t i = 0; i < rows; i++) { |
8158 | widths[i] = boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) * |
8159 | scaleBeforeInv - |
8160 | boxes.at({startRowBoxes + i, startColBoxes}) * scaleBeforeInv + |
8161 | T(((legacyPlusOne) ? 1 : 0)); |
8162 | heights[i] = boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) * |
8163 | scaleBeforeInv - |
8164 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) * |
8165 | scaleBeforeInv + |
8166 | T(((legacyPlusOne) ? 1 : 0)); |
8167 | |
8168 | ctrX[i] = boxes.at({startRowBoxes + i, startColBoxes}) * scaleBeforeInv + |
8169 | T(0.5) * widths[i]; |
8170 | ctrY[i] = boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) * |
8171 | scaleBeforeInv + |
8172 | T(0.5) * heights[i]; |
8173 | } |
8174 | |
8175 | std::vector<T> dx(rows), dy(rows), dw(rows), dh(rows); |
8176 | for (dim_t i = 0; i < rows; i++) { |
8177 | dx[i] = deltas.at({startRowDeltas + i, startColDeltas}) / T(weights[0]); |
8178 | dy[i] = deltas.at({startRowDeltas + i, startColDeltas + dim_t(1)}) / |
8179 | T(weights[1]); |
8180 | dw[i] = |
8181 | std::min(deltas.at({startRowDeltas + i, startColDeltas + dim_t(2)}) / |
8182 | T(weights[2]), |
8183 | bboxXformClip); |
8184 | dh[i] = |
8185 | std::min(deltas.at({startRowDeltas + i, startColDeltas + dim_t(3)}) / |
8186 | T(weights[3]), |
8187 | bboxXformClip); |
8188 | } |
8189 | |
8190 | std::vector<T> predCtrX(rows), predCtrY(rows), predW(rows), predH(rows); |
8191 | for (dim_t i = 0; i < rows; i++) { |
8192 | predCtrX[i] = dx[i] * widths[i] + ctrX[i]; |
8193 | predCtrY[i] = dy[i] * heights[i] + ctrY[i]; |
8194 | predW[i] = T(std::exp(float(dw[i]))) * widths[i]; |
8195 | predH[i] = T(std::exp(float(dh[i]))) * heights[i]; |
8196 | } |
8197 | |
  for (dim_t i = 0; i < rows; i++) {
    // x1
    boxesOut.at({startRowBoxesOut + i, startColBoxesOut}) =
        predCtrX[i] - T(0.5) * predW[i];
    // y1
    boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(1)}) =
        predCtrY[i] - T(0.5) * predH[i];
    // x2
    boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(2)}) =
        predCtrX[i] + T(0.5) * predW[i] - T(((legacyPlusOne) ? 1 : 0));
    // y2
    boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(3)}) =
        predCtrY[i] + T(0.5) * predH[i] - T(((legacyPlusOne) ? 1 : 0));
  }
8212 | } |
8213 | |
8214 | // Like bbox_transform_upright, but works on rotated boxes. |
8215 | // boxes: pixel coordinates of the bounding boxes |
8216 | // size (M, 5), format [ctr_x; ctr_y; width; height; angle (in degrees)] |
8217 | // deltas: bounding box translations and scales |
8218 | // size (M, 5), format [dx; dy; dw; dh; da] |
8219 | // dx, dy: scale-invariant translation of the center of the bounding box |
8220 | // dw, dh: log-space scaling of the width and height of the bounding box |
8221 | // da: delta for angle in radians |
8222 | // return: pixel coordinates of the bounding boxes |
8223 | // size (M, 5), format [ctr_x; ctr_y; width; height; angle (in degrees)] |
8224 | template <typename T> |
8225 | static void bbox_transform_rotated( |
8226 | Handle<T> &boxesOut, const Handle<T> &boxes, const Handle<T> &deltas, |
8227 | dim_t startRowBoxesOut, dim_t startColBoxesOut, dim_t startRowBoxes, |
8228 | dim_t startColBoxes, dim_t startRowDeltas, dim_t startColDeltas, dim_t rows, |
8229 | llvm::ArrayRef<float> weights, const T &bboxXformClip, T scaleBeforeInv, |
8230 | const bool angleBoundOn, ssize_t angleBoundLo, ssize_t angleBoundHi) { |
8231 | |
8232 | if (boxes.dims()[0] == 0) { |
8233 | return; |
8234 | } |
8235 | |
8236 | const T PI = 3.1415926535897931; |
8237 | |
8238 | std::vector<T> dx(rows), dy(rows), dw(rows), dh(rows), da(rows); |
8239 | for (dim_t i = 0; i < rows; i++) { |
8240 | dx[i] = deltas.at({startRowDeltas + i, startColDeltas}) / T(weights[0]); |
8241 | dy[i] = deltas.at({startRowDeltas + i, startColDeltas + dim_t(1)}) / |
8242 | T(weights[1]); |
8243 | dw[i] = |
8244 | std::min(deltas.at({startRowDeltas + i, startColDeltas + dim_t(2)}) / |
8245 | T(weights[2]), |
8246 | bboxXformClip); |
8247 | dh[i] = |
8248 | std::min(deltas.at({startRowDeltas + i, startColDeltas + dim_t(3)}) / |
8249 | T(weights[3]), |
8250 | bboxXformClip); |
8251 | // Convert back to degrees |
8252 | da[i] = deltas.at({startRowDeltas + i, startColDeltas + dim_t(4)}) * |
8253 | T(180.0) / PI; |
8254 | } |
8255 | |
8256 | for (dim_t i = 0; i < rows; i++) { |
8257 | // new ctr_x |
8258 | boxesOut.at({startRowBoxesOut + i, startColBoxesOut}) = |
8259 | dx[i] * boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) * |
8260 | scaleBeforeInv + |
8261 | boxes.at({startRowBoxes + i, startColBoxes}) * scaleBeforeInv; |
8262 | // new ctr_y |
8263 | boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(1)}) = |
8264 | dy[i] * boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) * |
8265 | scaleBeforeInv + |
8266 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) * |
8267 | scaleBeforeInv; |
8268 | // new width |
8269 | boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(2)}) = |
8270 | T(std::exp(float(dw[i]))) * |
8271 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) * |
8272 | scaleBeforeInv; |
8273 | // new height |
8274 | boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(3)}) = |
8275 | T(std::exp(float(dh[i]))) * |
8276 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) * |
8277 | scaleBeforeInv; |
8278 | // new angle |
8279 | boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(4)}) = |
8280 | da[i] + boxes.at({startRowBoxes + i, startColBoxes + dim_t(4)}); |
8281 | } |
8282 | |
8283 | if (angleBoundOn) { |
8284 | const ssize_t period = angleBoundHi - angleBoundLo; |
8285 | |
8286 | for (dim_t i = 0; i < rows; i++) { |
8287 | if (ssize_t(boxesOut.at({startRowBoxesOut + i, |
8288 | startColBoxesOut + dim_t(4)})) < angleBoundLo) { |
8289 | boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(4)}) += |
8290 | T(period); |
8291 | } else if (ssize_t(boxesOut.at( |
8292 | {startRowBoxesOut + i, startColBoxesOut + dim_t(4)})) > |
8293 | angleBoundHi) { |
8294 | boxesOut.at({startRowBoxesOut + i, startColBoxesOut + dim_t(4)}) -= |
8295 | T(period); |
8296 | } |
8297 | } |
8298 | } |
8299 | } |
8300 | |
8301 | // Clip boxes to image boundaries |
8302 | // boxes: pixel coordinates of bounding box, size (M * 4) |
8303 | template <typename T> |
8304 | void clip_boxes_upright(Handle<T> &boxes, dim_t startRowBoxes, |
8305 | dim_t startColBoxes, dim_t rows, int height, int width, |
8306 | T scaleAfter, bool legacyPlusOne = false, |
8307 | std::vector<dim_t> uprightRows = {}) { |
8308 | for (dim_t i = 0; i < rows; i++) { |
8309 | if (uprightRows.size() == rows && !uprightRows[i]) { |
8310 | continue; |
8311 | } |
8312 | // x1 >= 0 && x1 < width |
8313 | boxes.at({startRowBoxes + i, startColBoxes}) = |
8314 | scaleAfter * |
8315 | std::max(std::min(boxes.at({startRowBoxes + i, startColBoxes}), |
8316 | T(width - int(((legacyPlusOne) ? 1 : 0)))), |
8317 | T(0)); |
8318 | // y1 >= 0 && y1 < height |
8319 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) = |
8320 | scaleAfter * std::max(std::min(boxes.at({startRowBoxes + i, |
8321 | startColBoxes + dim_t(1)}), |
8322 | T(height - ((legacyPlusOne) ? 1 : 0))), |
8323 | T(0)); |
8324 | |
8325 | // x2 >= 0 && x2 < width |
8326 | boxes.at({startRowBoxes + i, startColBoxes + 2}) = |
8327 | scaleAfter * std::max(std::min(boxes.at({startRowBoxes + i, |
8328 | startColBoxes + dim_t(2)}), |
8329 | T(width - ((legacyPlusOne) ? 1 : 0))), |
8330 | T(0)); |
8331 | // y2 >= 0 && y2 < height |
8332 | boxes.at({startRowBoxes + i, startColBoxes + 3}) = |
8333 | scaleAfter * std::max(std::min(boxes.at({startRowBoxes + i, |
8334 | startColBoxes + dim_t(3)}), |
8335 | T(height - ((legacyPlusOne) ? 1 : 0))), |
8336 | T(0)); |
8337 | } |
8338 | } |
8339 | |
8340 | // Similar to clip_boxes_upright but handles rotated boxes with angle info. |
8341 | // boxes: size (M, 5), format [ctr_x; ctr_y; width; height; angle (in |
8342 | // degrees)] |
8343 | // |
8344 | // Clipping is only performed for boxes that are almost upright |
8345 | // (within a given `angle_thresh` tolerance) to maintain backward |
8346 | // compatibility for non-rotated boxes. |
8347 | // |
8348 | // We don't clip rotated boxes due to a couple of reasons: |
8349 | // (1) There are potentially multiple ways to clip a rotated box to make it |
8350 | // fit within the image. |
8351 | // (2) It's tricky to make the entire rectangular box fit within the image and |
8352 | // still be able to not leave out pixels of interest. |
8353 | // Therefore, we rely on upstream ops like RoIAlignRotated safely handling |
8354 | // this. |
8355 | template <typename T> |
8356 | void clip_boxes_rotated(Handle<T> &boxes, dim_t startRowBoxes, |
8357 | dim_t startColBoxes, dim_t rows, int imH, int imW, |
8358 | T scaleAfter, float angleThresh = 1.0, |
8359 | bool legacyPlusOne = false) { |
8360 | std::vector<dim_t> uprightRows(rows, 0); |
8361 | for (dim_t i = 0; i < rows; i++) { |
8362 | if (std::abs(float(boxes.at( |
8363 | {startRowBoxes + i, startColBoxes + dim_t(4)}))) <= angleThresh) { |
8364 | const T ctrX = boxes.at({startRowBoxes + i, startColBoxes}); |
8365 | const T ctrY = boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}); |
8366 | const T width = boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}); |
8367 | const T height = boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}); |
8368 | boxes.at({startRowBoxes + i, startColBoxes}) = |
8369 | ctrX - (width - T(((legacyPlusOne) ? 1 : 0))) / T(2.0); |
8370 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) = |
8371 | ctrY - (height - T(((legacyPlusOne) ? 1 : 0))) / T(2.0); |
8372 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) = |
8373 | ctrX + (width - T(((legacyPlusOne) ? 1 : 0))) / T(2.0); |
8374 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) = |
8375 | ctrY + (height - T(((legacyPlusOne) ? 1 : 0))) / T(2.0); |
8376 | uprightRows[i] = 1; |
8377 | } |
8378 | } |
8379 | clip_boxes_upright(boxes, startRowBoxes, startColBoxes, rows, imH, imW, |
8380 | /* scaleAfter */ T(1.0), legacyPlusOne, uprightRows); |
8381 | |
8382 | for (dim_t i = 0; i < rows; i++) { |
8383 | if (uprightRows[i] == 1) { |
8384 | const T x1 = boxes.at({startRowBoxes + i, startColBoxes}); |
8385 | const T y1 = boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}); |
8386 | const T x2 = boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}); |
8387 | const T y2 = boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}); |
8388 | boxes.at({startRowBoxes + i, startColBoxes}) = (x1 + x2) / T(2.0); |
8389 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(1)}) = |
8390 | (y1 + y2) / T(2.0); |
8391 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(2)}) = |
8392 | x2 - x1 + T(((legacyPlusOne) ? 1 : 0)); |
8393 | boxes.at({startRowBoxes + i, startColBoxes + dim_t(3)}) = |
8394 | y2 - y1 + T(((legacyPlusOne) ? 1 : 0)); |
8395 | } |
8396 | |
8397 | for (dim_t j = 0; j < 4; j++) { |
8398 | boxes.at({startRowBoxes + i, startColBoxes + j}) *= scaleAfter; |
8399 | } |
8400 | } |
8401 | } |
8402 | |
8403 | template <typename T> |
8404 | void BoundInterpreterFunction::fwdBBoxTransformInstFloatImpl( |
8405 | glow::BBoxTransformInst const *I) { |
8406 | auto roiIn = I->getRois(); |
8407 | auto deltaIn = I->getDeltas(); |
8408 | auto imInfoIn = I->getImInfo(); |
8409 | |
8410 | auto boxOut = I->getBoxOut(); |
8411 | auto roiBatchSplits = I->getRoiBatchSplits(); |
8412 | |
8413 | auto weights = I->getWeights(); |
8414 | const dim_t boxDim = I->getRotated() ? 5 : 4; |
8415 | auto applyScale = I->getApplyScale(); |
8416 | auto rotated = I->getRotated(); |
8417 | auto angleBoundOn = I->getAngleBoundOn(); |
8418 | auto angleBoundLo = I->getAngleBoundLo(); |
8419 | auto angleBoundHi = I->getAngleBoundHi(); |
8420 | auto angleThresh = I->getClipAngleThresh(); |
8421 | auto legacyPlusOne = I->getLegacyPlusOne(); |
8422 | const dim_t N = roiIn->dims()[0]; |
8423 | const dim_t numClasses = deltaIn->dims()[1] / boxDim; |
8424 | const dim_t batchSize = imInfoIn->dims()[0]; |
8425 | |
8426 | auto roisH = getTensor(roiIn)->getHandle<T>(); |
8427 | auto deltasH = getTensor(deltaIn)->getHandle<T>(); |
8428 | auto roiBatchSplitsH = getTensor(roiBatchSplits)->getHandle<T>(); |
8429 | |
8430 | // Count the number of RoIs per batch |
8431 | std::vector<int> numRoisPerBatch(batchSize, 0); |
8432 | if (roiIn->dims()[1] == boxDim) { |
8433 | numRoisPerBatch[0] = N; |
8434 | } else { |
8435 | for (dim_t i = 0; i < N; ++i) { |
8436 | const int roiBatchId = roisH.at({i, 0}); |
8437 | numRoisPerBatch[roiBatchId]++; |
8438 | } |
8439 | } |
8440 | |
8441 | auto imInfoH = getTensor(imInfoIn)->getHandle<T>(); |
8442 | auto boxOutH = getTensor(boxOut)->getHandle<T>(); |
8443 | getTensor(boxOut)->zero(); |
8444 | |
8445 | // Default value for minimum bounding box width and height after bounding |
8446 | // box transformation (bbox_transform()) in log-space |
8447 | const T bboxXformClip = std::log(1000.0 / 16.0); |
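  // Since exp(bboxXformClip) = 1000 / 16 = 62.5, the clamped dw and dh let a
  // predicted box grow by at most a factor of 62.5 in each dimension.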
8448 | |
8449 | // We assume roiIn and deltaIn over multiple batches are grouped |
8450 | // together in increasing order as generated by GenerateProposalsOp |
8451 | dim_t offset = 0; |
8452 | for (dim_t i = 0; i < batchSize; ++i) { |
8453 | const dim_t numRois = numRoisPerBatch[i]; |
8454 | const T scaleBefore = imInfoH.at({i, 2}); |
8455 | const T scaleAfter = applyScale ? scaleBefore : T(1.0); |
8456 | dim_t imgH = dim_t(float(imInfoH.at({i, 0}) / scaleBefore) + 0.5); |
8457 | dim_t imgW = dim_t(float(imInfoH.at({i, 1}) / scaleBefore) + 0.5); |
8458 | |
    // Process the rectangle starting at (startRowRoi, startColRoi), with a
    // height (rows) of numRois and a width (cols) of boxDim.
8461 | dim_t startRowRoi = offset; |
8462 | dim_t startColRoi = roiIn->dims()[1] != boxDim ? 1 : 0; |
8463 | dim_t rows = numRois; |
8464 | T scaleBeforeInv = T(1) / scaleBefore; |
8465 | |
    // The before/after scales are applied on the fly. The scale is not
    // applied to the angle of rotated boxes.
8468 | for (dim_t k = 0; k < numClasses; k++) { |
8469 | dim_t startRowDelta = offset; |
8470 | dim_t startColDelta = k * boxDim; |
8471 | if (rotated) { |
8472 | bbox_transform_rotated<T>(boxOutH, roisH, deltasH, startRowDelta, |
8473 | startColDelta, startRowRoi, startColRoi, |
8474 | startRowDelta, startColDelta, rows, weights, |
8475 | bboxXformClip, scaleBeforeInv, angleBoundOn, |
8476 | angleBoundLo, angleBoundHi); |
8477 | clip_boxes_rotated<T>(boxOutH, startRowDelta, startColDelta, rows, imgH, |
8478 | imgW, scaleAfter, angleThresh, legacyPlusOne); |
8479 | } else { |
8480 | bbox_transform_upright<T>(boxOutH, roisH, deltasH, startRowDelta, |
8481 | startColDelta, startRowRoi, startColRoi, |
8482 | startRowDelta, startColDelta, rows, weights, |
8483 | bboxXformClip, scaleBeforeInv, legacyPlusOne); |
8484 | clip_boxes_upright<T>(boxOutH, startRowDelta, startColDelta, rows, imgH, |
8485 | imgW, scaleAfter, legacyPlusOne); |
8486 | } |
8487 | } |
8488 | |
8489 | offset += rows; |
8490 | } |
8491 | |
8492 | for (dim_t i = 0; i < batchSize; i++) { |
8493 | roiBatchSplitsH.at({i}) = numRoisPerBatch[i]; |
8494 | } |
8495 | } |
8496 | |
8497 | void BoundInterpreterFunction::fwdBBoxTransformInst( |
8498 | glow::BBoxTransformInst const *I) { |
8499 | dispatchFloatingPointImpl(fwdBBoxTransformInstFloatImpl, |
8500 | I->getRois()->getElementType(), I); |
8501 | } |
8502 | |
8503 | void BoundInterpreterFunction::fwdExternalFunctionCallInst( |
8504 | glow::ExternalFunctionCallInst const *) { |
  LOG(FATAL) << "ExternalFunctionCallInst is not supported yet";
8506 | } |
8507 | |