1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Support/ZipUtils.h"
18#include "onnx/onnx_pb.h"
19
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/Signals.h"
22
23#include "google/protobuf/io/coded_stream.h"
24#include "google/protobuf/io/zero_copy_stream_impl.h"
25
26#include <glog/logging.h>
27
28#include <list>
29#include <random>
30#include <sstream>
31#include <string>
32#include <unordered_map>
33#include <unordered_set>
34
35namespace {
36llvm::cl::OptionCategory scramblerCat("Scrambler Category");
37llvm::cl::opt<std::string>
38 inputModelPathOpt("input_model", llvm::cl::desc("Input model zip file"),
39 llvm::cl::Required, llvm::cl::cat(scramblerCat));
40llvm::cl::opt<std::string>
41 outputModelPathOpt("output_model", llvm::cl::desc("Output model zip file"),
42 llvm::cl::Required, llvm::cl::cat(scramblerCat));
43llvm::cl::opt<std::string> inputDeferredWeightsPathOpt(
44 "input_deferred_weights",
45 llvm::cl::desc("Path to the input deferred weights file"),
46 llvm::cl::Optional, llvm::cl::init(""), llvm::cl::cat(scramblerCat));
47llvm::cl::opt<std::string> outputDeferredWeightsPathOpt(
48 "output_deferred_weights",
49 llvm::cl::desc("Path to the output deferred weights file"),
50 llvm::cl::Optional, llvm::cl::init(""), llvm::cl::cat(scramblerCat));
51llvm::cl::opt<std::string>
52 inputPatternOpt("inputs_pattern",
53 llvm::cl::desc("Input file pattern. in_{}.onnx"),
54 llvm::cl::init(""), llvm::cl::cat(scramblerCat));
55llvm::cl::opt<std::string>
56 outputPatternOpt("outputs_pattern",
57 llvm::cl::desc("Output file pattern. out_{}.onnx"),
58 llvm::cl::init(""), llvm::cl::cat(scramblerCat));
59llvm::cl::opt<unsigned> seqStartOpt(
60 "seq_start", llvm::cl::desc("Start index of input/output files"),
61 llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(scramblerCat));
62llvm::cl::opt<unsigned> seqLenOpt(
63 "seq_len", llvm::cl::desc("Lengths of the input/output file seqquence."),
64 llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(scramblerCat));
65llvm::cl::opt<unsigned>
66 methodOpt("method",
67 llvm::cl::desc(
68 "Scrambling method: 0: simple tag; 1: pad to the same length;"
69 "2: same lengths, change uppercase letters to a new random "
70 "upper case letter"),
71 llvm::cl::Optional, llvm::cl::init(0),
72 llvm::cl::cat(scramblerCat));
73} // namespace
74
75using namespace glow;
76
/// Total-bytes limit (2^31 - 1) handed to protobuf's CodedInputStream when a
/// file is parsed.
constexpr size_t MAX_PROTO_SIZE{0x7FFFFFFF};
/// Value passed as the `compressed` argument of the zip record helpers.
constexpr bool kCompressed{false};
/// Maximum number of attempts scrambleMethod2() makes to find an unused name.
constexpr int kMaxTrial{1000};
80
81void scrambleMethod2(std::string &str) {
82 static std::random_device rd;
83 static std::mt19937 gen(rd());
84 static std::unordered_set<std::string> used_names;
85 std::uniform_int_distribution<> dis(0, 25);
86 for (int trial = 0; trial < kMaxTrial; ++trial) {
87 std::transform(str.begin(), str.end(), str.begin(), [&](char c) {
88 if (c >= 'A' && c <= 'Z') {
89 return static_cast<char>('A' + dis(gen));
90 } else {
91 return c;
92 }
93 });
94 if (used_names.emplace(str).second) {
95 return;
96 }
97 }
98 LOG(FATAL) << "Bad luck. Cannot find a unique random name. Try run me again!";
99}
100
101std::string makeNewName(const std::string &in) {
102 static size_t idx = 0;
103 std::stringstream ss;
104 if (methodOpt == 2) {
105 std::string out = in;
106 scrambleMethod2(out);
107 return out;
108 } else if (methodOpt == 1) {
109 ss << idx++;
110 std::string tail = ss.str();
111 std::string out(in.size() - tail.size(), 'X');
112 out += tail;
113 if (out.size() != in.size()) {
114 LOG(WARNING) << "Cannot pad to the same length of " << in;
115 }
116 return out;
117 } else {
118 ss << "X__" << idx++;
119 return ss.str();
120 }
121}
122
123std::string makeNewNodeName(const std::string &in) {
124 static size_t idx = 0;
125 std::stringstream ss;
126 if (methodOpt == 1) {
127 ss << idx++;
128 std::string tail = ss.str();
129 std::string out(in.size() - tail.size(), 'N');
130 out += tail;
131 if (out.size() != in.size()) {
132 LOG(WARNING) << "Cannot pad to the same length of " << in;
133 }
134 return out;
135 } else {
136 ss << "N__" << idx++;
137 return ss.str();
138 }
139}
140
141bool parseIO(const std::string &filename, ::ONNX_NAMESPACE::GraphProto &g) {
142 std::ifstream ff(filename, std::ios::in | std::ios::binary);
143 if (!ff) {
144 return false;
145 }
146 google::protobuf::io::IstreamInputStream fileStream(&ff);
147 google::protobuf::io::CodedInputStream codedStream(&fileStream);
148#if GOOGLE_PROTOBUF_VERSION >= 3002000
149 codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE);
150#else
151 codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE, MAX_PROTO_SIZE);
152#endif
153 bool yes = g.ParseFromCodedStream(&codedStream);
154 if (!yes) {
155 return false;
156 }
157 return true;
158}
159
160void rewriteIO(const std::string &filename,
161 std::unordered_map<std::string, std::string> &name_map) {
162 LOG(INFO) << "Reading file: " << filename;
163 ::ONNX_NAMESPACE::GraphProto g;
164 if (!parseIO(filename, g)) {
165 LOG(ERROR) << "Cannot open " << filename;
166 return;
167 }
168 for (auto &t : *g.mutable_initializer()) {
169 const auto &name = t.name();
170 if (!name_map.count(name)) {
171 LOG(ERROR) << "It's very straight that input " << name
172 << " is not referenced in the net";
173 name_map.emplace(name, makeNewName(name));
174 }
175 t.set_name(name_map.at(name));
176 }
177 std::string new_filename = filename + ".2";
178 LOG(INFO) << "Writing new file: " << new_filename;
179 std::ofstream of(new_filename,
180 std::ios::out | std::ios::trunc | std::ios::binary);
181 if (!of) {
182 LOG(ERROR) << "Cannot open " << new_filename;
183 return;
184 }
185 std::string buffer;
186 g.SerializeToString(&buffer);
187 of << buffer;
188}
189
190std::list<::ONNX_NAMESPACE::TensorProto>
191readWeightsAndMaybeCopyData(ZipReader &zip, ZipWriter &zipO, bool compressed) {
192 std::list<::ONNX_NAMESPACE::TensorProto> weights;
193 auto numWeightsStr = zip.getRecord("weights");
194 size_t numWeights = 0;
195 numWeights = atoi(numWeightsStr.c_str());
196 std::string buffer;
197 for (size_t i = 0; i < numWeights; ++i) {
198 std::stringstream ss;
199 ss << "weight_" << i;
200 buffer = zip.getRecord(ss.str());
201 weights.emplace_back();
202 auto &t = weights.back();
203 t.ParseFromString(buffer);
204
205 ss.str("");
206 ss << "data_" << i;
207 if (zip.hasRecord(ss.str())) {
208 buffer = zip.getRecord(ss.str());
209 zipO.writeRecord(ss.str(), buffer.c_str(), buffer.size(), compressed);
210 }
211 }
212 return weights;
213}
214
215void writeWeights(ZipWriter &zip,
216 const std::list<::ONNX_NAMESPACE::TensorProto> &weights,
217 bool compressed) {
218 std::stringstream ss;
219 ss << weights.size() << "\n";
220 zip.writeRecord("weights", ss.str().c_str(), ss.str().size(), compressed);
221 std::string largeBuffer;
222 int i = 0;
223 // This part is probably quite inefficient as we are deserializing the
224 // protobuf to a char buffer and then put it to zip stream. I didn't dig
225 // enough to see if we can deserialize it into zip stream directly.
226 for (const auto &t : weights) {
227 std::stringstream nm;
228 nm << "weight_" << i++;
229 t.SerializeToString(&largeBuffer);
230 zip.writeRecord(nm.str(), largeBuffer.c_str(), largeBuffer.size(),
231 compressed);
232 }
233}
234
235void scramble() {
236 LOG(INFO) << "Input model: " << inputModelPathOpt;
237 ::ONNX_NAMESPACE::ModelProto modelDef;
238 std::list<::ONNX_NAMESPACE::TensorProto> weights;
239 std::unordered_map<std::string, std::string> name_map;
240 std::unordered_map<std::string, std::string> node_map;
241 {
242 LOG(INFO) << "Writing output model to " << outputModelPathOpt;
243 std::ofstream ffO(outputModelPathOpt,
244 std::ios::out | std::ios::trunc | std::ios::binary);
245 CHECK(ffO);
246 ZipWriter zipO(&ffO, "test");
247 {
248 ZipReader zip(inputModelPathOpt);
249 std::string buffer;
250 buffer = zip.getRecord("model");
251 modelDef.ParseFromString(buffer);
252 weights = readWeightsAndMaybeCopyData(zip, zipO, kCompressed);
253 }
254
255 auto *g = modelDef.mutable_graph();
256 for (auto &n : *g->mutable_node()) {
257 for (auto &i : *n.mutable_input()) {
258 if (!name_map.count(i)) {
259 name_map.emplace(i, makeNewName(i));
260 }
261 i = name_map.at(i);
262 }
263 for (auto &o : *n.mutable_output()) {
264 if (!name_map.count(o)) {
265 name_map.emplace(o, makeNewName(o));
266 }
267 o = name_map.at(o);
268 }
269 const auto &name = n.name();
270 if (!node_map.count(name)) {
271 node_map.emplace(name, makeNewNodeName(name));
272 }
273 n.set_name(node_map.at(name));
274 }
275 for (auto &i : *g->mutable_input()) {
276 const auto &name = i.name();
277 if (!name_map.count(name)) {
278 name_map.emplace(name, makeNewName(name));
279 }
280 i.set_name(name_map.at(name));
281 }
282 for (auto &o : *g->mutable_output()) {
283 const auto &name = o.name();
284 if (!name_map.count(name)) {
285 name_map.emplace(name, makeNewName(name));
286 }
287 o.set_name(name_map.at(name));
288 }
289 for (auto &t : weights) {
290 const auto &name = t.name();
291 if (!name_map.count(name)) {
292 LOG(ERROR) << "It's a bit straight that weight " << name
293 << " is not referenced in the net";
294 name_map.emplace(name, makeNewName(name));
295 }
296 t.set_name(name_map.at(name));
297 }
298 // Look for attributes of a list of strings matching a name and swap for
299 // scrambled version. Note that this should be fine because we currently
300 // only use a list of strings for vector<NodeValue>.
301 for (auto &n : *g->mutable_node()) {
302 for (auto &a : *n.mutable_attribute()) {
303 if (a.name() == "Predicate") {
304 LOG(FATAL) << "Predicate NodeValue unhandled.";
305 }
306 for (auto &s : *a.mutable_strings()) {
307 if (name_map.count(s)) {
308 s = name_map[s];
309 }
310 }
311 }
312 }
313
314 writeWeights(zipO, weights, kCompressed);
315 std::string largeBuffer;
316 modelDef.SerializeToString(&largeBuffer);
317 zipO.writeRecord("model", largeBuffer.c_str(), largeBuffer.size(),
318 kCompressed);
319 zipO.writeEndOfFile();
320 ffO.flush();
321 ffO.close();
322 }
323
324 if (!inputDeferredWeightsPathOpt.empty()) {
325 weights.clear();
326 // Open the zip writer first in case we need to copy raw tensor data
327 // directly from zip reader
328 LOG(INFO) << "Writing output deferred weights to "
329 << outputDeferredWeightsPathOpt;
330 std::ofstream ffO(outputDeferredWeightsPathOpt,
331 std::ios::out | std::ios::trunc | std::ios::binary);
332 CHECK(ffO);
333 ZipWriter zipO(&ffO, "test");
334
335 {
336 LOG(INFO) << "Input deferred weights: " << inputDeferredWeightsPathOpt;
337 ZipReader zip(inputDeferredWeightsPathOpt);
338 weights = readWeightsAndMaybeCopyData(zip, zipO, kCompressed);
339 for (auto &t : weights) {
340 const auto &name = t.name();
341 if (!name_map.count(name)) {
342 LOG(ERROR) << "It's very straight that weight " << name
343 << " is not referenced in the net";
344 name_map.emplace(name, makeNewName(name));
345 }
346 t.set_name(name_map.at(name));
347 }
348 }
349
350 writeWeights(zipO, weights, kCompressed);
351 zipO.writeEndOfFile();
352 ffO.flush();
353 ffO.close();
354 }
355
356 size_t input_iter = inputPatternOpt.find("{}");
357 CHECK_NE(input_iter, std::string::npos)
358 << "Input pattern " << inputPatternOpt << " has to contain {}";
359 size_t output_iter = outputPatternOpt.find("{}");
360 CHECK_NE(output_iter, std::string::npos)
361 << "Output pattern " << outputPatternOpt << " has to contain {}";
362 for (unsigned i = seqStartOpt; i < seqLenOpt; ++i) {
363 std::string input = inputPatternOpt;
364 input.replace(input_iter, 2, std::to_string(seqStartOpt + i));
365 rewriteIO(input, name_map);
366 std::string output = outputPatternOpt;
367 output.replace(output_iter, 2, std::to_string(seqStartOpt + i));
368 rewriteIO(output, name_map);
369 }
370}
371
372void parseCommandLine(int argc, char **argv) {
373 llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
374 llvm::cl::ParseCommandLineOptions(argc, argv,
375 "The name scrambler\n\n"
376 "Scramble the name for repro files");
377}
378
/// Tool entry point: parses the command line options, runs scramble() and
/// returns 0.
int main(int argc, char **argv) {
  parseCommandLine(argc, argv);
  scramble();
  return 0;
}
384