/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "BackendTestUtils.h"

#include "glow/Converter/Float16Converter.h"
#include "glow/Converter/TypeAToTypeBFunctionConverter.h"

#include "glow/Backend/Backend.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

#include "llvm/Support/Casting.h"

#include "gtest/gtest.h"

using namespace glow;
/// Check that a simple graph is converted properly.
/// Namely, check that:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///    FC(float)    Output: Placeholder(float)
///     |             |
///     |    +--------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// Gets converted into:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///   ConvertTo(float16)
///     |
///     V
///    FC(float16)   Output: Placeholder(float)
///     |              |
///     V              |
///   ConvertTo(float) |
///     |    +---------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// In particular, the input and output of the network shouldn't be modified.
TEST(TypeAToTypeBFunctionConverter, SimpleOneUseConversionFloatToFloat16) {
  Module mod;
  Function *F = mod.createFunction("test");
  PlaceholderBindings bindings;

  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
  auto *output =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);

  auto *FC = F->createFullyConnected(bindings, "FC", input, 10);
  auto *result = F->createSave("save", FC, output);

  size_t origGraphSize = F->getNodes().size();

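  // The converter rewrites every float-typed value in F into float16 and
  // inserts ConvertTo nodes at the boundaries, so the function's
  // placeholders keep their original float types.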
  PrecisionConfiguration precConfig;
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
  converter.convert();

  // We should have 4 more nodes:
  // 1 conversion float to float16 for each input of FC (3)
  // and 1 conversion float16 to float for the result of FC.
  EXPECT_EQ(F->getNodes().size(), origGraphSize + 4);
  // Make sure the save node is still in the function and is unchanged.
  EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
              F->getNodes().end());
  EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is fed from a conversion from float16 to float.
  auto *convertedBackFCRes = llvm::dyn_cast<ConvertToNode>(result->getInput());
  ASSERT_NE(convertedBackFCRes, nullptr);
  EXPECT_EQ(convertedBackFCRes->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);
  auto *convertedFC =
      llvm::dyn_cast<FullyConnectedNode>(convertedBackFCRes->getInput());
  ASSERT_NE(convertedFC, nullptr);
  EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
            ElemKind::Float16Ty);
  // Check that all the inputs of FC are ConvertTo nodes converting from float
  // to Float16Ty.
  for (unsigned idx = 0, end = convertedFC->getNumInputs(); idx != end;
       ++idx) {
    auto *convertedFCInput =
        llvm::dyn_cast<ConvertToNode>(convertedFC->getNthInput(idx));
    ASSERT_NE(convertedFCInput, nullptr);
    EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
              ElemKind::Float16Ty);
    EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
    EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::FloatTy);
  }
  // At this point we know the input of FC is convertTo(placeholder).
  // Check that this placeholder is the expected input.
  EXPECT_EQ(convertedFC->getInput()
                .getNode()
                ->getNthInput(ConvertToNode::InputIdx)
                .getNode(),
            input);
}

/// Check that a graph with a simple chain of computation is converted
/// properly. In particular, check that the intermediate conversion
/// steps are not eliminated by default.
/// Namely, check that:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///    FC(float)
///     |
///     V
///   ReLU(float)   Output: Placeholder(float)
///     |             |
///     |    +--------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// Gets converted into:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///   ConvertTo(float16)
///     |
///     V
///    FC(float16)
///     |
///     V
///   ConvertTo(float)
///     |
///     V
///   ConvertTo(float16)
///     |
///     V
///   ReLU(float16)   Output: Placeholder(float)
///     |               |
///     V               |
///   ConvertTo(float)  |
///     |    +----------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// In particular, the input and output of the network shouldn't be modified.
TEST(TypeAToTypeBFunctionConverter,
     SimpleChainOfComputationConversionFloatToFloat16) {
  Module mod;
  Function *F = mod.createFunction("test");
  PlaceholderBindings bindings;

  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
  auto *output =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);

  auto *FC = F->createFullyConnected(bindings, "FC", input, 10);
  auto *ReLU =
      F->createRELU("ReLU", FC, FC->getType(FullyConnectedNode::ResultIdx));
  auto *result = F->createSave("save", ReLU, output);

  size_t origGraphSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
  converter.convert();

  // We should have 6 more nodes:
  // 1 conversion float to float16 for each input of FC (3)
  // 1 conversion float to float16 for the input of ReLU
  // 1 conversion float16 to float for the result of FC.
  // 1 conversion float16 to float for the result of ReLU.
  EXPECT_EQ(F->getNodes().size(), origGraphSize + 6);
  // Make sure the save node is still in the function and is unchanged.
  EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
              F->getNodes().end());
  EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is fed from a conversion from float16 to float.
  auto *convertedBackReLURes =
      llvm::dyn_cast<ConvertToNode>(result->getInput());
  ASSERT_NE(convertedBackReLURes, nullptr);
  EXPECT_EQ(convertedBackReLURes->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);
  auto *convertedReLU =
      llvm::dyn_cast<ReluNode>(convertedBackReLURes->getInput());
  ASSERT_NE(convertedReLU, nullptr);
  EXPECT_EQ(convertedReLU->getElementType(ReluNode::ResultIdx),
            ElemKind::Float16Ty);

  // Check that the ReLU is fed from a conversion from float to float16.
  auto *convertedToReLUInput =
      llvm::dyn_cast<ConvertToNode>(convertedReLU->getInput());
  ASSERT_NE(convertedToReLUInput, nullptr);
  EXPECT_EQ(convertedToReLUInput->getElementType(ConvertToNode::ResultIdx),
            ElemKind::Float16Ty);

  // Check that this conversion is fed from a conversion from float16 to float.
  auto *convertedBackFCRes =
      llvm::dyn_cast<ConvertToNode>(convertedToReLUInput->getInput());
  ASSERT_NE(convertedBackFCRes, nullptr);
  EXPECT_EQ(convertedBackFCRes->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);
  // Check that this conversion comes from the float16 FC node.
  auto *convertedFC =
      llvm::dyn_cast<FullyConnectedNode>(convertedBackFCRes->getInput());
  ASSERT_NE(convertedFC, nullptr);
  EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
            ElemKind::Float16Ty);
  // Check that all the inputs of FC are ConvertTo nodes converting from float
  // to Float16Ty.
  for (unsigned idx = 0, end = convertedFC->getNumInputs(); idx != end;
       ++idx) {
    auto *convertedFCInput =
        llvm::dyn_cast<ConvertToNode>(convertedFC->getNthInput(idx));
    ASSERT_NE(convertedFCInput, nullptr);
    EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
              ElemKind::Float16Ty);
    EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
    EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::FloatTy);
  }
  // At this point we know the input of FC is convertTo(placeholder).
  // Check that this placeholder is the expected input.
  EXPECT_EQ(convertedFC->getInput()
                .getNode()
                ->getNthInput(ConvertToNode::InputIdx)
                .getNode(),
            input);
}

/// Check that the conversion honors the precision configuration for
/// blacklisting a node kind (Relu here) for a graph with a simple chain of
/// computation.
/// Namely, check that:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///    FC(float)
///     |
///     V
///   ReLU(float)   Output: Placeholder(float)
///     |             |
///     |    +--------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// Gets converted into:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///   ConvertTo(float16)
///     |
///     V
///    FC(float16)
///     |
///     V
///   ConvertTo(float)
///     |
///     V
///   ReLU(float)   Output: Placeholder(float)
///     |             |
///     |    +--------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// In particular, the input and output of the network shouldn't be modified.
TEST(TypeAToTypeBFunctionConverter, DoNotConvertReLUConversionFloatToFloat16) {
  Module mod;
  Function *F = mod.createFunction("test");
  PlaceholderBindings bindings;

  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
  auto *output =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);

  auto *FC = F->createFullyConnected(bindings, "FC", input, 10);
  auto *ReLU =
      F->createRELU("ReLU", FC, FC->getType(FullyConnectedNode::ResultIdx));
  auto *result = F->createSave("save", ReLU, output);

  size_t origGraphSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
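  // With useSetAsWhitelist left at its default (false), the kinds in
  // precisionModeKindSet are excluded from conversion, i.e. the set acts as
  // a blacklist: the Relu below must stay in float.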
  precConfig.precisionModeKindSet.insert(Kinded::Kind::ReluNodeKind);
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
  converter.convert();

  // We should have 4 more nodes:
  // 1 conversion float to float16 for each input of FC (3)
  // 1 conversion float16 to float for the result of FC.
  EXPECT_EQ(F->getNodes().size(), origGraphSize + 4);
  // Make sure the save node is still in the function and is unchanged.
  EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
              F->getNodes().end());
  EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is fed directly from the ReLU, which stayed in float.
  auto *resultInput = llvm::dyn_cast<ReluNode>(result->getInput());
  ASSERT_NE(resultInput, nullptr);
  EXPECT_EQ(resultInput->getElementType(ReluNode::ResultIdx),
            ElemKind::FloatTy);
  EXPECT_EQ(resultInput, ReLU);

  // Check that the ReLU is fed from a conversion from float16 to float.
  auto *convertedToReLUInput = llvm::dyn_cast<ConvertToNode>(ReLU->getInput());
  ASSERT_NE(convertedToReLUInput, nullptr);
  EXPECT_EQ(convertedToReLUInput->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);

  // Check that this conversion comes from the float16 FC node.
  auto *convertedFC =
      llvm::dyn_cast<FullyConnectedNode>(convertedToReLUInput->getInput());
  ASSERT_NE(convertedFC, nullptr);
  EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
            ElemKind::Float16Ty);
  // Check that all the inputs of FC are ConvertTo nodes converting from float
  // to Float16Ty.
  for (unsigned idx = 0, end = convertedFC->getNumInputs(); idx != end;
       ++idx) {
    auto *convertedFCInput =
        llvm::dyn_cast<ConvertToNode>(convertedFC->getNthInput(idx));
    ASSERT_NE(convertedFCInput, nullptr);
    EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
              ElemKind::Float16Ty);
    EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
    EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::FloatTy);
  }
  // At this point we know the input of FC is convertTo(placeholder).
  // Check that this placeholder is the expected input.
  EXPECT_EQ(convertedFC->getInput()
                .getNode()
                ->getNthInput(ConvertToNode::InputIdx)
                .getNode(),
            input);
}

/// Check that the conversion honors the precision configuration for
/// whitelisting a node kind (Relu here) for a graph with a simple chain of
/// computation.
/// Namely, check that:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///    FC(float)
///     |
///     V
///   ReLU(float)   Output: Placeholder(float)
///     |             |
///     |    +--------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// Gets converted into:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///    FC(float)
///     |
///     V
///   ConvertTo(float16)
///     |
///     V
///   ReLU(float16)
///     |
///     V
///   ConvertTo(float)   Output: Placeholder(float)
///     |                  |
///     |    +-------------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// In particular, the input and output of the network shouldn't be modified.
TEST(TypeAToTypeBFunctionConverter, OnlyReluConversionFloatToFloat16) {
  Module mod;
  Function *F = mod.createFunction("test");
  PlaceholderBindings bindings;

  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
  auto *output =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);

  auto *FC = F->createFullyConnected(bindings, "FC", input, 10);
  auto *RN =
      F->createRELU("Relu", FC, FC->getType(FullyConnectedNode::ResultIdx));
  auto *result = F->createSave("save", RN, output);

  size_t origGraphSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
  precConfig.precisionModeKindSet.insert(Kinded::Kind::ReluNodeKind);
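  // useSetAsWhitelist inverts the meaning of precisionModeKindSet: only the
  // listed kinds (here, Relu) are converted; everything else stays in float.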
  precConfig.useSetAsWhitelist = true;
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
  converter.convert();

  // We should have 2 more nodes:
  // 1 conversion float to float16 for the input of Relu.
  // 1 conversion float16 to float for the result of Relu.
  EXPECT_EQ(F->getNodes().size(), origGraphSize + 2);
  // Make sure the save node is still in the function and is unchanged.
  EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
              F->getNodes().end());
  EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is fed from a conversion from float16 to float.
  auto *resultInput = llvm::dyn_cast<ConvertToNode>(result->getInput());
  ASSERT_NE(resultInput, nullptr);
  EXPECT_EQ(resultInput->getInput().getElementType(), ElemKind::Float16Ty);
  EXPECT_EQ(resultInput->getResult().getElementType(), ElemKind::FloatTy);

  // Check the Relu has FP16 inputs and outputs.
  auto *convertedRelu = llvm::dyn_cast<ReluNode>(resultInput->getInput());
  ASSERT_NE(convertedRelu, nullptr);
  EXPECT_EQ(convertedRelu->getInput().getElementType(), ElemKind::Float16Ty);
  EXPECT_EQ(convertedRelu->getResult().getElementType(), ElemKind::Float16Ty);

  // Check that the Relu is fed from a conversion from float to float16.
  auto *convertedToReluInput = llvm::dyn_cast<ConvertToNode>(RN->getInput());
  ASSERT_NE(convertedToReluInput, nullptr);
  EXPECT_EQ(convertedToReluInput->getInput().getElementType(),
            ElemKind::FloatTy);
  EXPECT_EQ(convertedToReluInput->getResult().getElementType(),
            ElemKind::Float16Ty);

  // Check that this conversion comes from the original float FC node.
  EXPECT_EQ(convertedToReluInput->getInput().getNode(), FC);
  EXPECT_EQ(FC->getResult().getElementType(), ElemKind::FloatTy);
  // Check that all the inputs of FC are float.
  for (unsigned idx = 0, end = FC->getNumInputs(); idx != end; ++idx) {
    EXPECT_EQ(FC->getNthInput(idx).getElementType(), ElemKind::FloatTy);
  }
  // Check that the original placeholder is still the input to the FC and is
  // still float.
  EXPECT_EQ(FC->getInput().getNode(), input);
  EXPECT_EQ(input->getOutput().getElementType(), ElemKind::FloatTy);
}

/// Check that we don't convert types we didn't ask for.
/// Namely, check that:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///   TopK(float, Int64I)
///     |    |
///     |    |   Output: Placeholder(Int64I)
///     |    |    /
///     |    V   V
///     |    Save      Output: Placeholder(float)
///     |                |
///     |    +-----------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// Gets converted into:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///   ConvertTo(float16)
///     |
///     V
///   TopK(float16, Int64I)
///     |    |
///     |    |   Output: Placeholder(Int64I)
///     |    |    /
///     |    V   V
///     |    Save      Output: Placeholder(float)
///     V                |
///   ConvertTo(float)   |
///     |                |
///     |    +-----------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// In particular, the input and outputs of the network shouldn't be modified,
/// nor should the Int64I result.
TEST(TypeAToTypeBFunctionConverter, int64IConversionFloatToFloat16) {
  Module mod;
  Function *F = mod.createFunction("test");
  PlaceholderBindings bindings;

  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
  auto *output =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 3}, "Output", false);
  auto *outputIdx =
      mod.createPlaceholder(ElemKind::Int64ITy, {20, 3}, "Output", false);

  auto *topK = F->createTopK("topK", input, 3);
  auto *result = F->createSave("save", topK->getValues(), output);
  auto *resultIndices =
      F->createSave("saveIdx", topK->getIndices(), outputIdx);

  size_t origGraphSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
  converter.convert();

  // We should have 2 more nodes:
  // 1 conversion float to float16 for the input of TopK
  // and 1 conversion float16 to float for the result of TopK.
  EXPECT_EQ(F->getNodes().size(), origGraphSize + 2);
  // Make sure the save node is still in the function and is unchanged.
  EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
              F->getNodes().end());
  EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is fed from a conversion from float16 to float.
  auto *convertedBackTopKRes =
      llvm::dyn_cast<ConvertToNode>(result->getInput());
  ASSERT_NE(convertedBackTopKRes, nullptr);
  EXPECT_EQ(convertedBackTopKRes->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);
  auto *convertedTopK =
      llvm::dyn_cast<TopKNode>(convertedBackTopKRes->getInput());
  ASSERT_NE(convertedTopK, nullptr);
  EXPECT_EQ(convertedTopK->getElementType(TopKNode::ValuesIdx),
            ElemKind::Float16Ty);
  EXPECT_EQ(convertedTopK->getElementType(TopKNode::IndicesIdx),
            ElemKind::Int64ITy);
  // Check that the input of TopK is a ConvertTo node from float to
  // Float16Ty.
  auto *convertedTopKInput =
      llvm::dyn_cast<ConvertToNode>(convertedTopK->getInput());
  ASSERT_NE(convertedTopKInput, nullptr);
  EXPECT_EQ(convertedTopKInput->getElementType(ConvertToNode::ResultIdx),
            ElemKind::Float16Ty);
  EXPECT_TRUE(llvm::isa<Placeholder>(convertedTopKInput->getInput()));
  EXPECT_EQ(convertedTopKInput->getInput().getElementType(), ElemKind::FloatTy);
  // At this point we know the input of TopK is convertTo(placeholder).
  // Check that this placeholder is the expected input.
  EXPECT_EQ(convertedTopK->getInput()
                .getNode()
                ->getNthInput(ConvertToNode::InputIdx)
                .getNode(),
            input);

  // Now check the Int64ITy part of the graph.
  // Make sure the save node for the indices is still in the function and is
  // unchanged.
  EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(),
                        *resultIndices) != F->getNodes().end());
  EXPECT_EQ(resultIndices->getOutput(), outputIdx->getOutput());
  EXPECT_EQ(resultIndices->getInput(),
            convertedTopK->getNthResult(TopKNode::IndicesIdx));
  EXPECT_EQ(resultIndices->getInput().getElementType(), ElemKind::Int64ITy);
}

/// Check that the conversion optimization can get rid of conversions of
/// constants and of intermediate conversions.
/// Namely, check that:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     |    Weight: Constant(float)
///     |      |   Bias: Constant(float)
///     |      |    /
///     V      V   V
///    FC(float)
///     |
///     V
///   ReLU(float)   Output: Placeholder(float)
///     |             |
///     |    +--------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// Gets converted into:
/// \verbatim
/// Input: Placeholder(float)
///     |
///     V
///   ConvertTo(float16)
///     |
///     |    Weight: Constant(float16)
///     |      |   Bias: Constant(float16)
///     |      |    /
///     V      V   V
///    FC(float16)
///     |
///     V
///   ReLU(float16)   Output: Placeholder(float)
///     |               |
///     V               |
///   ConvertTo(float)  |
///     |    +----------+
///     |   /
///     V  V
///     Save
/// \endverbatim
///
/// In particular, the input and output of the network shouldn't be modified.
TEST(TypeAToTypeBFunctionConverter, OptimizeMiddleConversionsFloatToFloat16) {
  Module mod;
  Function *F = mod.createFunction("test");
  PlaceholderBindings bindings;

  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
  auto *output =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);

  auto *weights = mod.createConstant(
      mod.uniqueType(ElemKind::FloatTy, {13, 10}), "weights");
  weights->getPayloadMutable().getHandle().randomize(-5.0, 5.0, mod.getPRNG());
  Tensor origWeights;
  origWeights.assign(&weights->getPayload());
  auto *bias =
      mod.createConstant(mod.uniqueType(ElemKind::FloatTy, {10}), "bias");
  bias->getPayloadMutable().getHandle().randomize(-5.0, 5.0, mod.getPRNG());
  Tensor origBias;
  origBias.assign(&bias->getPayload());

  // This save is just to test that we do the right thing for constants with
  // more than one use.
  auto *saveBias = F->createSave("saveBias", bias);
  TypeRef FCTy = mod.uniqueType(ElemKind::FloatTy, {20, 10});
  auto *FC = F->createFullyConnected("FC", input, weights, bias, FCTy);
  auto *ReLU =
      F->createRELU("ReLU", FC, FC->getType(FullyConnectedNode::ResultIdx));
  auto *result = F->createSave("save", ReLU, output);

  size_t origGraphSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
  converter.convert();

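  // Run the graph optimizer: it is expected to fold the constants' ConvertTo
  // nodes into the constants themselves and to eliminate the intermediate
  // float16 -> float -> float16 round trip between FC and ReLU.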
  optimize(F, CompilationMode::Infer);

  // We should have 2 more nodes:
  // 1 conversion float to float16 for the input of FC
  // 1 conversion float16 to float for the result of ReLU.
  EXPECT_EQ(F->getNodes().size(), origGraphSize + 2);
  // Make sure the save node is still in the function and is unchanged.
  EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
              F->getNodes().end());
  EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is fed from a conversion from float16 to float.
  auto *convertedBackReLURes =
      llvm::dyn_cast<ConvertToNode>(result->getInput());
  ASSERT_NE(convertedBackReLURes, nullptr);
  EXPECT_EQ(convertedBackReLURes->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);
  auto *convertedReLU =
      llvm::dyn_cast<ReluNode>(convertedBackReLURes->getInput());
  ASSERT_NE(convertedReLU, nullptr);
  EXPECT_EQ(convertedReLU->getElementType(ReluNode::ResultIdx),
            ElemKind::Float16Ty);

  // Check that the ReLU is fed directly by the float16 FC.
  auto *convertedFC =
      llvm::dyn_cast<FullyConnectedNode>(convertedReLU->getInput());
  ASSERT_NE(convertedFC, nullptr);
  EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
            ElemKind::Float16Ty);
  // Check that the input of FC is a ConvertTo node converting "input" from
  // float to Float16Ty.
  auto *convertedFCInput =
      llvm::dyn_cast<ConvertToNode>(convertedFC->getInput());
  ASSERT_NE(convertedFCInput, nullptr);
  EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
            ElemKind::Float16Ty);
  EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
  EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::FloatTy);
  EXPECT_EQ(convertedFCInput->getInput().getNode(), input);

  // Check that the weights have been updated to float16.
  auto *convertedFCWeights =
      llvm::dyn_cast<Constant>(convertedFC->getWeights());
  ASSERT_NE(convertedFCWeights, nullptr);
  EXPECT_EQ(convertedFCWeights->getElementType(), ElemKind::Float16Ty);
  EXPECT_EQ(convertedFCWeights, weights);
  origWeights.convertToType(ElemKind::Float16Ty);
  EXPECT_TRUE(origWeights.isEqual(weights->getPayload()));

  // Check that the bias has been duplicated and converted.
  auto *convertedFCBias = llvm::dyn_cast<Constant>(convertedFC->getBias());
  ASSERT_NE(convertedFCBias, nullptr);
  EXPECT_EQ(convertedFCBias->getElementType(), ElemKind::Float16Ty);
  EXPECT_NE(convertedFCBias, bias);
  origBias.convertToType(ElemKind::Float16Ty);
  EXPECT_TRUE(origBias.isEqual(convertedFCBias->getPayload()));

  // Check that the original bias hasn't been altered.
  EXPECT_EQ(bias->getElementType(), ElemKind::FloatTy);
  EXPECT_EQ(saveBias->getInput().getNode(), bias);
}

/// Check that the conversion of placeholders inserts conversions
/// at the right places, and in all the functions.
/// Namely, check that:
/// \verbatim
/// #### F ####
/// Input: Placeholder(float)
///  |  |
///  |  |  Weight: Constant(float)
///  |  |   |  Bias: Constant(float)
///  |  |   |   /
///  |  V   V  V
///  |   FC(float)
///  |    |
///  |    V
///  |  ReLU(float)   Output: Placeholder(float)
///  |    |             |
///  |    |    +--------+
///  |    |   /
///  |    V  V
///  |    Save
///  |
///  | #### F2 ####
///  |     Output2: Placeholder(float)
/// +-+     /
/// | |    |
/// | V    V
/// |  Save
/// |
/// | #### F3 ####
/// |   Output3: Placeholder(float)
/// |    /
/// V   V
///  Save
/// \endverbatim
///
/// Gets converted into:
/// \verbatim
/// #### F ####
/// Input: Placeholder(float16)
///  |  |
///  |  V
///  | ConvertTo(float)
///  |  |
///  |  |  Weight: Constant(float)
///  |  |   |  Bias: Constant(float)
///  |  |   |   /
///  |  V   V  V
///  |   FC(float)
///  |    |
///  |    V
///  |  ReLU(float)   Output: Placeholder(float16)
///  |    |             |
///  |    V             |
///  | ConvertTo(float16)
///  |    |    +--------+
///  |    |   /
///  |    V  V
///  |    Save
///  | #### F2 ####
/// +-+
/// |  |
/// |  V
/// | ConvertTo(float)
/// |  |
/// |  |  Output2: Placeholder(float)
/// |  |   |
/// |  V   V
/// |  Save
/// |
/// | #### F3 ####
/// V
/// ConvertTo(float)
///  |
///  V
/// ConvertTo(float16)
///  |
///  |  Output3: Placeholder(float16)
///  |   /
///  V  V
///  Save
/// \endverbatim
///
/// In particular, the input and output of the network should be modified
/// and the input of the last save node should be converted to the expected
/// output type.
TEST(TypeAToTypeBFunctionConverter, convertPlaceholderFloatToFloat16) {
  Module mod;
  Function *F = mod.createFunction("test");
  Function *F2 = mod.createFunction("test2");
  Function *F3 = mod.createFunction("test3");
  PlaceholderBindings bindings;

  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
  Tensor *inputTensor = bindings.allocate(input);
  inputTensor->getHandle().randomize(-6.0, 6.0, mod.getPRNG());
  Tensor origInput;
  origInput.assign(inputTensor);

  auto *output =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);
  auto *output2 =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Output2", false);

  auto *saveOutput2 = F2->createSave("saveOutput2", input, output2);

  auto *output3 =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Output3", false);
  auto *saveOutput3 = F3->createSave("saveOutput3", input, output3);

  auto *weights = mod.createConstant(
      mod.uniqueType(ElemKind::FloatTy, {13, 10}), "weights");
  weights->getPayloadMutable().getHandle().randomize(-5.0, 5.0, mod.getPRNG());
  Tensor origWeights;
  origWeights.assign(&weights->getPayload());
  auto *bias =
      mod.createConstant(mod.uniqueType(ElemKind::FloatTy, {10, 20}), "bias");
  bias->getPayloadMutable().getHandle().randomize(-5.0, 5.0, mod.getPRNG());

  TypeRef FCTy = mod.uniqueType(ElemKind::FloatTy, {20, 10});
  auto *FC = F->createFullyConnected("FC", input, weights, bias, FCTy);
  auto *ReLU =
      F->createRELU("ReLU", FC, FC->getType(FullyConnectedNode::ResultIdx));
  auto *result = F->createSave("save", ReLU, output);

  size_t origGraphSize = F->getNodes().size();
  size_t f2OrigGraphSize = F2->getNodes().size();
  size_t f3OrigGraphSize = F3->getNodes().size();

  PrecisionConfiguration precConfig;
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
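  // Unlike convert(), convertPlaceholder changes the placeholder's type in
  // place (converting the backing tensor in the bindings as well) and patches
  // every function in the module that uses it. output2 is deliberately left
  // as float so F2 must get a conversion back to float.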
  for (auto *placeholder : mod.getPlaceholders()) {
    if (output2 == placeholder) {
      continue;
    }
    converter.convertPlaceholder(*placeholder, &bindings);
  }

  // We should have 2 more nodes in F:
  // 1 conversion for the input of the save node of output
  // 1 conversion from the input to the FC
  EXPECT_EQ(F->getNodes().size(), origGraphSize + 2);
  // Make sure the save node of F is still in the function and is unchanged.
  EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
              F->getNodes().end());
  EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is fed from a conversion from float to float16.
  auto *convertedBackReLURes =
      llvm::dyn_cast<ConvertToNode>(result->getInput());
  ASSERT_NE(convertedBackReLURes, nullptr);
  EXPECT_EQ(convertedBackReLURes->getElementType(ConvertToNode::ResultIdx),
            ElemKind::Float16Ty);
  auto *convertedReLU =
      llvm::dyn_cast<ReluNode>(convertedBackReLURes->getInput());
  ASSERT_NE(convertedReLU, nullptr);
  EXPECT_EQ(convertedReLU->getElementType(ReluNode::ResultIdx),
            ElemKind::FloatTy);

  // Check that the ReLU is fed directly by the float FC.
  auto *convertedFC =
      llvm::dyn_cast<FullyConnectedNode>(convertedReLU->getInput());
  ASSERT_NE(convertedFC, nullptr);
  EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
            ElemKind::FloatTy);
  // Check that the input of FC is a ConvertTo node converting "input" from
  // float16 back to float.
  auto *convertedFCInput =
      llvm::dyn_cast<ConvertToNode>(convertedFC->getInput());
  ASSERT_NE(convertedFCInput, nullptr);
  EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);
  EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
  EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::Float16Ty);
  EXPECT_EQ(convertedFCInput->getInput().getNode(), input);

  // Checks for F2.

  // We should have 1 more node in F2:
  // 1 conversion from the input to the input of the save node

  // Make sure the save node of F2 is still in the function and is unchanged.
  EXPECT_EQ(F2->getNodes().size(), f2OrigGraphSize + 1);
  EXPECT_TRUE(std::find(F2->getNodes().begin(), F2->getNodes().end(),
                        *saveOutput2) != F2->getNodes().end());
  EXPECT_EQ(saveOutput2->getOutput(), output2->getOutput());

  // Check that the save is fed from a conversion from float16 to float.
  auto *inputToFloat = llvm::dyn_cast<ConvertToNode>(saveOutput2->getInput());
  ASSERT_NE(inputToFloat, nullptr);
  EXPECT_EQ(inputToFloat->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);
  // Check that this input is "input".
  auto *inputOfF2 = llvm::dyn_cast<Placeholder>(inputToFloat->getInput());
  ASSERT_NE(inputOfF2, nullptr);
  EXPECT_EQ(inputOfF2, input);

  // Checks for F3.

  // We should have 2 more nodes in F3:
  // 1 conversion from the input to the input of the save node (coming
  //   from the input)
  // 1 conversion from the input to the input of the save node (coming
  //   from the requirement for the output)

  // Make sure the save node of F3 is still in the function and is unchanged.
  EXPECT_EQ(F3->getNodes().size(), f3OrigGraphSize + 2);
  EXPECT_TRUE(std::find(F3->getNodes().begin(), F3->getNodes().end(),
                        *saveOutput3) != F3->getNodes().end());
  EXPECT_EQ(saveOutput3->getOutput(), output3->getOutput());
  EXPECT_EQ(output3->getElementType(), ElemKind::Float16Ty);

  // Check that the save is fed from a conversion from float to float16.
  auto *convertOutput3 = llvm::dyn_cast<ConvertToNode>(saveOutput3->getInput());
  ASSERT_NE(convertOutput3, nullptr);
  EXPECT_EQ(convertOutput3->getElementType(ConvertToNode::ResultIdx),
            ElemKind::Float16Ty);

  auto *convertInputFor3 =
      llvm::dyn_cast<ConvertToNode>(convertOutput3->getInput());
  ASSERT_NE(convertInputFor3, nullptr);
  EXPECT_EQ(convertInputFor3->getElementType(ConvertToNode::ResultIdx),
            ElemKind::FloatTy);
  // Check that this input is "input".
  auto *inputOfF3 = llvm::dyn_cast<Placeholder>(convertInputFor3->getInput());
  ASSERT_NE(inputOfF3, nullptr);
  EXPECT_EQ(inputOfF3, input);

  origInput.convertToType(ElemKind::Float16Ty);
  EXPECT_TRUE(origInput.isEqual(*inputTensor));
}

/// Check that the verifier doesn't complain when there are
/// no-op conversions. This may happen in unoptimized networks.
/// E.g.,
/// Input: Placeholder(float)
///     |
///     V
/// OrigConvert: ConvertTo(float16)
///     |
///     V
///    Save
///
/// Now converting the network to float16 will yield:
/// Input: Placeholder(float)
///     |
///     V
/// ConvertTo(float16); convert the input to fp16
///     |
///     V
/// OrigConvert: ConvertTo(float16); <-- now this is a noop conversion.
///     |
///     V
///    Save
TEST(TypeAToTypeBFunctionConverter, convertExistingConversionToNoop) {
  Module mod;
  Function *F = mod.createFunction("test");
  auto *placeholder =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);

  auto *convert =
      F->createConvertTo("convert", placeholder, ElemKind::Float16Ty);
  auto *save = F->createSave("save", convert);

  size_t origSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
  converter.convert();

  EXPECT_EQ(F->getNodes().size(), origSize + 1);

  auto *convertToSave = llvm::dyn_cast<ConvertToNode>(save->getInput());
  EXPECT_EQ(convertToSave, convert);
  EXPECT_EQ(convert->getElementType(ConvertToNode::ResultIdx),
            ElemKind::Float16Ty);

  auto *addedConversion = llvm::dyn_cast<ConvertToNode>(convert->getInput());
  ASSERT_NE(addedConversion, nullptr);
  // At this point both the input and output of convert are FP16.
  EXPECT_EQ(addedConversion->getElementType(ConvertToNode::ResultIdx),
            ElemKind::Float16Ty);

  EXPECT_EQ(addedConversion->getInput().getNode(), placeholder);
  EXPECT_EQ(placeholder->getElementType(), ElemKind::FloatTy);

  EXPECT_TRUE(F->verify());
}

/// Helper for testing FRWQSLWS FP16 conversion, with and without FP16
/// accumulation based on \p forceFP16AccumSLS.
static void testConvertFRWQSLWS(bool forceFP16AccumSLS) {
  Module mod;
  Function *F = mod.createFunction("test");
  Tensor data(ElemKind::FloatTy, {3, 1});
  data.getHandle() = {
      2.0,
      -0.5,
      13,
  };

  Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");

  Placeholder *indices =
      mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
                            /* isTrainable */ false);
  Placeholder *lengths =
      mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
                            /* isTrainable */ false);
  auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
      "RQSLWS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy);
  SaveNode *S = F->createSave("save", R);

  size_t origSize = F->getNodes().size();

  CompilationContext cctx;
  PrecisionConfiguration &precConfig = cctx.precisionConfig;
  precConfig.convertToFP16 = true;
  precConfig.convertFusedToFP16 = true;
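  // forceFP16AccumSLS controls whether the converted SLWS accumulates in
  // FP16; as checked via getUseFP16Accumulation() below, the conversion
  // otherwise defaults to FP32 accumulation.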
  precConfig.forceFP16AccumSLS = forceFP16AccumSLS;
  transformForPrecisionMode(MockBackend(), F, cctx);

  // Should have added convert nodes for the Data, Weights, and Result. The
  // Data and Weights ConvertTo nodes should have been merged in, while the
  // Result still has a ConvertTo.
  EXPECT_EQ(F->getNodes().size(), origSize + 1);

  auto *convertResult = llvm::dyn_cast<ConvertToNode>(S->getInput());
  ASSERT_NE(convertResult, nullptr);
  EXPECT_EQ(convertResult->getResult().getElementType(), ElemKind::FloatTy);

  auto *SLWS =
      llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
          convertResult->getInput());
  ASSERT_NE(SLWS, nullptr);
  EXPECT_EQ(SLWS->getResult().getElementType(), ElemKind::Float16Ty);
  EXPECT_EQ(SLWS->getUseFP16Accumulation(), forceFP16AccumSLS);

  EXPECT_EQ(SLWS->getData().getElementType(), ElemKind::UInt8FusedFP16QTy);
  EXPECT_EQ(SLWS->getWeights().getElementType(), ElemKind::Float16Ty);

  EXPECT_TRUE(F->verify());
}

/// Test conversion of a FusedRowwiseQuantizedSparseLengthsWeightedSumNode to
/// FP16, instead of creating it directly. Use FP16 accumulation.
TEST(TypeAToTypeBFunctionConverter, convertFRWQSLWS_FP16Accum) {
  testConvertFRWQSLWS(/* forceFP16AccumSLS */ true);
}

/// Test conversion of a FusedRowwiseQuantizedSparseLengthsWeightedSumNode to
/// FP16, instead of creating it directly. Do not use FP16 accumulation; note
/// that conversion by default uses FP32 accumulation.
TEST(TypeAToTypeBFunctionConverter, convertFRWQSLWS_FP32Accum) {
  testConvertFRWQSLWS(/* forceFP16AccumSLS */ false);
}

/// Test skipping conversion of a
/// FusedRowwiseQuantizedSparseLengthsWeightedSumNode to FP16.
TEST(TypeAToTypeBFunctionConverter, skipConvertingFRWQSLWS) {
  Module mod;
  Function *F = mod.createFunction("test");
  Tensor data(ElemKind::FloatTy, {3, 1});
  data.getHandle() = {
      2.0,
      -0.5,
      13,
  };

  Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");

  Placeholder *indices =
      mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
                            /* isTrainable */ false);
  Placeholder *lengths =
      mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
                            /* isTrainable */ false);
  auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
      "RQSLWS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy);
  SaveNode *S = F->createSave("save", R);

  size_t origSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
  precConfig.convertToFP16 = true;
  precConfig.convertFusedToFP16 = true;
  precConfig.precisionModeKindSet.insert(
      Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind);
  convertFunctionToFloat16(F, precConfig);

  // Should have done nothing since we skipped its conversion. Check the
  // Function is the same as before.
  EXPECT_EQ(F->getNodes().size(), origSize);

  auto *SLWS =
      llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
          S->getInput());
  ASSERT_EQ(SLWS, R);

  auto *origData = llvm::dyn_cast<Constant>(SLWS->getData());
  ASSERT_EQ(origData, R->getData().getNode());
  EXPECT_EQ(origData->getOutput().getElementType(), ElemKind::UInt8FusedQTy);

  auto *origWeights = llvm::dyn_cast<Constant>(SLWS->getWeights());
  ASSERT_EQ(origWeights, weights);
  EXPECT_EQ(origWeights->getOutput().getElementType(), ElemKind::FloatTy);

  EXPECT_TRUE(F->verify());
}

static void
convertOnlyFloat16Ty(PrecisionConfiguration::Float16Format float16Format) {
  Module mod;
  Function *F = mod.createFunction("test");
  Tensor data(ElemKind::FloatTy, {3, 1});
  data.getHandle() = {
      2.0,
      -0.5,
      13,
  };

  Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");

  Placeholder *indices =
      mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
                            /* isTrainable */ false);
  Placeholder *lengths =
      mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
                            /* isTrainable */ false);
  auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
      "RQSLWS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy);
  SaveNode *S = F->createSave("save", R);

  size_t origSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
  precConfig.convertToFP16 = true;
  precConfig.convertFusedToFP16 = false;
  precConfig.float16Format = float16Format;
  convertFunctionToFloat16(F, precConfig);

  ElemKind convertedElementType =
      PrecisionConfiguration::getElementType(float16Format);

  // Should have added convert nodes for the weights and results.
  EXPECT_EQ(F->getNodes().size(), origSize + 2);

  auto *convertResult = llvm::dyn_cast<ConvertToNode>(S->getInput());
  ASSERT_NE(convertResult, nullptr);
  EXPECT_EQ(convertResult->getResult().getElementType(), ElemKind::FloatTy);

  auto *SLWS =
      llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
          convertResult->getInput());
  ASSERT_NE(SLWS, nullptr);
  EXPECT_EQ(SLWS->getResult().getElementType(), convertedElementType);

  auto *origData = llvm::dyn_cast<Constant>(SLWS->getData());
  ASSERT_NE(origData, nullptr);
  EXPECT_EQ(origData->getOutput().getElementType(), ElemKind::UInt8FusedQTy);

  auto *convertWeights = llvm::dyn_cast<ConvertToNode>(SLWS->getWeights());
  ASSERT_NE(convertWeights, nullptr);
  EXPECT_EQ(convertWeights->getResult().getElementType(), convertedElementType);

  auto *origWeights = llvm::dyn_cast<Constant>(convertWeights->getInput());
  ASSERT_NE(origWeights, nullptr);
  EXPECT_EQ(origWeights->getOutput().getElementType(), ElemKind::FloatTy);
  EXPECT_EQ(weights, origWeights);

  EXPECT_TRUE(F->verify());
}

/// Test conversion of only FP16 inputs of Node and not UInt8FusedQTy.
TEST(TypeAToTypeBFunctionConverter, convertOnlyFP16Ty) {
  convertOnlyFloat16Ty(PrecisionConfiguration::Float16Format::FP16);
}

/// Test conversion of only BFloat16 inputs of Node and not UInt8FusedQTy.
TEST(TypeAToTypeBFunctionConverter, convertOnlyBFloat16Ty) {
  convertOnlyFloat16Ty(PrecisionConfiguration::Float16Format::BFloat16);
}

/// Test conversion of only UInt8FusedQTy inputs of Node and not Float16Ty.
TEST(TypeAToTypeBFunctionConverter, convertOnlyUInt8FusedQTy) {
  Module mod;
  Function *F = mod.createFunction("test");
  Tensor data(ElemKind::FloatTy, {3, 1});
  data.getHandle() = {
      2.0,
      -0.5,
      13,
  };

  Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");

  Placeholder *indices =
      mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
                            /* isTrainable */ false);
  Placeholder *lengths =
      mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
                            /* isTrainable */ false);
  auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
      "RQSLWS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy);
  SaveNode *S = F->createSave("save", R);

  size_t origSize = F->getNodes().size();

  PrecisionConfiguration precConfig;
  precConfig.convertToFP16 = false;
  precConfig.convertFusedToFP16 = true;
  convertFunctionToFloat16(F, precConfig);

  // Should have added a convert node for the data.
  EXPECT_EQ(F->getNodes().size(), origSize + 1);

  auto *SLWS =
      llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
          S->getInput());
  ASSERT_EQ(SLWS, R);

  auto *convertData = llvm::dyn_cast<ConvertToNode>(SLWS->getData());
  ASSERT_NE(convertData, nullptr);
  EXPECT_EQ(convertData->getResult().getElementType(),
            ElemKind::UInt8FusedFP16QTy);

  auto *origData = llvm::dyn_cast<Constant>(convertData->getInput());
  ASSERT_NE(origData, nullptr);
  EXPECT_EQ(origData->getOutput().getElementType(), ElemKind::UInt8FusedQTy);

  auto *origWeights = llvm::dyn_cast<Constant>(SLWS->getWeights());
  ASSERT_EQ(origWeights, weights);
  EXPECT_EQ(origWeights->getOutput().getElementType(), ElemKind::FloatTy);

  EXPECT_TRUE(F->verify());
}

static void convertWithoutClipAroundNonNumericNodes(
    PrecisionConfiguration::Float16Format float16Format) {
  Module mod;
  Function *F = mod.createFunction("test");
  const dim_t dims[] = {1, 5, 10, 15};
  const dim_t dimsReshape[] = {10, 10, 15};
  Node *I0 = mod.createPlaceholder(ElemKind::FloatTy, dims, "i0", false);
  Node *I1 = mod.createPlaceholder(ElemKind::FloatTy, dims, "i1", false);
  Node *I2 = mod.createPlaceholder(ElemKind::Int32ITy, {2, 2, 2}, "i2", false);
  Node *CN = F->createConcat("concat", {I0, I1}, 1);
  Node *R = F->createReshape("reshape", CN, dimsReshape);
  Node *S = F->createSlice("slice", R, {0, 0, 0}, {5, 5, 5});
  Node *G = F->createGather("gather", S, I2);
  F->createSave("ret", G);

  PrecisionConfiguration precConfig;
  precConfig.convertToFP16 = true;
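  // clipFP16 requests Clip nodes that keep FP16 values inside the
  // representable range; the counts below verify that none are placed around
  // the shape-only (non-numeric) Concat/Reshape/Slice/Gather nodes.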
  precConfig.clipFP16 = true;
  precConfig.float16Format = float16Format;
  convertFunctionToFloat16(F, precConfig);

  int numClips = 0;
  int numConvertTos = 0;
  for (auto &n : F->getNodes()) {
    if (n.getKind() == Kinded::Kind::ClipNodeKind) {
      ++numClips;
    } else if (n.getKind() == Kinded::Kind::ConvertToNodeKind) {
      ++numConvertTos;
    }
  }

  EXPECT_EQ(9, numConvertTos);
  EXPECT_EQ(0, numClips);

  EXPECT_TRUE(F->verify());
}

// Test that we don't insert Clips around non-numeric nodes.
TEST(TypeAToTypeBFunctionConverter,
     convertWithFP16WithoutClipAroundNonNumericNodes) {
  convertWithoutClipAroundNonNumericNodes(
      PrecisionConfiguration::Float16Format::FP16);
}

// Test that we don't insert Clips around non-numeric nodes.
TEST(TypeAToTypeBFunctionConverter,
     convertWithBFloat16WithoutClipAroundNonNumericNodes) {
  convertWithoutClipAroundNonNumericNodes(
      PrecisionConfiguration::Float16Format::BFloat16);
}

static void convertWithoutClipAfterTanhOrSigmoid(
    PrecisionConfiguration::Float16Format float16Format) {
  Module mod;
  Function *F = mod.createFunction("test");
  const dim_t dims[] = {10, 20};
  const dim_t dims2[] = {10, 30};
  Node *I0 = mod.createPlaceholder(ElemKind::FloatTy, dims, "i0", false);
  Node *I1 = mod.createPlaceholder(ElemKind::FloatTy, dims2, "i1", false);
  Node *T = F->createTanh("tanh", {I0});
  Node *S = F->createSigmoid("sigmoid", {I1});
  Node *CN = F->createConcat("concat", {T, S}, 1);
  F->createSave("ret", CN);

  PrecisionConfiguration precConfig;
  precConfig.convertToFP16 = true;
  precConfig.clipFP16 = true;
  precConfig.float16Format = float16Format;
  convertFunctionToFloat16(F, precConfig);

  int numClips = 0;
  int numConvertTos = 0;
  for (auto &n : F->getNodes()) {
    if (n.getKind() == Kinded::Kind::ClipNodeKind) {
      ++numClips;
    } else if (n.getKind() == Kinded::Kind::ConvertToNodeKind) {
      ++numConvertTos;
    }
  }

  EXPECT_EQ(7, numConvertTos);
  EXPECT_EQ(2, numClips);

  EXPECT_TRUE(F->verify());
}

// Test that we don't insert Clips at the output of Tanh or Sigmoid.
TEST(TypeAToTypeBFunctionConverter,
     convertWithFP16WithoutClipAfterTanhOrSigmoid) {
  convertWithoutClipAfterTanhOrSigmoid(
      PrecisionConfiguration::Float16Format::FP16);
}

// Test that we don't insert Clips at the output of Tanh or Sigmoid.
TEST(TypeAToTypeBFunctionConverter,
     convertWithBFloat16WithoutClipAfterTanhOrSigmoid) {
  convertWithoutClipAfterTanhOrSigmoid(
      PrecisionConfiguration::Float16Format::BFloat16);
}

static void convertWithoutClipAfterFp16ConvertTo(
    PrecisionConfiguration::Float16Format float16Format) {
  Module mod;
  Function *F = mod.createFunction("test");
  const dim_t dims[] = {10, 20};
  const dim_t dims2[] = {10, 30};
  Node *I0 = mod.createPlaceholder(ElemKind::Float16Ty, dims, "i0", false);
  Node *I1 = mod.createPlaceholder(ElemKind::Float16Ty, dims2, "i1", false);
  Node *T = F->createConvertTo("c1", {I0}, ElemKind::FloatTy);
  Node *S = F->createConvertTo("c2", {I1}, ElemKind::FloatTy);
  Node *CN = F->createConcat("concat", {T, S}, 1);
  F->createSave("ret", CN);

  PrecisionConfiguration precConfig;
  precConfig.convertToFP16 = true;
  precConfig.clipFP16 = true;
  precConfig.float16Format = float16Format;
  convertFunctionToFloat16(F, precConfig);

  int numClips = 0;
  int numConvertTos = 0;
  for (auto &n : F->getNodes()) {
    if (n.getKind() == Kinded::Kind::ClipNodeKind) {
      ++numClips;
    } else if (n.getKind() == Kinded::Kind::ConvertToNodeKind) {
      ++numConvertTos;
    }
  }

  EXPECT_EQ(0, numClips);

  EXPECT_TRUE(F->verify());
}

// Test that we don't insert Clips at the output of ConvertTo if its input is
// fp16.
TEST(TypeAToTypeBFunctionConverter,
     convertWithFP16WithoutClipAfterFp16ConvertTo) {
  convertWithoutClipAfterFp16ConvertTo(
      PrecisionConfiguration::Float16Format::FP16);
}

// Test that we don't insert Clips at the output of ConvertTo if its input is
// bfloat16.
TEST(TypeAToTypeBFunctionConverter,
     convertWithBFloat16WithoutClipAfterFp16ConvertTo) {
  convertWithoutClipAfterFp16ConvertTo(
      PrecisionConfiguration::Float16Format::BFloat16);
}
1415 | |
// Test that we only insert clips for outputs.
static void
checkConvertOnlyOutputs(PrecisionConfiguration::Float16Format float16Format) {
  Module mod;
  Function *F = mod.createFunction("test");
  Node *I = mod.createPlaceholder(ElemKind::FloatTy, {10}, "i", false);
  ReluNode *RN = F->createRELU("relu", I);
  SaveNode *SN = F->createSave("ret", RN);

  PrecisionConfiguration precConfig;
  precConfig.convertToFP16 = true;
  precConfig.clipFP16 = true;
  precConfig.clipFP16SkipInputs = true;
  precConfig.convertPlaceholdersToFP16 = true;
  precConfig.convertConstantsToFP16 = true;
  precConfig.float16Format = float16Format;
  convertFunctionToFloat16(F, precConfig);

  ElemKind convertedElementType =
      PrecisionConfiguration::getElementType(float16Format);

  // Expected chain:
  // PH -> ConvertToFP16 -> ConvertToFP32 -> ConvertToFP16 -> Relu ->
  // Clip -> ConvertToFP32 -> Save
  // The first conversion comes from convertPlaceholdersToFP16; the FP32/FP16
  // pair after it comes from converting the Relu node itself. Because
  // clipFP16SkipInputs is set, only the Relu output gets a Clip.

  ConvertToNode *convertRN = llvm::dyn_cast<ConvertToNode>(SN->getInput());
  ASSERT_TRUE(convertRN);
  EXPECT_EQ(convertRN->getResult().getType()->getElementType(),
            ElemKind::FloatTy);
  ClipNode *clipRN = llvm::dyn_cast<ClipNode>(convertRN->getInput());
  ASSERT_TRUE(clipRN);
  ASSERT_TRUE(clipRN->getInput() == RN->getResult());
  ConvertToNode *convert32To16 = llvm::dyn_cast<ConvertToNode>(RN->getInput());
  ASSERT_TRUE(convert32To16);
  EXPECT_EQ(convert32To16->getResult().getType()->getElementType(),
            convertedElementType);
  ConvertToNode *convert16To32 =
      llvm::dyn_cast<ConvertToNode>(convert32To16->getInput());
  ASSERT_TRUE(convert16To32);
  EXPECT_EQ(convert16To32->getResult().getType()->getElementType(),
            ElemKind::FloatTy);
  ConvertToNode *convertPH =
      llvm::dyn_cast<ConvertToNode>(convert16To32->getInput());
  ASSERT_TRUE(convertPH);
  EXPECT_EQ(convertPH->getResult().getType()->getElementType(),
            convertedElementType);

  EXPECT_TRUE(F->verify());
}

// Test that we only insert clips for outputs.
TEST(TypeAToTypeBFunctionConverter, checkWithFP16ConvertOnlyOutputs) {
  checkConvertOnlyOutputs(PrecisionConfiguration::Float16Format::FP16);
}

// Test that we only insert clips for outputs.
TEST(TypeAToTypeBFunctionConverter, checkWithBFloat16ConvertOnlyOutputs) {
  checkConvertOnlyOutputs(PrecisionConfiguration::Float16Format::BFloat16);
}

// Test that converted Constants get a Clip while converted Placeholder
// inputs do not.
static void
checkConvertClipStorage(PrecisionConfiguration::Float16Format float16Format) {
  Module mod;
  Function *F = mod.createFunction("test");
  Node *PH = mod.createPlaceholder(ElemKind::FloatTy, {10}, "ph", false);
  Node *C = mod.createConstant(ElemKind::FloatTy, {10, 1}, "c");
  SaveNode *SPH = F->createSave("ret", PH);
  SaveNode *SC = F->createSave("ret", C);

  PrecisionConfiguration precConfig;
  precConfig.convertToFP16 = true;
  precConfig.clipFP16 = true;
  precConfig.clipFP16SkipInputs = true;
  precConfig.convertPlaceholdersToFP16 = true;
  precConfig.convertConstantsToFP16 = true;
  precConfig.float16Format = float16Format;
  convertFunctionToFloat16(F, precConfig);

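  // Expected chains:
  // PH -> ConvertToFP16 -> ConvertToFP32 -> Save (no Clip: inputs skipped)
  // C  -> ConvertToFP16 -> Clip -> ConvertToFP32 -> Save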
  ConvertToNode *convertFP32PH = llvm::dyn_cast<ConvertToNode>(SPH->getInput());
  ASSERT_TRUE(convertFP32PH);
  ConvertToNode *convertFP16PH =
      llvm::dyn_cast<ConvertToNode>(convertFP32PH->getInput());
  ASSERT_TRUE(convertFP16PH);

  ConvertToNode *convertFP32C = llvm::dyn_cast<ConvertToNode>(SC->getInput());
  ASSERT_TRUE(convertFP32C);
  ClipNode *clipC = llvm::dyn_cast<ClipNode>(convertFP32C->getInput());
  ASSERT_TRUE(clipC);
  ConvertToNode *convertFP16C =
      llvm::dyn_cast<ConvertToNode>(clipC->getInput());
  ASSERT_TRUE(convertFP16C);

  EXPECT_TRUE(F->verify());
}

// Test that converted Constants are clipped but Placeholders are not.
TEST(TypeAToTypeBFunctionConverter, checkWithFP16ConvertClipStorage) {
  checkConvertClipStorage(PrecisionConfiguration::Float16Format::FP16);
}

// Test that converted Constants are clipped but Placeholders are not.
TEST(TypeAToTypeBFunctionConverter, checkWithBFloat16ConvertClipStorage) {
  checkConvertClipStorage(PrecisionConfiguration::Float16Format::BFloat16);
}

/// Check that quantized FC with FP32 bias doesn't have bias converted to FP16.
TEST(TypeAToTypeBFunctionConverter, DoNotConvertFloatBiasWithIntInput) {
  Module mod;
  Function *F = mod.createFunction("test");
  PlaceholderBindings bindings;

  auto *input = mod.createPlaceholder(ElemKind::Int8QTy, {3, 8}, 0.05, -2,
                                      "input", false);
  auto *weight = mod.createConstant(ElemKind::Int8QTy, {8, 10}, 0.02, 3, "w");
  auto *bias = mod.createConstant(ElemKind::FloatTy, {10}, "b");

  auto *FC = F->createFullyConnected("FC", input, weight, bias);
  F->createSave("save", FC);

  std::string origGraph = F->toString();

  PrecisionConfiguration precConfig;
  TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
                                          ElemKind::Float16Ty, precConfig);
  converter.convert();

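  // The FC is quantized, so its float bias must be left alone and the graph
  // should be completely unchanged.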
  EXPECT_EQ(origGraph, F->toString());
}

/// Create a FRWQSLWS node with data type \p fusedKind and shape \p row x
/// \p col, check that its data can be properly converted from
/// UInt8FusedFP16QTy to UInt8FusedQTy, or from UInt4FusedFP16QTy to
/// UInt4FusedQTy, and that its indices can be properly converted to Int64.
static void testFRWQSLWSDataIndicesConvert(ElemKind fusedKind, dim_t row,
                                           dim_t col) {
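  // Sanity bound on the test size; it also keeps the generated data values
  // (2 * i + j) comfortably inside the fp16 range.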
  EXPECT_LT(row, 100);
  EXPECT_LT(col, 100);

  Module mod;
  Function *F = mod.createFunction("test");
  Tensor data(ElemKind::FloatTy, {row, col});
  auto dataH = data.getHandle();
  for (dim_t i = 0; i < row; i++) {
    for (dim_t j = 0; j < col; j++) {
      dataH.at({i, j}) = 2.0 * i + 1.0 * j;
    }
  }

  Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");
  Placeholder *indices =
      mod.createPlaceholder(ElemKind::Int32ITy, {8}, "indices",
                            /* isTrainable */ false);
  Placeholder *lengths =
      mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
                            /* isTrainable */ false);
  auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
      "RQSLWS", data, weights, indices, lengths, fusedKind);
  SaveNode *S = F->createSave("save", R);

  size_t origSize = F->getNodes().size();
  CompilationContext cctx;
  PrecisionConfiguration &precConfig = cctx.precisionConfig;
  precConfig.convert4BitFusedToFP32 = true;
  precConfig.convert8BitFusedToFP32 = true;
  precConfig.convertIndicesToInt64 = true;
  precConfig.forceFP16AccumSLS = false;
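  // convert{4,8}BitFusedToFP32 rewrite the fused-FP16 data payload to the
  // FP32-scale variants, and convertIndicesToInt64 widens the Int32 indices
  // to Int64.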

  transformForPrecisionMode(MockBackend(), F, cctx);
  // Should have added ConvertTo nodes for the data and indices.
  EXPECT_EQ(F->getNodes().size(), origSize + 2);

  optimize(F, CompilationMode::Infer);
  // Since the data is a Constant, constant folding is applied during
  // optimization and a new converted data Constant is created. Therefore,
  // only one ConvertTo node is left, for the indices.
  EXPECT_EQ(F->getNodes().size(), origSize + 1);

  auto *SLWS =
      llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
          S->getInput());
  ASSERT_NE(SLWS, nullptr);
  if (fusedKind == ElemKind::UInt8FusedFP16QTy) {
    EXPECT_EQ(SLWS->getData().getElementType(), ElemKind::UInt8FusedQTy);
  } else {
    EXPECT_EQ(SLWS->getData().getElementType(), ElemKind::UInt4FusedQTy);
  }
  EXPECT_EQ(SLWS->getIndices().getElementType(), ElemKind::Int64ITy);

  EXPECT_TRUE(F->verify());
}

/// Test converting UInt8FusedFP16QTy to UInt8FusedQTy.
TEST(TypeAToTypeBFunctionConverter, FRWLWSConvert8Bit) {
  testFRWQSLWSDataIndicesConvert(ElemKind::UInt8FusedFP16QTy, 10, 10);
}

/// Test converting UInt4FusedFP16QTy to UInt4FusedQTy.
TEST(TypeAToTypeBFunctionConverter, FRWLWSConvert4Bit) {
  testFRWQSLWSDataIndicesConvert(ElemKind::UInt4FusedFP16QTy, 10, 10);
}