1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_BACKENDS_INTERPRETER_INTERPRETERFUNCTION_H
17#define GLOW_BACKENDS_INTERPRETER_INTERPRETERFUNCTION_H
18
19#include "glow/Backend/Backend.h"
20#include "glow/Backend/BackendUtils.h"
21#include "glow/Backend/CompiledFunction.h"
22#include "glow/Base/Tensor.h"
23#include "glow/ExecutionContext/ExecutionContext.h"
24#include "glow/Quantization/Base/Base.h"
25
26#include "llvm/ADT/ArrayRef.h"
27
28#include <memory>
29#include <unordered_map>
30
31namespace glow {
32
33class IRFunction;
34class Value;
35class Tensor;
36class Constant;
37
38// Forward declare all of the classes.
39#define DEF_VALUE(CLASS, NAME) class CLASS;
40#define DEF_INSTR(CLASS, NAME) class CLASS;
41#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME)
42#include "glow/AutoGenInstr.def"
43
44/// Function "compiled" for execution by the interpreter.
45class InterpreterFunction final : public CompiledFunction,
46 public IRInstructionProcessingHandler {
47 /// The IR to be executed.
48 std::unique_ptr<IRFunction> F_;
49
50 /// Maps Value.name to tensors for constants.
51 std::unordered_map<std::string, Tensor *> constants_;
52
53public:
54 InterpreterFunction(std::unique_ptr<IRFunction> F,
55 runtime::RuntimeBundle &&bundle);
56
57 /// \name CompiledFunction interface
58 ///@{
59 ~InterpreterFunction() override;
60
61 Error execute(ExecutionContext *context) override;
62
63 /// Collects constants for runtime.
64 void collectConstants(const Module *module) override;
65
66 /// Add a constant to the function, this is used for loading static
67 /// placeholders.
68 void addConstant(std::string name, Tensor *T);
69
70 /// Get reference to IR function.
71 IRFunction *getIR() { return F_.get(); }
72
73 /// Read trace events out of this func and write them into /p context
74 void translateTraceEvents(ExecutionContext *context) const override;
75
76 /// \returns the backend used to compile this function.
77 virtual std::string getCompileBackendName() const override {
78 return "Interpreter";
79 }
80 ///@}
81};
82
83/// An InterpreterFunction bound to a specific invocation.
/// An InterpreterFunction bound to a specific invocation.
/// Holds the per-execution tensor state (intermediate and external tensors)
/// while borrowing the constant map from the owning InterpreterFunction.
class BoundInterpreterFunction : public IRInstructionProcessingHandler {
  /// Maps values to Tensors, that are owned by this class.
  std::unordered_map<const Value *, Tensor *> tensors_;

  /// Maps values to Tensors, that are *not* owned by this class.
  /// (e.g. placeholder backing tensors supplied by the caller).
  std::unordered_map<const Value *, Tensor *> externalTensors_;

  /// A reference to the constant map from the owning InterpreterFunction.
  /// Not owned; must outlive this object.
  const std::unordered_map<std::string, Tensor *> &constants_;

public:
  /// Bind to the constant map \p constants owned by the compiled function.
  explicit BoundInterpreterFunction(
      const std::unordered_map<std::string, Tensor *> &constants)
      : constants_(constants) {}

  // Frees the owned tensors in \ref tensors_ (external tensors are left
  // untouched).
  ~BoundInterpreterFunction();

  /// Interpret the instructions of \p F, reading/writing placeholder state
  /// through \p context. \returns an Error on failure.
  Error execute(IRFunction *F, ExecutionContext *context);

  /// \returns a pointer to the tensor that is saved under \p v.
  Tensor *getTensor(const Value *v) const;

  /// \returns a typed handle to the tensor that is stored at \p v.
  template <class ElemTy = float>
  Handle<ElemTy> getWeightHandle(Value *v) const {
    return getTensor(v)->getHandle<ElemTy>();
  }

private:
  /// Allocate a tensor to back the value \p v. Do not allocate anything if a
  /// tensor is already allocated for \p v.
  /// \returns a tensor for \p v.
  Tensor *getOrCreateTensor(const Value *v);

  /// Allocate an unowned tensor to back the value \p v. The source tensor of
  /// the unowned tensor is provided by \p src.
  /// \returns a tensor for \p v.
  Tensor *getOrCreateUnownedTensor(const Value *v, const Value *src,
                                   llvm::ArrayRef<dim_t> offsets);

  /// If a tensor is allocated for \p v then delete it.
  void deleteTensor(const Value *v);

  /// @name BoundInterpreterFunction methods. This is a list of method
  /// declarations that are used by the interpreter to dispatch different
  /// instructions.
  ///@{

  // Generate one fwd<Inst> dispatch entry point per auto-generated
  // instruction class.
#define DEF_VALUE(CLASS, NAME)
#define DEF_INSTR(CLASS, NAME) void fwd##CLASS(const CLASS *I);
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME)
#include "glow/AutoGenInstr.def"

  // Convolution implementations. The "QuantizedImpl" variants accumulate in
  // \p AccumulatorTy with a separate bias element type; the "FloatImpl"
  // variants operate on floating-point element types.
  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdConvolutionInstQuantizedImpl(Value *inV, Value *outV, Value *filterV,
                                       Value *biasV,
                                       llvm::ArrayRef<unsigned_t> kernelSizes,
                                       llvm::ArrayRef<unsigned_t> strides,
                                       llvm::ArrayRef<unsigned_t> pads,
                                       size_t group,
                                       llvm::ArrayRef<unsigned_t> dilation);

  template <typename ElemTy = float>
  void fwdConvolutionInstFloatImpl(Value *inV, Value *outV, Value *filterV,
                                   Value *biasV,
                                   llvm::ArrayRef<unsigned_t> kernelSizes,
                                   llvm::ArrayRef<unsigned_t> strides,
                                   llvm::ArrayRef<unsigned_t> pads,
                                   size_t group,
                                   llvm::ArrayRef<unsigned_t> dilation);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdConvolution3DInstQuantizedImpl(Value *inV, Value *outV,
                                         Value *filterV, Value *biasV,
                                         llvm::ArrayRef<unsigned_t> kernelSizes,
                                         llvm::ArrayRef<unsigned_t> strides,
                                         llvm::ArrayRef<unsigned_t> pads,
                                         size_t group);

  template <typename ElemTy = float>
  void fwdConvolution3DInstFloatImpl(Value *inV, Value *outV, Value *filterV,
                                     Value *biasV,
                                     llvm::ArrayRef<unsigned_t> kernelSizes,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     size_t group);

  template <typename ElemTy = float>
  void fwdConvTransposeInstFloatImpl(Value *inV, Value *outV, Value *filterV,
                                     Value *biasV,
                                     llvm::ArrayRef<unsigned_t> kernelSizes,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     size_t group,
                                     llvm::ArrayRef<unsigned_t> dilation);

  // Normalization implementations, parameterized on the number of spatial
  // dimensions \p numDims.
  template <typename ElemTy = float>
  void fwdBatchNormalizationFloatImpl(const BatchNormalizationInst *I,
                                      int numDims);
  // Int8 batch-norm; \p ParamTy is the type of the scale/bias parameters.
  template <typename ParamTy = float16_t>
  void fwdBatchNormalizationI8Impl(const BatchNormalizationInst *I,
                                   int numDims);

  template <typename ElemTy = float>
  void
  fwdLayerNormalizationInstFloatImpl(const glow::LayerNormalizationInst *I);

  // Pooling implementations: separate int8 and templated-float paths.
  void fwdAvgPoolInstI8Impl(const AvgPoolInst *I);
  template <typename ElemTy> void fwdAvgPoolInstFloatImpl(const AvgPoolInst *I);

  void fwdAvgPool3DInstI8Impl(const AvgPoolInst *I);
  template <typename ElemTy>
  void fwdAvgPool3DInstFloatImpl(const AvgPoolInst *I);

  void fwdAdaptiveAvgPoolInstI8Impl(const AdaptiveAvgPoolInst *I);
  template <typename ElemTy>
  void fwdAdaptiveAvgPoolInstFloatImpl(const AdaptiveAvgPoolInst *I);

  template <typename ElemTy> void fwdSoftMaxInstImpl(const SoftMaxInst *I);
  template <typename ElemTy>
  void fwdLogSoftMaxInstImpl(const LogSoftMaxInst *I);

  // Matrix-multiply / fully-connected implementations.
  template <typename ElemTy, typename AccumulatorTy>
  void fwdMatMulInstQuantizedImpl(const MatMulInst *I);
  template <typename ElemTy> void fwdMatMulInstFloatImpl(const MatMulInst *I);

  template <typename ElemTy>
  void fwdBatchMatMulInstFloatImpl(const BatchMatMulInst *I);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdFullyConnectedInstQuantizedImpl(const FullyConnectedInst *I);
  template <typename ElemTy>
  void fwdFullyConnectedInstFloatImpl(const FullyConnectedInst *I);

  // Dynamic row-wise quantized FC: \p baseRow selects the starting output row;
  // scales/offsets are per-row quantization parameters.
  template <typename ElemTy, typename OutputTy, typename AccumulatorTy>
  void fwdDynRowwiseQuantizedFullyConnectedInstImpl(
      Handle<ElemTy> inW, Handle<OutputTy> &outW, dim_t baseRow,
      Handle<ElemTy> weightsW, Handle<float> biasW, Handle<float> scalesW,
      Handle<int32_t> offsetsW);

  void fwdDynRowwiseQuantizedFullyConnectedInstPreimpl(
      Tensor *inputTensor, Tensor *weightsTensor, Tensor *biasTensor,
      Tensor *resultTensor, Tensor *wScaleTensor, Tensor *wOffsetTensor,
      bool isSymmetric, bool isPerBatchElement);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdRowwiseQuantizedFullyConnectedInstImpl(Value *inV, Value *outV,
                                                 Value *weightsV, Value *biasV,
                                                 Value *scalesV,
                                                 Value *offsetsV);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdChannelwiseQuantizedConv2DInstImpl(
      const ChannelwiseQuantizedConvolutionInst *I);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdChannelwiseQuantizedConv3DInstImpl(
      const ChannelwiseQuantizedConvolutionInst *I);

  // Element-wise arithmetic implementations.
  void fwdElementAddInstI8Impl(const ElementAddInst *I);
  template <typename ElemTy>
  void fwdElementAddInstArithmeticImpl(const ElementAddInst *I);

  void fwdElementMaxInstI8Impl(const ElementMaxInst *I);
  template <typename ElemTy>
  void fwdElementMaxInstArithmeticImpl(const ElementMaxInst *I);

  template <typename ElemTy>
  void fwdBatchedAddInstFloatImpl(const BatchedAddInst *I);

  // Element-wise comparisons; \p CmpTy is the type the comparison is
  // performed in (defaults to the element type).
  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy = ElemTy>
  void fwdElementCmpEQInstImpl(const ElementCmpEQInst *I);

  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy = ElemTy>
  void fwdElementCmpNEQInstImpl(const ElementCmpNEQInst *I);

  template <typename ElemTy>
  void fwdBatchOneHotImpl(const glow::BatchOneHotInst *I);

  template <typename ElemTy>
  void fwdSpaceToDepthInstImpl(const glow::SpaceToDepthInst *I);

  template <typename ElemTy>
  void fwdResizeNearestInstImpl(const ResizeNearestInst *I);

  template <typename ElemTy>
  void fwdResizeBilinearInstImpl(const ResizeBilinearInst *I);

  template <typename ElemTy> void fwdSigmoidInstFloatImpl(const SigmoidInst *I);

  template <typename ElemTy> void fwdTanhInstFloatImpl(const TanhInst *I);

  template <typename ElemTy>
  void fwdSoftPlusInstFloatImpl(const SoftPlusInst *I);

  template <typename ElemTy>
  void fwdCrossEntropyLossInstFloatImpl(const CrossEntropyLossInst *I);

  template <typename ElemTy>
  void fwdLocalResponseNormalizationInstFloatImpl(
      const glow::LocalResponseNormalizationInst *I);

  template <typename ElemTy>
  void fwdElementSubInstArithmeticImpl(const ElementSubInst *I);

  template <typename ElemTy>
  void fwdElementMulInstArithmeticImpl(const ElementMulInst *I);

  template <typename ElemTy>
  void fwdElementMinInstArithmeticImpl(const ElementMinInst *I);

  template <typename ElemTy>
  void fwdElementBitwiseOrInstImpl(const ElementBitwiseOrInst *I);

  template <typename ElemTy>
  void fwdElementBitwiseXorInstImpl(const ElementBitwiseXorInst *I);

  template <typename ElemTy>
  void fwdElementBitwiseAndInstImpl(const ElementBitwiseAndInst *I);

  // Generic unary helper: applies \p func element-wise (computed in float).
  template <typename ElemTy, typename InstKind>
  void fwdUnaryArithmeticImpl(const InstKind *I,
                              std::function<float(float)> func);

  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy = ElemTy>
  void fwdElementCmpLTEInstImpl(const ElementCmpLTEInst *I);

  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy = ElemTy>
  void fwdElementCmpLTInstImpl(const ElementCmpLTInst *I);

  // Shared helper the CmpEQ/NEQ/LTE/LT implementations dispatch through;
  // \p cmpHelper supplies the actual predicate.
  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy, typename InstCmpKind>
  void
  fwdElementCmpHelperImpl(const InstCmpKind *I,
                          std::function<bool(CmpTy LHS, CmpTy RHS)> cmpHelper);

  template <typename ElemTy>
  void fwdElementPowInstFloatImpl(const ElementPowInst *I);

  void fwdElementPowInstI8Impl(const ElementPowInst *I);

  template <typename ElemTy>
  void fwdElementIsNaNInstFloatImpl(const ElementIsNaNInst *I);

  template <typename ElemTy>
  void fwdElementLogInstFloatImpl(const ElementLogInst *I);

  template <typename ElemTy>
  void fwdElementExpInstFloatImpl(const ElementExpInst *I);

  template <typename ElemTy>
  void fwdElementSignInstFloatImpl(const ElementSignInst *I);

  template <typename ElemTy> void fwdNonZeroInstImpl(const NonZeroInst *I);

  template <typename ElemTy>
  void fwdElementSelectInstFloatImpl(const ElementSelectInst *I);

  template <typename ElemTy, typename InstKind>
  void fwdUnaryTrigonometricImpl(const InstKind *I,
                                 std::function<float(float)> func);

  // Batched reductions; \p eBatchDims / \p eDestDims are the expanded shapes
  // of the batch and destination tensors.
  template <typename ElemTy>
  void fwdBatchedReduceAddInstImpl(Value *batch, Value *dest, unsigned_t axis,
                                   const ShapeVector &eBatchDims,
                                   const ShapeVector &eDestDims);

  // \p max is the identity element used to seed the min-reduction.
  template <typename ElemTy>
  void fwdBatchedReduceMinInstImpl(Value *batch, Value *dest,
                                   const ShapeVector &eBatchDims,
                                   const ShapeVector &eDestDims, ElemTy max);

  // \p min is the identity element used to seed the max-reduction.
  template <typename ElemTy>
  void fwdBatchedReduceMaxInstImpl(Value *batch, Value *dest,
                                   const ShapeVector &eBatchDims,
                                   const ShapeVector &eDestDims, ElemTy min);

  template <typename ElemTy>
  void fwdBatchedReduceProdInstFloatImpl(Value *batch, Value *dest,
                                         unsigned_t axis,
                                         const ShapeVector &eBatchDims,
                                         const ShapeVector &eDestDims);

  template <typename ElemTy>
  void fwdCumSumInstImpl(Value *input, Value *dest, int64_t dim, bool exclusive,
                         bool reverse);

  template <typename ElemTy>
  void fwdLengthsSumInstFloatImpl(const LengthsSumInst *I);

  // Gather / scatter implementations.
  template <typename ElemTy> void fwdGatherInstImpl(const GatherInst *I);
  template <typename ElemTy> void fwdGatherNDInstImpl(const GatherNDInst *I);
  template <typename IndexTy>
  void fwdGatherElementsInstImpl(const GatherElementsInst *I);
  template <typename ElemTy>
  void fwdGatherRangesInstImpl(const GatherRangesInst *I);
  template <typename ElemTy, typename IndicesElemTy>
  void fwdScatterDataInstCopyImpl(const ScatterDataInst *I);
  template <typename ElemTy, typename IndicesElemTy>
  void fwdScatterDataInstAddFloatImpl(const ScatterDataInst *I);
  template <typename ElemTy, typename IndicesElemTy>
  void fwdScatterDataInstAddQuantizedImpl(const ScatterDataInst *I);

  // Sparse-lengths / embedding implementations. \p TI is the index type.
  template <typename ElemTy>
  void fwdSparseLengthsSumInstI8Impl(const SparseLengthsSumInst *I);
  template <typename ElemTy, typename TI>
  void fwdSparseLengthsSumInstFloatImpl(const SparseLengthsSumInst *I);

  template <typename ElemTy>
  void
  fwdSparseLengthsWeightedSumInstI8Impl(const SparseLengthsWeightedSumInst *I);
  template <typename ElemTy, typename TI>
  void fwdSparseLengthsWeightedSumInstFloatImpl(
      const SparseLengthsWeightedSumInst *I);

  template <typename ElemTy>
  void fwdEmbeddingInstImpl(Tensor *wtT, Tensor *indT, Tensor *outT,
                            int64_t padIdx, bool sparse, bool scale,
                            dim_t embedding_dim);

  template <typename ElemTy, typename IndexTy>
  void fwdEmbeddingBagInstFloatImpl(const EmbeddingBagInst *I);

  // The numbered Impl1/Impl2 pairs are staged dispatchers: Impl1 resolves one
  // template type, then forwards to Impl2 which resolves the remaining ones.
  template <typename ElemTy, typename IndexTy>
  void fwdBatchSparseToDenseInstImpl1(const BatchSparseToDenseInst *I);

  template <typename ElemTy, typename LengthsTy, typename IndicesTy>
  void fwdBatchSparseToDenseInstImpl2(const BatchSparseToDenseInst *I);

  template <typename ElemTy>
  void
  fwdFillExamplesWithIndicatorInstImpl1(const FillExamplesWithIndicatorInst *I);

  template <typename ElemTy, typename IndicatorTy>
  void
  fwdFillExamplesWithIndicatorInstImpl2(const FillExamplesWithIndicatorInst *I);

  template <class eTy>
  void fwdRescaleQuantizedInstImpl(Value *src, Value *dest,
                                   TensorQuantizationParams &srcQ,
                                   TensorQuantizationParams &destQ);

  template <typename ElemTy> void fwdModuloInstImpl(glow::ModuloInst const *I);

  template <typename ElemTy>
  void fwdCollectRpnProposalsInstImpl(const CollectRpnProposalsInst *I);

  template <typename T, typename AccumT, typename TI>
  void fwdRowwiseQuantizedSparseLengthsWeightedSumImpl(
      const RowwiseQuantizedSparseLengthsWeightedSumInst *I);

  template <typename T, typename AccumT, typename TI>
  void fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl(
      const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I);

  template <typename T>
  void fwdNonMaxSuppressionInstImpl(glow::NonMaxSuppressionInst const *I);

  void fwdAudioSpectrogramInstFloatImpl(glow::AudioSpectrogramInst const *I);

  void fwdMFCCInstFloatImpl(glow::MFCCInst const *I);

  template <typename T>
  void fwdROIAlignInstFloatImpl(glow::ROIAlignInst const *I);

  template <typename T>
  void fwdBBoxTransformInstFloatImpl(glow::BBoxTransformInst const *I);

  template <typename T, typename AccumT, typename IndexT>
  void fwdEmbeddingBagByteRowwiseOffsetsImpl(
      const EmbeddingBagByteRowwiseOffsetsInst *I);

  template <typename ElemTy> void fwdFlipInstImpl(const FlipInst *I);

  template <typename ElemTy>
  void fwdBatchedPairwiseDotProductInstImpl(
      const glow::BatchedPairwiseDotProductInst *I);

  template <typename ElemTy>
  void fwdBatchedPairwiseDotProductGradInstImpl(
      const glow::BatchedPairwiseDotProductGradInst *I);
  // Gradient of average-pooling, split by spatial rank (2D vs 3D).
  void fwdAvgPool2DGradInst(const AvgPoolGradInst *I);
  void fwdAvgPool3DGradInst(const AvgPoolGradInst *I);

  template <typename ElemTy, typename IndexTy>
  void fwdBatchedUnaryEmbeddingsBagsInstImpl(
      const BatchedUnaryEmbeddingsBagsInst *I);

  template <typename IndexTy, typename OutputTy>
  void
  fwdIntNBitSplitEmbeddingBagsInstImpl(const IntNBitSplitEmbeddingBagsInst *I);

  template <typename IndexTy, typename OutputTy>
  void fwdIntNBitSplitEmbeddingWeightedBagsInstImpl(
      const IntNBitSplitEmbeddingWeightedBagsInst *I);

  template <typename IndexTy, typename OutputTy>
  void fwdIntNBitSplitEmbeddingWeightedBagsImpl(
      Tensor *out, Tensor *devWeights, Tensor *uvmWeights,
      Tensor *weightsPlacements, Tensor *weightsTys, Tensor *dimOffsets,
      Tensor *indices, Tensor *offsets, Tensor *weightsOffsets,
      int64_t poolingMode, Tensor *indiceWeights, int64_t totalDims,
      int64_t outputDType);

  ///@}
};
499
500} // end namespace glow
501
502#endif // GLOW_BACKENDS_INTERPRETER_INTERPRETERFUNCTION_H
503