1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Converter/Float16Converter.h" |
18 | |
19 | #include "glow/Converter/TypeAToTypeBFunctionConverter.h" |
20 | #include "glow/Graph/Graph.h" |
21 | |
22 | using namespace glow; |
23 | |
24 | /// Helper to pass over all Nodes in \p F and look for inputs of UInt8FusedQTy, |
25 | /// and convert them to UInt8FusedFP16QTy. \p precConfig contains the |
26 | /// black/whitelist for skipping nodes for transformation. |
27 | static void |
28 | convertFusedRowwiseQuantizedInputs(Function *F, |
29 | const PrecisionConfiguration &precConfig) { |
30 | auto &mod = *F->getParent(); |
31 | |
32 | // Iterate from original end to beginning to avoid processing new |
33 | // ConvertToNodes added during the pass. |
34 | auto nodeIt = F->getNodes().end(); |
35 | auto stopIt = F->getNodes().begin(); |
36 | do { |
37 | --nodeIt; |
38 | Node &node = *nodeIt; |
39 | // Only convert allowed nodes based on black/whitelist. |
40 | const bool inSet = precConfig.precisionModeKindSet.count(node.getKind()); |
41 | const bool allowConversion = precConfig.useSetAsWhitelist ? inSet : !inSet; |
42 | if (!allowConversion) { |
43 | continue; |
44 | } |
45 | |
46 | // Now check if any inputs are UInt8FusedQTy, and convert them accordingly. |
47 | for (unsigned idx = 0, end = node.getNumInputs(); idx != end; ++idx) { |
48 | NodeValue input = node.getNthInput(idx); |
49 | if (input.getElementType() != ElemKind::UInt8FusedQTy) { |
50 | continue; |
51 | } |
52 | |
53 | // Create the conversion using the same shape but without the extra space |
54 | // needed for FP16 scale/offset instead of FP32. |
55 | const auto &shape = input.dims(); |
56 | assert(shape.size() == 2 && "UInt8FusedQTy must be 2D." ); |
57 | assert(precConfig.float16Format == |
58 | PrecisionConfiguration::Float16Format::FP16 && |
59 | "Only fused FP16 scale/offset is supported" ); |
60 | const dim_t newCols = shape[1] - 2 * (sizeof(float) - sizeof(float16_t)); |
61 | auto OT = mod.uniqueType(ElemKind::UInt8FusedFP16QTy, {shape[0], newCols}, |
62 | 1.0, 0); // Dummy scale/offset. |
63 | ConvertToNode *CN = F->createConvertTo( |
64 | input.getNode()->getName().str() + ".FP16" , input, OT); |
65 | node.setNthInput(idx, CN); |
66 | } |
67 | } while (nodeIt != stopIt); |
68 | } |
69 | |
70 | void glow::convertFunctionToFloat16(Function *F, |
71 | const PrecisionConfiguration &precConfig) { |
72 | DCHECK(precConfig.convertToFP16 || precConfig.convertFusedToFP16) |
73 | << "Expected to convert at least one of FloatTy or UInt8FusedQTy." ; |
74 | |
75 | // Convert FloatTy to Float16Ty or BFloat16Ty. |
76 | ElemKind destTy = |
77 | PrecisionConfiguration::getElementType(precConfig.float16Format); |
78 | TypeAToTypeBFunctionConverter converter(*F, ElemKind::FloatTy, destTy, |
79 | precConfig); |
80 | if (precConfig.convertToFP16) { |
81 | converter.convert(); |
82 | |
83 | // Storage nodes are not converted + clipped directly -- they need to be |
84 | // converted via adding ConvertToNodes instead of directly setting their |
85 | // types like the TypeAToTypeBFunctionConverter does. |
86 | converter.convertAndClipStorage(); |
87 | } |
88 | |
89 | // Now we want to additionally convert all nodes with inputs in UInt8FusedQTy |
90 | // to UInt8FusedFP16QTy. This does not fit cleanly into the |
91 | // TypeAToTypeBFunctionConverter, so write a custom pass to do so. |
92 | if (precConfig.convertFusedToFP16) { |
93 | convertFusedRowwiseQuantizedInputs(F, precConfig); |
94 | } |
95 | } |
96 | |