1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Converter/FusedRowwiseConverter.h" |
18 | |
19 | #include "glow/Converter/TypeAToTypeBFunctionConverter.h" |
20 | #include "glow/Graph/Graph.h" |
21 | |
22 | using namespace glow; |
23 | |
24 | /// Helper to pass over all Nodes in \p F and look for data param of |
25 | /// FusedRowwiseSLWS with type UInt4FusedFP16QTy (if \p convertUInt4FP16 is |
26 | /// true) or UInt8FusedFP16QTy (if \p convertUInt8FP16 is true), and convert |
27 | /// the scale/offset to fp32. |
28 | static void convertFusedRowwiseQuantizedData(Function *F, bool convertUInt4FP16, |
29 | bool convertUInt8FP16) { |
30 | auto &mod = *F->getParent(); |
31 | |
32 | // Iterate from original end to beginning to avoid processing new |
33 | // ConvertToNodes added during the pass. |
34 | auto nodeIt = F->getNodes().end(); |
35 | auto stopIt = F->getNodes().begin(); |
36 | do { |
37 | --nodeIt; |
38 | Node &node = *nodeIt; |
39 | unsigned idx = 0; |
40 | // Only convert FusedRowwiseQuantizedSparseLengthsWeightedSumNode and |
41 | // FusedRowwiseQuantizedSparseLengthsSumNode. |
42 | switch (node.getKind()) { |
43 | case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: { |
44 | idx = FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx; |
45 | break; |
46 | } |
47 | |
48 | case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind: { |
49 | idx = FusedRowwiseQuantizedSparseLengthsSumNode::DataIdx; |
50 | break; |
51 | } |
52 | |
53 | case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: { |
54 | idx = FusedRowwiseQuantizedSparseLengthsSumNode::DataIdx; |
55 | break; |
56 | } |
57 | default: |
58 | continue; |
59 | } |
60 | |
61 | NodeValue data = node.getNthInput(idx); |
62 | auto dataType = data.getType()->getElementType(); |
63 | if (dataType != ElemKind::UInt8FusedFP16QTy && |
64 | dataType != ElemKind::UInt4FusedFP16QTy) { |
65 | continue; |
66 | } |
67 | |
68 | if (dataType == ElemKind::UInt8FusedFP16QTy && !convertUInt8FP16) { |
69 | continue; |
70 | } |
71 | |
72 | if (dataType == ElemKind::UInt4FusedFP16QTy && !convertUInt4FP16) { |
73 | continue; |
74 | } |
75 | |
76 | const auto &shape = data.dims(); |
77 | assert(shape.size() == 2 && "FusedRowwise Tensor must be 2D." ); |
78 | |
79 | const dim_t newCols = |
80 | (shape[1] - 2 * sizeof(float16_t)) + 2 * sizeof(float); |
81 | |
82 | auto dT = (dataType == ElemKind::UInt8FusedFP16QTy) |
83 | ? ElemKind::UInt8FusedQTy |
84 | : ElemKind::UInt4FusedQTy; |
85 | auto OT = mod.uniqueType(dT, {shape[0], newCols}, 1.0, 0); |
86 | ConvertToNode *CN = |
87 | F->createConvertTo(data.getNode()->getName().str() + ".FP32" , data, OT); |
88 | node.setNthInput(idx, CN); |
89 | } while (nodeIt != stopIt); |
90 | } |
91 | |
92 | void glow::convertFunctionToFP32ScaleOffset( |
93 | Function *F, const PrecisionConfiguration &precConfig) { |
94 | bool convertUInt4FP16 = precConfig.convert4BitFusedToFP32; |
95 | bool convertUInt8FP16 = precConfig.convert8BitFusedToFP32; |
96 | DCHECK(convertUInt4FP16 || convertUInt8FP16) |
97 | << "Expect to convert at least one of UInt4FusedFP16QTy or " |
98 | "UInt8FusedFP16QTy." ; |
99 | convertFusedRowwiseQuantizedData(F, convertUInt4FP16, convertUInt8FP16); |
100 | } |
101 | |
102 | void glow::convertFunctionIndicesToInt64( |
103 | Function *F, const PrecisionConfiguration &precConfig) { |
104 | DCHECK(precConfig.convertIndicesToInt64) |
105 | << "Should enable indices conversion." ; |
106 | // Iterate from original end to beginning to avoid processing new |
107 | // ConvertToNodes added during the pass. |
108 | auto nodeIt = F->getNodes().end(); |
109 | auto stopIt = F->getNodes().begin(); |
110 | do { |
111 | --nodeIt; |
112 | Node &node = *nodeIt; |
113 | // Only convert FusedRowwiseQuantizedSparseLengthsWeightedSumNode; |
114 | if (node.getKind() != |
115 | Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind) { |
116 | continue; |
117 | } |
118 | auto idx = FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx; |
119 | NodeValue indices = node.getNthInput(idx); |
120 | auto indicesType = indices.getType()->getElementType(); |
121 | if (indicesType == ElemKind::Int64ITy) { |
122 | continue; |
123 | } |
124 | DCHECK(indicesType == ElemKind::Int32ITy) << "Indices must be Int32ITy." ; |
125 | ConvertToNode *CN = |
126 | F->createConvertTo(indices.getNode()->getName().str() + ".Int64" , |
127 | indices, ElemKind::Int64ITy); |
128 | node.setNthInput(idx, CN); |
129 | } while (nodeIt != stopIt); |
130 | } |
131 | |