1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_BACKENDS_INTERPRETER_INTERPRETERFUNCTION_H |
17 | #define GLOW_BACKENDS_INTERPRETER_INTERPRETERFUNCTION_H |
18 | |
19 | #include "glow/Backend/Backend.h" |
20 | #include "glow/Backend/BackendUtils.h" |
21 | #include "glow/Backend/CompiledFunction.h" |
22 | #include "glow/Base/Tensor.h" |
23 | #include "glow/ExecutionContext/ExecutionContext.h" |
24 | #include "glow/Quantization/Base/Base.h" |
25 | |
26 | #include "llvm/ADT/ArrayRef.h" |
27 | |
28 | #include <memory> |
29 | #include <unordered_map> |
30 | |
31 | namespace glow { |
32 | |
33 | class IRFunction; |
34 | class Value; |
35 | class Tensor; |
36 | class Constant; |
37 | |
38 | // Forward declare all of the classes. |
39 | #define DEF_VALUE(CLASS, NAME) class CLASS; |
40 | #define DEF_INSTR(CLASS, NAME) class CLASS; |
41 | #define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) |
42 | #include "glow/AutoGenInstr.def" |
43 | |
44 | /// Function "compiled" for execution by the interpreter. |
45 | class InterpreterFunction final : public CompiledFunction, |
46 | public IRInstructionProcessingHandler { |
47 | /// The IR to be executed. |
48 | std::unique_ptr<IRFunction> F_; |
49 | |
50 | /// Maps Value.name to tensors for constants. |
51 | std::unordered_map<std::string, Tensor *> constants_; |
52 | |
53 | public: |
54 | InterpreterFunction(std::unique_ptr<IRFunction> F, |
55 | runtime::RuntimeBundle &&bundle); |
56 | |
57 | /// \name CompiledFunction interface |
58 | ///@{ |
59 | ~InterpreterFunction() override; |
60 | |
61 | Error execute(ExecutionContext *context) override; |
62 | |
63 | /// Collects constants for runtime. |
64 | void collectConstants(const Module *module) override; |
65 | |
66 | /// Add a constant to the function, this is used for loading static |
67 | /// placeholders. |
68 | void addConstant(std::string name, Tensor *T); |
69 | |
70 | /// Get reference to IR function. |
71 | IRFunction *getIR() { return F_.get(); } |
72 | |
73 | /// Read trace events out of this func and write them into /p context |
74 | void translateTraceEvents(ExecutionContext *context) const override; |
75 | |
76 | /// \returns the backend used to compile this function. |
77 | virtual std::string getCompileBackendName() const override { |
78 | return "Interpreter" ; |
79 | } |
80 | ///@} |
81 | }; |
82 | |
83 | /// An InterpreterFunction bound to a specific invocation. |
/// An InterpreterFunction bound to a specific invocation.
class BoundInterpreterFunction : public IRInstructionProcessingHandler {
  /// Maps values to Tensors, that are owned by this class.
  std::unordered_map<const Value *, Tensor *> tensors_;

  /// Maps values to Tensors, that are *not* owned by this class
  /// (presumably tensors supplied by the caller's bindings — see the .cpp).
  std::unordered_map<const Value *, Tensor *> externalTensors_;

  /// A reference to the constant map from the owning InterpreterFunction.
  /// Not owned; the owning InterpreterFunction must outlive this object.
  const std::unordered_map<std::string, Tensor *> &constants_;

public:
  /// Binds this invocation to the shared \p constants map; no tensors are
  /// allocated at construction time.
  explicit BoundInterpreterFunction(
      const std::unordered_map<std::string, Tensor *> &constants)
      : constants_(constants) {}

  /// Releases the tensors owned by this class (tensors_).
  ~BoundInterpreterFunction();

  /// Executes the IR function \p F for this invocation, using \p context
  /// for placeholder bindings. \returns an Error on failure.
  Error execute(IRFunction *F, ExecutionContext *context);

  /// \returns a pointer to the tensor that is saved under \p v.
  Tensor *getTensor(const Value *v) const;

  /// \returns a typed handle to the tensor that is stored at \p v.
  template <class ElemTy = float>
  Handle<ElemTy> getWeightHandle(Value *v) const {
    return getTensor(v)->getHandle<ElemTy>();
  }

private:
  /// Allocate a tensor to back the value \p v. Do not allocate anything if a
  /// tensor is already allocated for \p v.
  /// \returns a tensor for \p v.
  Tensor *getOrCreateTensor(const Value *v);

  /// Allocate an unowned tensor to back the value \p v. The source tensor of
  /// the unowned tensor is provided by \p src, viewed at \p offsets.
  /// \returns a tensor for \p v.
  Tensor *getOrCreateUnownedTensor(const Value *v, const Value *src,
                                   llvm::ArrayRef<dim_t> offsets);

  /// If a tensor is allocated for \p v then delete it.
  void deleteTensor(const Value *v);

  /// @name BoundInterpreterFunction methods. This is a list of method
  /// declarations that are used by the interpreter to dispatch different
  /// instructions.
  ///@{

  // Declares one fwd<InstClass>(const InstClass *) dispatcher per
  // auto-generated instruction; value classes and backend-specific
  // instructions are intentionally skipped.
#define DEF_VALUE(CLASS, NAME)
#define DEF_INSTR(CLASS, NAME) void fwd##CLASS(const CLASS *I);
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME)
#include "glow/AutoGenInstr.def"

  // --- Convolution / normalization implementation helpers. The *Quantized*
  // variants are parameterized on element, accumulator and bias types; the
  // *Float* variants default to float.

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdConvolutionInstQuantizedImpl(Value *inV, Value *outV, Value *filterV,
                                       Value *biasV,
                                       llvm::ArrayRef<unsigned_t> kernelSizes,
                                       llvm::ArrayRef<unsigned_t> strides,
                                       llvm::ArrayRef<unsigned_t> pads,
                                       size_t group,
                                       llvm::ArrayRef<unsigned_t> dilation);

  template <typename ElemTy = float>
  void fwdConvolutionInstFloatImpl(Value *inV, Value *outV, Value *filterV,
                                   Value *biasV,
                                   llvm::ArrayRef<unsigned_t> kernelSizes,
                                   llvm::ArrayRef<unsigned_t> strides,
                                   llvm::ArrayRef<unsigned_t> pads,
                                   size_t group,
                                   llvm::ArrayRef<unsigned_t> dilation);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdConvolution3DInstQuantizedImpl(Value *inV, Value *outV,
                                         Value *filterV, Value *biasV,
                                         llvm::ArrayRef<unsigned_t> kernelSizes,
                                         llvm::ArrayRef<unsigned_t> strides,
                                         llvm::ArrayRef<unsigned_t> pads,
                                         size_t group);

  template <typename ElemTy = float>
  void fwdConvolution3DInstFloatImpl(Value *inV, Value *outV, Value *filterV,
                                     Value *biasV,
                                     llvm::ArrayRef<unsigned_t> kernelSizes,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     size_t group);

  template <typename ElemTy = float>
  void fwdConvTransposeInstFloatImpl(Value *inV, Value *outV, Value *filterV,
                                     Value *biasV,
                                     llvm::ArrayRef<unsigned_t> kernelSizes,
                                     llvm::ArrayRef<unsigned_t> strides,
                                     llvm::ArrayRef<unsigned_t> pads,
                                     size_t group,
                                     llvm::ArrayRef<unsigned_t> dilation);

  // \p numDims selects between the spatial-rank variants of batch norm.
  template <typename ElemTy = float>
  void fwdBatchNormalizationFloatImpl(const BatchNormalizationInst *I,
                                      int numDims);
  // Int8 batch norm; \p ParamTy is the type of the scale/bias parameters.
  template <typename ParamTy = float16_t>
  void fwdBatchNormalizationI8Impl(const BatchNormalizationInst *I,
                                   int numDims);

  template <typename ElemTy = float>
  void
  fwdLayerNormalizationInstFloatImpl(const glow::LayerNormalizationInst *I);

  // --- Pooling helpers: I8 (quantized) and element-type-templated float.

  void fwdAvgPoolInstI8Impl(const AvgPoolInst *I);
  template <typename ElemTy> void fwdAvgPoolInstFloatImpl(const AvgPoolInst *I);

  void fwdAvgPool3DInstI8Impl(const AvgPoolInst *I);
  template <typename ElemTy>
  void fwdAvgPool3DInstFloatImpl(const AvgPoolInst *I);

  void fwdAdaptiveAvgPoolInstI8Impl(const AdaptiveAvgPoolInst *I);
  template <typename ElemTy>
  void fwdAdaptiveAvgPoolInstFloatImpl(const AdaptiveAvgPoolInst *I);

  // --- Softmax family.

  template <typename ElemTy> void fwdSoftMaxInstImpl(const SoftMaxInst *I);
  template <typename ElemTy>
  void fwdLogSoftMaxInstImpl(const LogSoftMaxInst *I);

  // --- MatMul / fully-connected helpers.

  template <typename ElemTy, typename AccumulatorTy>
  void fwdMatMulInstQuantizedImpl(const MatMulInst *I);
  template <typename ElemTy> void fwdMatMulInstFloatImpl(const MatMulInst *I);

  template <typename ElemTy>
  void fwdBatchMatMulInstFloatImpl(const BatchMatMulInst *I);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdFullyConnectedInstQuantizedImpl(const FullyConnectedInst *I);
  template <typename ElemTy>
  void fwdFullyConnectedInstFloatImpl(const FullyConnectedInst *I);

  // Inner kernel for dynamic row-wise quantized FC; operates on handles and
  // writes rows starting at \p baseRow of \p outW.
  template <typename ElemTy, typename OutputTy, typename AccumulatorTy>
  void fwdDynRowwiseQuantizedFullyConnectedInstImpl(
      Handle<ElemTy> inW, Handle<OutputTy> &outW, dim_t baseRow,
      Handle<ElemTy> weightsW, Handle<float> biasW, Handle<float> scalesW,
      Handle<int32_t> offsetsW);

  // Non-templated entry that prepares tensors before dispatching to the
  // templated implementation above.
  void fwdDynRowwiseQuantizedFullyConnectedInstPreimpl(
      Tensor *inputTensor, Tensor *weightsTensor, Tensor *biasTensor,
      Tensor *resultTensor, Tensor *wScaleTensor, Tensor *wOffsetTensor,
      bool isSymmetric, bool isPerBatchElement);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdRowwiseQuantizedFullyConnectedInstImpl(Value *inV, Value *outV,
                                                 Value *weightsV, Value *biasV,
                                                 Value *scalesV,
                                                 Value *offsetsV);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdChannelwiseQuantizedConv2DInstImpl(
      const ChannelwiseQuantizedConvolutionInst *I);

  template <typename ElemTy, typename AccumulatorTy,
            typename BiasElemTy = int32_t>
  void fwdChannelwiseQuantizedConv3DInstImpl(
      const ChannelwiseQuantizedConvolutionInst *I);

  // --- Element-wise arithmetic / comparison helpers.

  void fwdElementAddInstI8Impl(const ElementAddInst *I);
  template <typename ElemTy>
  void fwdElementAddInstArithmeticImpl(const ElementAddInst *I);

  void fwdElementMaxInstI8Impl(const ElementMaxInst *I);
  template <typename ElemTy>
  void fwdElementMaxInstArithmeticImpl(const ElementMaxInst *I);

  template <typename ElemTy>
  void fwdBatchedAddInstFloatImpl(const BatchedAddInst *I);

  // Comparison helpers carry quantization parameters (offset/scale types)
  // so quantized operands can be compared; \p CmpTy is the comparison type.
  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy = ElemTy>
  void fwdElementCmpEQInstImpl(const ElementCmpEQInst *I);

  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy = ElemTy>
  void fwdElementCmpNEQInstImpl(const ElementCmpNEQInst *I);

  template <typename ElemTy>
  void fwdBatchOneHotImpl(const glow::BatchOneHotInst *I);

  template <typename ElemTy>
  void fwdSpaceToDepthInstImpl(const glow::SpaceToDepthInst *I);

  template <typename ElemTy>
  void fwdResizeNearestInstImpl(const ResizeNearestInst *I);

  template <typename ElemTy>
  void fwdResizeBilinearInstImpl(const ResizeBilinearInst *I);

  template <typename ElemTy> void fwdSigmoidInstFloatImpl(const SigmoidInst *I);

  template <typename ElemTy> void fwdTanhInstFloatImpl(const TanhInst *I);

  template <typename ElemTy>
  void fwdSoftPlusInstFloatImpl(const SoftPlusInst *I);

  template <typename ElemTy>
  void fwdCrossEntropyLossInstFloatImpl(const CrossEntropyLossInst *I);

  template <typename ElemTy>
  void fwdLocalResponseNormalizationInstFloatImpl(
      const glow::LocalResponseNormalizationInst *I);

  template <typename ElemTy>
  void fwdElementSubInstArithmeticImpl(const ElementSubInst *I);

  template <typename ElemTy>
  void fwdElementMulInstArithmeticImpl(const ElementMulInst *I);

  template <typename ElemTy>
  void fwdElementMinInstArithmeticImpl(const ElementMinInst *I);

  template <typename ElemTy>
  void fwdElementBitwiseOrInstImpl(const ElementBitwiseOrInst *I);

  template <typename ElemTy>
  void fwdElementBitwiseXorInstImpl(const ElementBitwiseXorInst *I);

  template <typename ElemTy>
  void fwdElementBitwiseAndInstImpl(const ElementBitwiseAndInst *I);

  // Generic unary element-wise helper: applies \p func to each element of
  // \p I's input. \p InstKind is the concrete instruction class.
  template <typename ElemTy, typename InstKind>
  void fwdUnaryArithmeticImpl(const InstKind *I,
                              std::function<float(float)> func);

  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy = ElemTy>
  void fwdElementCmpLTEInstImpl(const ElementCmpLTEInst *I);

  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy = ElemTy>
  void fwdElementCmpLTInstImpl(const ElementCmpLTInst *I);

  // Shared driver for the comparison instructions above: \p cmpHelper
  // provides the per-element predicate.
  template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
            typename CmpTy, typename InstCmpKind>
  void
  fwdElementCmpHelperImpl(const InstCmpKind *I,
                          std::function<bool(CmpTy LHS, CmpTy RHS)> cmpHelper);

  template <typename ElemTy>
  void fwdElementPowInstFloatImpl(const ElementPowInst *I);

  void fwdElementPowInstI8Impl(const ElementPowInst *I);

  template <typename ElemTy>
  void fwdElementIsNaNInstFloatImpl(const ElementIsNaNInst *I);

  template <typename ElemTy>
  void fwdElementLogInstFloatImpl(const ElementLogInst *I);

  template <typename ElemTy>
  void fwdElementExpInstFloatImpl(const ElementExpInst *I);

  template <typename ElemTy>
  void fwdElementSignInstFloatImpl(const ElementSignInst *I);

  template <typename ElemTy> void fwdNonZeroInstImpl(const NonZeroInst *I);

  template <typename ElemTy>
  void fwdElementSelectInstFloatImpl(const ElementSelectInst *I);

  // Generic trigonometric helper: applies \p func element-wise.
  template <typename ElemTy, typename InstKind>
  void fwdUnaryTrigonometricImpl(const InstKind *I,
                                 std::function<float(float)> func);

  // --- Batched reductions. \p eBatchDims / \p eDestDims are the expanded
  // shapes of the batch and destination tensors.
  template <typename ElemTy>
  void fwdBatchedReduceAddInstImpl(Value *batch, Value *dest, unsigned_t axis,
                                   const ShapeVector &eBatchDims,
                                   const ShapeVector &eDestDims);

  // \p max is the identity element used to seed the min-reduction.
  template <typename ElemTy>
  void fwdBatchedReduceMinInstImpl(Value *batch, Value *dest,
                                   const ShapeVector &eBatchDims,
                                   const ShapeVector &eDestDims, ElemTy max);

  // \p min is the identity element used to seed the max-reduction.
  template <typename ElemTy>
  void fwdBatchedReduceMaxInstImpl(Value *batch, Value *dest,
                                   const ShapeVector &eBatchDims,
                                   const ShapeVector &eDestDims, ElemTy min);

  template <typename ElemTy>
  void fwdBatchedReduceProdInstFloatImpl(Value *batch, Value *dest,
                                         unsigned_t axis,
                                         const ShapeVector &eBatchDims,
                                         const ShapeVector &eDestDims);

  template <typename ElemTy>
  void fwdCumSumInstImpl(Value *input, Value *dest, int64_t dim, bool exclusive,
                         bool reverse);

  template <typename ElemTy>
  void fwdLengthsSumInstFloatImpl(const LengthsSumInst *I);

  // --- Gather / scatter helpers.

  template <typename ElemTy> void fwdGatherInstImpl(const GatherInst *I);
  template <typename ElemTy> void fwdGatherNDInstImpl(const GatherNDInst *I);
  template <typename IndexTy>
  void fwdGatherElementsInstImpl(const GatherElementsInst *I);
  template <typename ElemTy>
  void fwdGatherRangesInstImpl(const GatherRangesInst *I);
  template <typename ElemTy, typename IndicesElemTy>
  void fwdScatterDataInstCopyImpl(const ScatterDataInst *I);
  template <typename ElemTy, typename IndicesElemTy>
  void fwdScatterDataInstAddFloatImpl(const ScatterDataInst *I);
  template <typename ElemTy, typename IndicesElemTy>
  void fwdScatterDataInstAddQuantizedImpl(const ScatterDataInst *I);

  // --- SparseLengths* helpers; \p TI is the index element type.

  template <typename ElemTy>
  void fwdSparseLengthsSumInstI8Impl(const SparseLengthsSumInst *I);
  template <typename ElemTy, typename TI>
  void fwdSparseLengthsSumInstFloatImpl(const SparseLengthsSumInst *I);

  template <typename ElemTy>
  void
  fwdSparseLengthsWeightedSumInstI8Impl(const SparseLengthsWeightedSumInst *I);
  template <typename ElemTy, typename TI>
  void fwdSparseLengthsWeightedSumInstFloatImpl(
      const SparseLengthsWeightedSumInst *I);

  // --- Embedding helpers.

  template <typename ElemTy>
  void fwdEmbeddingInstImpl(Tensor *wtT, Tensor *indT, Tensor *outT,
                            int64_t padIdx, bool sparse, bool scale,
                            dim_t embedding_dim);

  template <typename ElemTy, typename IndexTy>
  void fwdEmbeddingBagInstFloatImpl(const EmbeddingBagInst *I);

  // Two overload-like variants distinguished by how many index/length types
  // they are parameterized on (see the .cpp for the dispatch logic).
  template <typename ElemTy, typename IndexTy>
  void fwdBatchSparseToDenseInstImpl1(const BatchSparseToDenseInst *I);

  template <typename ElemTy, typename LengthsTy, typename IndicesTy>
  void fwdBatchSparseToDenseInstImpl2(const BatchSparseToDenseInst *I);

  template <typename ElemTy>
  void
  fwdFillExamplesWithIndicatorInstImpl1(const FillExamplesWithIndicatorInst *I);

  template <typename ElemTy, typename IndicatorTy>
  void
  fwdFillExamplesWithIndicatorInstImpl2(const FillExamplesWithIndicatorInst *I);

  // Requantizes \p src into \p dest using the given quantization params.
  template <class eTy>
  void fwdRescaleQuantizedInstImpl(Value *src, Value *dest,
                                   TensorQuantizationParams &srcQ,
                                   TensorQuantizationParams &destQ);

  template <typename ElemTy> void fwdModuloInstImpl(glow::ModuloInst const *I);

  template <typename ElemTy>
  void fwdCollectRpnProposalsInstImpl(const CollectRpnProposalsInst *I);

  template <typename T, typename AccumT, typename TI>
  void fwdRowwiseQuantizedSparseLengthsWeightedSumImpl(
      const RowwiseQuantizedSparseLengthsWeightedSumInst *I);

  template <typename T, typename AccumT, typename TI>
  void fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl(
      const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I);

  // --- Vision / audio helpers.

  template <typename T>
  void fwdNonMaxSuppressionInstImpl(glow::NonMaxSuppressionInst const *I);

  void fwdAudioSpectrogramInstFloatImpl(glow::AudioSpectrogramInst const *I);

  void fwdMFCCInstFloatImpl(glow::MFCCInst const *I);

  template <typename T>
  void fwdROIAlignInstFloatImpl(glow::ROIAlignInst const *I);

  template <typename T>
  void fwdBBoxTransformInstFloatImpl(glow::BBoxTransformInst const *I);

  template <typename T, typename AccumT, typename IndexT>
  void fwdEmbeddingBagByteRowwiseOffsetsImpl(
      const EmbeddingBagByteRowwiseOffsetsInst *I);

  template <typename ElemTy> void fwdFlipInstImpl(const FlipInst *I);

  template <typename ElemTy>
  void fwdBatchedPairwiseDotProductInstImpl(
      const glow::BatchedPairwiseDotProductInst *I);

  template <typename ElemTy>
  void fwdBatchedPairwiseDotProductGradInstImpl(
      const glow::BatchedPairwiseDotProductGradInst *I);
  // 2D/3D variants of the average-pool gradient, both dispatched from the
  // same AvgPoolGradInst.
  void fwdAvgPool2DGradInst(const AvgPoolGradInst *I);
  void fwdAvgPool3DGradInst(const AvgPoolGradInst *I);

  template <typename ElemTy, typename IndexTy>
  void fwdBatchedUnaryEmbeddingsBagsInstImpl(
      const BatchedUnaryEmbeddingsBagsInst *I);

  // --- IntNBit split-embedding helpers.

  template <typename IndexTy, typename OutputTy>
  void
  fwdIntNBitSplitEmbeddingBagsInstImpl(const IntNBitSplitEmbeddingBagsInst *I);

  template <typename IndexTy, typename OutputTy>
  void fwdIntNBitSplitEmbeddingWeightedBagsInstImpl(
      const IntNBitSplitEmbeddingWeightedBagsInst *I);

  // Tensor-level worker shared by the two instruction-level entry points
  // above; all tensor pointers are borrowed, not owned.
  template <typename IndexTy, typename OutputTy>
  void fwdIntNBitSplitEmbeddingWeightedBagsImpl(
      Tensor *out, Tensor *devWeights, Tensor *uvmWeights,
      Tensor *weightsPlacements, Tensor *weightsTys, Tensor *dimOffsets,
      Tensor *indices, Tensor *offsets, Tensor *weightsOffsets,
      int64_t poolingMode, Tensor *indiceWeights, int64_t totalDims,
      int64_t outputDType);

  ///@}
};
499 | |
500 | } // end namespace glow |
501 | |
502 | #endif // GLOW_BACKENDS_INTERPRETER_INTERPRETERFUNCTION_H |
503 | |