1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Converter/Float16Converter.h"
18
19#include "glow/Converter/TypeAToTypeBFunctionConverter.h"
20#include "glow/Graph/Graph.h"
21
22using namespace glow;
23
24/// Helper to pass over all Nodes in \p F and look for inputs of UInt8FusedQTy,
25/// and convert them to UInt8FusedFP16QTy. \p precConfig contains the
26/// black/whitelist for skipping nodes for transformation.
27static void
28convertFusedRowwiseQuantizedInputs(Function *F,
29 const PrecisionConfiguration &precConfig) {
30 auto &mod = *F->getParent();
31
32 // Iterate from original end to beginning to avoid processing new
33 // ConvertToNodes added during the pass.
34 auto nodeIt = F->getNodes().end();
35 auto stopIt = F->getNodes().begin();
36 do {
37 --nodeIt;
38 Node &node = *nodeIt;
39 // Only convert allowed nodes based on black/whitelist.
40 const bool inSet = precConfig.precisionModeKindSet.count(node.getKind());
41 const bool allowConversion = precConfig.useSetAsWhitelist ? inSet : !inSet;
42 if (!allowConversion) {
43 continue;
44 }
45
46 // Now check if any inputs are UInt8FusedQTy, and convert them accordingly.
47 for (unsigned idx = 0, end = node.getNumInputs(); idx != end; ++idx) {
48 NodeValue input = node.getNthInput(idx);
49 if (input.getElementType() != ElemKind::UInt8FusedQTy) {
50 continue;
51 }
52
53 // Create the conversion using the same shape but without the extra space
54 // needed for FP16 scale/offset instead of FP32.
55 const auto &shape = input.dims();
56 assert(shape.size() == 2 && "UInt8FusedQTy must be 2D.");
57 assert(precConfig.float16Format ==
58 PrecisionConfiguration::Float16Format::FP16 &&
59 "Only fused FP16 scale/offset is supported");
60 const dim_t newCols = shape[1] - 2 * (sizeof(float) - sizeof(float16_t));
61 auto OT = mod.uniqueType(ElemKind::UInt8FusedFP16QTy, {shape[0], newCols},
62 1.0, 0); // Dummy scale/offset.
63 ConvertToNode *CN = F->createConvertTo(
64 input.getNode()->getName().str() + ".FP16", input, OT);
65 node.setNthInput(idx, CN);
66 }
67 } while (nodeIt != stopIt);
68}
69
70void glow::convertFunctionToFloat16(Function *F,
71 const PrecisionConfiguration &precConfig) {
72 DCHECK(precConfig.convertToFP16 || precConfig.convertFusedToFP16)
73 << "Expected to convert at least one of FloatTy or UInt8FusedQTy.";
74
75 // Convert FloatTy to Float16Ty or BFloat16Ty.
76 ElemKind destTy =
77 PrecisionConfiguration::getElementType(precConfig.float16Format);
78 TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy, destTy,
79 precConfig);
80 if (precConfig.convertToFP16) {
81 converter.convert();
82
83 // Storage nodes are not converted + clipped directly -- they need to be
84 // converted via adding ConvertToNodes instead of directly setting their
85 // types like the TypeAToTypeBFunctionConverter does.
86 converter.convertAndClipStorage();
87 }
88
89 // Now we want to additionally convert all nodes with inputs in UInt8FusedQTy
90 // to UInt8FusedFP16QTy. This does not fit cleanly into the
91 // TypeAToTypeBFunctionConverter, so write a custom pass to do so.
92 if (precConfig.convertFusedToFP16) {
93 convertFusedRowwiseQuantizedInputs(F, precConfig);
94 }
95}
96