1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Quantization/Serialization.h" |
18 | #include "glow/Quantization/Base/Base.h" |
19 | #include "glow/Support/Support.h" |
20 | |
21 | #include "llvm/Support/FileSystem.h" |
22 | #include "llvm/Support/YAMLParser.h" |
23 | #include "llvm/Support/YAMLTraits.h" |
24 | #include "llvm/Support/raw_ostream.h" |
25 | |
26 | #include <glog/logging.h> |
27 | |
28 | /// Yaml serializer for the Glow tools version. |
29 | LLVM_YAML_STRONG_TYPEDEF(std::string, YAMLGlowToolsVersion) |
30 | |
31 | /// Yaml serializer for the graph hash code. |
32 | LLVM_YAML_STRONG_TYPEDEF(llvm::yaml::Hex64, YAMLGraphPreLowerHash) |
33 | |
34 | /// Yaml serializer for vector of NodeProfilingInfo. |
35 | LLVM_YAML_IS_SEQUENCE_VECTOR(glow::NodeProfilingInfo); |
36 | |
37 | namespace llvm { |
38 | namespace yaml { |
39 | |
/// The default behavior of YAML is to serialize floating point numbers
/// using the "%g" format specifier, which is not guaranteed to print all
/// the decimals. During a round-trip (serialize, deserialize) decimals
/// might be lost, and hence precision is lost. Although this might not be
/// critical for some quantization schemas, for "SymmetricWithPower2Scale"
/// the round-trip must preserve the exact representation of the floating
/// point scale, which is a power of 2. FloatWrapper is a distinct type whose
/// ScalarTraits specialization overrides the YAML serializer so that all the
/// digits are printed.
struct FloatWrapper {
  /// The wrapped value.
  float val_;

  /// Wraps \p val. The constructor is explicit so a plain float is never
  /// implicitly converted into a FloatWrapper; existing uses direct-initialize.
  explicit FloatWrapper(float val) : val_(val) {}
};
52 | |
53 | template <> struct ScalarTraits<FloatWrapper> { |
54 | static void output(const FloatWrapper &value, void *ctxt, |
55 | llvm::raw_ostream &out) { |
56 | // Print number with all the digits and without trailing 0's |
57 | char buffer[200]; |
58 | snprintf(buffer, sizeof(buffer), "%.126f" , value.val_); |
59 | int n = strlen(buffer) - 1; |
60 | while ((n > 0) && (buffer[n] == '0') && (buffer[n - 1] != '.')) { |
61 | buffer[n--] = '\0'; |
62 | } |
63 | out << buffer; |
64 | } |
65 | static StringRef input(StringRef scalar, void *ctxt, FloatWrapper &value) { |
66 | if (to_float(scalar, value.val_)) |
67 | return StringRef(); |
68 | return "invalid floating point number" ; |
69 | } |
70 | static QuotingType mustQuote(StringRef) { return QuotingType::None; } |
71 | }; |
72 | |
73 | /// Mapping for YAMLGlowToolsVersion yaml serializer. |
74 | template <> struct MappingTraits<YAMLGlowToolsVersion> { |
75 | static void mapping(IO &io, YAMLGlowToolsVersion &ver) { |
76 | io.mapRequired("GlowToolsVersion" , ver.value); |
77 | } |
78 | }; |
79 | |
80 | /// Mapping for YAMLGraphPreLowerHash yaml serializer. |
81 | template <> struct MappingTraits<YAMLGraphPreLowerHash> { |
82 | static void mapping(IO &io, YAMLGraphPreLowerHash &hash) { |
83 | io.mapRequired("GraphPreLowerHash" , hash.value); |
84 | } |
85 | }; |
86 | |
87 | /// Mapping for NodeProfilingInfo yaml serializer. |
88 | template <> struct MappingTraits<glow::NodeProfilingInfo> { |
89 | struct FloatNormalized { |
90 | FloatNormalized(IO &io) : val_(0.0) {} |
91 | FloatNormalized(IO &, float &val) : val_(val) {} |
92 | float denormalize(IO &) { return val_.val_; } |
93 | FloatWrapper val_; |
94 | }; |
95 | static void mapping(IO &io, glow::NodeProfilingInfo &info) { |
96 | MappingNormalization<FloatNormalized, float> min( |
97 | io, info.tensorProfilingParams_.min); |
98 | MappingNormalization<FloatNormalized, float> max( |
99 | io, info.tensorProfilingParams_.max); |
100 | io.mapRequired("NodeOutputName" , info.nodeOutputName_); |
101 | io.mapRequired("Min" , min->val_); |
102 | io.mapRequired("Max" , max->val_); |
103 | io.mapRequired("Histogram" , info.tensorProfilingParams_.histogram); |
104 | } |
105 | }; |
106 | |
107 | } // end namespace yaml |
108 | } // end namespace llvm |
109 | |
110 | namespace glow { |
111 | |
112 | void serializeProfilingInfosToYaml( |
113 | llvm::StringRef fileName, llvm::hash_code graphPreLowerHash, |
114 | std::vector<NodeProfilingInfo> &profilingInfos) { |
115 | |
116 | // Open YAML output stream. |
117 | std::error_code EC; |
118 | llvm::raw_fd_ostream outputStream(fileName, EC, llvm::sys::fs::F_None); |
119 | CHECK(!EC) << "Error opening YAML file '" << fileName.str() << "'!" ; |
120 | llvm::yaml::Output yout(outputStream); |
121 | |
122 | // Write Glow tools version. |
123 | #ifdef GLOW_VERSION |
124 | YAMLGlowToolsVersion yamlVersion = YAMLGlowToolsVersion(GLOW_VERSION); |
125 | #else |
126 | YAMLGlowToolsVersion yamlVersion = YAMLGlowToolsVersion("" ); |
127 | #endif |
128 | yout << yamlVersion; |
129 | |
130 | // Write graph hash. |
131 | auto uint64Hash = static_cast<uint64_t>(graphPreLowerHash); |
132 | YAMLGraphPreLowerHash yamlHash = llvm::yaml::Hex64(uint64Hash); |
133 | yout << yamlHash; |
134 | |
135 | // Write profiling info. |
136 | yout << profilingInfos; |
137 | } |
138 | |
139 | bool deserializeProfilingInfosFromYaml( |
140 | llvm::StringRef fileName, llvm::hash_code &graphPreLowerHash, |
141 | std::vector<NodeProfilingInfo> &profilingInfos) { |
142 | |
143 | if (!llvm::sys::fs::exists(fileName)) { |
144 | return false; |
145 | } |
146 | |
147 | // Open YAML input stream. |
148 | llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> text = |
149 | llvm::MemoryBuffer::getFileAsStream(fileName); |
150 | CHECK(!text.getError()) << "Unable to open file with name: " |
151 | << fileName.str(); |
152 | std::unique_ptr<llvm::MemoryBuffer> buffer = std::move(*text); |
153 | llvm::yaml::Input yin(buffer->getBuffer()); |
154 | |
155 | // Error message in case of incorrect profile format. |
156 | std::string profileErrMsg = |
157 | strFormat("Error reading YAML file '%s'!" , fileName.data()); |
158 | #ifdef GLOW_VERSION |
159 | profileErrMsg += strFormat(" Verify that the YAML file was generated with " |
160 | "the current version (%s) of the Glow tools!" , |
161 | GLOW_VERSION); |
162 | #endif |
163 | |
164 | // Read Glow tools version. |
165 | YAMLGlowToolsVersion yamlVersion; |
166 | yin >> yamlVersion; |
167 | CHECK(yin.nextDocument()) << profileErrMsg; |
168 | |
169 | // Read graph hash. |
170 | YAMLGraphPreLowerHash hash; |
171 | yin >> hash; |
172 | graphPreLowerHash = llvm::hash_code(static_cast<size_t>(hash.value)); |
173 | CHECK(yin.nextDocument()) << profileErrMsg; |
174 | |
175 | // Read profiling info. |
176 | yin >> profilingInfos; |
177 | CHECK(!yin.error()) << profileErrMsg; |
178 | |
179 | for (const auto &PI : profilingInfos) { |
180 | CHECK_LE(PI.min(), PI.max()) |
181 | << "Bad profile for node " << PI.nodeOutputName_.c_str(); |
182 | } |
183 | |
184 | return true; |
185 | } |
186 | |
187 | } // namespace glow |
188 | |