1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Converter/FusedRowwiseConverter.h"
18
19#include "glow/Converter/TypeAToTypeBFunctionConverter.h"
20#include "glow/Graph/Graph.h"
21
22using namespace glow;
23
24/// Helper to pass over all Nodes in \p F and look for data param of
25/// FusedRowwiseSLWS with type UInt4FusedFP16QTy (if \p convertUInt4FP16 is
26/// true) or UInt8FusedFP16QTy (if \p convertUInt8FP16 is true), and convert
27/// the scale/offset to fp32.
28static void convertFusedRowwiseQuantizedData(Function *F, bool convertUInt4FP16,
29 bool convertUInt8FP16) {
30 auto &mod = *F->getParent();
31
32 // Iterate from original end to beginning to avoid processing new
33 // ConvertToNodes added during the pass.
34 auto nodeIt = F->getNodes().end();
35 auto stopIt = F->getNodes().begin();
36 do {
37 --nodeIt;
38 Node &node = *nodeIt;
39 unsigned idx = 0;
40 // Only convert FusedRowwiseQuantizedSparseLengthsWeightedSumNode and
41 // FusedRowwiseQuantizedSparseLengthsSumNode.
42 switch (node.getKind()) {
43 case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: {
44 idx = FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx;
45 break;
46 }
47
48 case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind: {
49 idx = FusedRowwiseQuantizedSparseLengthsSumNode::DataIdx;
50 break;
51 }
52
53 case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: {
54 idx = FusedRowwiseQuantizedSparseLengthsSumNode::DataIdx;
55 break;
56 }
57 default:
58 continue;
59 }
60
61 NodeValue data = node.getNthInput(idx);
62 auto dataType = data.getType()->getElementType();
63 if (dataType != ElemKind::UInt8FusedFP16QTy &&
64 dataType != ElemKind::UInt4FusedFP16QTy) {
65 continue;
66 }
67
68 if (dataType == ElemKind::UInt8FusedFP16QTy && !convertUInt8FP16) {
69 continue;
70 }
71
72 if (dataType == ElemKind::UInt4FusedFP16QTy && !convertUInt4FP16) {
73 continue;
74 }
75
76 const auto &shape = data.dims();
77 assert(shape.size() == 2 && "FusedRowwise Tensor must be 2D.");
78
79 const dim_t newCols =
80 (shape[1] - 2 * sizeof(float16_t)) + 2 * sizeof(float);
81
82 auto dT = (dataType == ElemKind::UInt8FusedFP16QTy)
83 ? ElemKind::UInt8FusedQTy
84 : ElemKind::UInt4FusedQTy;
85 auto OT = mod.uniqueType(dT, {shape[0], newCols}, 1.0, 0);
86 ConvertToNode *CN =
87 F->createConvertTo(data.getNode()->getName().str() + ".FP32", data, OT);
88 node.setNthInput(idx, CN);
89 } while (nodeIt != stopIt);
90}
91
92void glow::convertFunctionToFP32ScaleOffset(
93 Function *F, const PrecisionConfiguration &precConfig) {
94 bool convertUInt4FP16 = precConfig.convert4BitFusedToFP32;
95 bool convertUInt8FP16 = precConfig.convert8BitFusedToFP32;
96 DCHECK(convertUInt4FP16 || convertUInt8FP16)
97 << "Expect to convert at least one of UInt4FusedFP16QTy or "
98 "UInt8FusedFP16QTy.";
99 convertFusedRowwiseQuantizedData(F, convertUInt4FP16, convertUInt8FP16);
100}
101
102void glow::convertFunctionIndicesToInt64(
103 Function *F, const PrecisionConfiguration &precConfig) {
104 DCHECK(precConfig.convertIndicesToInt64)
105 << "Should enable indices conversion.";
106 // Iterate from original end to beginning to avoid processing new
107 // ConvertToNodes added during the pass.
108 auto nodeIt = F->getNodes().end();
109 auto stopIt = F->getNodes().begin();
110 do {
111 --nodeIt;
112 Node &node = *nodeIt;
113 // Only convert FusedRowwiseQuantizedSparseLengthsWeightedSumNode;
114 if (node.getKind() !=
115 Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind) {
116 continue;
117 }
118 auto idx = FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx;
119 NodeValue indices = node.getNthInput(idx);
120 auto indicesType = indices.getType()->getElementType();
121 if (indicesType == ElemKind::Int64ITy) {
122 continue;
123 }
124 DCHECK(indicesType == ElemKind::Int32ITy) << "Indices must be Int32ITy.";
125 ConvertToNode *CN =
126 F->createConvertTo(indices.getNode()->getName().str() + ".Int64",
127 indices, ElemKind::Int64ITy);
128 node.setNthInput(idx, CN);
129 } while (nodeIt != stopIt);
130}
131