1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/LLVMIRCodeGen/LLVMBackend.h"
18#include "glow/LLVMIRCodeGen/BundleSaver.h"
19#include "glow/LLVMIRCodeGen/CommandLine.h"
20#include "glow/LLVMIRCodeGen/LLVMCompiledFunction.h"
21
22#include "glow/Backend/BackendUtils.h"
23#include "glow/Graph/Graph.h"
24#include "glow/Graph/PlaceholderBindings.h"
25#include "glow/IR/Instrs.h"
26#include "glow/Optimizer/IROptimizer/IROptimizer.h"
27#include "glow/Support/Debug.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/SmallVector.h"
30#include "llvm/IR/IRBuilder.h"
31#include "llvm/IR/LLVMContext.h"
32#include "llvm/Support/Host.h"
33
34using namespace glow;
35
36namespace {
37
38//===----------------------------------------------------------------------===//
39// Functions for executing code using JIT
40//===----------------------------------------------------------------------===//
41
42/// Perform memory allocation for a JIT execution.
43void allocateJITMemory(const IRFunction *F, AllocationsInfo &allocationsInfo) {
44 allocationsInfo.numberValues(F);
45 allocationsInfo.allocateWeightVars(F);
46 allocationsInfo.allocateActivations(F);
47 allocationsInfo.allocateTensorViews(F);
48}
49
50} // end namespace
51
52bool LLVMBackend::isOpSupported(const NodeInfo &NI) const {
53 // Note: For brevity below, "X ==> Y, Z" signifes that Node X is IRGen'd into
54 // Instructions Y and Z.
55 switch (NI.getKind()) {
56 case Kinded::Kind::BatchedReduceMaxNodeKind:
57 case Kinded::Kind::BatchedReduceMinNodeKind:
58 return NI.allInputsAndOutputsHaveSameElemKind(
59 {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Int64ITy});
60
61 case Kinded::Kind::BatchedReduceProdNodeKind:
62 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
63
64 case Kinded::Kind::AddNodeKind:
65 case Kinded::Kind::MulNodeKind:
66 return NI.allInputsAndOutputsHaveSameElemKind(
67 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy,
68 ElemKind::Int64ITy});
69
70 case Kinded::Kind::ReluNodeKind:
71 case Kinded::Kind::ClipNodeKind:
72 case Kinded::Kind::LeakyReluNodeKind:
73 case Kinded::Kind::SubNodeKind:
74 case Kinded::Kind::MaxNodeKind:
75 case Kinded::Kind::MinNodeKind:
76 case Kinded::Kind::BatchedReduceAddNodeKind:
77 case Kinded::Kind::MatMulNodeKind:
78 case Kinded::Kind::AvgPoolNodeKind:
79 return NI.allInputsAndOutputsHaveSameElemKind(
80 {ElemKind::FloatTy, ElemKind::Int8QTy});
81
82 case Kinded::Kind::AdaptiveAvgPoolNodeKind:
83 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
84
85 case Kinded::Kind::MaxPoolNodeKind:
86 return NI.allInputsAndOutputsHaveSameElemKind(
87 {ElemKind::FloatTy, ElemKind::Int8QTy}, {},
88 {MaxPoolNode::ArgmaxIdx}) &&
89 (NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int64ITy ||
90 NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int32ITy);
91
92 case Kinded::Kind::ArgMaxNodeKind:
93 case Kinded::Kind::ArgMinNodeKind:
94 return NI.allInputsAndOutputsHaveSameElemKind(
95 {ElemKind::FloatTy, ElemKind::Int8QTy}, {},
96 {ArgMaxNode::ResultIdx}) &&
97 (NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int64ITy ||
98 NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int32ITy);
99
100 case Kinded::Kind::ResizeNearestNodeKind:
101 case Kinded::Kind::ResizeBilinearNodeKind:
102 return NI.allInputsAndOutputsHaveSameElemKind(
103 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32QTy,
104 ElemKind::Int32ITy, ElemKind::Int64ITy});
105
106 case Kinded::Kind::SaveNodeKind:
107 case Kinded::Kind::ReshapeNodeKind:
108 // These are implemented via a Copy Instruction.
109 return NI.allInputsAndOutputsHaveSameElemKind(
110 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy,
111 ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy,
112 ElemKind::BoolTy});
113
114 // InsertTensor ==> Copy + InsertTensor. Copy supports everything
115 // ReshapeNode above supports, so InsertTensor is the limiting factor.
116 case Kinded::Kind::InsertTensorNodeKind:
117 // Concat ==> Splat + Insert. Both only support the following.
118 case Kinded::Kind::ConcatNodeKind:
119 case Kinded::Kind::SplatNodeKind:
120 case Kinded::Kind::TouchNodeKind:
121 return NI.allInputsAndOutputsHaveSameElemKind(
122 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
123 ElemKind::Int32ITy, ElemKind::BoolTy});
124 case Kinded::Kind::SliceNodeKind:
125 return NI.allInputsAndOutputsHaveSameElemKind(
126 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32QTy,
127 ElemKind::Int32ITy, ElemKind::Int64ITy});
128 case Kinded::Kind::FmodNodeKind:
129 return NI.allInputsAndOutputsHaveSameElemKind(
130 {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Int64ITy});
131 case Kinded::Kind::SpaceToDepthNodeKind:
132 case Kinded::Kind::DivNodeKind:
133 return NI.allInputsAndOutputsHaveSameElemKind(
134 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
135 ElemKind::Int32ITy});
136
137 case Kinded::Kind::TransposeNodeKind:
138 return NI.allInputsAndOutputsHaveSameElemKind(
139 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
140 ElemKind::BoolTy});
141
142 case Kinded::Kind::FlipNodeKind:
143 return NI.allInputsAndOutputsHaveSameElemKind(
144 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy,
145 ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy,
146 ElemKind::BoolTy});
147
148 case Kinded::Kind::SparseLengthsSumNodeKind:
149 return NI.allInputsAndOutputsHaveSameElemKind(
150 {ElemKind::FloatTy}, {SparseLengthsSumNode::IndicesIdx,
151 SparseLengthsSumNode::LengthsIdx}) &&
152 (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
153 ElemKind::Int64ITy ||
154 NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
155 ElemKind::Int32ITy) &&
156 (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) ==
157 ElemKind::Int32ITy);
158
159 case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
160 return NI.allInputsAndOutputsHaveSameElemKind(
161 {ElemKind::FloatTy},
162 {SparseLengthsWeightedSumNode::IndicesIdx,
163 SparseLengthsWeightedSumNode::LengthsIdx}) &&
164 (NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
165 ElemKind::Int64ITy ||
166 NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
167 ElemKind::Int32ITy) &&
168 (NI.getInElemTy(SparseLengthsWeightedSumNode::LengthsIdx) ==
169 ElemKind::Int32ITy);
170
171 case Kinded::Kind::EmbeddingNodeKind:
172 return NI.allInputsAndOutputsHaveSameElemKind(
173 {ElemKind::FloatTy}, {EmbeddingNode::IndicesIdx}) &&
174 (NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int32ITy);
175
176 case Kinded::Kind::EmbeddingBagNodeKind:
177 return NI.allInputsAndOutputsHaveSameElemKind(
178 {ElemKind::FloatTy},
179 {EmbeddingBagNode::IndicesIdx, EmbeddingBagNode::OffsetsIdx}) &&
180 (NI.getInElemTy(EmbeddingBagNode::IndicesIdx) ==
181 ElemKind::Int32ITy) &&
182 (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) == ElemKind::Int32ITy);
183
184 case Kinded::Kind::SparseLengthsWeightedSumGradNodeKind:
185 // GradOfInputNamedIndicesIdx and GradOfInputNamedLengthsIdx do not need to
186 // be checked because they are not used.
187 return NI.allInputsAndOutputsHaveSameElemKind(
188 {ElemKind::FloatTy},
189 {SparseLengthsWeightedSumGradNode::IndicesIdx,
190 SparseLengthsWeightedSumGradNode::LengthsIdx},
191 {SparseLengthsWeightedSumGradNode::GradOfInputNamedIndicesIdx,
192 SparseLengthsWeightedSumGradNode::
193 GradOfInputNamedLengthsIdx}) &&
194 (NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) ==
195 ElemKind::Int32ITy ||
196 NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) ==
197 ElemKind::Int32ITy) &&
198 (NI.getInElemTy(SparseLengthsWeightedSumGradNode::LengthsIdx) ==
199 ElemKind::Int32ITy);
200
201 case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind:
202 return (NI.getInElemTy(
203 RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
204 ElemKind::UInt8QTy) &&
205 (NI.getInElemTy(
206 RowwiseQuantizedSparseLengthsWeightedSumNode::ScalesIdx) ==
207 ElemKind::FloatTy) &&
208 (NI.getInElemTy(
209 RowwiseQuantizedSparseLengthsWeightedSumNode::OffsetsIdx) ==
210 ElemKind::FloatTy) &&
211 (NI.getInElemTy(
212 RowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx) ==
213 ElemKind::FloatTy) &&
214 (NI.getInElemTy(
215 RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
216 ElemKind::Int64ITy ||
217 NI.getInElemTy(
218 RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
219 ElemKind::Int32ITy) &&
220 (NI.getInElemTy(
221 RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
222 ElemKind::Int32ITy) &&
223 (NI.getOutElemTy(
224 RowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
225 ElemKind::FloatTy);
226
227 case Kinded::Kind::LengthsRangeFillNodeKind:
228 case Kinded::Kind::LengthsToRangesNodeKind:
229 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});
230
231 case Kinded::Kind::IntLookupTableNodeKind:
232 return NI.allInputsAndOutputsHaveSameElemKind(
233 {ElemKind::Int8QTy, ElemKind::Int16QTy});
234
235 case Kinded::Kind::RescaleQuantizedNodeKind:
236 return NI.allInputsAndOutputsHaveSameElemKind(
237 {ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32QTy});
238
239 case Kinded::Kind::PowNodeKind:
240 case Kinded::Kind::AvgPoolGradNodeKind:
241 case Kinded::Kind::QuantizationProfileNodeKind:
242 case Kinded::Kind::LocalResponseNormalizationNodeKind:
243 case Kinded::Kind::LocalResponseNormalizationGradNodeKind:
244 case Kinded::Kind::LogNodeKind:
245 case Kinded::Kind::TanhNodeKind:
246 case Kinded::Kind::SigmoidNodeKind:
247 case Kinded::Kind::ExpNodeKind:
248 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
249
250 case Kinded::Kind::ModuloNodeKind:
251 return NI.allInputsAndOutputsHaveSameElemKind(
252 {ElemKind::Int32ITy, ElemKind::Int64ITy});
253
254 case Kinded::Kind::MaxPoolGradNodeKind:
255 return NI.allInputsAndOutputsHaveSameElemKind(
256 {ElemKind::FloatTy},
257 {MaxPoolGradNode::OriginalOutputForArgmaxIdx,
258 MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx}) &&
259 (NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) ==
260 ElemKind::Int64ITy ||
261 NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) ==
262 ElemKind::Int32ITy) &&
263 (NI.getInElemTy(
264 MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) ==
265 ElemKind::Int64ITy ||
266 NI.getInElemTy(
267 MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) ==
268 ElemKind::Int32ITy);
269
270 case Kinded::Kind::ConvolutionNodeKind:
271 if (!NI.getInTy(ConvolutionNode::InputIdx)->isQuantizedType()) {
272 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
273 }
274
275 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy},
276 {ConvolutionNode::BiasIdx}) &&
277 (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int8QTy ||
278 NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy);
279
280 case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
281 return (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::InputIdx) ==
282 ElemKind::Int8QTy) &&
283 (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterIdx) ==
284 ElemKind::Int8QTy) &&
285 ((NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
286 ElemKind::Int8QTy) ||
287 (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
288 ElemKind::Int32QTy)) &&
289 (NI.getInElemTy(
290 ChannelwiseQuantizedConvolutionNode::FilterScalesIdx) ==
291 ElemKind::FloatTy) &&
292 (NI.getInElemTy(
293 ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx) ==
294 ElemKind::Int32ITy) &&
295 (NI.getInElemTy(
296 ChannelwiseQuantizedConvolutionNode::BiasScalesIdx) ==
297 ElemKind::FloatTy) &&
298 (NI.getInElemTy(
299 ChannelwiseQuantizedConvolutionNode::BiasOffsetsIdx) ==
300 ElemKind::Int32ITy) &&
301 (NI.getOutElemTy(ChannelwiseQuantizedConvolutionNode::ResultIdx) ==
302 ElemKind::Int8QTy);
303
304 case Kinded::Kind::ConvTransposeNodeKind:
305 // TODO - not quantized support yet in libjit.
306 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
307
308 case Kinded::Kind::BatchedAddNodeKind:
309 if (!NI.getInTy(BatchedAddNode::BatchIdx)->isQuantizedType()) {
310 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
311 }
312 // Allow for Int8QTy or Int32QTy for the Slice input.
313 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy},
314 {BatchedAddNode::SliceIdx}) &&
315 ((NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int8QTy) ||
316 (NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int32QTy));
317
318 case Kinded::Kind::GatherNodeKind:
319 return NI.allInputsAndOutputsHaveSameElemKind(
320 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
321 ElemKind::Int32ITy},
322 {GatherNode::IndicesIdx}) &&
323 ((NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int32ITy) ||
324 (NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy));
325
326 case Kinded::Kind::GatherNDNodeKind:
327 return NI.allInputsAndOutputsHaveSameElemKind(
328 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
329 ElemKind::Int32ITy},
330 {GatherNDNode::IndicesIdx}) &&
331 ((NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int32ITy) ||
332 (NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int64ITy));
333
334 case Kinded::Kind::GatherRangesNodeKind:
335 return NI.allInputsAndOutputsHaveSameElemKind(
336 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
337 ElemKind::Int32ITy},
338 {GatherRangesNode::RangesIdx}, {GatherRangesNode::LengthsIdx}) &&
339 ((NI.getInElemTy(GatherRangesNode::RangesIdx) ==
340 NI.getOutElemTy(GatherRangesNode::LengthsIdx)) &&
341 ((NI.getOutElemTy(GatherRangesNode::LengthsIdx) ==
342 ElemKind::Int32ITy) ||
343 (NI.getOutElemTy(GatherRangesNode::LengthsIdx) ==
344 ElemKind::Int64ITy)));
345
346 case Kinded::Kind::ScatterDataNodeKind:
347 // ScatterData ==> Copy + ScatterData. Copy supports everything
348 // ReshapeNode above supports, however ScatterData only supports the
349 // following.
350 return NI.allInputsAndOutputsHaveSameElemKind(
351 {ElemKind::FloatTy, ElemKind::Int8QTy},
352 {ScatterDataNode::IndicesIdx}) &&
353 (NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int64ITy ||
354 NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int32ITy);
355
356 case Kinded::Kind::SelectNodeKind:
357 return NI.allInputsAndOutputsHaveSameElemKind(
358 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy},
359 {SelectNode::CondIdx}) &&
360 ((NI.getInElemTy(SelectNode::CondIdx) == ElemKind::BoolTy));
361
362 case Kinded::Kind::NotNodeKind:
363 case Kinded::Kind::AndNodeKind:
364 case Kinded::Kind::OrNodeKind:
365 case Kinded::Kind::XorNodeKind:
366 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::BoolTy});
367
368 case Kinded::Kind::AbsNodeKind:
369 case Kinded::Kind::NegNodeKind:
370 case Kinded::Kind::FloorNodeKind:
371 case Kinded::Kind::CeilNodeKind:
372 case Kinded::Kind::RoundNodeKind:
373 case Kinded::Kind::SqrtNodeKind:
374 case Kinded::Kind::ErfNodeKind:
375 case Kinded::Kind::RsqrtNodeKind:
376 case Kinded::Kind::ReciprocalNodeKind:
377 case Kinded::Kind::SinNodeKind:
378 case Kinded::Kind::CosNodeKind:
379 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
380
381 case Kinded::Kind::CmpEQNodeKind:
382 case Kinded::Kind::CmpNEQNodeKind:
383 case Kinded::Kind::CmpLTNodeKind:
384 case Kinded::Kind::CmpLTENodeKind:
385 return NI.allInputsAndOutputsHaveSameElemKind(
386 {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy,
387 ElemKind::Int64ITy},
388 {}, {CmpEQNode::ResultIdx}) &&
389 (NI.getOutElemTy(CmpEQNode::ResultIdx) == ElemKind::BoolTy);
390
391 case Kinded::Kind::IsNaNNodeKind:
392 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}, {},
393 {IsNaNNode::ResultIdx}) &&
394 (NI.getOutElemTy(IsNaNNode::ResultIdx) == ElemKind::BoolTy);
395
396 case Kinded::Kind::TopKNodeKind:
397 return NI.allInputsAndOutputsHaveSameElemKind(
398 {ElemKind::FloatTy, ElemKind::Int8QTy}, {},
399 {TopKNode::IndicesIdx}) &&
400 (NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int64ITy ||
401 NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int32ITy);
402
403 case Kinded::Kind::QuantizeNodeKind:
404 return (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::FloatTy) &&
405 ((NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int8QTy) ||
406 (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int16QTy) ||
407 (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int32QTy));
408
409 case Kinded::Kind::DequantizeNodeKind:
410 return ((NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int8QTy) ||
411 (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int16QTy) ||
412 (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int32QTy)) &&
413 (NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::FloatTy);
414
415 case Kinded::Kind::SoftMaxNodeKind:
416 return NI.allInputsAndOutputsHaveSameElemKind(
417 {ElemKind::FloatTy, ElemKind::Int8QTy},
418 {SoftMaxNode::SelectedIdx}) &&
419 (NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int64ITy ||
420 NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int32ITy);
421
422 case Kinded::Kind::CrossEntropyLossNodeKind:
423 return NI.allInputsAndOutputsHaveSameElemKind(
424 {ElemKind::FloatTy}, {CrossEntropyLossNode::LabelsIdx}) &&
425 (NI.getInElemTy(CrossEntropyLossNode::LabelsIdx) ==
426 ElemKind::Int64ITy ||
427 NI.getInElemTy(CrossEntropyLossNode::LabelsIdx) ==
428 ElemKind::Int32ITy);
429
430 case Kinded::Kind::LengthsSumNodeKind:
431 return NI.allInputsAndOutputsHaveSameElemKind(
432 {ElemKind::FloatTy}, {LengthsSumNode::LengthsIdx}) &&
433 (NI.getInElemTy(LengthsSumNode::LengthsIdx) == ElemKind::Int32ITy);
434
435 case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind:
436 return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx) ==
437 ElemKind::UInt8FusedQTy) &&
438 (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
439 ElemKind::FloatTy) &&
440 (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) ==
441 ElemKind::Int32ITy) &&
442 (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) ==
443 ElemKind::Int32ITy) &&
444 (NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
445 ElemKind::FloatTy);
446
447 case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
448 return (NI.getInElemTy(
449 FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
450 ElemKind::UInt8FusedQTy) &&
451 (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
452 WeightsIdx) == ElemKind::FloatTy) &&
453 ((NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
454 IndicesIdx) == ElemKind::Int64ITy ||
455 NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
456 IndicesIdx) == ElemKind::Int32ITy)) &&
457 (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
458 LengthsIdx) == ElemKind::Int32ITy) &&
459 (NI.getOutElemTy(
460 FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
461 ElemKind::FloatTy);
462
463 case Kinded::Kind::FullyConnectedNodeKind:
464 if (!NI.getInTy(FullyConnectedNode::InputIdx)->isQuantizedType()) {
465 return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
466 }
467 return NI.allInputsAndOutputsHaveSameElemKind(
468 {ElemKind::Int8QTy}, {FullyConnectedNode::BiasIdx}) &&
469 (NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::Int8QTy ||
470 NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::Int32QTy);
471
472 case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind:
473 return (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::InputIdx) ==
474 ElemKind::Int8QTy) &&
475 (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
476 ElemKind::Int8QTy) &&
477 (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
478 ElemKind::FloatTy) &&
479 (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
480 ElemKind::Int32ITy) &&
481 (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
482 ElemKind::Int8QTy ||
483 NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
484 ElemKind::Int32QTy) &&
485 (NI.getOutElemTy(RowwiseQuantizedFullyConnectedNode::ResultIdx) ==
486 ElemKind::Int8QTy);
487
488 case Kinded::Kind::SoftMaxGradNodeKind:
489 return NI.allInputsAndOutputsHaveSameElemKind(
490 {ElemKind::FloatTy}, {SoftMaxGradNode::SelectedIdx},
491 {SoftMaxGradNode::GradOfInputNamedSelectedIdx}) &&
492 (NI.getInElemTy(SoftMaxGradNode::SelectedIdx) ==
493 ElemKind::Int64ITy ||
494 NI.getInElemTy(SoftMaxGradNode::SelectedIdx) == ElemKind::Int32ITy);
495
496 case Kinded::Kind::ConvolutionGradNodeKind:
497 return NI.allInputsAndOutputsHaveSameElemKind(
498 {ElemKind::FloatTy}, {},
499 {ConvolutionGradNode::GradOfInputNamedInputIdx});
500
501 case Kinded::Kind::CrossEntropyLossGradNodeKind:
502 return NI.allInputsAndOutputsHaveSameElemKind(
503 {ElemKind::FloatTy}, {CrossEntropyLossGradNode::LabelsIdx},
504 {CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx}) &&
505 (NI.getInElemTy(CrossEntropyLossGradNode::LabelsIdx) ==
506 ElemKind::Int64ITy) &&
507 (NI.getOutElemTy(
508 CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx) ==
509 ElemKind::Int64ITy);
510
511 case Kinded::Kind::TraceEventNodeKind:
512 return NI.getInElemTy(TraceEventNode::DataIdx) == ElemKind::Int64ITy;
513
514 case Kinded::Kind::NonMaxSuppressionNodeKind:
515 return NI.getInElemTy(NonMaxSuppressionNode::BoxesIdx) ==
516 ElemKind::FloatTy &&
517 NI.getInElemTy(NonMaxSuppressionNode::ScoresIdx) ==
518 ElemKind::FloatTy &&
519 (NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) ==
520 ElemKind::Int32ITy ||
521 NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) ==
522 ElemKind::Int64ITy) &&
523 (NI.getOutElemTy(
524 NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
525 ElemKind::Int32ITy ||
526 NI.getOutElemTy(
527 NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
528 ElemKind::Int64ITy);
529
530 case Kinded::Kind::TFLiteDetectionPostProcessNodeKind:
531 return NI.getInElemTy(TFLiteDetectionPostProcessNode::BoxesIdx) ==
532 ElemKind::FloatTy &&
533 NI.getInElemTy(TFLiteDetectionPostProcessNode::ScoresIdx) ==
534 ElemKind::FloatTy &&
535 NI.getInElemTy(TFLiteDetectionPostProcessNode::AnchorsIdx) ==
536 ElemKind::FloatTy &&
537 NI.getOutElemTy(TFLiteDetectionPostProcessNode::DetectionBoxesIdx) ==
538 ElemKind::FloatTy &&
539 NI.getOutElemTy(
540 TFLiteDetectionPostProcessNode::DetectionClassesIdx) ==
541 ElemKind::Int32ITy &&
542 NI.getOutElemTy(
543 TFLiteDetectionPostProcessNode::DetectionScoresIdx) ==
544 ElemKind::FloatTy &&
545 NI.getOutElemTy(TFLiteDetectionPostProcessNode::NumDetectionsIdx) ==
546 ElemKind::Int32ITy;
547
548 case Kinded::Kind::AudioSpectrogramNodeKind:
549 return NI.getInElemTy(AudioSpectrogramNode::InputIdx) ==
550 ElemKind::FloatTy &&
551 NI.getOutElemTy(AudioSpectrogramNode::SpectrogramIdx) ==
552 ElemKind::FloatTy;
553
554 case Kinded::Kind::MFCCNodeKind:
555 return NI.getInElemTy(MFCCNode::SpectrogramIdx) == ElemKind::FloatTy &&
556 NI.getOutElemTy(MFCCNode::CoefficientsIdx) == ElemKind::FloatTy;
557
558 case Kinded::Kind::ConvertToNodeKind:
559 return ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int32ITy) &&
560 (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::FloatTy)) ||
561 ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::BoolTy) &&
562 (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::FloatTy)) ||
563 ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int64ITy) &&
564 (NI.getOutElemTy(ConvertToNode::ResultIdx) ==
565 ElemKind::Int32ITy)) ||
566 ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int32ITy) &&
567 (NI.getOutElemTy(ConvertToNode::ResultIdx) ==
568 ElemKind::Int64ITy)) ||
569 ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::FloatTy) &&
570 (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::BoolTy)) ||
571 ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::FloatTy) &&
572 (NI.getOutElemTy(ConvertToNode::ResultIdx) ==
573 ElemKind::Int32ITy)) ||
574 ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::BoolTy) &&
575 (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::Int32ITy));
576
577 default:
578 return false;
579 }
580}
581
582LLVMBackendOptions::LLVMBackendOptions() {
583 // Initialize using command-line options by default.
584 arch_ = llvmArch;
585 target_ = llvmTarget;
586 cpu_ = llvmCPU;
587 abi_ = llvmABI;
588 floatABI_ = floatABI;
589 codeModel_ = llvmCodeModel;
590 bundleCodeModel_ = llvmBundleCodeModel;
591 relocModel_ = llvmRelocModel;
592 bundleAPI_ = bundleAPI;
593 targetFeatures_.append(llvmTargetFeatures.begin(), llvmTargetFeatures.end());
594}
595
596LLVMBackend::LLVMBackend() {}
597
598std::string LLVMBackend::getHostTarget() {
599 return llvm::sys::getProcessTriple();
600}
601
602std::string LLVMBackend::getHostCPU() {
603 auto cpu_name = llvm::sys::getHostCPUName();
604 // Skip avx512 because LLVM does not support it well.
605 cpu_name.consume_back("-avx512");
606 return cpu_name.str();
607}
608
609llvm::SmallVector<std::string, 0> LLVMBackend::getHostFeatures() {
610 llvm::SmallVector<std::string, 0> result;
611 llvm::StringMap<bool> hostFeatures;
612 if (llvm::sys::getHostCPUFeatures(hostFeatures)) {
613 for (auto &feature : hostFeatures) {
614 if (feature.second) {
615 llvm::StringRef fn = feature.first();
616 // Skip avx512 because LLVM does not support it well.
617 if (fn.startswith("avx512")) {
618 continue;
619 }
620 result.push_back(fn.str());
621 }
622 }
623 }
624 return result;
625}
626
/// Emit the entry point for JIT called "jitmain".
/// Function has the following API:
/// int jitmain(uint8_t *baseConstantWeightVars,
///             uint8_t *baseInOutWeightVars,
///             uint8_t *baseActivations);
/// The integer width of the return type matches the libjit integer width; the
/// returned value is whatever the generated main entry function produces.
/// The three base pointers and a compile-time constant offsets array are
/// forwarded to the (internalized) main entry so LLVM can constant-fold them.
void LLVMBackend::emitJitMain(LLVMIRGen &irgen) const {
  AllocationsInfo &allocationsInfo = irgen.getAllocationsInfo();
  auto int8PtrTy = llvm::Type::getInt8PtrTy(irgen.getLLVMContext());
  // Return type width is configurable to match libjit's integer size.
  llvm::Type *retTy =
      llvm::Type::getIntNTy(irgen.getLLVMContext(), irgen.getLibjitIntWidth());
  llvm::FunctionType *jitFuncTy =
      llvm::FunctionType::get(retTy, {int8PtrTy, int8PtrTy, int8PtrTy}, false);
  // "jitmain" has external linkage so the JIT can look it up by name.
  auto *func =
      llvm::Function::Create(jitFuncTy, llvm::Function::ExternalLinkage,
                             "jitmain", &irgen.getModule());
  llvm::BasicBlock *entry_bb =
      llvm::BasicBlock::Create(irgen.getLLVMContext(), "entry", func);
  llvm::IRBuilder<> builder(entry_bb);
  // Add a provisional terminator to make the function well-formed. The real
  // return is created below, after which this placeholder is erased.
  auto *zero = builder.getIntN(irgen.getLibjitIntWidth(), 0);
  auto *ret = builder.CreateRet(zero);
  builder.SetInsertPoint(ret);

  // Prepare arguments for the "main" function: the three base pointers passed
  // to jitmain are forwarded as-is.
  llvm::SmallVector<llvm::Value *, 4> initFunctionCallArgs;
  initFunctionCallArgs.push_back(func->args().begin());
  initFunctionCallArgs.push_back(func->args().begin() + 1);
  initFunctionCallArgs.push_back(func->args().begin() + 2);
  // Now form the offsets array and pass it as the last argument.
  // NOTE(review): this uses irgen's own builder rather than the local one —
  // presumably so the constant lands at irgen's insertion point; confirm.
  auto offsetsArray =
      irgen.emitConstOffsetsArray(irgen.getBuilder(), allocationsInfo);
  initFunctionCallArgs.push_back(offsetsArray);
  // Invoke the main entry with constant arguments and let LLVM optimizer make
  // use of it. Internal linkage allows the entry to be inlined/optimized away.
  auto *entryF = irgen.getModule().getFunction(irgen.getMainEntryName());
  entryF->setLinkage(llvm::Function::InternalLinkage);
  auto *result = irgen.createCall(builder, entryF, initFunctionCallArgs);
  // Terminate the function by returning the entry function's result.
  builder.CreateRet(result);
  // Remove the provisional terminator.
  ret->eraseFromParent();
  // Emit JIT file printer.
  irgen.generateJITFileWriter();
  // Create the debug info for the entry point function.
  irgen.generateFunctionDebugInfo(func);
}
673
674std::unique_ptr<CompiledFunction>
675LLVMBackend::compileIR(std::unique_ptr<IRFunction> IR) const {
676 auto function = compileIRWithoutConstants(IR.get());
677 static_cast<LLVMCompiledFunction *>(function.get())
678 ->getRuntimeBundle()
679 .collectConstants(IR.get());
680 return function;
681}
682
683std::unique_ptr<CompiledFunction>
684LLVMBackend::compileIRWithoutConstants(IRFunction *IR) const {
685 AllocationsInfo allocationsInfo;
686 std::unique_ptr<LLVMIRGen> irgen = createIRGen(IR, allocationsInfo);
687 llvm::SmallVector<std::string, 8> targetFeatures(llvmTargetFeatures.begin(),
688 llvmTargetFeatures.end());
689 irgen->initTargetMachine(getOptions());
690 irgen->initCodeGen();
691 irgen->setIRFunction(IR);
692 // Perform the address assignment for activations and WeightVars.
693 allocateJITMemory(IR, irgen->getAllocationsInfo());
694 // Emit the code for the body of the entry function.
695 irgen->performCodeGen();
696 // Create the jitmain function to be invoked by JIT.
697 emitJitMain(*irgen);
698 irgen->finishCodeGen();
699 // Hand over the module to JIT for the machine code generation.
700 auto JIT = glow::make_unique<GlowJIT>(irgen->takeTargetMachine());
701 JIT->setContext(irgen->takeLLVMContext());
702 JIT->addModule(irgen->borrowModule());
703 // Build runtimeBundle object containing offsets and allocation sizes.
704 MemoryAllocator constantAllocator("ConstantWeights", 0);
705 MemoryAllocator placeholderAllocator("Placeholders", 0);
706 MemoryAllocator activationsAllocator("Activations", 0);
707 auto runtimeInfo = runtime::RuntimeBundle::create(
708 *IR, constantAllocator, placeholderAllocator, activationsAllocator);
709 return createCompiledFunction(std::move(JIT), std::move(runtimeInfo));
710}
711
712Expected<std::unique_ptr<CompiledFunction>>
713LLVMBackend::compile(Function *F, const BackendOptions &opts) const {
714 TraceInfo traceInfo = buildManualTraceInfo(F);
715 auto IR = generateAndOptimizeIR(F, *this, shouldShareBuffers());
716
717 if (opts.autoInstrument) {
718 autoInstrument(traceInfo, IR.get());
719 }
720
721 std::unique_ptr<CompiledFunction> compiledFunc;
722 if (opts.collectConstants) {
723 compiledFunc = compileIR(std::move(IR));
724 } else {
725 compiledFunc = compileIRWithoutConstants(IR.get());
726 }
727
728 compiledFunc->setTraceInfo(std::move(traceInfo));
729 return Expected<std::unique_ptr<CompiledFunction>>(std::move(compiledFunc));
730}
731
732void LLVMBackend::save(Function *F, llvm::StringRef outputDir,
733 llvm::StringRef bundleName,
734 llvm::StringRef mainEntryName) const {
735 llvm::SmallVector<std::string, 8> targetFeatures(llvmTargetFeatures.begin(),
736 llvmTargetFeatures.end());
737 auto IR = generateAndOptimizeIR(F, *this, shouldShareBuffers());
738 auto bundleSaver = createBundleSaver(*this, outputDir, bundleName);
739 bundleSaver->save(mainEntryName, IR.get());
740 bundleSaver->produceBundle();
741}
742
743void LLVMBackend::saveFunctions(llvm::ArrayRef<BundleEntry> entries,
744 llvm::StringRef outputDir,
745 llvm::StringRef bundleName) const {
746 auto bundleSaver = createBundleSaver(*this, outputDir, bundleName);
747 std::vector<std::unique_ptr<glow::IRFunction>> irFunctions;
748 for (auto &entry : entries) {
749 auto IR = generateAndOptimizeIR(entry.func, *this, shouldShareBuffers());
750 bundleSaver->save(entry.name, IR.get());
751 irFunctions.emplace_back(std::move(IR));
752 }
753 bundleSaver->produceBundle();
754}
755
756std::unique_ptr<BundleSaver>
757LLVMBackend::createBundleSaver(const LLVMBackend &llvmBackend,
758 llvm::StringRef outputDir,
759 llvm::StringRef bundleName) const {
760 return glow::make_unique<BundleSaver>(llvmBackend, outputDir, bundleName);
761}
762