1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/LLVMIRCodeGen/LLVMBackend.h" |
18 | #include "glow/LLVMIRCodeGen/BundleSaver.h" |
19 | #include "glow/LLVMIRCodeGen/CommandLine.h" |
20 | #include "glow/LLVMIRCodeGen/LLVMCompiledFunction.h" |
21 | |
22 | #include "glow/Backend/BackendUtils.h" |
23 | #include "glow/Graph/Graph.h" |
24 | #include "glow/Graph/PlaceholderBindings.h" |
25 | #include "glow/IR/Instrs.h" |
26 | #include "glow/Optimizer/IROptimizer/IROptimizer.h" |
27 | #include "glow/Support/Debug.h" |
28 | #include "llvm/ADT/STLExtras.h" |
29 | #include "llvm/ADT/SmallVector.h" |
30 | #include "llvm/IR/IRBuilder.h" |
31 | #include "llvm/IR/LLVMContext.h" |
32 | #include "llvm/Support/Host.h" |
33 | |
34 | using namespace glow; |
35 | |
36 | namespace { |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Functions for executing code using JIT |
40 | //===----------------------------------------------------------------------===// |
41 | |
/// Perform memory allocation for a JIT execution: assign a numeric id and a
/// concrete address/offset to every value used by \p F, recording the results
/// in \p allocationsInfo. Numbering must happen first, since the subsequent
/// allocation passes are keyed by those ids.
void allocateJITMemory(const IRFunction *F, AllocationsInfo &allocationsInfo) {
  allocationsInfo.numberValues(F);
  allocationsInfo.allocateWeightVars(F);
  allocationsInfo.allocateActivations(F);
  // Tensor views alias storage assigned above, so they are handled last.
  allocationsInfo.allocateTensorViews(F);
}
49 | |
50 | } // end namespace |
51 | |
52 | bool LLVMBackend::isOpSupported(const NodeInfo &NI) const { |
53 | // Note: For brevity below, "X ==> Y, Z" signifes that Node X is IRGen'd into |
54 | // Instructions Y and Z. |
55 | switch (NI.getKind()) { |
56 | case Kinded::Kind::BatchedReduceMaxNodeKind: |
57 | case Kinded::Kind::BatchedReduceMinNodeKind: |
58 | return NI.allInputsAndOutputsHaveSameElemKind( |
59 | {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Int64ITy}); |
60 | |
61 | case Kinded::Kind::BatchedReduceProdNodeKind: |
62 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}); |
63 | |
64 | case Kinded::Kind::AddNodeKind: |
65 | case Kinded::Kind::MulNodeKind: |
66 | return NI.allInputsAndOutputsHaveSameElemKind( |
67 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy, |
68 | ElemKind::Int64ITy}); |
69 | |
70 | case Kinded::Kind::ReluNodeKind: |
71 | case Kinded::Kind::ClipNodeKind: |
72 | case Kinded::Kind::LeakyReluNodeKind: |
73 | case Kinded::Kind::SubNodeKind: |
74 | case Kinded::Kind::MaxNodeKind: |
75 | case Kinded::Kind::MinNodeKind: |
76 | case Kinded::Kind::BatchedReduceAddNodeKind: |
77 | case Kinded::Kind::MatMulNodeKind: |
78 | case Kinded::Kind::AvgPoolNodeKind: |
79 | return NI.allInputsAndOutputsHaveSameElemKind( |
80 | {ElemKind::FloatTy, ElemKind::Int8QTy}); |
81 | |
82 | case Kinded::Kind::AdaptiveAvgPoolNodeKind: |
83 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}); |
84 | |
85 | case Kinded::Kind::MaxPoolNodeKind: |
86 | return NI.allInputsAndOutputsHaveSameElemKind( |
87 | {ElemKind::FloatTy, ElemKind::Int8QTy}, {}, |
88 | {MaxPoolNode::ArgmaxIdx}) && |
89 | (NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int64ITy || |
90 | NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int32ITy); |
91 | |
92 | case Kinded::Kind::ArgMaxNodeKind: |
93 | case Kinded::Kind::ArgMinNodeKind: |
94 | return NI.allInputsAndOutputsHaveSameElemKind( |
95 | {ElemKind::FloatTy, ElemKind::Int8QTy}, {}, |
96 | {ArgMaxNode::ResultIdx}) && |
97 | (NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int64ITy || |
98 | NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int32ITy); |
99 | |
100 | case Kinded::Kind::ResizeNearestNodeKind: |
101 | case Kinded::Kind::ResizeBilinearNodeKind: |
102 | return NI.allInputsAndOutputsHaveSameElemKind( |
103 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32QTy, |
104 | ElemKind::Int32ITy, ElemKind::Int64ITy}); |
105 | |
106 | case Kinded::Kind::SaveNodeKind: |
107 | case Kinded::Kind::ReshapeNodeKind: |
108 | // These are implemented via a Copy Instruction. |
109 | return NI.allInputsAndOutputsHaveSameElemKind( |
110 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy, |
111 | ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy, |
112 | ElemKind::BoolTy}); |
113 | |
114 | // InsertTensor ==> Copy + InsertTensor. Copy supports everything |
115 | // ReshapeNode above supports, so InsertTensor is the limiting factor. |
116 | case Kinded::Kind::InsertTensorNodeKind: |
117 | // Concat ==> Splat + Insert. Both only support the following. |
118 | case Kinded::Kind::ConcatNodeKind: |
119 | case Kinded::Kind::SplatNodeKind: |
120 | case Kinded::Kind::TouchNodeKind: |
121 | return NI.allInputsAndOutputsHaveSameElemKind( |
122 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy, |
123 | ElemKind::Int32ITy, ElemKind::BoolTy}); |
124 | case Kinded::Kind::SliceNodeKind: |
125 | return NI.allInputsAndOutputsHaveSameElemKind( |
126 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32QTy, |
127 | ElemKind::Int32ITy, ElemKind::Int64ITy}); |
128 | case Kinded::Kind::FmodNodeKind: |
129 | return NI.allInputsAndOutputsHaveSameElemKind( |
130 | {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Int64ITy}); |
131 | case Kinded::Kind::SpaceToDepthNodeKind: |
132 | case Kinded::Kind::DivNodeKind: |
133 | return NI.allInputsAndOutputsHaveSameElemKind( |
134 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy, |
135 | ElemKind::Int32ITy}); |
136 | |
137 | case Kinded::Kind::TransposeNodeKind: |
138 | return NI.allInputsAndOutputsHaveSameElemKind( |
139 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy, |
140 | ElemKind::BoolTy}); |
141 | |
142 | case Kinded::Kind::FlipNodeKind: |
143 | return NI.allInputsAndOutputsHaveSameElemKind( |
144 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy, |
145 | ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy, |
146 | ElemKind::BoolTy}); |
147 | |
148 | case Kinded::Kind::SparseLengthsSumNodeKind: |
149 | return NI.allInputsAndOutputsHaveSameElemKind( |
150 | {ElemKind::FloatTy}, {SparseLengthsSumNode::IndicesIdx, |
151 | SparseLengthsSumNode::LengthsIdx}) && |
152 | (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) == |
153 | ElemKind::Int64ITy || |
154 | NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) == |
155 | ElemKind::Int32ITy) && |
156 | (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) == |
157 | ElemKind::Int32ITy); |
158 | |
159 | case Kinded::Kind::SparseLengthsWeightedSumNodeKind: |
160 | return NI.allInputsAndOutputsHaveSameElemKind( |
161 | {ElemKind::FloatTy}, |
162 | {SparseLengthsWeightedSumNode::IndicesIdx, |
163 | SparseLengthsWeightedSumNode::LengthsIdx}) && |
164 | (NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) == |
165 | ElemKind::Int64ITy || |
166 | NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) == |
167 | ElemKind::Int32ITy) && |
168 | (NI.getInElemTy(SparseLengthsWeightedSumNode::LengthsIdx) == |
169 | ElemKind::Int32ITy); |
170 | |
171 | case Kinded::Kind::EmbeddingNodeKind: |
172 | return NI.allInputsAndOutputsHaveSameElemKind( |
173 | {ElemKind::FloatTy}, {EmbeddingNode::IndicesIdx}) && |
174 | (NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int32ITy); |
175 | |
176 | case Kinded::Kind::EmbeddingBagNodeKind: |
177 | return NI.allInputsAndOutputsHaveSameElemKind( |
178 | {ElemKind::FloatTy}, |
179 | {EmbeddingBagNode::IndicesIdx, EmbeddingBagNode::OffsetsIdx}) && |
180 | (NI.getInElemTy(EmbeddingBagNode::IndicesIdx) == |
181 | ElemKind::Int32ITy) && |
182 | (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) == ElemKind::Int32ITy); |
183 | |
184 | case Kinded::Kind::SparseLengthsWeightedSumGradNodeKind: |
185 | // GradOfInputNamedIndicesIdx and GradOfInputNamedLengthsIdx do not need to |
186 | // be checked because they are not used. |
187 | return NI.allInputsAndOutputsHaveSameElemKind( |
188 | {ElemKind::FloatTy}, |
189 | {SparseLengthsWeightedSumGradNode::IndicesIdx, |
190 | SparseLengthsWeightedSumGradNode::LengthsIdx}, |
191 | {SparseLengthsWeightedSumGradNode::GradOfInputNamedIndicesIdx, |
192 | SparseLengthsWeightedSumGradNode:: |
193 | GradOfInputNamedLengthsIdx}) && |
194 | (NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) == |
195 | ElemKind::Int32ITy || |
196 | NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) == |
197 | ElemKind::Int32ITy) && |
198 | (NI.getInElemTy(SparseLengthsWeightedSumGradNode::LengthsIdx) == |
199 | ElemKind::Int32ITy); |
200 | |
201 | case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind: |
202 | return (NI.getInElemTy( |
203 | RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) == |
204 | ElemKind::UInt8QTy) && |
205 | (NI.getInElemTy( |
206 | RowwiseQuantizedSparseLengthsWeightedSumNode::ScalesIdx) == |
207 | ElemKind::FloatTy) && |
208 | (NI.getInElemTy( |
209 | RowwiseQuantizedSparseLengthsWeightedSumNode::OffsetsIdx) == |
210 | ElemKind::FloatTy) && |
211 | (NI.getInElemTy( |
212 | RowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx) == |
213 | ElemKind::FloatTy) && |
214 | (NI.getInElemTy( |
215 | RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) == |
216 | ElemKind::Int64ITy || |
217 | NI.getInElemTy( |
218 | RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) == |
219 | ElemKind::Int32ITy) && |
220 | (NI.getInElemTy( |
221 | RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) == |
222 | ElemKind::Int32ITy) && |
223 | (NI.getOutElemTy( |
224 | RowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) == |
225 | ElemKind::FloatTy); |
226 | |
227 | case Kinded::Kind::LengthsRangeFillNodeKind: |
228 | case Kinded::Kind::LengthsToRangesNodeKind: |
229 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy}); |
230 | |
231 | case Kinded::Kind::IntLookupTableNodeKind: |
232 | return NI.allInputsAndOutputsHaveSameElemKind( |
233 | {ElemKind::Int8QTy, ElemKind::Int16QTy}); |
234 | |
235 | case Kinded::Kind::RescaleQuantizedNodeKind: |
236 | return NI.allInputsAndOutputsHaveSameElemKind( |
237 | {ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32QTy}); |
238 | |
239 | case Kinded::Kind::PowNodeKind: |
240 | case Kinded::Kind::AvgPoolGradNodeKind: |
241 | case Kinded::Kind::QuantizationProfileNodeKind: |
242 | case Kinded::Kind::LocalResponseNormalizationNodeKind: |
243 | case Kinded::Kind::LocalResponseNormalizationGradNodeKind: |
244 | case Kinded::Kind::LogNodeKind: |
245 | case Kinded::Kind::TanhNodeKind: |
246 | case Kinded::Kind::SigmoidNodeKind: |
247 | case Kinded::Kind::ExpNodeKind: |
248 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}); |
249 | |
250 | case Kinded::Kind::ModuloNodeKind: |
251 | return NI.allInputsAndOutputsHaveSameElemKind( |
252 | {ElemKind::Int32ITy, ElemKind::Int64ITy}); |
253 | |
254 | case Kinded::Kind::MaxPoolGradNodeKind: |
255 | return NI.allInputsAndOutputsHaveSameElemKind( |
256 | {ElemKind::FloatTy}, |
257 | {MaxPoolGradNode::OriginalOutputForArgmaxIdx, |
258 | MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx}) && |
259 | (NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) == |
260 | ElemKind::Int64ITy || |
261 | NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) == |
262 | ElemKind::Int32ITy) && |
263 | (NI.getInElemTy( |
264 | MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) == |
265 | ElemKind::Int64ITy || |
266 | NI.getInElemTy( |
267 | MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) == |
268 | ElemKind::Int32ITy); |
269 | |
270 | case Kinded::Kind::ConvolutionNodeKind: |
271 | if (!NI.getInTy(ConvolutionNode::InputIdx)->isQuantizedType()) { |
272 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}); |
273 | } |
274 | |
275 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy}, |
276 | {ConvolutionNode::BiasIdx}) && |
277 | (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int8QTy || |
278 | NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy); |
279 | |
280 | case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind: |
281 | return (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::InputIdx) == |
282 | ElemKind::Int8QTy) && |
283 | (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterIdx) == |
284 | ElemKind::Int8QTy) && |
285 | ((NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) == |
286 | ElemKind::Int8QTy) || |
287 | (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) == |
288 | ElemKind::Int32QTy)) && |
289 | (NI.getInElemTy( |
290 | ChannelwiseQuantizedConvolutionNode::FilterScalesIdx) == |
291 | ElemKind::FloatTy) && |
292 | (NI.getInElemTy( |
293 | ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx) == |
294 | ElemKind::Int32ITy) && |
295 | (NI.getInElemTy( |
296 | ChannelwiseQuantizedConvolutionNode::BiasScalesIdx) == |
297 | ElemKind::FloatTy) && |
298 | (NI.getInElemTy( |
299 | ChannelwiseQuantizedConvolutionNode::BiasOffsetsIdx) == |
300 | ElemKind::Int32ITy) && |
301 | (NI.getOutElemTy(ChannelwiseQuantizedConvolutionNode::ResultIdx) == |
302 | ElemKind::Int8QTy); |
303 | |
304 | case Kinded::Kind::ConvTransposeNodeKind: |
305 | // TODO - not quantized support yet in libjit. |
306 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}); |
307 | |
308 | case Kinded::Kind::BatchedAddNodeKind: |
309 | if (!NI.getInTy(BatchedAddNode::BatchIdx)->isQuantizedType()) { |
310 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}); |
311 | } |
312 | // Allow for Int8QTy or Int32QTy for the Slice input. |
313 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy}, |
314 | {BatchedAddNode::SliceIdx}) && |
315 | ((NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int8QTy) || |
316 | (NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int32QTy)); |
317 | |
318 | case Kinded::Kind::GatherNodeKind: |
319 | return NI.allInputsAndOutputsHaveSameElemKind( |
320 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy, |
321 | ElemKind::Int32ITy}, |
322 | {GatherNode::IndicesIdx}) && |
323 | ((NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int32ITy) || |
324 | (NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy)); |
325 | |
326 | case Kinded::Kind::GatherNDNodeKind: |
327 | return NI.allInputsAndOutputsHaveSameElemKind( |
328 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy, |
329 | ElemKind::Int32ITy}, |
330 | {GatherNDNode::IndicesIdx}) && |
331 | ((NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int32ITy) || |
332 | (NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int64ITy)); |
333 | |
334 | case Kinded::Kind::GatherRangesNodeKind: |
335 | return NI.allInputsAndOutputsHaveSameElemKind( |
336 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy, |
337 | ElemKind::Int32ITy}, |
338 | {GatherRangesNode::RangesIdx}, {GatherRangesNode::LengthsIdx}) && |
339 | ((NI.getInElemTy(GatherRangesNode::RangesIdx) == |
340 | NI.getOutElemTy(GatherRangesNode::LengthsIdx)) && |
341 | ((NI.getOutElemTy(GatherRangesNode::LengthsIdx) == |
342 | ElemKind::Int32ITy) || |
343 | (NI.getOutElemTy(GatherRangesNode::LengthsIdx) == |
344 | ElemKind::Int64ITy))); |
345 | |
346 | case Kinded::Kind::ScatterDataNodeKind: |
347 | // ScatterData ==> Copy + ScatterData. Copy supports everything |
348 | // ReshapeNode above supports, however ScatterData only supports the |
349 | // following. |
350 | return NI.allInputsAndOutputsHaveSameElemKind( |
351 | {ElemKind::FloatTy, ElemKind::Int8QTy}, |
352 | {ScatterDataNode::IndicesIdx}) && |
353 | (NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int64ITy || |
354 | NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int32ITy); |
355 | |
356 | case Kinded::Kind::SelectNodeKind: |
357 | return NI.allInputsAndOutputsHaveSameElemKind( |
358 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy}, |
359 | {SelectNode::CondIdx}) && |
360 | ((NI.getInElemTy(SelectNode::CondIdx) == ElemKind::BoolTy)); |
361 | |
362 | case Kinded::Kind::NotNodeKind: |
363 | case Kinded::Kind::AndNodeKind: |
364 | case Kinded::Kind::OrNodeKind: |
365 | case Kinded::Kind::XorNodeKind: |
366 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::BoolTy}); |
367 | |
368 | case Kinded::Kind::AbsNodeKind: |
369 | case Kinded::Kind::NegNodeKind: |
370 | case Kinded::Kind::FloorNodeKind: |
371 | case Kinded::Kind::CeilNodeKind: |
372 | case Kinded::Kind::RoundNodeKind: |
373 | case Kinded::Kind::SqrtNodeKind: |
374 | case Kinded::Kind::ErfNodeKind: |
375 | case Kinded::Kind::RsqrtNodeKind: |
376 | case Kinded::Kind::ReciprocalNodeKind: |
377 | case Kinded::Kind::SinNodeKind: |
378 | case Kinded::Kind::CosNodeKind: |
379 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}); |
380 | |
381 | case Kinded::Kind::CmpEQNodeKind: |
382 | case Kinded::Kind::CmpNEQNodeKind: |
383 | case Kinded::Kind::CmpLTNodeKind: |
384 | case Kinded::Kind::CmpLTENodeKind: |
385 | return NI.allInputsAndOutputsHaveSameElemKind( |
386 | {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy, |
387 | ElemKind::Int64ITy}, |
388 | {}, {CmpEQNode::ResultIdx}) && |
389 | (NI.getOutElemTy(CmpEQNode::ResultIdx) == ElemKind::BoolTy); |
390 | |
391 | case Kinded::Kind::IsNaNNodeKind: |
392 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}, {}, |
393 | {IsNaNNode::ResultIdx}) && |
394 | (NI.getOutElemTy(IsNaNNode::ResultIdx) == ElemKind::BoolTy); |
395 | |
396 | case Kinded::Kind::TopKNodeKind: |
397 | return NI.allInputsAndOutputsHaveSameElemKind( |
398 | {ElemKind::FloatTy, ElemKind::Int8QTy}, {}, |
399 | {TopKNode::IndicesIdx}) && |
400 | (NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int64ITy || |
401 | NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int32ITy); |
402 | |
403 | case Kinded::Kind::QuantizeNodeKind: |
404 | return (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::FloatTy) && |
405 | ((NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int8QTy) || |
406 | (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int16QTy) || |
407 | (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int32QTy)); |
408 | |
409 | case Kinded::Kind::DequantizeNodeKind: |
410 | return ((NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int8QTy) || |
411 | (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int16QTy) || |
412 | (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int32QTy)) && |
413 | (NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::FloatTy); |
414 | |
415 | case Kinded::Kind::SoftMaxNodeKind: |
416 | return NI.allInputsAndOutputsHaveSameElemKind( |
417 | {ElemKind::FloatTy, ElemKind::Int8QTy}, |
418 | {SoftMaxNode::SelectedIdx}) && |
419 | (NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int64ITy || |
420 | NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int32ITy); |
421 | |
422 | case Kinded::Kind::CrossEntropyLossNodeKind: |
423 | return NI.allInputsAndOutputsHaveSameElemKind( |
424 | {ElemKind::FloatTy}, {CrossEntropyLossNode::LabelsIdx}) && |
425 | (NI.getInElemTy(CrossEntropyLossNode::LabelsIdx) == |
426 | ElemKind::Int64ITy || |
427 | NI.getInElemTy(CrossEntropyLossNode::LabelsIdx) == |
428 | ElemKind::Int32ITy); |
429 | |
430 | case Kinded::Kind::LengthsSumNodeKind: |
431 | return NI.allInputsAndOutputsHaveSameElemKind( |
432 | {ElemKind::FloatTy}, {LengthsSumNode::LengthsIdx}) && |
433 | (NI.getInElemTy(LengthsSumNode::LengthsIdx) == ElemKind::Int32ITy); |
434 | |
435 | case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: |
436 | return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx) == |
437 | ElemKind::UInt8FusedQTy) && |
438 | (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) == |
439 | ElemKind::FloatTy) && |
440 | (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) == |
441 | ElemKind::Int32ITy) && |
442 | (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) == |
443 | ElemKind::Int32ITy) && |
444 | (NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) == |
445 | ElemKind::FloatTy); |
446 | |
447 | case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: |
448 | return (NI.getInElemTy( |
449 | FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) == |
450 | ElemKind::UInt8FusedQTy) && |
451 | (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode:: |
452 | WeightsIdx) == ElemKind::FloatTy) && |
453 | ((NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode:: |
454 | IndicesIdx) == ElemKind::Int64ITy || |
455 | NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode:: |
456 | IndicesIdx) == ElemKind::Int32ITy)) && |
457 | (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode:: |
458 | LengthsIdx) == ElemKind::Int32ITy) && |
459 | (NI.getOutElemTy( |
460 | FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) == |
461 | ElemKind::FloatTy); |
462 | |
463 | case Kinded::Kind::FullyConnectedNodeKind: |
464 | if (!NI.getInTy(FullyConnectedNode::InputIdx)->isQuantizedType()) { |
465 | return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}); |
466 | } |
467 | return NI.allInputsAndOutputsHaveSameElemKind( |
468 | {ElemKind::Int8QTy}, {FullyConnectedNode::BiasIdx}) && |
469 | (NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::Int8QTy || |
470 | NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::Int32QTy); |
471 | |
472 | case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind: |
473 | return (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::InputIdx) == |
474 | ElemKind::Int8QTy) && |
475 | (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::WeightsIdx) == |
476 | ElemKind::Int8QTy) && |
477 | (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::ScalesIdx) == |
478 | ElemKind::FloatTy) && |
479 | (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::OffsetsIdx) == |
480 | ElemKind::Int32ITy) && |
481 | (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) == |
482 | ElemKind::Int8QTy || |
483 | NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) == |
484 | ElemKind::Int32QTy) && |
485 | (NI.getOutElemTy(RowwiseQuantizedFullyConnectedNode::ResultIdx) == |
486 | ElemKind::Int8QTy); |
487 | |
488 | case Kinded::Kind::SoftMaxGradNodeKind: |
489 | return NI.allInputsAndOutputsHaveSameElemKind( |
490 | {ElemKind::FloatTy}, {SoftMaxGradNode::SelectedIdx}, |
491 | {SoftMaxGradNode::GradOfInputNamedSelectedIdx}) && |
492 | (NI.getInElemTy(SoftMaxGradNode::SelectedIdx) == |
493 | ElemKind::Int64ITy || |
494 | NI.getInElemTy(SoftMaxGradNode::SelectedIdx) == ElemKind::Int32ITy); |
495 | |
496 | case Kinded::Kind::ConvolutionGradNodeKind: |
497 | return NI.allInputsAndOutputsHaveSameElemKind( |
498 | {ElemKind::FloatTy}, {}, |
499 | {ConvolutionGradNode::GradOfInputNamedInputIdx}); |
500 | |
501 | case Kinded::Kind::CrossEntropyLossGradNodeKind: |
502 | return NI.allInputsAndOutputsHaveSameElemKind( |
503 | {ElemKind::FloatTy}, {CrossEntropyLossGradNode::LabelsIdx}, |
504 | {CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx}) && |
505 | (NI.getInElemTy(CrossEntropyLossGradNode::LabelsIdx) == |
506 | ElemKind::Int64ITy) && |
507 | (NI.getOutElemTy( |
508 | CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx) == |
509 | ElemKind::Int64ITy); |
510 | |
511 | case Kinded::Kind::TraceEventNodeKind: |
512 | return NI.getInElemTy(TraceEventNode::DataIdx) == ElemKind::Int64ITy; |
513 | |
514 | case Kinded::Kind::NonMaxSuppressionNodeKind: |
515 | return NI.getInElemTy(NonMaxSuppressionNode::BoxesIdx) == |
516 | ElemKind::FloatTy && |
517 | NI.getInElemTy(NonMaxSuppressionNode::ScoresIdx) == |
518 | ElemKind::FloatTy && |
519 | (NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) == |
520 | ElemKind::Int32ITy || |
521 | NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) == |
522 | ElemKind::Int64ITy) && |
523 | (NI.getOutElemTy( |
524 | NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) == |
525 | ElemKind::Int32ITy || |
526 | NI.getOutElemTy( |
527 | NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) == |
528 | ElemKind::Int64ITy); |
529 | |
530 | case Kinded::Kind::TFLiteDetectionPostProcessNodeKind: |
531 | return NI.getInElemTy(TFLiteDetectionPostProcessNode::BoxesIdx) == |
532 | ElemKind::FloatTy && |
533 | NI.getInElemTy(TFLiteDetectionPostProcessNode::ScoresIdx) == |
534 | ElemKind::FloatTy && |
535 | NI.getInElemTy(TFLiteDetectionPostProcessNode::AnchorsIdx) == |
536 | ElemKind::FloatTy && |
537 | NI.getOutElemTy(TFLiteDetectionPostProcessNode::DetectionBoxesIdx) == |
538 | ElemKind::FloatTy && |
539 | NI.getOutElemTy( |
540 | TFLiteDetectionPostProcessNode::DetectionClassesIdx) == |
541 | ElemKind::Int32ITy && |
542 | NI.getOutElemTy( |
543 | TFLiteDetectionPostProcessNode::DetectionScoresIdx) == |
544 | ElemKind::FloatTy && |
545 | NI.getOutElemTy(TFLiteDetectionPostProcessNode::NumDetectionsIdx) == |
546 | ElemKind::Int32ITy; |
547 | |
548 | case Kinded::Kind::AudioSpectrogramNodeKind: |
549 | return NI.getInElemTy(AudioSpectrogramNode::InputIdx) == |
550 | ElemKind::FloatTy && |
551 | NI.getOutElemTy(AudioSpectrogramNode::SpectrogramIdx) == |
552 | ElemKind::FloatTy; |
553 | |
554 | case Kinded::Kind::MFCCNodeKind: |
555 | return NI.getInElemTy(MFCCNode::SpectrogramIdx) == ElemKind::FloatTy && |
556 | NI.getOutElemTy(MFCCNode::CoefficientsIdx) == ElemKind::FloatTy; |
557 | |
558 | case Kinded::Kind::ConvertToNodeKind: |
559 | return ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int32ITy) && |
560 | (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::FloatTy)) || |
561 | ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::BoolTy) && |
562 | (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::FloatTy)) || |
563 | ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int64ITy) && |
564 | (NI.getOutElemTy(ConvertToNode::ResultIdx) == |
565 | ElemKind::Int32ITy)) || |
566 | ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int32ITy) && |
567 | (NI.getOutElemTy(ConvertToNode::ResultIdx) == |
568 | ElemKind::Int64ITy)) || |
569 | ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::FloatTy) && |
570 | (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::BoolTy)) || |
571 | ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::FloatTy) && |
572 | (NI.getOutElemTy(ConvertToNode::ResultIdx) == |
573 | ElemKind::Int32ITy)) || |
574 | ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::BoolTy) && |
575 | (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::Int32ITy)); |
576 | |
577 | default: |
578 | return false; |
579 | } |
580 | } |
581 | |
LLVMBackendOptions::LLVMBackendOptions() {
  // Initialize using command-line options by default. Each of the globals
  // referenced below is a command-line option declared in
  // LLVMIRCodeGen/CommandLine.h; constructing the options object snapshots
  // their current values so later command-line changes don't affect it.
  arch_ = llvmArch;
  target_ = llvmTarget;
  cpu_ = llvmCPU;
  abi_ = llvmABI;
  floatABI_ = floatABI;
  codeModel_ = llvmCodeModel;
  bundleCodeModel_ = llvmBundleCodeModel;
  relocModel_ = llvmRelocModel;
  bundleAPI_ = bundleAPI;
  // Copy the list of extra target features (e.g. "+sse4.2").
  targetFeatures_.append(llvmTargetFeatures.begin(), llvmTargetFeatures.end());
}
595 | |
596 | LLVMBackend::LLVMBackend() {} |
597 | |
598 | std::string LLVMBackend::getHostTarget() { |
599 | return llvm::sys::getProcessTriple(); |
600 | } |
601 | |
602 | std::string LLVMBackend::getHostCPU() { |
603 | auto cpu_name = llvm::sys::getHostCPUName(); |
604 | // Skip avx512 because LLVM does not support it well. |
605 | cpu_name.consume_back("-avx512" ); |
606 | return cpu_name.str(); |
607 | } |
608 | |
609 | llvm::SmallVector<std::string, 0> LLVMBackend::getHostFeatures() { |
610 | llvm::SmallVector<std::string, 0> result; |
611 | llvm::StringMap<bool> hostFeatures; |
612 | if (llvm::sys::getHostCPUFeatures(hostFeatures)) { |
613 | for (auto &feature : hostFeatures) { |
614 | if (feature.second) { |
615 | llvm::StringRef fn = feature.first(); |
616 | // Skip avx512 because LLVM does not support it well. |
617 | if (fn.startswith("avx512" )) { |
618 | continue; |
619 | } |
620 | result.push_back(fn.str()); |
621 | } |
622 | } |
623 | } |
624 | return result; |
625 | } |
626 | |
/// Emit the entry point for JIT called "jitmain".
/// Function has the following API:
///   int jitmain(uint8_t *baseConstantWeightVars,
///               uint8_t *baseInOutWeightVars,
///               uint8_t *baseActivations);
void LLVMBackend::emitJitMain(LLVMIRGen &irgen) const {
  AllocationsInfo &allocationsInfo = irgen.getAllocationsInfo();
  auto int8PtrTy = llvm::Type::getInt8PtrTy(irgen.getLLVMContext());
  // The return type width matches the libjit integer width (32 or 64 bit).
  llvm::Type *retTy =
      llvm::Type::getIntNTy(irgen.getLLVMContext(), irgen.getLibjitIntWidth());
  llvm::FunctionType *jitFuncTy =
      llvm::FunctionType::get(retTy, {int8PtrTy, int8PtrTy, int8PtrTy}, false);
  auto *func =
      llvm::Function::Create(jitFuncTy, llvm::Function::ExternalLinkage,
                             "jitmain" , &irgen.getModule());
  llvm::BasicBlock *entry_bb =
      llvm::BasicBlock::Create(irgen.getLLVMContext(), "entry" , func);
  llvm::IRBuilder<> builder(entry_bb);
  // Add a provisional terminator to make the function well-formed; some of
  // the IR emission below requires the enclosing block to be terminated.
  auto *zero = builder.getIntN(irgen.getLibjitIntWidth(), 0);
  auto *ret = builder.CreateRet(zero);
  // Emit all subsequent instructions before the provisional terminator.
  builder.SetInsertPoint(ret);

  // Prepare arguments for the "main" function: the three base memory-region
  // pointers received by jitmain are forwarded as-is.
  llvm::SmallVector<llvm::Value *, 4> initFunctionCallArgs;
  initFunctionCallArgs.push_back(func->args().begin());
  initFunctionCallArgs.push_back(func->args().begin() + 1);
  initFunctionCallArgs.push_back(func->args().begin() + 2);
  // Now form the offsets array and pass it as the last argument.
  auto offsetsArray =
      irgen.emitConstOffsetsArray(irgen.getBuilder(), allocationsInfo);
  initFunctionCallArgs.push_back(offsetsArray);
  // Invoke the main entry with constant arguments and let LLVM optimizer make
  // use of it. Internal linkage allows the entry to be inlined/optimized away.
  auto *entryF = irgen.getModule().getFunction(irgen.getMainEntryName());
  entryF->setLinkage(llvm::Function::InternalLinkage);
  auto *result = irgen.createCall(builder, entryF, initFunctionCallArgs);
  // Terminate the function with the real return value.
  builder.CreateRet(result);
  // Remove the provisional terminator, leaving a single return.
  ret->eraseFromParent();
  // Emit JIT file printer.
  irgen.generateJITFileWriter();
  // Create the debug info for the entry point function.
  irgen.generateFunctionDebugInfo(func);
}
673 | |
674 | std::unique_ptr<CompiledFunction> |
675 | LLVMBackend::compileIR(std::unique_ptr<IRFunction> IR) const { |
676 | auto function = compileIRWithoutConstants(IR.get()); |
677 | static_cast<LLVMCompiledFunction *>(function.get()) |
678 | ->getRuntimeBundle() |
679 | .collectConstants(IR.get()); |
680 | return function; |
681 | } |
682 | |
683 | std::unique_ptr<CompiledFunction> |
684 | LLVMBackend::compileIRWithoutConstants(IRFunction *IR) const { |
685 | AllocationsInfo allocationsInfo; |
686 | std::unique_ptr<LLVMIRGen> irgen = createIRGen(IR, allocationsInfo); |
687 | llvm::SmallVector<std::string, 8> targetFeatures(llvmTargetFeatures.begin(), |
688 | llvmTargetFeatures.end()); |
689 | irgen->initTargetMachine(getOptions()); |
690 | irgen->initCodeGen(); |
691 | irgen->setIRFunction(IR); |
692 | // Perform the address assignment for activations and WeightVars. |
693 | allocateJITMemory(IR, irgen->getAllocationsInfo()); |
694 | // Emit the code for the body of the entry function. |
695 | irgen->performCodeGen(); |
696 | // Create the jitmain function to be invoked by JIT. |
697 | emitJitMain(*irgen); |
698 | irgen->finishCodeGen(); |
699 | // Hand over the module to JIT for the machine code generation. |
700 | auto JIT = glow::make_unique<GlowJIT>(irgen->takeTargetMachine()); |
701 | JIT->setContext(irgen->takeLLVMContext()); |
702 | JIT->addModule(irgen->borrowModule()); |
703 | // Build runtimeBundle object containing offsets and allocation sizes. |
704 | MemoryAllocator constantAllocator("ConstantWeights" , 0); |
705 | MemoryAllocator placeholderAllocator("Placeholders" , 0); |
706 | MemoryAllocator activationsAllocator("Activations" , 0); |
707 | auto runtimeInfo = runtime::RuntimeBundle::create( |
708 | *IR, constantAllocator, placeholderAllocator, activationsAllocator); |
709 | return createCompiledFunction(std::move(JIT), std::move(runtimeInfo)); |
710 | } |
711 | |
712 | Expected<std::unique_ptr<CompiledFunction>> |
713 | LLVMBackend::compile(Function *F, const BackendOptions &opts) const { |
714 | TraceInfo traceInfo = buildManualTraceInfo(F); |
715 | auto IR = generateAndOptimizeIR(F, *this, shouldShareBuffers()); |
716 | |
717 | if (opts.autoInstrument) { |
718 | autoInstrument(traceInfo, IR.get()); |
719 | } |
720 | |
721 | std::unique_ptr<CompiledFunction> compiledFunc; |
722 | if (opts.collectConstants) { |
723 | compiledFunc = compileIR(std::move(IR)); |
724 | } else { |
725 | compiledFunc = compileIRWithoutConstants(IR.get()); |
726 | } |
727 | |
728 | compiledFunc->setTraceInfo(std::move(traceInfo)); |
729 | return Expected<std::unique_ptr<CompiledFunction>>(std::move(compiledFunc)); |
730 | } |
731 | |
732 | void LLVMBackend::save(Function *F, llvm::StringRef outputDir, |
733 | llvm::StringRef bundleName, |
734 | llvm::StringRef mainEntryName) const { |
735 | llvm::SmallVector<std::string, 8> targetFeatures(llvmTargetFeatures.begin(), |
736 | llvmTargetFeatures.end()); |
737 | auto IR = generateAndOptimizeIR(F, *this, shouldShareBuffers()); |
738 | auto bundleSaver = createBundleSaver(*this, outputDir, bundleName); |
739 | bundleSaver->save(mainEntryName, IR.get()); |
740 | bundleSaver->produceBundle(); |
741 | } |
742 | |
743 | void LLVMBackend::saveFunctions(llvm::ArrayRef<BundleEntry> entries, |
744 | llvm::StringRef outputDir, |
745 | llvm::StringRef bundleName) const { |
746 | auto bundleSaver = createBundleSaver(*this, outputDir, bundleName); |
747 | std::vector<std::unique_ptr<glow::IRFunction>> irFunctions; |
748 | for (auto &entry : entries) { |
749 | auto IR = generateAndOptimizeIR(entry.func, *this, shouldShareBuffers()); |
750 | bundleSaver->save(entry.name, IR.get()); |
751 | irFunctions.emplace_back(std::move(IR)); |
752 | } |
753 | bundleSaver->produceBundle(); |
754 | } |
755 | |
756 | std::unique_ptr<BundleSaver> |
757 | LLVMBackend::createBundleSaver(const LLVMBackend &llvmBackend, |
758 | llvm::StringRef outputDir, |
759 | llvm::StringRef bundleName) const { |
760 | return glow::make_unique<BundleSaver>(llvmBackend, outputDir, bundleName); |
761 | } |
762 | |