1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Converter/TypeAToTypeBFunctionConverter.h" |
18 | |
19 | #include "glow/Base/Tensor.h" |
20 | #include "glow/Graph/Graph.h" |
21 | |
22 | using namespace glow; |
23 | |
/// Builds a converter that rewrites nodes of \p F whose element kind is
/// \p fromKind into \p toKind, subject to the restrictions carried by
/// \p precConfig (kind white/blacklist, clipping policy, etc.).
TypeAToTypeBFunctionConverter::TypeAToTypeBFunctionConverter(
    Function &F, ElemKind fromKind, ElemKind toKind,
    const PrecisionConfiguration &precConfig)
    : FunctionConverter(F), mod_(*F.getParent()), dstKind_(toKind),
      srcKind_(fromKind), precConfig_(precConfig) {}
29 | |
30 | bool TypeAToTypeBFunctionConverter::canConvert(const Node &node) const { |
31 | // For some ops, if we're converting to FP16/BFloat16 and the bias is FP32 and |
32 | // the input is quantized then don't convert to FP16/BFloat16. |
33 | if (srcKind_ == ElemKind::FloatTy && |
34 | (dstKind_ == ElemKind::Float16Ty || dstKind_ == ElemKind::BFloat16Ty)) { |
35 | #define QUANT_INPUT_FLOAT_BIAS_CASE(NODE_NAME_) \ |
36 | case glow::Kinded::Kind::NODE_NAME_##NodeKind: { \ |
37 | auto *N = llvm::cast<NODE_NAME_##Node>(&node); \ |
38 | if (N->getBias().getType()->getElementType() == ElemKind::FloatTy && \ |
39 | N->getInput().getType()->isQuantizedType()) { \ |
40 | return false; \ |
41 | } \ |
42 | break; \ |
43 | } |
44 | |
45 | #define QUANT_OR_FP16_INPUT_FLOAT_BIAS_CASE(NODE_NAME_) \ |
46 | case glow::Kinded::Kind::NODE_NAME_##NodeKind: { \ |
47 | auto *N = llvm::cast<NODE_NAME_##Node>(&node); \ |
48 | if (N->getBias().getType()->getElementType() == ElemKind::FloatTy && \ |
49 | (N->getInput().getType()->isQuantizedType() || \ |
50 | N->getInput().getType()->getElementType() == ElemKind::Float16Ty)) { \ |
51 | return false; \ |
52 | } \ |
53 | break; \ |
54 | } |
55 | |
56 | switch (node.getKind()) { |
57 | QUANT_INPUT_FLOAT_BIAS_CASE(FullyConnected); |
58 | QUANT_INPUT_FLOAT_BIAS_CASE(RowwiseQuantizedFullyConnected); |
59 | QUANT_INPUT_FLOAT_BIAS_CASE(Convolution); |
60 | QUANT_INPUT_FLOAT_BIAS_CASE(ConvTranspose); |
61 | QUANT_INPUT_FLOAT_BIAS_CASE(Convolution3D); |
62 | QUANT_INPUT_FLOAT_BIAS_CASE(ChannelwiseQuantizedConvolution); |
63 | QUANT_OR_FP16_INPUT_FLOAT_BIAS_CASE(BatchNormalization); |
64 | default: |
65 | break; |
66 | } |
67 | #undef QUANT_INPUT_FLOAT_BIAS_CASE |
68 | } |
69 | |
70 | const bool inSet = precConfig_.precisionModeKindSet.count(node.getKind()); |
71 | const bool allowConversion = precConfig_.useSetAsWhitelist ? inSet : !inSet; |
72 | |
73 | if (!allowConversion) { |
74 | return false; |
75 | } |
76 | return FunctionConverter::canConvert(node); |
77 | } |
78 | |
79 | TypeRef TypeAToTypeBFunctionConverter::getTargetTypeForOutput( |
80 | const NodeValue &out) const { |
81 | if (out.getType()->getElementType() != srcKind_) { |
82 | return nullptr; |
83 | } |
84 | return mod_.uniqueType(dstKind_, out.dims()); |
85 | } |
86 | |
87 | TypeRef |
88 | TypeAToTypeBFunctionConverter::getTargetTypeForInput(const Node &use, |
89 | unsigned idx) const { |
90 | #define IGNORE_CONVERT(nodeKind, inputIdx) \ |
91 | if (use.getKind() == nodeKind && idx == inputIdx) { \ |
92 | return nullptr; \ |
93 | } |
94 | IGNORE_CONVERT(Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind, |
95 | DynamicQuantizedFullyConnectedNode::BiasIdx) |
96 | IGNORE_CONVERT(Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind, |
97 | DynamicRowwiseQuantizedFullyConnectedNode::BiasIdx) |
98 | IGNORE_CONVERT(Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind, |
99 | DynamicRowwiseQuantizedFullyConnectedNode::ScalesIdx) |
100 | #undef IGNORE_CONVERT |
101 | return getTargetTypeForOutput(use.getNthInput(idx)); |
102 | } |
103 | |
104 | Node *TypeAToTypeBFunctionConverter::createConversion(Function &function, |
105 | const Node &node, |
106 | NodeValue &val, |
107 | TypeRef destTy, |
108 | bool isInput) { |
109 | assert(((destTy->getElementType() == dstKind_ && |
110 | val.getType()->getElementType() == srcKind_) || |
111 | (destTy->getElementType() == srcKind_ && |
112 | val.getType()->getElementType() == dstKind_)) && |
113 | "Unexpected conversion type" ); |
114 | bool needClip = |
115 | ((dstKind_ == ElemKind::Float16Ty || dstKind_ == ElemKind::BFloat16Ty) && |
116 | precConfig_.clipFP16 && !(isInput && precConfig_.clipFP16SkipInputs)); |
117 | |
118 | if (precConfig_.skipBiasFp32tofp16Convert && |
119 | isBiasInput<FullyConnectedNode>(val, &node)) { |
120 | return val; |
121 | } |
122 | if (needClip) { |
123 | switch (node.getKind()) { |
124 | case Kinded::Kind::ConcatNodeKind: |
125 | case Kinded::Kind::GatherNodeKind: |
126 | case Kinded::Kind::ReshapeNodeKind: |
127 | case Kinded::Kind::SliceNodeKind: |
128 | case Kinded::Kind::TileNodeKind: |
129 | case Kinded::Kind::TransposeNodeKind: |
130 | needClip = false; |
131 | break; |
132 | case Kinded::Kind::SigmoidNodeKind: |
133 | case Kinded::Kind::TanhNodeKind: |
134 | needClip = isInput; |
135 | break; |
136 | case Kinded::Kind::ConvertToNodeKind: |
137 | needClip = (llvm::dyn_cast<const ConvertToNode>(&node) |
138 | ->getInput() |
139 | .getElementType() == ElemKind::Float16Ty || |
140 | llvm::dyn_cast<const ConvertToNode>(&node) |
141 | ->getInput() |
142 | .getElementType() == ElemKind::BFloat16Ty) |
143 | ? false |
144 | : true; |
145 | break; |
146 | default: |
147 | break; |
148 | } |
149 | } |
150 | if (needClip) { |
151 | assert((destTy->getElementType() == ElemKind::Float16Ty || |
152 | val.getType()->getElementType() == ElemKind::Float16Ty || |
153 | destTy->getElementType() == ElemKind::BFloat16Ty || |
154 | val.getType()->getElementType() == ElemKind::BFloat16Ty) && |
155 | "Unexpected conversion type" ); |
156 | // If the input is fp32 and output is fp16, then we want to do the convert |
157 | // before the clip. This way the clip can execute in fp16 mode. |
158 | if (destTy->getElementType() == ElemKind::Float16Ty && |
159 | val.getType()->getElementType() == ElemKind::FloatTy) { |
160 | auto convert = function.createConvertTo( |
161 | val.getNode()->getName().str() + "_converted" , val, destTy); |
162 | return function.createClipMinMaxFP16( |
163 | val.getNode()->getName().str() + "_clipped" , convert); |
164 | } else if (destTy->getElementType() == ElemKind::BFloat16Ty && |
165 | val.getType()->getElementType() == ElemKind::FloatTy) { |
166 | auto convert = function.createConvertTo( |
167 | val.getNode()->getName().str() + "_converted" , val, destTy); |
168 | return function.createClipMinMaxBFloat16( |
169 | val.getNode()->getName().str() + "_clipped" , convert); |
170 | } else { |
171 | auto clip = function.createClipMinMaxFP16( |
172 | val.getNode()->getName().str() + "_clipped" , val); |
173 | return function.createConvertTo( |
174 | val.getNode()->getName().str() + "_converted" , clip, destTy); |
175 | } |
176 | } else { |
177 | return function.createConvertTo( |
178 | val.getNode()->getName().str() + "_converted" , val, destTy); |
179 | } |
180 | } |
181 | |
182 | void TypeAToTypeBFunctionConverter::convertTensor(Tensor &tensor, |
183 | TypeRef destTy) { |
184 | assert(destTy->getElementType() == dstKind_); |
185 | tensor.convertToType(dstKind_); |
186 | } |
187 | |
188 | void convertAndClipStorageHelper( |
189 | Storage &S, Function &F, bool clipFloat16, |
190 | PrecisionConfiguration::Float16Format float16Format, ElemKind srcKind, |
191 | ElemKind dstKind) { |
192 | if (S.getOutput().getType()->getElementType() != srcKind) { |
193 | return; |
194 | } |
195 | |
196 | ConvertToNode *convertToFloat16 = F.createConvertTo( |
197 | S.getName().str() + "convert_to" , S.getOutput(), dstKind); |
198 | |
199 | NodeValue NV = convertToFloat16->getResult(); |
200 | if (clipFloat16) { |
201 | switch (float16Format) { |
202 | case PrecisionConfiguration::Float16Format::FP16: |
203 | NV = F.createClipMinMaxFP16(S.getName().str() + "_clipped" , NV) |
204 | ->getResult(); |
205 | break; |
206 | case PrecisionConfiguration::Float16Format::BFloat16: |
207 | NV = F.createClipMinMaxBFloat16(S.getName().str() + "_clipped" , NV) |
208 | ->getResult(); |
209 | break; |
210 | default: |
211 | llvm_unreachable("Unknown float16 format" ); |
212 | } |
213 | } |
214 | |
215 | // We have to convert back to the srcKind now as the users currently must be |
216 | // expecting FP32. The optimizer will remove if possible. |
217 | NodeValue convertBack = |
218 | F.createConvertTo(NV.getNode()->getName().str() + "convert_back" , NV, |
219 | srcKind) |
220 | ->getResult(); |
221 | |
222 | // We need to specify to skip replacing convertToFloat16 here as otherwise we |
223 | // will create a cycle in the graph. |
224 | S.getOutput().replaceAllUsesOfWith(convertBack, &F, convertToFloat16); |
225 | } |
226 | |
227 | void TypeAToTypeBFunctionConverter::convertAndClipStorage() { |
228 | if (precConfig_.convertPlaceholdersToFP16) { |
229 | for (Placeholder *PH : function_.findPlaceholders()) { |
230 | // If the PH is not used as an input then we do not clip it. |
231 | if (!isInput(PH, function_)) { |
232 | continue; |
233 | } |
234 | convertAndClipStorageHelper( |
235 | *PH, function_, |
236 | precConfig_.clipFP16 && !precConfig_.clipFP16SkipInputs, |
237 | precConfig_.float16Format, srcKind_, dstKind_); |
238 | } |
239 | } |
240 | if (precConfig_.convertConstantsToFP16) { |
241 | for (Constant *C : function_.findConstants()) { |
242 | if (precConfig_.skipBiasFp32tofp16Convert && C->hasOneUse() && |
243 | isBiasInput<FullyConnectedNode>(C->getOutput(), |
244 | C->getUsers().front().getUser())) { |
245 | continue; |
246 | } |
247 | convertAndClipStorageHelper(*C, function_, precConfig_.clipFP16, |
248 | precConfig_.float16Format, srcKind_, |
249 | dstKind_); |
250 | } |
251 | } |
252 | } |
253 | |
254 | template <class T> |
255 | bool TypeAToTypeBFunctionConverter::isBiasInput(NodeValue input, |
256 | const Node *N) { |
257 | auto *castedN = llvm::dyn_cast<T>(N); |
258 | return castedN && castedN->getBias() == input; |
259 | } |
260 | |