1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "BackendTestUtils.h"
17
18#include "glow/Converter/Float16Converter.h"
19#include "glow/Converter/TypeAToTypeBFunctionConverter.h"
20
21#include "glow/Backend/Backend.h"
22#include "glow/ExecutionEngine/ExecutionEngine.h"
23#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
24
25#include "llvm/Support/Casting.h"
26
27#include "gtest/gtest.h"
28
29using namespace glow;
30
31/// Check that a simple graph is converted properly.
32/// Namely, check that:
33/// \verbatim
34/// Input: Placeholder(float)
35/// |
36/// V
37/// FC(float) Output: Placeholder(float)
38/// | |
39/// | +-------+
40/// | /
41/// V V
42/// Save
43/// \endverbatim
44///
45/// Gets converted into:
46/// \verbatim
47/// Input: Placeholder(float)
48/// |
49/// V
50/// ConvertTo(float16)
51/// |
52/// V
53/// FC(float16) Output: Placeholder(float)
54/// | |
55/// V |
56/// ConvertTo(float) |
57/// | +---------+
58/// | /
59/// V V
60/// Save
61/// \endverbatim
62///
63/// In particular, the input and output of the network shouldn't be modified.
64TEST(TypeAToTypeBFunctionConverter, SimpleOneUseConversionFloatToFloat16) {
65 Module mod;
66 Function *F = mod.createFunction("test");
67 PlaceholderBindings bindings;
68
69 auto *input =
70 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
71 auto *output =
72 mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);
73
74 auto *FC = F->createFullyConnected(bindings, "FC", input, 10);
75 auto *result = F->createSave("save", FC, output);
76
77 size_t origGraphSize = F->getNodes().size();
78
79 PrecisionConfiguration precConfig;
80 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
81 ElemKind::Float16Ty, precConfig);
82 converter.convert();
83
84 // We should have 4 more nodes:
85 // 1 conversion float to float16 for each input of FC (3)
86 // and 1 conversion float16 to float for the result of FC.
87 EXPECT_EQ(F->getNodes().size(), origGraphSize + 4);
88 // Make sure the save node is still in the function and is unchanged.
89 EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
90 F->getNodes().end());
91 EXPECT_EQ(result->getOutput(), output->getOutput());
92 // Check that the save is fed from a conversion from float16 to float.
93 auto *convertedBackFCRes = llvm::dyn_cast<ConvertToNode>(result->getInput());
94 ASSERT_NE(convertedBackFCRes, nullptr);
95 EXPECT_EQ(convertedBackFCRes->getElementType(ConvertToNode::ResultIdx),
96 ElemKind::FloatTy);
97 auto *convertedFC =
98 llvm::dyn_cast<FullyConnectedNode>(convertedBackFCRes->getInput());
99 ASSERT_NE(convertedFC, nullptr);
100 EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
101 ElemKind::Float16Ty);
  // Check that all the inputs of FC are ConvertTo nodes from float to
  // Float16Ty.
104 for (unsigned idx = 0, end = convertedFC->getNumInputs(); idx != end; ++idx) {
105 auto *convertedFCInput =
106 llvm::dyn_cast<ConvertToNode>(convertedFC->getNthInput(idx));
107 ASSERT_NE(convertedFCInput, nullptr);
108 EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
109 ElemKind::Float16Ty);
110 EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
111 EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::FloatTy);
112 }
113 // At this point we know the input of FC is convertTo(placeholder).
114 // Check that this placeholder is the expected input.
115 EXPECT_EQ(convertedFC->getInput()
116 .getNode()
117 ->getNthInput(ConvertToNode::InputIdx)
118 .getNode(),
119 input);
120}
121
122/// Check that a graph with a simple chain of computation is converted
123/// properly. In particular, check that the intermediate conversion
124/// steps are not eliminated by default.
125/// Namely, check that:
126/// \verbatim
127/// Input: Placeholder(float)
128/// |
129/// V
130/// FC(float)
131/// |
132/// V
133/// ReLU(float) Output: Placeholder(float)
134/// | |
135/// | +-------+
136/// | /
137/// V V
138/// Save
139/// \endverbatim
140///
141/// Gets converted into:
142/// \verbatim
143/// Input: Placeholder(float)
144/// |
145/// V
146/// ConvertTo(float16)
147/// |
148/// V
149/// FC(float16)
150/// |
151/// V
152/// ConvertTo(float)
153/// |
154/// V
155/// ConvertTo(float16)
156/// |
157/// V
158/// ReLU(float16) Output: Placeholder(float)
159/// | |
160/// V |
161/// ConvertTo(float) |
162/// | +---------+
163/// | /
164/// V V
165/// Save
166/// \endverbatim
167///
168/// In particular, the input and output of the network shouldn't be modified.
169TEST(TypeAToTypeBFunctionConverter,
170 SimpleChainOfComputationConversionFloatToFloat16) {
171 Module mod;
172 Function *F = mod.createFunction("test");
173 PlaceholderBindings bindings;
174
175 auto *input =
176 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
177 auto *output =
178 mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);
179
180 auto *FC = F->createFullyConnected(bindings, "FC", input, 10);
181 auto *ReLU =
182 F->createRELU("ReLU", FC, FC->getType(FullyConnectedNode::ResultIdx));
183 auto *result = F->createSave("save", ReLU, output);
184
185 size_t origGraphSize = F->getNodes().size();
186
187 PrecisionConfiguration precConfig;
188 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
189 ElemKind::Float16Ty, precConfig);
190 converter.convert();
191
192 // We should have 6 more nodes:
193 // 1 conversion float to float16 for each input of FC (3)
194 // 1 conversion float to float16 for the input of ReLU
195 // 1 conversion float16 to float for the result of FC.
196 // 1 conversion float16 to float for the result of ReLU.
197 EXPECT_EQ(F->getNodes().size(), origGraphSize + 6);
198 // Make sure the save node is still in the function and is unchanged.
199 EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
200 F->getNodes().end());
201 EXPECT_EQ(result->getOutput(), output->getOutput());
202 // Check that the save is fed from a conversion from float16 to float.
203 auto *convertedBackReLURes =
204 llvm::dyn_cast<ConvertToNode>(result->getInput());
205 ASSERT_NE(convertedBackReLURes, nullptr);
206 EXPECT_EQ(convertedBackReLURes->getElementType(ConvertToNode::ResultIdx),
207 ElemKind::FloatTy);
208 auto *convertedReLU =
209 llvm::dyn_cast<ReluNode>(convertedBackReLURes->getInput());
210 ASSERT_NE(convertedReLU, nullptr);
211 EXPECT_EQ(convertedReLU->getElementType(ReluNode::ResultIdx),
212 ElemKind::Float16Ty);
213
214 // Check that the ReLU is fed from a conversion from float to float16.
215 auto *convertedToReLUInput =
216 llvm::dyn_cast<ConvertToNode>(convertedReLU->getInput());
217 ASSERT_NE(convertedToReLUInput, nullptr);
218 EXPECT_EQ(convertedToReLUInput->getElementType(ConvertToNode::ResultIdx),
219 ElemKind::Float16Ty);
220
221 // Check that this conversion is fed from a conversion from float16 to float.
222 auto *convertedBackFCRes =
223 llvm::dyn_cast<ConvertToNode>(convertedToReLUInput->getInput());
224 ASSERT_NE(convertedBackFCRes, nullptr);
225 EXPECT_EQ(convertedBackFCRes->getElementType(ConvertToNode::ResultIdx),
226 ElemKind::FloatTy);
227 // Check that this conversion comes from the float16 FC node.
228 auto *convertedFC =
229 llvm::dyn_cast<FullyConnectedNode>(convertedBackFCRes->getInput());
230 ASSERT_NE(convertedFC, nullptr);
231 EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
232 ElemKind::Float16Ty);
  // Check that all the inputs of FC are ConvertTo nodes from float to
  // Float16Ty.
235 for (unsigned idx = 0, end = convertedFC->getNumInputs(); idx != end; ++idx) {
236 auto *convertedFCInput =
237 llvm::dyn_cast<ConvertToNode>(convertedFC->getNthInput(idx));
238 ASSERT_NE(convertedFCInput, nullptr);
239 EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
240 ElemKind::Float16Ty);
241 EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
242 EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::FloatTy);
243 }
244 // At this point we know the input of FC is convertTo(placeholder).
245 // Check that this placeholder is the expected input.
246 EXPECT_EQ(convertedFC->getInput()
247 .getNode()
248 ->getNthInput(ConvertToNode::InputIdx)
249 .getNode(),
250 input);
251}
252
/// Check that the conversion honors the precision configuration for
/// blacklisting a node kind (Relu here) in a graph with a simple chain of
/// computation.
255/// Namely, check that:
256/// \verbatim
257/// Input: Placeholder(float)
258/// |
259/// V
260/// FC(float)
261/// |
262/// V
263/// ReLU(float) Output: Placeholder(float)
264/// | |
265/// | +-------+
266/// | /
267/// V V
268/// Save
269/// \endverbatim
270///
271/// Gets converted into:
272/// \verbatim
273/// Input: Placeholder(float)
274/// |
275/// V
276/// ConvertTo(float16)
277/// |
278/// V
279/// FC(float16)
280/// |
281/// V
282/// ConvertTo(float)
283/// |
284/// V
285/// ReLU(float) Output: Placeholder(float)
286/// | |
287/// | +---------+
288/// | /
289/// V V
290/// Save
291/// \endverbatim
292///
293/// In particular, the input and output of the network shouldn't be modified.
294TEST(TypeAToTypeBFunctionConverter, DoNotConvertReLUConversionFloatToFloat16) {
295 Module mod;
296 Function *F = mod.createFunction("test");
297 PlaceholderBindings bindings;
298
299 auto *input =
300 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
301 auto *output =
302 mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);
303
304 auto *FC = F->createFullyConnected(bindings, "FC", input, 10);
305 auto *ReLU =
306 F->createRELU("ReLU", FC, FC->getType(FullyConnectedNode::ResultIdx));
307 auto *result = F->createSave("save", ReLU, output);
308
309 size_t origGraphSize = F->getNodes().size();
310
311 PrecisionConfiguration precConfig;
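  // Blacklist Relu: node kinds in precisionModeKindSet are excluded from the
  // float16 conversion.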
312 precConfig.precisionModeKindSet.insert(Kinded::Kind::ReluNodeKind);
313 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
314 ElemKind::Float16Ty, precConfig);
315 converter.convert();
316
317 // We should have 4 more nodes:
318 // 1 conversion float to float16 for each input of FC (3)
319 // 1 conversion float16 to float for the result of FC.
320 EXPECT_EQ(F->getNodes().size(), origGraphSize + 4);
321 // Make sure the save node is still in the function and is unchanged.
322 EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
323 F->getNodes().end());
324 EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is still fed directly by the float ReLU (no
  // conversion was inserted).
326 auto *resultInput = llvm::dyn_cast<ReluNode>(result->getInput());
327 ASSERT_NE(resultInput, nullptr);
328 EXPECT_EQ(resultInput->getElementType(ReluNode::ResultIdx),
329 ElemKind::FloatTy);
330 EXPECT_EQ(resultInput, ReLU);
331
332 // Check that the ReLU is fed from a conversion from float16 to float.
333 auto *convertedToReLUInput = llvm::dyn_cast<ConvertToNode>(ReLU->getInput());
334 ASSERT_NE(convertedToReLUInput, nullptr);
335 EXPECT_EQ(convertedToReLUInput->getElementType(ConvertToNode::ResultIdx),
336 ElemKind::FloatTy);
337
338 // Check that this conversion comes from the float16 FC node.
339 auto *convertedFC =
340 llvm::dyn_cast<FullyConnectedNode>(convertedToReLUInput->getInput());
341 ASSERT_NE(convertedFC, nullptr);
342 EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
343 ElemKind::Float16Ty);
  // Check that all the inputs of FC are ConvertTo nodes from float to
  // Float16Ty.
346 for (unsigned idx = 0, end = convertedFC->getNumInputs(); idx != end; ++idx) {
347 auto *convertedFCInput =
348 llvm::dyn_cast<ConvertToNode>(convertedFC->getNthInput(idx));
349 ASSERT_NE(convertedFCInput, nullptr);
350 EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
351 ElemKind::Float16Ty);
352 EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
353 EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::FloatTy);
354 }
355 // At this point we know the input of FC is convertTo(placeholder).
356 // Check that this placeholder is the expected input.
357 EXPECT_EQ(convertedFC->getInput()
358 .getNode()
359 ->getNthInput(ConvertToNode::InputIdx)
360 .getNode(),
361 input);
362}
363
/// Check that the conversion honors the precision configuration for
/// whitelisting a node kind (Relu here) in a graph with a simple chain of
/// computation.
366/// Namely, check that:
367/// \verbatim
368/// Input: Placeholder(float)
369/// |
370/// V
371/// FC(float)
372/// |
373/// V
374/// ReLU(float) Output: Placeholder(float)
375/// | |
376/// | +-------+
377/// | /
378/// V V
379/// Save
380/// \endverbatim
381///
382/// Gets converted into:
383/// \verbatim
384/// Input: Placeholder(float)
385/// |
386/// V
387/// FC(float)
388/// |
389/// V
390/// ConvertTo(float16)
391/// |
392/// V
393/// ReLU(float16)
394/// |
395/// V
396/// ConvertTo(float) Output: Placeholder(float)
397/// | |
398/// | +---------------+
399/// | /
400/// V V
401/// Save
402/// \endverbatim
403///
404/// In particular, the input and output of the network shouldn't be modified.
405TEST(TypeAToTypeBFunctionConverter, OnlyReluConversionFloatToFloat16) {
406 Module mod;
407 Function *F = mod.createFunction("test");
408 PlaceholderBindings bindings;
409
410 auto *input =
411 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
412 auto *output =
413 mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);
414
415 auto *FC = F->createFullyConnected(bindings, "FC", input, 10);
416 auto *RN =
417 F->createRELU("Relu", FC, FC->getType(FullyConnectedNode::ResultIdx));
418 auto *result = F->createSave("save", RN, output);
419
420 size_t origGraphSize = F->getNodes().size();
421
422 PrecisionConfiguration precConfig;
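  // With useSetAsWhitelist, only the node kinds in precisionModeKindSet
  // (Relu here) are converted.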
423 precConfig.precisionModeKindSet.insert(Kinded::Kind::ReluNodeKind);
424 precConfig.useSetAsWhitelist = true;
425 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
426 ElemKind::Float16Ty, precConfig);
427 converter.convert();
428
  // We should have 2 more nodes:
  // 1 conversion float to float16 for the input of Relu.
  // 1 conversion float16 to float for the result of Relu.
432 EXPECT_EQ(F->getNodes().size(), origGraphSize + 2);
433 // Make sure the save node is still in the function and is unchanged.
434 EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
435 F->getNodes().end());
436 EXPECT_EQ(result->getOutput(), output->getOutput());
437 // Check that the save is fed from a conversion from float16 to float.
438 auto *resultInput = llvm::dyn_cast<ConvertToNode>(result->getInput());
439 ASSERT_NE(resultInput, nullptr);
440 EXPECT_EQ(resultInput->getInput().getElementType(), ElemKind::Float16Ty);
441 EXPECT_EQ(resultInput->getResult().getElementType(), ElemKind::FloatTy);
442
443 // Check the Relu has FP16 inputs and outputs.
444 auto *convertedRelu = llvm::dyn_cast<ReluNode>(resultInput->getInput());
445 ASSERT_NE(convertedRelu, nullptr);
446 EXPECT_EQ(convertedRelu->getInput().getElementType(), ElemKind::Float16Ty);
447 EXPECT_EQ(convertedRelu->getResult().getElementType(), ElemKind::Float16Ty);
448
449 // Check that the Relu is fed from a conversion from float to float16.
450 auto *convertedToReluInput = llvm::dyn_cast<ConvertToNode>(RN->getInput());
451 ASSERT_NE(convertedToReluInput, nullptr);
452 EXPECT_EQ(convertedToReluInput->getInput().getElementType(),
453 ElemKind::FloatTy);
454 EXPECT_EQ(convertedToReluInput->getResult().getElementType(),
455 ElemKind::Float16Ty);
456
457 // Check that this conversion comes from the original float FC node.
458 EXPECT_EQ(convertedToReluInput->getInput().getNode(), FC);
459 EXPECT_EQ(FC->getResult().getElementType(), ElemKind::FloatTy);
  // Check that all the inputs of FC are float.
461 for (unsigned idx = 0, end = FC->getNumInputs(); idx != end; ++idx) {
462 EXPECT_EQ(FC->getNthInput(idx).getElementType(), ElemKind::FloatTy);
463 }
  // Check that the original placeholder is still the input to the FC and is
  // still float.
465 EXPECT_EQ(FC->getInput().getNode(), input);
466 EXPECT_EQ(input->getOutput().getElementType(), ElemKind::FloatTy);
467}
468
/// Check that we don't convert types we didn't ask for.
470/// Namely, check that:
471/// \verbatim
472/// Input: Placeholder(float)
473/// |
474/// V
475/// TopK(float, Int64I)
476/// | |
477/// | | Output: Placeholder(Int64I)
478/// | | /
479/// | V V
480/// | Save Output: Placeholder(float)
481/// | |
482/// | +-------+
483/// | /
484/// V V
485/// Save
486/// \endverbatim
487///
488/// Gets converted into:
489/// \verbatim
490/// Input: Placeholder(float)
491/// |
492/// V
493/// ConvertTo(float16)
494/// |
495/// V
496/// TopK(float, Int64I)
497/// | |
498/// | | Output: Placeholder(Int64I)
499/// | | /
500/// | V V
501/// | Save Output: Placeholder(float)
502/// V |
503/// ConvertTo(float) |
504/// | |
505/// | +----------+
506/// | /
507/// V V
508/// \endverbatim
509///
/// In particular, neither the input and outputs of the network nor the Int64I
/// result should be modified.
512TEST(TypeAToTypeBFunctionConverter, int64IConversionFloatToFloat16) {
513 Module mod;
514 Function *F = mod.createFunction("test");
515 PlaceholderBindings bindings;
516
517 auto *input =
518 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
519 auto *output =
520 mod.createPlaceholder(ElemKind::FloatTy, {20, 3}, "Output", false);
521 auto *outputIdx =
522 mod.createPlaceholder(ElemKind::Int64ITy, {20, 3}, "Output", false);
523
524 auto *topK = F->createTopK("topK", input, 3);
525 auto *result = F->createSave("save", topK->getValues(), output);
526 auto *resultIndices = F->createSave("saveIdx", topK->getIndices(), outputIdx);
527
528 size_t origGraphSize = F->getNodes().size();
529
530 PrecisionConfiguration precConfig;
531 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
532 ElemKind::Float16Ty, precConfig);
533 converter.convert();
534
  // We should have 2 more nodes:
  // 1 conversion float to float16 for the input of TopK
  // and 1 conversion float16 to float for the values result of TopK.
538 EXPECT_EQ(F->getNodes().size(), origGraphSize + 2);
539 // Make sure the save node is still in the function and is unchanged.
540 EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
541 F->getNodes().end());
542 EXPECT_EQ(result->getOutput(), output->getOutput());
543 // Check that the save is fed from a conversion from float16 to float.
544 auto *convertedBackTopKRes =
545 llvm::dyn_cast<ConvertToNode>(result->getInput());
546 ASSERT_NE(convertedBackTopKRes, nullptr);
547 EXPECT_EQ(convertedBackTopKRes->getElementType(ConvertToNode::ResultIdx),
548 ElemKind::FloatTy);
549 auto *convertedTopK =
550 llvm::dyn_cast<TopKNode>(convertedBackTopKRes->getInput());
551 ASSERT_NE(convertedTopK, nullptr);
552 EXPECT_EQ(convertedTopK->getElementType(TopKNode::ValuesIdx),
553 ElemKind::Float16Ty);
554 EXPECT_EQ(convertedTopK->getElementType(TopKNode::IndicesIdx),
555 ElemKind::Int64ITy);
556 // Check that the input of TopK is a convertTo node from float to
557 // Float16Ty.
558 auto *convertedTopKInput =
559 llvm::dyn_cast<ConvertToNode>(convertedTopK->getInput());
560 ASSERT_NE(convertedTopKInput, nullptr);
561 EXPECT_EQ(convertedTopKInput->getElementType(ConvertToNode::ResultIdx),
562 ElemKind::Float16Ty);
563 EXPECT_TRUE(llvm::isa<Placeholder>(convertedTopKInput->getInput()));
564 EXPECT_EQ(convertedTopKInput->getInput().getElementType(), ElemKind::FloatTy);
565 // At this point we know the input of TopK is convertTo(placeholder).
566 // Check that this placeholder is the expected input.
567 EXPECT_EQ(convertedTopK->getInput()
568 .getNode()
569 ->getNthInput(ConvertToNode::InputIdx)
570 .getNode(),
571 input);
572
573 // Now check the Int64ITy part of the graph.
574 // Make sure the save node for the indices is still in the function and is
575 // unchanged.
576 EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(),
577 *resultIndices) != F->getNodes().end());
578 EXPECT_EQ(resultIndices->getOutput(), outputIdx->getOutput());
579 EXPECT_EQ(resultIndices->getInput(),
580 convertedTopK->getNthResult(TopKNode::IndicesIdx));
581 EXPECT_EQ(resultIndices->getInput().getElementType(), ElemKind::Int64ITy);
582}
583
584/// Check that the conversion optimization can get rid of conversion of
585/// constants and intermediate conversions.
586/// Namely, check that:
587/// \verbatim
588/// Input: Placeholder(float)
589/// |
590/// | Weight: Constant(float)
591/// | | Bias: Constant(float)
592/// | | /
593/// V V V
594/// FC(float)
595/// |
596/// V
597/// ReLU(float) Output: Placeholder(float)
598/// | |
599/// | +-------+
600/// | /
601/// V V
602/// Save
603/// \endverbatim
604///
605/// Gets converted into:
606/// \verbatim
607/// Input: Placeholder(float)
608/// |
609/// V
610/// ConvertTo(float16)
611/// |
612/// | Weight: Constant(float16)
613/// | | Bias: Constant(float16)
614/// | | /
615/// V V V
616/// FC(float16)
617/// |
618/// V
619/// ReLU(float16) Output: Placeholder(float)
620/// | |
621/// V |
622/// ConvertTo(float) |
623/// | +---------+
624/// | /
625/// V V
626/// Save
627/// \endverbatim
628///
629/// In particular, the input and output of the network shouldn't be modified.
630TEST(TypeAToTypeBFunctionConverter, OptimizeMiddleConversionsFloatToFloat16) {
631 Module mod;
632 Function *F = mod.createFunction("test");
633 PlaceholderBindings bindings;
634
635 auto *input =
636 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
637 auto *output =
638 mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);
639
640 auto *weights = mod.createConstant(
641 mod.uniqueType(ElemKind::FloatTy, {13, 10}), "weights");
642 weights->getPayloadMutable().getHandle().randomize(-5.0, 5.0, mod.getPRNG());
643 Tensor origWeights;
644 origWeights.assign(&weights->getPayload());
645 auto *bias =
646 mod.createConstant(mod.uniqueType(ElemKind::FloatTy, {10}), "bias");
647 bias->getPayloadMutable().getHandle().randomize(-5.0, 5.0, mod.getPRNG());
648 Tensor origBias;
649 origBias.assign(&bias->getPayload());
650
651 // This save is just to test that we do the right thing for constants with
652 // more than one use.
653 auto *saveBias = F->createSave("saveBias", bias);
654 TypeRef FCTy = mod.uniqueType(ElemKind::FloatTy, {20, 10});
655 auto *FC = F->createFullyConnected("FC", input, weights, bias, FCTy);
656 auto *ReLU =
657 F->createRELU("ReLU", FC, FC->getType(FullyConnectedNode::ResultIdx));
658 auto *result = F->createSave("save", ReLU, output);
659
660 size_t origGraphSize = F->getNodes().size();
661
662 PrecisionConfiguration precConfig;
663 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
664 ElemKind::Float16Ty, precConfig);
665 converter.convert();
666
667 optimize(F, CompilationMode::Infer);
668
669 // We should have 2 more nodes:
670 // 1 conversion float to float16 for the input of FC
671 // 1 conversion float16 to float for the result of ReLU.
672 EXPECT_EQ(F->getNodes().size(), origGraphSize + 2);
673 // Make sure the save node is still in the function and is unchanged.
674 EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
675 F->getNodes().end());
676 EXPECT_EQ(result->getOutput(), output->getOutput());
677 // Check that the save is fed from a conversion from float16 to float.
678 auto *convertedBackReLURes =
679 llvm::dyn_cast<ConvertToNode>(result->getInput());
680 ASSERT_NE(convertedBackReLURes, nullptr);
681 EXPECT_EQ(convertedBackReLURes->getElementType(ConvertToNode::ResultIdx),
682 ElemKind::FloatTy);
683 auto *convertedReLU =
684 llvm::dyn_cast<ReluNode>(convertedBackReLURes->getInput());
685 ASSERT_NE(convertedReLU, nullptr);
686 EXPECT_EQ(convertedReLU->getElementType(ReluNode::ResultIdx),
687 ElemKind::Float16Ty);
688
689 // Check that the ReLU is fed directly by FC float16.
690 auto *convertedFC =
691 llvm::dyn_cast<FullyConnectedNode>(convertedReLU->getInput());
692 ASSERT_NE(convertedFC, nullptr);
693 EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
694 ElemKind::Float16Ty);
  // Check that the input of FC is a ConvertTo node converting "input" from
  // float to Float16Ty.
697 auto *convertedFCInput =
698 llvm::dyn_cast<ConvertToNode>(convertedFC->getInput());
699 ASSERT_NE(convertedFCInput, nullptr);
700 EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
701 ElemKind::Float16Ty);
702 EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
703 EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::FloatTy);
704 EXPECT_EQ(convertedFCInput->getInput().getNode(), input);
705
706 // Check that the weights have been updated to float16.
707 auto *convertedFCWeights =
708 llvm::dyn_cast<Constant>(convertedFC->getWeights());
709 ASSERT_NE(convertedFCWeights, nullptr);
710 EXPECT_EQ(convertedFCWeights->getElementType(), ElemKind::Float16Ty);
711 EXPECT_EQ(convertedFCWeights, weights);
712 origWeights.convertToType(ElemKind::Float16Ty);
713 EXPECT_TRUE(origWeights.isEqual(weights->getPayload()));
714
715 // Check that the bias has been duplicated and converted.
716 auto *convertedFCBias = llvm::dyn_cast<Constant>(convertedFC->getBias());
717 ASSERT_NE(convertedFCBias, nullptr);
718 EXPECT_EQ(convertedFCBias->getElementType(), ElemKind::Float16Ty);
719 EXPECT_NE(convertedFCBias, bias);
720 origBias.convertToType(ElemKind::Float16Ty);
721 EXPECT_TRUE(origBias.isEqual(convertedFCBias->getPayload()));
722
723 // Check that the original bias hasn't been altered.
724 EXPECT_EQ(bias->getElementType(), ElemKind::FloatTy);
725 EXPECT_EQ(saveBias->getInput().getNode(), bias);
726}
727
/// Check that the conversion of placeholders inserts conversions
/// at the right places, and in all the functions.
730/// Namely, check that:
731/// \verbatim
732/// #### F ####
733/// Input: Placeholder(float)
734/// | |
735/// | | Weight: Constant(float)
736/// | | | Bias: Constant(float)
737/// | | | /
738/// | V V V
739/// | FC(float)
740/// | |
741/// | V
742/// | ReLU(float) Output: Placeholder(float)
743/// | | |
744/// | | +-------+
745/// | | /
746/// | V V
747/// | Save
748/// |
749/// | #### F2 ####
750/// | Output2: Placeholder(float)
751/// +-+ /
752/// | | |
753/// | V V
754/// | Save
755/// |
756/// | #### F3 ####
757/// | Output3: Placeholder(float)
758/// | /
759/// V V
760/// Save
761/// \endverbatim
762///
763/// Gets converted into:
764/// \verbatim
765/// #### F ####
766/// Input: Placeholder(float16)
767/// | |
768/// | V
769/// |ConvertTo(float)
770/// | |
771/// | | Weight: Constant(float)
772/// | | | Bias: Constant(float)
773/// | | | /
774/// | V V V
775/// | FC(float)
776/// | |
777/// | V
778/// | ReLU(float) Output: Placeholder(float16)
779/// | | |
780/// | V |
781/// |ConvertTo(float16)|
782/// | | +---------+
783/// | | /
784/// | V V
785/// | Save
786/// | #### F2 ####
787/// +-+
788/// | |
789/// | V
790/// | ConvertTo(float)
791/// | |
792/// | | Output2: Placeholder(float)
793/// | | |
794/// | V V
795/// | Save
796/// |
797/// | #### F3 ####
798/// V
799/// ConvertTo(float)
800/// |
801/// V
802/// ConvertTo(float16)
803/// |
804/// | Output3: Placeholder(float16)
805/// | /
806/// V V
807/// Save
808/// \endverbatim
809///
810/// In particular, the input and output of the network should be modified
811/// and the input of the last save node should be converted to the expected
812/// output type.
813TEST(TypeAToTypeBFunctionConverter, convertPlaceholderFloatToFloat16) {
814 Module mod;
815 Function *F = mod.createFunction("test");
816 Function *F2 = mod.createFunction("test2");
817 Function *F3 = mod.createFunction("test3");
818 PlaceholderBindings bindings;
819
820 auto *input =
821 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
822 Tensor *inputTensor = bindings.allocate(input);
823 inputTensor->getHandle().randomize(-6.0, 6.0, mod.getPRNG());
824 Tensor origInput;
825 origInput.assign(inputTensor);
826
827 auto *output =
828 mod.createPlaceholder(ElemKind::FloatTy, {20, 10}, "Output", false);
829 auto *output2 =
830 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Output2", false);
831
832 auto *saveOutput2 = F2->createSave("saveOutput2", input, output2);
833
834 auto *output3 =
      mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Output3", false);
836 auto *saveOutput3 = F3->createSave("saveOutput3", input, output3);
837
838 auto *weights = mod.createConstant(
839 mod.uniqueType(ElemKind::FloatTy, {13, 10}), "weights");
840 weights->getPayloadMutable().getHandle().randomize(-5.0, 5.0, mod.getPRNG());
841 Tensor origWeights;
842 origWeights.assign(&weights->getPayload());
843 auto *bias =
844 mod.createConstant(mod.uniqueType(ElemKind::FloatTy, {10, 20}), "bias");
845 bias->getPayloadMutable().getHandle().randomize(-5.0, 5.0, mod.getPRNG());
846
847 TypeRef FCTy = mod.uniqueType(ElemKind::FloatTy, {20, 10});
848 auto *FC = F->createFullyConnected("FC", input, weights, bias, FCTy);
849 auto *ReLU =
850 F->createRELU("ReLU", FC, FC->getType(FullyConnectedNode::ResultIdx));
851 auto *result = F->createSave("save", ReLU, output);
852
853 size_t origGraphSize = F->getNodes().size();
854 size_t f2OrigGraphSize = F2->getNodes().size();
855 size_t f3OrigGraphSize = F3->getNodes().size();
856
857 PrecisionConfiguration precConfig;
858 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
859 ElemKind::Float16Ty, precConfig);
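  // Convert every placeholder except output2; output2 is used to check that
  // placeholders left in float get a conversion inserted in front of their
  // save node.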
860 for (auto *placeholder : mod.getPlaceholders()) {
861 if (output2 == placeholder) {
862 continue;
863 }
864 converter.convertPlaceholder(*placeholder, &bindings);
865 }
866
  // We should have 2 more nodes in F:
  // 1 conversion for the input of the save node of the output
  // 1 conversion from the input placeholder to the FC
870 EXPECT_EQ(F->getNodes().size(), origGraphSize + 2);
  // Make sure the save node of F is still in the function and is unchanged.
872 EXPECT_TRUE(std::find(F->getNodes().begin(), F->getNodes().end(), *result) !=
873 F->getNodes().end());
874 EXPECT_EQ(result->getOutput(), output->getOutput());
  // Check that the save is fed from a conversion from float to float16.
876 auto *convertedBackReLURes =
877 llvm::dyn_cast<ConvertToNode>(result->getInput());
878 ASSERT_NE(convertedBackReLURes, nullptr);
879 EXPECT_EQ(convertedBackReLURes->getElementType(ConvertToNode::ResultIdx),
880 ElemKind::Float16Ty);
881 auto *convertedReLU =
882 llvm::dyn_cast<ReluNode>(convertedBackReLURes->getInput());
883 ASSERT_NE(convertedReLU, nullptr);
884 EXPECT_EQ(convertedReLU->getElementType(ReluNode::ResultIdx),
885 ElemKind::FloatTy);
886
  // Check that the ReLU is fed directly by the float FC.
888 auto *convertedFC =
889 llvm::dyn_cast<FullyConnectedNode>(convertedReLU->getInput());
890 ASSERT_NE(convertedFC, nullptr);
891 EXPECT_EQ(convertedFC->getElementType(FullyConnectedNode::ResultIdx),
892 ElemKind::FloatTy);
  // Check that the input of FC is a ConvertTo node converting "input" from
  // Float16Ty back to float.
895 auto *convertedFCInput =
896 llvm::dyn_cast<ConvertToNode>(convertedFC->getInput());
897 ASSERT_NE(convertedFCInput, nullptr);
898 EXPECT_EQ(convertedFCInput->getElementType(ConvertToNode::ResultIdx),
899 ElemKind::FloatTy);
900 EXPECT_TRUE(llvm::isa<Placeholder>(convertedFCInput->getInput()));
901 EXPECT_EQ(convertedFCInput->getInput().getElementType(), ElemKind::Float16Ty);
902 EXPECT_EQ(convertedFCInput->getInput().getNode(), input);
903
904 // Checks for F2.
905
906 // We should have 1 more node in F2:
  // 1 conversion from the input placeholder to the input of the save node.

  // Make sure the save node of F2 is still in the function and is unchanged.
910 EXPECT_EQ(F2->getNodes().size(), f2OrigGraphSize + 1);
911 EXPECT_TRUE(std::find(F2->getNodes().begin(), F2->getNodes().end(),
912 *saveOutput2) != F2->getNodes().end());
913 EXPECT_EQ(saveOutput2->getOutput(), output2->getOutput());
914
915 // Check that the save is fed from a conversion from float16 to float.
916 auto *inputToFloat = llvm::dyn_cast<ConvertToNode>(saveOutput2->getInput());
917 ASSERT_NE(inputToFloat, nullptr);
918 EXPECT_EQ(inputToFloat->getElementType(ConvertToNode::ResultIdx),
919 ElemKind::FloatTy);
920 // Check that this input is "input".
921 auto *inputOfF2 = llvm::dyn_cast<Placeholder>(inputToFloat->getInput());
922 ASSERT_NE(inputOfF2, nullptr);
923 EXPECT_EQ(inputOfF2, input);
924
925 // Checks for F3.
926
  // We should have 2 more nodes in F3:
  // 1 conversion of the input back to float (required because the input
  //   placeholder is now float16)
  // 1 conversion of that value to float16 (required because the output
  //   placeholder is now float16)

  // Make sure the save node of F3 is still in the function and is unchanged.
934 EXPECT_EQ(F3->getNodes().size(), f3OrigGraphSize + 2);
935 EXPECT_TRUE(std::find(F3->getNodes().begin(), F3->getNodes().end(),
936 *saveOutput3) != F3->getNodes().end());
937 EXPECT_EQ(saveOutput3->getOutput(), output3->getOutput());
938 EXPECT_EQ(output3->getElementType(), ElemKind::Float16Ty);
939
  // Check that the save is fed from a conversion from float to float16.
941 auto *convertOutput3 = llvm::dyn_cast<ConvertToNode>(saveOutput3->getInput());
942 ASSERT_NE(convertOutput3, nullptr);
943 EXPECT_EQ(convertOutput3->getElementType(ConvertToNode::ResultIdx),
944 ElemKind::Float16Ty);
945
946 auto *convertInputFor3 =
947 llvm::dyn_cast<ConvertToNode>(convertOutput3->getInput());
948 ASSERT_NE(convertInputFor3, nullptr);
949 EXPECT_EQ(convertInputFor3->getElementType(ConvertToNode::ResultIdx),
950 ElemKind::FloatTy);
951 // Check that this input is "input".
952 auto *inputOfF3 = llvm::dyn_cast<Placeholder>(convertInputFor3->getInput());
953 ASSERT_NE(inputOfF3, nullptr);
954 EXPECT_EQ(inputOfF3, input);
955
956 origInput.convertToType(ElemKind::Float16Ty);
957 EXPECT_TRUE(origInput.isEqual(*inputTensor));
958}
959
/// Check that verification doesn't complain when there are
/// no-op conversions. This may happen in an unoptimized network.
962/// E.g.,
963/// Input: Placeholder(float)
964/// |
965/// V
966/// OrigConvert: ConvertTo(float16)
967/// |
968/// V
969/// Save
970///
971/// Now converting the network to float16 will yield:
972/// Input: Placeholder(float)
973/// |
974/// V
975/// ConvertTo(float16); convert the input to fp16
976/// |
977/// V
978/// OrigConvert: ConvertTo(float16); <-- now this is a noop conversion.
979/// |
980/// V
981/// Save
982TEST(TypeAToTypeBFunctionConverter, convertExistingConversionToNoop) {
983 Module mod;
984 Function *F = mod.createFunction("test");
985 auto *placeholder =
986 mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "Input", false);
987
988 auto *convert =
989 F->createConvertTo("convert", placeholder, ElemKind::Float16Ty);
990 auto *save = F->createSave("save", convert);
991
992 size_t origSize = F->getNodes().size();
993
994 PrecisionConfiguration precConfig;
995 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
996 ElemKind::Float16Ty, precConfig);
997 converter.convert();
998
999 EXPECT_EQ(F->getNodes().size(), origSize + 1);
1000
1001 auto *convertToSave = llvm::dyn_cast<ConvertToNode>(save->getInput());
1002 EXPECT_EQ(convertToSave, convert);
1003 EXPECT_EQ(convert->getElementType(ConvertToNode::ResultIdx),
1004 ElemKind::Float16Ty);
1005
1006 auto *addedConversion = llvm::dyn_cast<ConvertToNode>(convert->getInput());
1007 ASSERT_NE(addedConversion, nullptr);
1008 // At this point both the input and output of convert are FP16.
1009 EXPECT_EQ(addedConversion->getElementType(ConvertToNode::ResultIdx),
1010 ElemKind::Float16Ty);
1011
1012 EXPECT_EQ(addedConversion->getInput().getNode(), placeholder);
1013 EXPECT_EQ(placeholder->getElementType(), ElemKind::FloatTy);
1014
1015 EXPECT_TRUE(F->verify());
1016}
1017
1018/// Helper for testing FRWQSLWS FP16 conversion, with and without FP16
1019/// accumulation based on \p forceFP16AccumSLS.
1020static void testConvertFRWQSLWS(bool forceFP16AccumSLS) {
1021 Module mod;
1022 Function *F = mod.createFunction("test");
1023 Tensor data(ElemKind::FloatTy, {3, 1});
1024 data.getHandle() = {
1025 2.0,
1026 -0.5,
1027 13,
1028 };
1029
1030 Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");
1031
1032 Placeholder *indices =
1033 mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
1034 /* isTrainable */ false);
1035 Placeholder *lengths =
1036 mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
1037 /* isTrainable */ false);
1038 auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1039 "RQSLWS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy);
1040 SaveNode *S = F->createSave("save", R);
1041
1042 size_t origSize = F->getNodes().size();
1043
1044 CompilationContext cctx;
1045 PrecisionConfiguration &precConfig = cctx.precisionConfig;
1046 precConfig.convertToFP16 = true;
1047 precConfig.convertFusedToFP16 = true;
1048 precConfig.forceFP16AccumSLS = forceFP16AccumSLS;
1049 transformForPrecisionMode(MockBackend(), F, cctx);
1050
  // Should have added convert nodes for the Data, Weights, and Result. The
  // Data and Weights ConvertTo nodes should have been merged in, while the
  // Result still has a ConvertTo.
1054 EXPECT_EQ(F->getNodes().size(), origSize + 1);
1055
1056 auto *convertResult = llvm::dyn_cast<ConvertToNode>(S->getInput());
1057 ASSERT_NE(convertResult, nullptr);
1058 EXPECT_EQ(convertResult->getResult().getElementType(), ElemKind::FloatTy);
1059
1060 auto *SLWS =
1061 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
1062 convertResult->getInput());
1063 ASSERT_NE(SLWS, nullptr);
1064 EXPECT_EQ(SLWS->getResult().getElementType(), ElemKind::Float16Ty);
1065 EXPECT_EQ(SLWS->getUseFP16Accumulation(), forceFP16AccumSLS);
1066
1067 EXPECT_EQ(SLWS->getData().getElementType(), ElemKind::UInt8FusedFP16QTy);
1068 EXPECT_EQ(SLWS->getWeights().getElementType(), ElemKind::Float16Ty);
1069
1070 EXPECT_TRUE(F->verify());
1071}
1072
1073/// Test conversion of a FusedRowwiseQuantizedSparseLengthsWeightedSumNode to
1074/// FP16, instead of creating it directly. Use FP16 accumulation.
1075TEST(TypeAToTypeBFunctionConverter, convertFRWQSLWS_FP16Accum) {
1076 testConvertFRWQSLWS(/* forceFP16AccumSLS */ true);
1077}
1078
1079/// Test conversion of a FusedRowwiseQuantizedSparseLengthsWeightedSumNode to
1080/// FP16, instead of creating it directly. Do not use FP16 accumulation; note
1081/// that conversion by default uses FP32 accumulation.
1082TEST(TypeAToTypeBFunctionConverter, convertFRWQSLWS_FP32Accum) {
1083 testConvertFRWQSLWS(/* forceFP16AccumSLS */ false);
1084}
1085
1086/// Test skipping conversion of a
1087/// FusedRowwiseQuantizedSparseLengthsWeightedSumNode to FP16.
1088TEST(TypeAToTypeBFunctionConverter, skipConvertingFRWQSLWS) {
1089 Module mod;
1090 Function *F = mod.createFunction("test");
1091 Tensor data(ElemKind::FloatTy, {3, 1});
1092 data.getHandle() = {
1093 2.0,
1094 -0.5,
1095 13,
1096 };
1097
1098 Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");
1099
1100 Placeholder *indices =
1101 mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
1102 /* isTrainable */ false);
1103 Placeholder *lengths =
1104 mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
1105 /* isTrainable */ false);
1106 auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1107 "RQSLWS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy);
1108 SaveNode *S = F->createSave("save", R);
1109
1110 size_t origSize = F->getNodes().size();
1111
1112 PrecisionConfiguration precConfig;
1113 precConfig.convertToFP16 = true;
1114 precConfig.convertFusedToFP16 = true;
1115 precConfig.precisionModeKindSet.insert(
1116 Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind);
1117 convertFunctionToFloat16(F, precConfig);
1118
1119 // Should have done nothing since we skipped its conversion. Check the
1120 // Function is the same as before.
1121 EXPECT_EQ(F->getNodes().size(), origSize);
1122
1123 auto *SLWS =
1124 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
1125 S->getInput());
1126 ASSERT_EQ(SLWS, R);
1127
1128 auto *origData = llvm::dyn_cast<Constant>(SLWS->getData());
1129 ASSERT_EQ(origData, R->getData().getNode());
1130 EXPECT_EQ(origData->getOutput().getElementType(), ElemKind::UInt8FusedQTy);
1131
1132 auto *origWeights = llvm::dyn_cast<Constant>(SLWS->getWeights());
1133 ASSERT_EQ(origWeights, weights);
1134 EXPECT_EQ(origWeights->getOutput().getElementType(), ElemKind::FloatTy);
1135
1136 EXPECT_TRUE(F->verify());
1137}
1138
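/// Helper for testing conversion of only the 16-bit float element kind given
/// by \p float16Format, while leaving UInt8FusedQTy data untouched.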
1139static void
1140convertOnlyFloat16Ty(PrecisionConfiguration::Float16Format float16Format) {
1141 Module mod;
1142 Function *F = mod.createFunction("test");
1143 Tensor data(ElemKind::FloatTy, {3, 1});
1144 data.getHandle() = {
1145 2.0,
1146 -0.5,
1147 13,
1148 };
1149
1150 Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");
1151
1152 Placeholder *indices =
1153 mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
1154 /* isTrainable */ false);
1155 Placeholder *lengths =
1156 mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
1157 /* isTrainable */ false);
1158 auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1159 "RQSLWS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy);
1160 SaveNode *S = F->createSave("save", R);
1161
1162 size_t origSize = F->getNodes().size();
1163
1164 PrecisionConfiguration precConfig;
1165 precConfig.convertToFP16 = true;
1166 precConfig.convertFusedToFP16 = false;
1167 precConfig.float16Format = float16Format;
1168 convertFunctionToFloat16(F, precConfig);
1169
1170 ElemKind convertedElementType =
1171 PrecisionConfiguration::getElementType(float16Format);
1172
  // Should have added convert nodes for the weights and the result.
1174 EXPECT_EQ(F->getNodes().size(), origSize + 2);
1175
1176 auto *convertResult = llvm::dyn_cast<ConvertToNode>(S->getInput());
1177 ASSERT_NE(convertResult, nullptr);
1178 EXPECT_EQ(convertResult->getResult().getElementType(), ElemKind::FloatTy);
1179
1180 auto *SLWS =
1181 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
1182 convertResult->getInput());
1183 ASSERT_NE(SLWS, nullptr);
1184 EXPECT_EQ(SLWS->getResult().getElementType(), convertedElementType);
1185
1186 auto *origData = llvm::dyn_cast<Constant>(SLWS->getData());
1187 ASSERT_NE(origData, nullptr);
1188 EXPECT_EQ(origData->getOutput().getElementType(), ElemKind::UInt8FusedQTy);
1189
1190 auto *convertWeights = llvm::dyn_cast<ConvertToNode>(SLWS->getWeights());
1191 ASSERT_NE(convertWeights, nullptr);
1192 EXPECT_EQ(convertWeights->getResult().getElementType(), convertedElementType);
1193
1194 auto *origWeights = llvm::dyn_cast<Constant>(convertWeights->getInput());
1195 ASSERT_NE(origWeights, nullptr);
1196 EXPECT_EQ(origWeights->getOutput().getElementType(), ElemKind::FloatTy);
1197 EXPECT_EQ(weights, origWeights);
1198
1199 EXPECT_TRUE(F->verify());
1200}
1201
1202/// Test conversion of only FP16 inputs of Node and not UInt8FusedQTy.
1203TEST(TypeAToTypeBFunctionConverter, convertOnlyFP16Ty) {
1204 convertOnlyFloat16Ty(PrecisionConfiguration::Float16Format::FP16);
1205}
1206
1207/// Test conversion of only BFloat16 inputs of Node and not UInt8FusedQTy.
1208TEST(TypeAToTypeBFunctionConverter, convertOnlyBFloat16Ty) {
1209 convertOnlyFloat16Ty(PrecisionConfiguration::Float16Format::BFloat16);
1210}
1211
1212/// Test conversion of only UInt8FusedQTy inputs of Node and not Float16Ty.
1213TEST(TypeAToTypeBFunctionConverter, convertOnlyUInt8FusedQTy) {
1214 Module mod;
1215 Function *F = mod.createFunction("test");
1216 Tensor data(ElemKind::FloatTy, {3, 1});
1217 data.getHandle() = {
1218 2.0,
1219 -0.5,
1220 13,
1221 };
1222
1223 Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");
1224
1225 Placeholder *indices =
1226 mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
1227 /* isTrainable */ false);
1228 Placeholder *lengths =
1229 mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
1230 /* isTrainable */ false);
1231 auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1232 "RQSLWS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy);
1233 SaveNode *S = F->createSave("save", R);
1234
1235 size_t origSize = F->getNodes().size();
1236
1237 PrecisionConfiguration precConfig;
1238 precConfig.convertToFP16 = false;
1239 precConfig.convertFusedToFP16 = true;
1240 convertFunctionToFloat16(F, precConfig);
1241
  // Should have added a ConvertTo node for the data.
1243 EXPECT_EQ(F->getNodes().size(), origSize + 1);
1244
1245 auto *SLWS =
1246 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
1247 S->getInput());
1248 ASSERT_EQ(SLWS, R);
1249
1250 auto *convertData = llvm::dyn_cast<ConvertToNode>(SLWS->getData());
1251 ASSERT_NE(convertData, nullptr);
1252 EXPECT_EQ(convertData->getResult().getElementType(),
1253 ElemKind::UInt8FusedFP16QTy);
1254
1255 auto *origData = llvm::dyn_cast<Constant>(convertData->getInput());
1256 ASSERT_NE(origData, nullptr);
1257 EXPECT_EQ(origData->getOutput().getElementType(), ElemKind::UInt8FusedQTy);
1258
1259 auto *origWeights = llvm::dyn_cast<Constant>(SLWS->getWeights());
1260 ASSERT_EQ(origWeights, weights);
1261 EXPECT_EQ(origWeights->getOutput().getElementType(), ElemKind::FloatTy);
1262
1263 EXPECT_TRUE(F->verify());
1264}
1265
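/// Helper to check that no Clip nodes are inserted around non-numeric nodes
/// (Concat, Reshape, Slice, Gather) when converting to \p float16Format with
/// clipFP16 enabled.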
1266static void convertWithoutClipAroundNonNumericNodes(
1267 PrecisionConfiguration::Float16Format float16Format) {
1268 Module mod;
1269 Function *F = mod.createFunction("test");
1270 const dim_t dims[] = {1, 5, 10, 15};
1271 const dim_t dimsReshape[] = {10, 10, 15};
1272 Node *I0 = mod.createPlaceholder(ElemKind::FloatTy, dims, "i0", false);
1273 Node *I1 = mod.createPlaceholder(ElemKind::FloatTy, dims, "i1", false);
1274 Node *I2 = mod.createPlaceholder(ElemKind::Int32ITy, {2, 2, 2}, "i2", false);
1275 Node *CN = F->createConcat("concat", {I0, I1}, 1);
1276 Node *R = F->createReshape("reshape", CN, dimsReshape);
1277 Node *S = F->createSlice("slice", R, {0, 0, 0}, {5, 5, 5});
1278 Node *G = F->createGather("gather", S, I2);
1279 F->createSave("ret", G);
1280
1281 PrecisionConfiguration precConfig;
1282 precConfig.convertToFP16 = true;
1283 precConfig.clipFP16 = true;
1284 precConfig.float16Format = float16Format;
1285 convertFunctionToFloat16(F, precConfig);
1286
1287 int numClips = 0;
1288 int numConvertTos = 0;
1289 for (auto &n : F->getNodes()) {
1290 if (n.getKind() == Kinded::Kind::ClipNodeKind) {
1291 ++numClips;
1292 } else if (n.getKind() == Kinded::Kind::ConvertToNodeKind) {
1293 ++numConvertTos;
1294 }
1295 }
1296
1297 EXPECT_EQ(9, numConvertTos);
1298 EXPECT_EQ(0, numClips);
1299
1300 EXPECT_TRUE(F->verify());
1301}
1302
1303// Test that we don't insert Clips around non-numeric nodes.
1304TEST(TypeAToTypeBFunctionConverter,
1305 convertWithFP16WithoutClipAroundNonNumericNodes) {
1306 convertWithoutClipAroundNonNumericNodes(
1307 PrecisionConfiguration::Float16Format::FP16);
1308}
1309
1310// Test that we don't insert Clips around non-numeric nodes.
1311TEST(TypeAToTypeBFunctionConverter,
1312 convertWithBFloat16WithoutClipAroundNonNumericNodes) {
1313 convertWithoutClipAroundNonNumericNodes(
1314 PrecisionConfiguration::Float16Format::BFloat16);
1315}
1316
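/// Helper to check that no Clip nodes are inserted after Tanh or Sigmoid when
/// converting to \p float16Format with clipFP16 enabled.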
1317static void convertWithoutClipAfterTanhOrSigmoid(
1318 PrecisionConfiguration::Float16Format float16Format) {
1319 Module mod;
1320 Function *F = mod.createFunction("test");
1321 const dim_t dims[] = {10, 20};
1322 const dim_t dims2[] = {10, 30};
1323 Node *I0 = mod.createPlaceholder(ElemKind::FloatTy, dims, "i0", false);
1324 Node *I1 = mod.createPlaceholder(ElemKind::FloatTy, dims2, "i1", false);
1325 Node *T = F->createTanh("tanh", {I0});
1326 Node *S = F->createSigmoid("sigmoid", {I1});
1327 Node *CN = F->createConcat("concat", {T, S}, 1);
1328 F->createSave("ret", CN);
1329
1330 PrecisionConfiguration precConfig;
1331 precConfig.convertToFP16 = true;
1332 precConfig.clipFP16 = true;
1333 precConfig.float16Format = float16Format;
1334 convertFunctionToFloat16(F, precConfig);
1335
1336 int numClips = 0;
1337 int numConvertTos = 0;
1338 for (auto &n : F->getNodes()) {
1339 if (n.getKind() == Kinded::Kind::ClipNodeKind) {
1340 ++numClips;
1341 } else if (n.getKind() == Kinded::Kind::ConvertToNodeKind) {
1342 ++numConvertTos;
1343 }
1344 }
1345
1346 EXPECT_EQ(7, numConvertTos);
1347 EXPECT_EQ(2, numClips);
1348
1349 EXPECT_TRUE(F->verify());
1350}
1351
1352// Test that we don't insert Clips at the output of Tanh or Sigmoid
1353TEST(TypeAToTypeBFunctionConverter,
1354 convertWithFP16WithoutClipAfterTanhOrSigmoid) {
1355 convertWithoutClipAfterTanhOrSigmoid(
1356 PrecisionConfiguration::Float16Format::FP16);
1357}
1358
1359// Test that we don't insert Clips at the output of Tanh or Sigmoid
1360TEST(TypeAToTypeBFunctionConverter,
1361 convertWithBFloat16WithoutClipAfterTanhOrSigmoid) {
1362 convertWithoutClipAfterTanhOrSigmoid(
1363 PrecisionConfiguration::Float16Format::BFloat16);
1364}
1365
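/// Helper to check that no Clip node is inserted after a ConvertTo whose input
/// is already a 16-bit float, when converting to \p float16Format with
/// clipFP16 enabled.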
1366static void convertWithoutClipAfterFp16ConvertTo(
1367 PrecisionConfiguration::Float16Format float16Format) {
1368 Module mod;
1369 Function *F = mod.createFunction("test");
1370 const dim_t dims[] = {10, 20};
1371 const dim_t dims2[] = {10, 30};
1372 Node *I0 = mod.createPlaceholder(ElemKind::Float16Ty, dims, "i0", false);
1373 Node *I1 = mod.createPlaceholder(ElemKind::Float16Ty, dims2, "i1", false);
1374 Node *T = F->createConvertTo("c1", {I0}, ElemKind::FloatTy);
1375 Node *S = F->createConvertTo("c2", {I1}, ElemKind::FloatTy);
1376 Node *CN = F->createConcat("concat", {T, S}, 1);
1377 F->createSave("ret", CN);
1378
1379 PrecisionConfiguration precConfig;
1380 precConfig.convertToFP16 = true;
1381 precConfig.clipFP16 = true;
1382 precConfig.float16Format = float16Format;
1383 convertFunctionToFloat16(F, precConfig);
1384
1385 int numClips = 0;
1386 int numConvertTos = 0;
1387 for (auto &n : F->getNodes()) {
1388 if (n.getKind() == Kinded::Kind::ClipNodeKind) {
1389 ++numClips;
1390 } else if (n.getKind() == Kinded::Kind::ConvertToNodeKind) {
1391 ++numConvertTos;
1392 }
1393 }
1394
1395 EXPECT_EQ(0, numClips);
1396
1397 EXPECT_TRUE(F->verify());
1398}
1399
1400// Test that we don't insert Clips at the output of ConvertTo if its input is
1401// fp16.
1402TEST(TypeAToTypeBFunctionConverter,
1403 convertWithFP16WithoutClipAfterFp16ConvertTo) {
1404 convertWithoutClipAfterFp16ConvertTo(
1405 PrecisionConfiguration::Float16Format::FP16);
1406}
1407
1408// Test that we don't insert Clips at the output of ConvertTo if its input is
1409// bfloat16.
1410TEST(TypeAToTypeBFunctionConverter,
1411 convertWithBFloat16WithoutClipAfterFp16ConvertTo) {
1412 convertWithoutClipAfterFp16ConvertTo(
1413 PrecisionConfiguration::Float16Format::BFloat16);
1414}
1415
1416// Test that we only insert clips for outputs.
1417static void
1418checkConvertOnlyOutputs(PrecisionConfiguration::Float16Format float16Format) {
1419 Module mod;
1420 Function *F = mod.createFunction("test");
1421 Node *I = mod.createPlaceholder(ElemKind::FloatTy, {10}, "i", false);
1422 ReluNode *RN = F->createRELU("relu", I);
1423 SaveNode *SN = F->createSave("ret", RN);
1424
1425 PrecisionConfiguration precConfig;
1426 precConfig.convertToFP16 = true;
1427 precConfig.clipFP16 = true;
1428 precConfig.clipFP16SkipInputs = true;
1429 precConfig.convertPlaceholdersToFP16 = true;
1430 precConfig.convertConstantsToFP16 = true;
1431 precConfig.float16Format = float16Format;
1432 convertFunctionToFloat16(F, precConfig);
1433
1434 ElemKind convertedElementType =
1435 PrecisionConfiguration::getElementType(float16Format);
1436
1437 // PH -> ConvertToFP16 -> ConvertToFP32 -> ConvertToFP16 -> Relu ->
1438 // Clip -> ConvertToFP32 -> Save
1439
1440 ConvertToNode *convertRN = llvm::dyn_cast<ConvertToNode>(SN->getInput());
1441 ASSERT_TRUE(convertRN);
1442 EXPECT_EQ(convertRN->getResult().getType()->getElementType(),
1443 ElemKind::FloatTy);
1444 ClipNode *clipRN = llvm::dyn_cast<ClipNode>(convertRN->getInput());
1445 convertRN->getInput().getNode()->dump();
1446 ASSERT_TRUE(clipRN);
1447 ASSERT_TRUE(clipRN->getInput() == RN->getResult());
1448 ConvertToNode *convert32To16 = llvm::dyn_cast<ConvertToNode>(RN->getInput());
1449 ASSERT_TRUE(convert32To16);
1450 EXPECT_EQ(convert32To16->getResult().getType()->getElementType(),
1451 convertedElementType);
1452 ConvertToNode *convert16To32 =
1453 llvm::dyn_cast<ConvertToNode>(convert32To16->getInput());
1454 ASSERT_TRUE(convert16To32);
1455 EXPECT_EQ(convert16To32->getResult().getType()->getElementType(),
1456 ElemKind::FloatTy);
1457 ConvertToNode *convertPH =
1458 llvm::dyn_cast<ConvertToNode>(convert16To32->getInput());
1459 ASSERT_TRUE(convertPH);
1460 EXPECT_EQ(convertPH->getResult().getType()->getElementType(),
1461 convertedElementType);
1462
1463 EXPECT_TRUE(F->verify());
1464}
1465
1466// Test that we only insert clips for outputs.
1467TEST(TypeAToTypeBFunctionConverter, checkWithFP16ConvertOnlyOutputs) {
1468 checkConvertOnlyOutputs(PrecisionConfiguration::Float16Format::FP16);
1469}
1470
1471// Test that we only insert clips for outputs.
1472TEST(TypeAToTypeBFunctionConverter, checkWithBFloat16ConvertOnlyOutputs) {
1473 checkConvertOnlyOutputs(PrecisionConfiguration::Float16Format::BFloat16);
1474}
1475
// Test that converted Constants get a Clip inserted while converted
// Placeholder inputs do not.
1477static void
1478checkConvertClipStorage(PrecisionConfiguration::Float16Format float16Format) {
1479 Module mod;
1480 Function *F = mod.createFunction("test");
1481 Node *PH = mod.createPlaceholder(ElemKind::FloatTy, {10}, "ph", false);
1482 Node *C = mod.createConstant(ElemKind::FloatTy, {10, 1}, "c");
1483 SaveNode *SPH = F->createSave("ret", PH);
1484 SaveNode *SC = F->createSave("ret", C);
1485
1486 PrecisionConfiguration precConfig;
1487 precConfig.convertToFP16 = true;
1488 precConfig.clipFP16 = true;
1489 precConfig.clipFP16SkipInputs = true;
1490 precConfig.convertPlaceholdersToFP16 = true;
1491 precConfig.convertConstantsToFP16 = true;
1492 precConfig.float16Format = float16Format;
1493 convertFunctionToFloat16(F, precConfig);
1494
1495 ConvertToNode *convertFP32PH = llvm::dyn_cast<ConvertToNode>(SPH->getInput());
1496 ASSERT_TRUE(convertFP32PH);
1497 ConvertToNode *convertFP16PH =
1498 llvm::dyn_cast<ConvertToNode>(convertFP32PH->getInput());
1499 ASSERT_TRUE(convertFP16PH);
1500
1501 ConvertToNode *convertFP32C = llvm::dyn_cast<ConvertToNode>(SC->getInput());
1502 ASSERT_TRUE(convertFP32C);
1503 ClipNode *clipC = llvm::dyn_cast<ClipNode>(convertFP32C->getInput());
1504 ASSERT_TRUE(clipC);
1505 ConvertToNode *convertFP16C =
1506 llvm::dyn_cast<ConvertToNode>(clipC->getInput());
1507 ASSERT_TRUE(convertFP16C);
1508
1509 EXPECT_TRUE(F->verify());
1510}
1511
// Test that converted Constants are clipped but converted Placeholder inputs
// are not (FP16 format).
1513TEST(TypeAToTypeBFunctionConverter, checkWithFP16ConvertClipStorage) {
1514 checkConvertClipStorage(PrecisionConfiguration::Float16Format::FP16);
1515}
1516
// Test that converted Constants are clipped but converted Placeholder inputs
// are not (BFloat16 format).
1518TEST(TypeAToTypeBFunctionConverter, checkWithBFloat16ConvertClipStorage) {
1519 checkConvertClipStorage(PrecisionConfiguration::Float16Format::BFloat16);
1520}
1521
1522/// Check that quantized FC with FP32 bias doesn't have bias converted to FP16.
1523TEST(TypeAToTypeBFunctionConverter, DoNotConvertFloatBiasWithIntInput) {
1524 Module mod;
1525 Function *F = mod.createFunction("test");
1526 PlaceholderBindings bindings;
1527
1528 auto *input = mod.createPlaceholder(ElemKind::Int8QTy, {3, 8}, 0.05, -2,
1529 "input", false);
1530 auto *weight = mod.createConstant(ElemKind::Int8QTy, {8, 10}, 0.02, 3, "w");
  auto *bias = mod.createConstant(ElemKind::FloatTy, {10}, "b");
1532
1533 auto *FC = F->createFullyConnected("FC", input, weight, bias);
1534 F->createSave("save", FC);
1535
1536 std::string origGraph = F->toString();
1537
1538 PrecisionConfiguration precConfig;
1539 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy,
1540 ElemKind::Float16Ty, precConfig);
1541 converter.convert();
1542
1543 EXPECT_EQ(origGraph, F->toString());
1544}
1545
/// Create a FRWQSLWS node with data type \p fusedKind and shape \p row and \p
/// col, check that its data can be properly converted from UInt8FusedFP16QTy
/// to UInt8FusedQTy, or from UInt4FusedFP16QTy to UInt4FusedQTy, and that its
/// indices can be properly converted to Int64.
1550static void testFRWQSLWSDataIndicesConvert(ElemKind fusedKind, dim_t row,
1551 dim_t col) {
1552 EXPECT_LT(row, 100);
1553 EXPECT_LT(col, 100);
1554
1555 Module mod;
1556 Function *F = mod.createFunction("test");
1557 Tensor data(ElemKind::FloatTy, {row, col});
1558 auto dataH = data.getHandle();
1559 for (dim_t i = 0; i < row; i++) {
1560 for (dim_t j = 0; j < col; j++) {
1561 dataH.at({i, j}) = 2.0 * i + 1.0 * j;
1562 }
1563 }
1564
1565 Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");
1566 Placeholder *indices =
1567 mod.createPlaceholder(ElemKind::Int32ITy, {8}, "indices",
1568 /* isTrainable */ false);
1569 Placeholder *lengths =
1570 mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
1571 /* isTrainable */ false);
1572 auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
1573 "RQSLWS", data, weights, indices, lengths, fusedKind);
1574 SaveNode *S = F->createSave("save", R);
1575
1576 size_t origSize = F->getNodes().size();
1577 CompilationContext cctx;
1578 PrecisionConfiguration &precConfig = cctx.precisionConfig;
1579 precConfig.convert4BitFusedToFP32 = true;
1580 precConfig.convert8BitFusedToFP32 = true;
1581 precConfig.convertIndicesToInt64 = true;
1582 precConfig.forceFP16AccumSLS = false;
1583
1584 transformForPrecisionMode(MockBackend(), F, cctx);
1585 // Should have added ConvertTo nodes for the Data and indices.
1586 EXPECT_EQ(F->getNodes().size(), origSize + 2);
1587
1588 optimize(F, CompilationMode::Infer);
  // Since data is a constant, after optimization, constant folding should be
  // applied and a new data constant is created. Therefore, only 1 ConvertTo
  // node is left, for the indices.
1592 EXPECT_EQ(F->getNodes().size(), origSize + 1);
1593
1594 auto *SLWS =
1595 llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
1596 S->getInput());
1597 ASSERT_NE(SLWS, nullptr);
1598 if (fusedKind == ElemKind::UInt8FusedFP16QTy) {
1599 EXPECT_EQ(SLWS->getData().getElementType(), ElemKind::UInt8FusedQTy);
1600 } else {
1601 EXPECT_EQ(SLWS->getData().getElementType(), ElemKind::UInt4FusedQTy);
1602 }
1603 EXPECT_EQ(SLWS->getIndices().getElementType(), ElemKind::Int64ITy);
1604
1605 EXPECT_TRUE(F->verify());
1606}
1607
/// Test converting UInt8FusedFP16QTy to UInt8FusedQTy.
1609TEST(TypeAToTypeBFunctionConverter, FRWLWSConvert8Bit) {
1610 testFRWQSLWSDataIndicesConvert(ElemKind::UInt8FusedFP16QTy, 10, 10);
1611}
1612
/// Test converting UInt4FusedFP16QTy to UInt4FusedQTy.
1614TEST(TypeAToTypeBFunctionConverter, FRWLWSConvert4Bit) {
1615 testFRWQSLWSDataIndicesConvert(ElemKind::UInt4FusedFP16QTy, 10, 10);
1616}
1617