1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Support/ZipUtils.h" |
18 | #include "onnx/onnx_pb.h" |
19 | |
20 | #include "llvm/Support/CommandLine.h" |
21 | #include "llvm/Support/Signals.h" |
22 | |
23 | #include "google/protobuf/io/coded_stream.h" |
24 | #include "google/protobuf/io/zero_copy_stream_impl.h" |
25 | |
26 | #include <glog/logging.h> |
27 | |
28 | #include <list> |
29 | #include <random> |
30 | #include <sstream> |
31 | #include <string> |
32 | #include <unordered_map> |
33 | #include <unordered_set> |
34 | |
35 | namespace { |
36 | llvm::cl::OptionCategory scramblerCat("Scrambler Category" ); |
37 | llvm::cl::opt<std::string> |
38 | inputModelPathOpt("input_model" , llvm::cl::desc("Input model zip file" ), |
39 | llvm::cl::Required, llvm::cl::cat(scramblerCat)); |
40 | llvm::cl::opt<std::string> |
41 | outputModelPathOpt("output_model" , llvm::cl::desc("Output model zip file" ), |
42 | llvm::cl::Required, llvm::cl::cat(scramblerCat)); |
43 | llvm::cl::opt<std::string> inputDeferredWeightsPathOpt( |
44 | "input_deferred_weights" , |
45 | llvm::cl::desc("Path to the input deferred weights file" ), |
46 | llvm::cl::Optional, llvm::cl::init("" ), llvm::cl::cat(scramblerCat)); |
47 | llvm::cl::opt<std::string> outputDeferredWeightsPathOpt( |
48 | "output_deferred_weights" , |
49 | llvm::cl::desc("Path to the output deferred weights file" ), |
50 | llvm::cl::Optional, llvm::cl::init("" ), llvm::cl::cat(scramblerCat)); |
51 | llvm::cl::opt<std::string> |
52 | inputPatternOpt("inputs_pattern" , |
53 | llvm::cl::desc("Input file pattern. in_{}.onnx" ), |
54 | llvm::cl::init("" ), llvm::cl::cat(scramblerCat)); |
55 | llvm::cl::opt<std::string> |
56 | outputPatternOpt("outputs_pattern" , |
57 | llvm::cl::desc("Output file pattern. out_{}.onnx" ), |
58 | llvm::cl::init("" ), llvm::cl::cat(scramblerCat)); |
59 | llvm::cl::opt<unsigned> seqStartOpt( |
60 | "seq_start" , llvm::cl::desc("Start index of input/output files" ), |
61 | llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(scramblerCat)); |
62 | llvm::cl::opt<unsigned> seqLenOpt( |
63 | "seq_len" , llvm::cl::desc("Lengths of the input/output file seqquence." ), |
64 | llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(scramblerCat)); |
65 | llvm::cl::opt<unsigned> |
66 | methodOpt("method" , |
67 | llvm::cl::desc( |
68 | "Scrambling method: 0: simple tag; 1: pad to the same length;" |
69 | "2: same lengths, change uppercase letters to a new random " |
70 | "upper case letter" ), |
71 | llvm::cl::Optional, llvm::cl::init(0), |
72 | llvm::cl::cat(scramblerCat)); |
73 | } // namespace |
74 | |
75 | using namespace glow; |
76 | |
/// Byte limit handed to the protobuf coded input stream in parseIO.
constexpr size_t MAX_PROTO_SIZE{0x7FFFFFFF};
/// Compression flag forwarded to every zip record written by this tool.
constexpr bool kCompressed{false};
/// Maximum number of attempts scrambleMethod2 makes to find an unused name.
constexpr int kMaxTrial{1000};
80 | |
81 | void scrambleMethod2(std::string &str) { |
82 | static std::random_device rd; |
83 | static std::mt19937 gen(rd()); |
84 | static std::unordered_set<std::string> used_names; |
85 | std::uniform_int_distribution<> dis(0, 25); |
86 | for (int trial = 0; trial < kMaxTrial; ++trial) { |
87 | std::transform(str.begin(), str.end(), str.begin(), [&](char c) { |
88 | if (c >= 'A' && c <= 'Z') { |
89 | return static_cast<char>('A' + dis(gen)); |
90 | } else { |
91 | return c; |
92 | } |
93 | }); |
94 | if (used_names.emplace(str).second) { |
95 | return; |
96 | } |
97 | } |
98 | LOG(FATAL) << "Bad luck. Cannot find a unique random name. Try run me again!" ; |
99 | } |
100 | |
101 | std::string makeNewName(const std::string &in) { |
102 | static size_t idx = 0; |
103 | std::stringstream ss; |
104 | if (methodOpt == 2) { |
105 | std::string out = in; |
106 | scrambleMethod2(out); |
107 | return out; |
108 | } else if (methodOpt == 1) { |
109 | ss << idx++; |
110 | std::string tail = ss.str(); |
111 | std::string out(in.size() - tail.size(), 'X'); |
112 | out += tail; |
113 | if (out.size() != in.size()) { |
114 | LOG(WARNING) << "Cannot pad to the same length of " << in; |
115 | } |
116 | return out; |
117 | } else { |
118 | ss << "X__" << idx++; |
119 | return ss.str(); |
120 | } |
121 | } |
122 | |
123 | std::string makeNewNodeName(const std::string &in) { |
124 | static size_t idx = 0; |
125 | std::stringstream ss; |
126 | if (methodOpt == 1) { |
127 | ss << idx++; |
128 | std::string tail = ss.str(); |
129 | std::string out(in.size() - tail.size(), 'N'); |
130 | out += tail; |
131 | if (out.size() != in.size()) { |
132 | LOG(WARNING) << "Cannot pad to the same length of " << in; |
133 | } |
134 | return out; |
135 | } else { |
136 | ss << "N__" << idx++; |
137 | return ss.str(); |
138 | } |
139 | } |
140 | |
141 | bool parseIO(const std::string &filename, ::ONNX_NAMESPACE::GraphProto &g) { |
142 | std::ifstream ff(filename, std::ios::in | std::ios::binary); |
143 | if (!ff) { |
144 | return false; |
145 | } |
146 | google::protobuf::io::IstreamInputStream fileStream(&ff); |
147 | google::protobuf::io::CodedInputStream codedStream(&fileStream); |
148 | #if GOOGLE_PROTOBUF_VERSION >= 3002000 |
149 | codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE); |
150 | #else |
151 | codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE, MAX_PROTO_SIZE); |
152 | #endif |
153 | bool yes = g.ParseFromCodedStream(&codedStream); |
154 | if (!yes) { |
155 | return false; |
156 | } |
157 | return true; |
158 | } |
159 | |
160 | void rewriteIO(const std::string &filename, |
161 | std::unordered_map<std::string, std::string> &name_map) { |
162 | LOG(INFO) << "Reading file: " << filename; |
163 | ::ONNX_NAMESPACE::GraphProto g; |
164 | if (!parseIO(filename, g)) { |
165 | LOG(ERROR) << "Cannot open " << filename; |
166 | return; |
167 | } |
168 | for (auto &t : *g.mutable_initializer()) { |
169 | const auto &name = t.name(); |
170 | if (!name_map.count(name)) { |
171 | LOG(ERROR) << "It's very straight that input " << name |
172 | << " is not referenced in the net" ; |
173 | name_map.emplace(name, makeNewName(name)); |
174 | } |
175 | t.set_name(name_map.at(name)); |
176 | } |
177 | std::string new_filename = filename + ".2" ; |
178 | LOG(INFO) << "Writing new file: " << new_filename; |
179 | std::ofstream of(new_filename, |
180 | std::ios::out | std::ios::trunc | std::ios::binary); |
181 | if (!of) { |
182 | LOG(ERROR) << "Cannot open " << new_filename; |
183 | return; |
184 | } |
185 | std::string buffer; |
186 | g.SerializeToString(&buffer); |
187 | of << buffer; |
188 | } |
189 | |
190 | std::list<::ONNX_NAMESPACE::TensorProto> |
191 | readWeightsAndMaybeCopyData(ZipReader &zip, ZipWriter &zipO, bool compressed) { |
192 | std::list<::ONNX_NAMESPACE::TensorProto> weights; |
193 | auto numWeightsStr = zip.getRecord("weights" ); |
194 | size_t numWeights = 0; |
195 | numWeights = atoi(numWeightsStr.c_str()); |
196 | std::string buffer; |
197 | for (size_t i = 0; i < numWeights; ++i) { |
198 | std::stringstream ss; |
199 | ss << "weight_" << i; |
200 | buffer = zip.getRecord(ss.str()); |
201 | weights.emplace_back(); |
202 | auto &t = weights.back(); |
203 | t.ParseFromString(buffer); |
204 | |
205 | ss.str("" ); |
206 | ss << "data_" << i; |
207 | if (zip.hasRecord(ss.str())) { |
208 | buffer = zip.getRecord(ss.str()); |
209 | zipO.writeRecord(ss.str(), buffer.c_str(), buffer.size(), compressed); |
210 | } |
211 | } |
212 | return weights; |
213 | } |
214 | |
215 | void writeWeights(ZipWriter &zip, |
216 | const std::list<::ONNX_NAMESPACE::TensorProto> &weights, |
217 | bool compressed) { |
218 | std::stringstream ss; |
219 | ss << weights.size() << "\n" ; |
220 | zip.writeRecord("weights" , ss.str().c_str(), ss.str().size(), compressed); |
221 | std::string largeBuffer; |
222 | int i = 0; |
223 | // This part is probably quite inefficient as we are deserializing the |
224 | // protobuf to a char buffer and then put it to zip stream. I didn't dig |
225 | // enough to see if we can deserialize it into zip stream directly. |
226 | for (const auto &t : weights) { |
227 | std::stringstream nm; |
228 | nm << "weight_" << i++; |
229 | t.SerializeToString(&largeBuffer); |
230 | zip.writeRecord(nm.str(), largeBuffer.c_str(), largeBuffer.size(), |
231 | compressed); |
232 | } |
233 | } |
234 | |
235 | void scramble() { |
236 | LOG(INFO) << "Input model: " << inputModelPathOpt; |
237 | ::ONNX_NAMESPACE::ModelProto modelDef; |
238 | std::list<::ONNX_NAMESPACE::TensorProto> weights; |
239 | std::unordered_map<std::string, std::string> name_map; |
240 | std::unordered_map<std::string, std::string> node_map; |
241 | { |
242 | LOG(INFO) << "Writing output model to " << outputModelPathOpt; |
243 | std::ofstream ffO(outputModelPathOpt, |
244 | std::ios::out | std::ios::trunc | std::ios::binary); |
245 | CHECK(ffO); |
246 | ZipWriter zipO(&ffO, "test" ); |
247 | { |
248 | ZipReader zip(inputModelPathOpt); |
249 | std::string buffer; |
250 | buffer = zip.getRecord("model" ); |
251 | modelDef.ParseFromString(buffer); |
252 | weights = readWeightsAndMaybeCopyData(zip, zipO, kCompressed); |
253 | } |
254 | |
255 | auto *g = modelDef.mutable_graph(); |
256 | for (auto &n : *g->mutable_node()) { |
257 | for (auto &i : *n.mutable_input()) { |
258 | if (!name_map.count(i)) { |
259 | name_map.emplace(i, makeNewName(i)); |
260 | } |
261 | i = name_map.at(i); |
262 | } |
263 | for (auto &o : *n.mutable_output()) { |
264 | if (!name_map.count(o)) { |
265 | name_map.emplace(o, makeNewName(o)); |
266 | } |
267 | o = name_map.at(o); |
268 | } |
269 | const auto &name = n.name(); |
270 | if (!node_map.count(name)) { |
271 | node_map.emplace(name, makeNewNodeName(name)); |
272 | } |
273 | n.set_name(node_map.at(name)); |
274 | } |
275 | for (auto &i : *g->mutable_input()) { |
276 | const auto &name = i.name(); |
277 | if (!name_map.count(name)) { |
278 | name_map.emplace(name, makeNewName(name)); |
279 | } |
280 | i.set_name(name_map.at(name)); |
281 | } |
282 | for (auto &o : *g->mutable_output()) { |
283 | const auto &name = o.name(); |
284 | if (!name_map.count(name)) { |
285 | name_map.emplace(name, makeNewName(name)); |
286 | } |
287 | o.set_name(name_map.at(name)); |
288 | } |
289 | for (auto &t : weights) { |
290 | const auto &name = t.name(); |
291 | if (!name_map.count(name)) { |
292 | LOG(ERROR) << "It's a bit straight that weight " << name |
293 | << " is not referenced in the net" ; |
294 | name_map.emplace(name, makeNewName(name)); |
295 | } |
296 | t.set_name(name_map.at(name)); |
297 | } |
298 | // Look for attributes of a list of strings matching a name and swap for |
299 | // scrambled version. Note that this should be fine because we currently |
300 | // only use a list of strings for vector<NodeValue>. |
301 | for (auto &n : *g->mutable_node()) { |
302 | for (auto &a : *n.mutable_attribute()) { |
303 | if (a.name() == "Predicate" ) { |
304 | LOG(FATAL) << "Predicate NodeValue unhandled." ; |
305 | } |
306 | for (auto &s : *a.mutable_strings()) { |
307 | if (name_map.count(s)) { |
308 | s = name_map[s]; |
309 | } |
310 | } |
311 | } |
312 | } |
313 | |
314 | writeWeights(zipO, weights, kCompressed); |
315 | std::string largeBuffer; |
316 | modelDef.SerializeToString(&largeBuffer); |
317 | zipO.writeRecord("model" , largeBuffer.c_str(), largeBuffer.size(), |
318 | kCompressed); |
319 | zipO.writeEndOfFile(); |
320 | ffO.flush(); |
321 | ffO.close(); |
322 | } |
323 | |
324 | if (!inputDeferredWeightsPathOpt.empty()) { |
325 | weights.clear(); |
326 | // Open the zip writer first in case we need to copy raw tensor data |
327 | // directly from zip reader |
328 | LOG(INFO) << "Writing output deferred weights to " |
329 | << outputDeferredWeightsPathOpt; |
330 | std::ofstream ffO(outputDeferredWeightsPathOpt, |
331 | std::ios::out | std::ios::trunc | std::ios::binary); |
332 | CHECK(ffO); |
333 | ZipWriter zipO(&ffO, "test" ); |
334 | |
335 | { |
336 | LOG(INFO) << "Input deferred weights: " << inputDeferredWeightsPathOpt; |
337 | ZipReader zip(inputDeferredWeightsPathOpt); |
338 | weights = readWeightsAndMaybeCopyData(zip, zipO, kCompressed); |
339 | for (auto &t : weights) { |
340 | const auto &name = t.name(); |
341 | if (!name_map.count(name)) { |
342 | LOG(ERROR) << "It's very straight that weight " << name |
343 | << " is not referenced in the net" ; |
344 | name_map.emplace(name, makeNewName(name)); |
345 | } |
346 | t.set_name(name_map.at(name)); |
347 | } |
348 | } |
349 | |
350 | writeWeights(zipO, weights, kCompressed); |
351 | zipO.writeEndOfFile(); |
352 | ffO.flush(); |
353 | ffO.close(); |
354 | } |
355 | |
356 | size_t input_iter = inputPatternOpt.find("{}" ); |
357 | CHECK_NE(input_iter, std::string::npos) |
358 | << "Input pattern " << inputPatternOpt << " has to contain {}" ; |
359 | size_t output_iter = outputPatternOpt.find("{}" ); |
360 | CHECK_NE(output_iter, std::string::npos) |
361 | << "Output pattern " << outputPatternOpt << " has to contain {}" ; |
362 | for (unsigned i = seqStartOpt; i < seqLenOpt; ++i) { |
363 | std::string input = inputPatternOpt; |
364 | input.replace(input_iter, 2, std::to_string(seqStartOpt + i)); |
365 | rewriteIO(input, name_map); |
366 | std::string output = outputPatternOpt; |
367 | output.replace(output_iter, 2, std::to_string(seqStartOpt + i)); |
368 | rewriteIO(output, name_map); |
369 | } |
370 | } |
371 | |
372 | void parseCommandLine(int argc, char **argv) { |
373 | llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); |
374 | llvm::cl::ParseCommandLineOptions(argc, argv, |
375 | "The name scrambler\n\n" |
376 | "Scramble the name for repro files" ); |
377 | } |
378 | |
379 | int main(int argc, char **argv) { |
380 | parseCommandLine(argc, argv); |
381 | scramble(); |
382 | return 0; |
383 | } |
384 | |