1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Quantization/Serialization.h"
18#include "glow/Quantization/Base/Base.h"
19#include "glow/Support/Support.h"
20
21#include "llvm/Support/FileSystem.h"
22#include "llvm/Support/YAMLParser.h"
23#include "llvm/Support/YAMLTraits.h"
24#include "llvm/Support/raw_ostream.h"
25
26#include <glog/logging.h>
27
28/// Yaml serializer for the Glow tools version.
29LLVM_YAML_STRONG_TYPEDEF(std::string, YAMLGlowToolsVersion)
30
31/// Yaml serializer for the graph hash code.
32LLVM_YAML_STRONG_TYPEDEF(llvm::yaml::Hex64, YAMLGraphPreLowerHash)
33
34/// Yaml serializer for vector of NodeProfilingInfo.
35LLVM_YAML_IS_SEQUENCE_VECTOR(glow::NodeProfilingInfo);
36
37namespace llvm {
38namespace yaml {
39
40/// The default behavior of YAML is to serialize floating point numbers
41/// using the "%g" format specifier which is not guaranteed to print all
42/// the decimals. During a round-trip (serialize, deserialize) decimals
43/// might be lost and hence precision is lost. Although this might not be
44/// critical for some quantization schema, for "SymmetricWithPower2Scale"
45/// the round-trip must preserve the exact representation of the floating
46/// point scale which is a power of 2. The code below is a workaround to
47/// overwrite the behavior of the YAML serializer to print all the digits.
48struct FloatWrapper {
49 float val_;
50 FloatWrapper(float val) : val_(val) {}
51};
52
53template <> struct ScalarTraits<FloatWrapper> {
54 static void output(const FloatWrapper &value, void *ctxt,
55 llvm::raw_ostream &out) {
56 // Print number with all the digits and without trailing 0's
57 char buffer[200];
58 snprintf(buffer, sizeof(buffer), "%.126f", value.val_);
59 int n = strlen(buffer) - 1;
60 while ((n > 0) && (buffer[n] == '0') && (buffer[n - 1] != '.')) {
61 buffer[n--] = '\0';
62 }
63 out << buffer;
64 }
65 static StringRef input(StringRef scalar, void *ctxt, FloatWrapper &value) {
66 if (to_float(scalar, value.val_))
67 return StringRef();
68 return "invalid floating point number";
69 }
70 static QuotingType mustQuote(StringRef) { return QuotingType::None; }
71};
72
73/// Mapping for YAMLGlowToolsVersion yaml serializer.
74template <> struct MappingTraits<YAMLGlowToolsVersion> {
75 static void mapping(IO &io, YAMLGlowToolsVersion &ver) {
76 io.mapRequired("GlowToolsVersion", ver.value);
77 }
78};
79
80/// Mapping for YAMLGraphPreLowerHash yaml serializer.
81template <> struct MappingTraits<YAMLGraphPreLowerHash> {
82 static void mapping(IO &io, YAMLGraphPreLowerHash &hash) {
83 io.mapRequired("GraphPreLowerHash", hash.value);
84 }
85};
86
87/// Mapping for NodeProfilingInfo yaml serializer.
88template <> struct MappingTraits<glow::NodeProfilingInfo> {
89 struct FloatNormalized {
90 FloatNormalized(IO &io) : val_(0.0) {}
91 FloatNormalized(IO &, float &val) : val_(val) {}
92 float denormalize(IO &) { return val_.val_; }
93 FloatWrapper val_;
94 };
95 static void mapping(IO &io, glow::NodeProfilingInfo &info) {
96 MappingNormalization<FloatNormalized, float> min(
97 io, info.tensorProfilingParams_.min);
98 MappingNormalization<FloatNormalized, float> max(
99 io, info.tensorProfilingParams_.max);
100 io.mapRequired("NodeOutputName", info.nodeOutputName_);
101 io.mapRequired("Min", min->val_);
102 io.mapRequired("Max", max->val_);
103 io.mapRequired("Histogram", info.tensorProfilingParams_.histogram);
104 }
105};
106
107} // end namespace yaml
108} // end namespace llvm
109
110namespace glow {
111
112void serializeProfilingInfosToYaml(
113 llvm::StringRef fileName, llvm::hash_code graphPreLowerHash,
114 std::vector<NodeProfilingInfo> &profilingInfos) {
115
116 // Open YAML output stream.
117 std::error_code EC;
118 llvm::raw_fd_ostream outputStream(fileName, EC, llvm::sys::fs::F_None);
119 CHECK(!EC) << "Error opening YAML file '" << fileName.str() << "'!";
120 llvm::yaml::Output yout(outputStream);
121
122 // Write Glow tools version.
123#ifdef GLOW_VERSION
124 YAMLGlowToolsVersion yamlVersion = YAMLGlowToolsVersion(GLOW_VERSION);
125#else
126 YAMLGlowToolsVersion yamlVersion = YAMLGlowToolsVersion("");
127#endif
128 yout << yamlVersion;
129
130 // Write graph hash.
131 auto uint64Hash = static_cast<uint64_t>(graphPreLowerHash);
132 YAMLGraphPreLowerHash yamlHash = llvm::yaml::Hex64(uint64Hash);
133 yout << yamlHash;
134
135 // Write profiling info.
136 yout << profilingInfos;
137}
138
139bool deserializeProfilingInfosFromYaml(
140 llvm::StringRef fileName, llvm::hash_code &graphPreLowerHash,
141 std::vector<NodeProfilingInfo> &profilingInfos) {
142
143 if (!llvm::sys::fs::exists(fileName)) {
144 return false;
145 }
146
147 // Open YAML input stream.
148 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> text =
149 llvm::MemoryBuffer::getFileAsStream(fileName);
150 CHECK(!text.getError()) << "Unable to open file with name: "
151 << fileName.str();
152 std::unique_ptr<llvm::MemoryBuffer> buffer = std::move(*text);
153 llvm::yaml::Input yin(buffer->getBuffer());
154
155 // Error message in case of incorrect profile format.
156 std::string profileErrMsg =
157 strFormat("Error reading YAML file '%s'!", fileName.data());
158#ifdef GLOW_VERSION
159 profileErrMsg += strFormat(" Verify that the YAML file was generated with "
160 "the current version (%s) of the Glow tools!",
161 GLOW_VERSION);
162#endif
163
164 // Read Glow tools version.
165 YAMLGlowToolsVersion yamlVersion;
166 yin >> yamlVersion;
167 CHECK(yin.nextDocument()) << profileErrMsg;
168
169 // Read graph hash.
170 YAMLGraphPreLowerHash hash;
171 yin >> hash;
172 graphPreLowerHash = llvm::hash_code(static_cast<size_t>(hash.value));
173 CHECK(yin.nextDocument()) << profileErrMsg;
174
175 // Read profiling info.
176 yin >> profilingInfos;
177 CHECK(!yin.error()) << profileErrMsg;
178
179 for (const auto &PI : profilingInfos) {
180 CHECK_LE(PI.min(), PI.max())
181 << "Bad profile for node " << PI.nodeOutputName_.c_str();
182 }
183
184 return true;
185}
186
187} // namespace glow
188