1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Base/Type.h" |
18 | #include "llvm/Support/NativeFormatting.h" |
19 | #include "llvm/Support/raw_ostream.h" |
20 | |
21 | namespace glow { |
22 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Type &type) { |
23 | os << type.getElementName(); |
24 | |
25 | if (type.isQuantizedType()) { |
26 | os << "[S:" ; |
27 | llvm::write_double(os, type.getScale(), llvm::FloatStyle::Fixed, 9); |
28 | os << " O:" ; |
29 | os << type.getOffset(); |
30 | os << ']'; |
31 | auto valueRange = type.getQuantizedValueRange(); |
32 | os << "[" ; |
33 | llvm::write_double(os, valueRange.first, llvm::FloatStyle::Fixed, 3); |
34 | os << "," ; |
35 | llvm::write_double(os, valueRange.second, llvm::FloatStyle::Fixed, 3); |
36 | os << "]" ; |
37 | } |
38 | |
39 | os << '<'; |
40 | for (unsigned i = 0; i < type.numSizes_; ++i) { |
41 | if (i) { |
42 | os << " x " ; |
43 | } |
44 | os << type.sizes_[i]; |
45 | if (type.numSizes_ >= 2 && i + 1 < type.numSizes_ && |
46 | type.strides_[i] != type.strides_[i + 1] * type.sizes_[i + 1]) { |
47 | assert(type.strides_[i] % type.strides_[i + 1] == 0); |
48 | // Print the alignment only if it is not 1. |
49 | os << ":" << (type.strides_[i] / type.strides_[i + 1]); |
50 | } |
51 | } |
52 | os << '>'; |
53 | |
54 | return os; |
55 | } |
56 | |
57 | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const TypeRef &type) { |
58 | if (!type) { |
59 | return os << "<none>" ; |
60 | } |
61 | return os << *type; |
62 | } |
63 | |
64 | void Type::dump(llvm::raw_ostream &out) const { out << this; } |
65 | |
66 | void Type::dump() const { dump(llvm::outs()); } |
67 | |
68 | std::string Type::toString() const { |
69 | std::string storage; |
70 | llvm::raw_string_ostream os(storage); |
71 | os << this; |
72 | return os.str(); |
73 | } |
74 | |
75 | Type Type::fromString(llvm::StringRef str) { |
76 | |
77 | // Get element type. |
78 | std::pair<llvm::StringRef, llvm::StringRef> strPair; |
79 | auto strPair1 = str.split('<'); |
80 | auto strPair2 = str.split('['); |
81 | if (strPair1.first.size() < strPair2.first.size()) { |
82 | strPair = strPair1; |
83 | } else { |
84 | strPair = strPair2; |
85 | } |
86 | CHECK(strPair.first.size()) << "Type string element type field invalid!" ; |
87 | ElemKind elemTy = Type::getElementKindFromName(strPair.first); |
88 | |
89 | // Get scale and offset for quantized type. |
90 | double scale = 0; |
91 | int32_t offset = 0; |
92 | if (isQuantizedElemKind(elemTy)) { |
93 | // Get scale. |
94 | strPair = strPair.second.split(':').second.split(' '); |
95 | CHECK(!strPair.first.getAsDouble(scale)) |
96 | << "Type string scale field invalid!" ; |
97 | // Get offset. |
98 | strPair = strPair.second.split(':').second.split(']'); |
99 | CHECK(!strPair.first.getAsInteger(0, offset)) |
100 | << "Type string offset field invalid!" ; |
101 | // Ignore quantized min/max range. |
102 | strPair = strPair.second.split('<'); |
103 | } |
104 | |
105 | // Get shape. |
106 | llvm::StringRef shapeStr = strPair.second; |
107 | CHECK(shapeStr.size()) << "Type string shape field invalid!" ; |
108 | CHECK_EQ(shapeStr.back(), '>') << "Type string shape field invalid!" ; |
109 | shapeStr = shapeStr.drop_back(); |
110 | CHECK(shapeStr.size()) << "Type string shape field invalid!" ; |
111 | |
112 | // Add the delimiter in the end to have the loop self contained. |
113 | // Note: Type alignment field not supported. |
114 | std::string shapeStrExt = shapeStr.str() + " x" ; |
115 | shapeStr = llvm::StringRef(shapeStrExt); |
116 | ShapeVector dims; |
117 | while (shapeStr.contains('x')) { |
118 | auto splitRes = shapeStr.split('x'); |
119 | auto dimStr = splitRes.first.trim(); |
120 | CHECK(!dimStr.contains(':')) << "Type with alignment field not supported!" ; |
121 | dim_t dim; |
122 | CHECK(!dimStr.getAsInteger(0, dim)) << "Type string shape field invalid!" ; |
123 | dims.push_back(dim); |
124 | shapeStr = splitRes.second; |
125 | } |
126 | |
127 | // Return type. |
128 | if (isQuantizedElemKind(elemTy)) { |
129 | return Type(elemTy, dims, (float)scale, offset); |
130 | } else { |
131 | return Type(elemTy, dims); |
132 | } |
133 | } |
134 | |
135 | std::pair<float, float> getQuantizedValueRange(float scale, int32_t offset, |
136 | ElemKind elementType) { |
137 | assert(isQuantizedElemKind(elementType) && |
138 | "Can't get the quantized value range of a non-quantized type" ); |
139 | |
140 | int64_t low = 0, high = 0; |
141 | switch (elementType) { |
142 | case ElemKind::Int32QTy: { |
143 | low = INT32_MIN; |
144 | high = INT32_MAX; |
145 | break; |
146 | } |
147 | case ElemKind::Int16QTy: { |
148 | low = INT16_MIN; |
149 | high = INT16_MAX; |
150 | break; |
151 | } |
152 | case ElemKind::Int8QTy: { |
153 | low = INT8_MIN; |
154 | high = INT8_MAX; |
155 | break; |
156 | } |
157 | case ElemKind::UInt8QTy: { |
158 | low = UINT8_MIN; |
159 | high = UINT8_MAX; |
160 | break; |
161 | } |
162 | default:; |
163 | } |
164 | |
165 | float lowFloat = (low - offset) * scale; |
166 | float highFloat = (high - offset) * scale; |
167 | return std::make_pair(lowFloat, highFloat); |
168 | } |
169 | |
170 | } // namespace glow |
171 | |