1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Converter/TypeAToTypeBFunctionConverter.h"
18
19#include "glow/Base/Tensor.h"
20#include "glow/Graph/Graph.h"
21
22using namespace glow;
23
/// Builds a converter that rewrites values of element kind \p fromKind in
/// \p F into \p toKind, subject to the policy in \p precConfig (kind
/// whitelist/blacklist, clipping, bias-skip options).
TypeAToTypeBFunctionConverter::TypeAToTypeBFunctionConverter(
    Function &F, ElemKind fromKind, ElemKind toKind,
    const PrecisionConfiguration &precConfig)
    : FunctionConverter(F), mod_(*F.getParent()), dstKind_(toKind),
      srcKind_(fromKind), precConfig_(precConfig) {}
29
30bool TypeAToTypeBFunctionConverter::canConvert(const Node &node) const {
31 // For some ops, if we're converting to FP16/BFloat16 and the bias is FP32 and
32 // the input is quantized then don't convert to FP16/BFloat16.
33 if (srcKind_ == ElemKind::FloatTy &&
34 (dstKind_ == ElemKind::Float16Ty || dstKind_ == ElemKind::BFloat16Ty)) {
35#define QUANT_INPUT_FLOAT_BIAS_CASE(NODE_NAME_) \
36 case glow::Kinded::Kind::NODE_NAME_##NodeKind: { \
37 auto *N = llvm::cast<NODE_NAME_##Node>(&node); \
38 if (N->getBias().getType()->getElementType() == ElemKind::FloatTy && \
39 N->getInput().getType()->isQuantizedType()) { \
40 return false; \
41 } \
42 break; \
43 }
44
45#define QUANT_OR_FP16_INPUT_FLOAT_BIAS_CASE(NODE_NAME_) \
46 case glow::Kinded::Kind::NODE_NAME_##NodeKind: { \
47 auto *N = llvm::cast<NODE_NAME_##Node>(&node); \
48 if (N->getBias().getType()->getElementType() == ElemKind::FloatTy && \
49 (N->getInput().getType()->isQuantizedType() || \
50 N->getInput().getType()->getElementType() == ElemKind::Float16Ty)) { \
51 return false; \
52 } \
53 break; \
54 }
55
56 switch (node.getKind()) {
57 QUANT_INPUT_FLOAT_BIAS_CASE(FullyConnected);
58 QUANT_INPUT_FLOAT_BIAS_CASE(RowwiseQuantizedFullyConnected);
59 QUANT_INPUT_FLOAT_BIAS_CASE(Convolution);
60 QUANT_INPUT_FLOAT_BIAS_CASE(ConvTranspose);
61 QUANT_INPUT_FLOAT_BIAS_CASE(Convolution3D);
62 QUANT_INPUT_FLOAT_BIAS_CASE(ChannelwiseQuantizedConvolution);
63 QUANT_OR_FP16_INPUT_FLOAT_BIAS_CASE(BatchNormalization);
64 default:
65 break;
66 }
67#undef QUANT_INPUT_FLOAT_BIAS_CASE
68 }
69
70 const bool inSet = precConfig_.precisionModeKindSet.count(node.getKind());
71 const bool allowConversion = precConfig_.useSetAsWhitelist ? inSet : !inSet;
72
73 if (!allowConversion) {
74 return false;
75 }
76 return FunctionConverter::canConvert(node);
77}
78
79TypeRef TypeAToTypeBFunctionConverter::getTargetTypeForOutput(
80 const NodeValue &out) const {
81 if (out.getType()->getElementType() != srcKind_) {
82 return nullptr;
83 }
84 return mod_.uniqueType(dstKind_, out.dims());
85}
86
87TypeRef
88TypeAToTypeBFunctionConverter::getTargetTypeForInput(const Node &use,
89 unsigned idx) const {
90#define IGNORE_CONVERT(nodeKind, inputIdx) \
91 if (use.getKind() == nodeKind && idx == inputIdx) { \
92 return nullptr; \
93 }
94 IGNORE_CONVERT(Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind,
95 DynamicQuantizedFullyConnectedNode::BiasIdx)
96 IGNORE_CONVERT(Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind,
97 DynamicRowwiseQuantizedFullyConnectedNode::BiasIdx)
98 IGNORE_CONVERT(Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind,
99 DynamicRowwiseQuantizedFullyConnectedNode::ScalesIdx)
100#undef IGNORE_CONVERT
101 return getTargetTypeForOutput(use.getNthInput(idx));
102}
103
104Node *TypeAToTypeBFunctionConverter::createConversion(Function &function,
105 const Node &node,
106 NodeValue &val,
107 TypeRef destTy,
108 bool isInput) {
109 assert(((destTy->getElementType() == dstKind_ &&
110 val.getType()->getElementType() == srcKind_) ||
111 (destTy->getElementType() == srcKind_ &&
112 val.getType()->getElementType() == dstKind_)) &&
113 "Unexpected conversion type");
114 bool needClip =
115 ((dstKind_ == ElemKind::Float16Ty || dstKind_ == ElemKind::BFloat16Ty) &&
116 precConfig_.clipFP16 && !(isInput && precConfig_.clipFP16SkipInputs));
117
118 if (precConfig_.skipBiasFp32tofp16Convert &&
119 isBiasInput<FullyConnectedNode>(val, &node)) {
120 return val;
121 }
122 if (needClip) {
123 switch (node.getKind()) {
124 case Kinded::Kind::ConcatNodeKind:
125 case Kinded::Kind::GatherNodeKind:
126 case Kinded::Kind::ReshapeNodeKind:
127 case Kinded::Kind::SliceNodeKind:
128 case Kinded::Kind::TileNodeKind:
129 case Kinded::Kind::TransposeNodeKind:
130 needClip = false;
131 break;
132 case Kinded::Kind::SigmoidNodeKind:
133 case Kinded::Kind::TanhNodeKind:
134 needClip = isInput;
135 break;
136 case Kinded::Kind::ConvertToNodeKind:
137 needClip = (llvm::dyn_cast<const ConvertToNode>(&node)
138 ->getInput()
139 .getElementType() == ElemKind::Float16Ty ||
140 llvm::dyn_cast<const ConvertToNode>(&node)
141 ->getInput()
142 .getElementType() == ElemKind::BFloat16Ty)
143 ? false
144 : true;
145 break;
146 default:
147 break;
148 }
149 }
150 if (needClip) {
151 assert((destTy->getElementType() == ElemKind::Float16Ty ||
152 val.getType()->getElementType() == ElemKind::Float16Ty ||
153 destTy->getElementType() == ElemKind::BFloat16Ty ||
154 val.getType()->getElementType() == ElemKind::BFloat16Ty) &&
155 "Unexpected conversion type");
156 // If the input is fp32 and output is fp16, then we want to do the convert
157 // before the clip. This way the clip can execute in fp16 mode.
158 if (destTy->getElementType() == ElemKind::Float16Ty &&
159 val.getType()->getElementType() == ElemKind::FloatTy) {
160 auto convert = function.createConvertTo(
161 val.getNode()->getName().str() + "_converted", val, destTy);
162 return function.createClipMinMaxFP16(
163 val.getNode()->getName().str() + "_clipped", convert);
164 } else if (destTy->getElementType() == ElemKind::BFloat16Ty &&
165 val.getType()->getElementType() == ElemKind::FloatTy) {
166 auto convert = function.createConvertTo(
167 val.getNode()->getName().str() + "_converted", val, destTy);
168 return function.createClipMinMaxBFloat16(
169 val.getNode()->getName().str() + "_clipped", convert);
170 } else {
171 auto clip = function.createClipMinMaxFP16(
172 val.getNode()->getName().str() + "_clipped", val);
173 return function.createConvertTo(
174 val.getNode()->getName().str() + "_converted", clip, destTy);
175 }
176 } else {
177 return function.createConvertTo(
178 val.getNode()->getName().str() + "_converted", val, destTy);
179 }
180}
181
/// Converts the payload of \p tensor in place to dstKind_. \p destTy is only
/// consulted to assert the requested element type matches dstKind_.
void TypeAToTypeBFunctionConverter::convertTensor(Tensor &tensor,
                                                  TypeRef destTy) {
  assert(destTy->getElementType() == dstKind_);
  tensor.convertToType(dstKind_);
}
187
188void convertAndClipStorageHelper(
189 Storage &S, Function &F, bool clipFloat16,
190 PrecisionConfiguration::Float16Format float16Format, ElemKind srcKind,
191 ElemKind dstKind) {
192 if (S.getOutput().getType()->getElementType() != srcKind) {
193 return;
194 }
195
196 ConvertToNode *convertToFloat16 = F.createConvertTo(
197 S.getName().str() + "convert_to", S.getOutput(), dstKind);
198
199 NodeValue NV = convertToFloat16->getResult();
200 if (clipFloat16) {
201 switch (float16Format) {
202 case PrecisionConfiguration::Float16Format::FP16:
203 NV = F.createClipMinMaxFP16(S.getName().str() + "_clipped", NV)
204 ->getResult();
205 break;
206 case PrecisionConfiguration::Float16Format::BFloat16:
207 NV = F.createClipMinMaxBFloat16(S.getName().str() + "_clipped", NV)
208 ->getResult();
209 break;
210 default:
211 llvm_unreachable("Unknown float16 format");
212 }
213 }
214
215 // We have to convert back to the srcKind now as the users currently must be
216 // expecting FP32. The optimizer will remove if possible.
217 NodeValue convertBack =
218 F.createConvertTo(NV.getNode()->getName().str() + "convert_back", NV,
219 srcKind)
220 ->getResult();
221
222 // We need to specify to skip replacing convertToFloat16 here as otherwise we
223 // will create a cycle in the graph.
224 S.getOutput().replaceAllUsesOfWith(convertBack, &F, convertToFloat16);
225}
226
227void TypeAToTypeBFunctionConverter::convertAndClipStorage() {
228 if (precConfig_.convertPlaceholdersToFP16) {
229 for (Placeholder *PH : function_.findPlaceholders()) {
230 // If the PH is not used as an input then we do not clip it.
231 if (!isInput(PH, function_)) {
232 continue;
233 }
234 convertAndClipStorageHelper(
235 *PH, function_,
236 precConfig_.clipFP16 && !precConfig_.clipFP16SkipInputs,
237 precConfig_.float16Format, srcKind_, dstKind_);
238 }
239 }
240 if (precConfig_.convertConstantsToFP16) {
241 for (Constant *C : function_.findConstants()) {
242 if (precConfig_.skipBiasFp32tofp16Convert && C->hasOneUse() &&
243 isBiasInput<FullyConnectedNode>(C->getOutput(),
244 C->getUsers().front().getUser())) {
245 continue;
246 }
247 convertAndClipStorageHelper(*C, function_, precConfig_.clipFP16,
248 precConfig_.float16Format, srcKind_,
249 dstKind_);
250 }
251 }
252}
253
254template <class T>
255bool TypeAToTypeBFunctionConverter::isBiasInput(NodeValue input,
256 const Node *N) {
257 auto *castedN = llvm::dyn_cast<T>(N);
258 return castedN && castedN->getBias() == input;
259}
260