1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <cstdio> |
16 | #include <memory> |
17 | #include <string> |
18 | |
19 | #include "absl/strings/string_view.h" |
20 | #include "tensorflow/lite/toco/model.h" |
21 | #include "tensorflow/lite/toco/model_cmdline_flags.h" |
22 | #include "tensorflow/lite/toco/model_flags.pb.h" |
23 | #include "tensorflow/lite/toco/toco_cmdline_flags.h" |
24 | #include "tensorflow/lite/toco/toco_flags.pb.h" |
25 | #include "tensorflow/lite/toco/toco_port.h" |
26 | #include "tensorflow/lite/toco/toco_tooling.h" |
27 | #include "tensorflow/lite/toco/toco_types.h" |
28 | #include "tensorflow/core/lib/core/errors.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | |
31 | namespace toco { |
32 | namespace { |
33 | |
34 | // Checks the permissions of the output file to ensure it is writeable. |
35 | void CheckOutputFilePermissions(const Arg<std::string>& output_file) { |
36 | QCHECK(output_file.specified()) << "Missing required flag --output_file.\n" ; |
37 | QCHECK(port::file::Writable(output_file.value()).ok()) |
38 | << "Specified output_file is not writable: " << output_file.value() |
39 | << ".\n" ; |
40 | } |
41 | |
42 | // Checks the permissions of the frozen model file. |
43 | void CheckFrozenModelPermissions(const Arg<std::string>& input_file) { |
44 | QCHECK(input_file.specified()) << "Missing required flag --input_file.\n" ; |
45 | QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok()) |
46 | << "Specified input_file does not exist: " << input_file.value() << ".\n" ; |
47 | QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok()) |
48 | << "Specified input_file exists, but is not readable: " |
49 | << input_file.value() << ".\n" ; |
50 | } |
51 | |
52 | // Reads the contents of the GraphDef from either the frozen graph file or the |
53 | // SavedModel directory. If it reads the SavedModel directory, it updates the |
54 | // ModelFlags and TocoFlags accordingly. |
55 | void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, |
56 | const ParsedModelFlags& parsed_model_flags, |
57 | std::string* graph_def_contents) { |
58 | port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n" ); |
59 | |
60 | // Ensure savedmodel_directory is not set. |
61 | QCHECK(!parsed_toco_flags.savedmodel_directory.specified()) |
62 | << "Use `tensorflow/lite/python/tflite_convert` script with " |
63 | << "SavedModel directories.\n" ; |
64 | |
65 | // Checks the input file permissions and reads the contents. |
66 | CheckFrozenModelPermissions(parsed_toco_flags.input_file); |
67 | CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), |
68 | graph_def_contents, port::file::Defaults()) |
69 | .ok()); |
70 | } |
71 | } // namespace |
72 | |
73 | tensorflow::Status Convert(const std::string& graph_def_contents, |
74 | const TocoFlags& toco_flags, |
75 | const ModelFlags& model_flags, |
76 | std::string* output_file_contents, |
77 | int64_t* arithmetic_ops_count = nullptr) { |
78 | std::unique_ptr<Model> model = |
79 | Import(toco_flags, model_flags, graph_def_contents); |
80 | TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model.get())); |
81 | TF_RETURN_IF_ERROR(Export(toco_flags, *model, toco_flags.allow_custom_ops(), |
82 | output_file_contents)); |
83 | if (arithmetic_ops_count != nullptr) { |
84 | *arithmetic_ops_count = model->ArithmeticOpsCount(); |
85 | } |
86 | return ::tensorflow::OkStatus(); |
87 | } |
88 | |
89 | tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags, |
90 | const ParsedModelFlags& parsed_model_flags) { |
91 | ModelFlags model_flags; |
92 | ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags); |
93 | |
94 | TocoFlags toco_flags; |
95 | ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); |
96 | |
97 | std::string graph_def_contents; |
98 | ReadInputData(parsed_toco_flags, parsed_model_flags, &graph_def_contents); |
99 | CheckOutputFilePermissions(parsed_toco_flags.output_file); |
100 | |
101 | std::string output_file_contents; |
102 | TF_RETURN_IF_ERROR(Convert(graph_def_contents, toco_flags, model_flags, |
103 | &output_file_contents)); |
104 | |
105 | TF_RETURN_IF_ERROR( |
106 | port::file::SetContents(parsed_toco_flags.output_file.value(), |
107 | output_file_contents, port::file::Defaults())); |
108 | return tensorflow::Status(); |
109 | } |
110 | |
111 | } // namespace toco |
112 | |