1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Exporter/ProtobufWriter.h" |
18 | #include <google/protobuf/io/coded_stream.h> |
19 | #include <google/protobuf/io/zero_copy_stream_impl.h> |
20 | |
21 | namespace glow { |
22 | |
23 | ProtobufWriter::ProtobufWriter(const std::string &modelFilename, Function *F, |
24 | Error *errPtr, bool writingToFile) |
25 | : F_(F) { |
26 | // Verify that the version of the library that we linked against is |
27 | // compatible with the version of the headers we compiled against. |
28 | GOOGLE_PROTOBUF_VERIFY_VERSION; |
29 | |
30 | // if errPtr already contains an error then don't continue with constructor |
31 | if (errPtr && *errPtr) { |
32 | return; |
33 | } |
34 | |
35 | // Lambda to setup the ProtobufWriter and return any Errors that were |
36 | // raised. |
37 | auto setup = [&]() -> Error { |
38 | if (writingToFile) { |
39 | // Try to open file for write |
40 | ff_.open(modelFilename, |
41 | std::ios::out | std::ios::trunc | std::ios::binary); |
42 | RETURN_ERR_IF_NOT(ff_, |
43 | "Can't find the output file name: " + modelFilename, |
44 | ErrorValue::ErrorCode::MODEL_WRITER_INVALID_FILENAME); |
45 | } |
46 | return Error::success(); |
47 | }; |
48 | |
49 | if (errPtr) { |
50 | *errPtr = setup(); |
51 | } else { |
52 | EXIT_ON_ERR(setup()); |
53 | } |
54 | } |
55 | |
56 | Error ProtobufWriter::writeModel(const ::google::protobuf::Message &modelProto, |
57 | bool textMode) { |
58 | { |
59 | ::google::protobuf::io::OstreamOutputStream zeroCopyOutput(&ff_); |
60 | // Write the content. |
61 | if (textMode) { |
62 | RETURN_ERR_IF_NOT( |
63 | google::protobuf::TextFormat::Print(modelProto, &zeroCopyOutput), |
64 | "Can't write to the output file name" , |
65 | ErrorValue::ErrorCode::MODEL_WRITER_SERIALIZATION_ERROR); |
66 | } else { |
67 | ::google::protobuf::io::CodedOutputStream codedOutput(&zeroCopyOutput); |
68 | modelProto.SerializeToCodedStream(&codedOutput); |
69 | RETURN_ERR_IF_NOT( |
70 | !codedOutput.HadError(), "Can't write to the output file name" , |
71 | ErrorValue::ErrorCode::MODEL_WRITER_SERIALIZATION_ERROR); |
72 | } |
73 | } |
74 | ff_.flush(); |
75 | ff_.close(); |
76 | return Error::success(); |
77 | } |
78 | |
79 | } // namespace glow |
80 | |