1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/cc/saved_model/reader.h" |
17 | |
18 | #include <unordered_set> |
19 | |
20 | #include "absl/memory/memory.h" |
21 | #include "tensorflow/cc/saved_model/constants.h" |
22 | #include "tensorflow/cc/saved_model/metrics.h" |
23 | #include "tensorflow/cc/saved_model/util.h" |
24 | #include "tensorflow/core/framework/attr_value.pb.h" |
25 | #include "tensorflow/core/framework/function.pb.h" |
26 | #include "tensorflow/core/framework/graph.pb.h" |
27 | #include "tensorflow/core/framework/node_def.pb.h" |
28 | #include "tensorflow/core/framework/tensor.pb.h" |
29 | #include "tensorflow/core/lib/io/path.h" |
30 | #include "tensorflow/core/lib/strings/str_util.h" |
31 | #include "tensorflow/core/lib/strings/strcat.h" |
32 | #include "tensorflow/core/platform/env.h" |
33 | #include "tensorflow/core/platform/file_system_helper.h" |
34 | #include "tensorflow/core/platform/statusor.h" |
35 | #include "tensorflow/core/protobuf/saved_model.pb.h" |
36 | #include "tensorflow/core/util/tensor_bundle/byte_swap.h" |
37 | |
38 | namespace tensorflow { |
39 | namespace { |
40 | |
41 | // Reads the SavedModel proto from saved_model.pb in `export_dir`. |
42 | // Returns a failure status when the SavedModel file does not exist. |
43 | Status ReadSavedModel(absl::string_view export_dir, |
44 | SavedModel* saved_model_proto) { |
45 | LOG(INFO) << "Reading SavedModel from: " << export_dir; |
46 | |
47 | const std::string saved_model_pb_path = |
48 | io::JoinPath(export_dir, kSavedModelFilenamePb); |
49 | |
50 | TF_ASSIGN_OR_RETURN( |
51 | bool saved_model_pb_exists, |
52 | internal::FileExists(Env::Default(), saved_model_pb_path)); |
53 | if (saved_model_pb_exists) { |
54 | Status result = |
55 | ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto); |
56 | if (result.ok()) { |
57 | metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto)) |
58 | .IncrementBy(1); |
59 | } |
60 | return result; |
61 | } |
62 | const std::string saved_model_pbtxt_path = |
63 | io::JoinPath(export_dir, kSavedModelFilenamePbTxt); |
64 | TF_ASSIGN_OR_RETURN( |
65 | bool saved_model_pbtxt_exists, |
66 | internal::FileExists(Env::Default(), saved_model_pbtxt_path)); |
67 | if (saved_model_pbtxt_exists) { |
68 | Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path, |
69 | saved_model_proto); |
70 | if (result.ok()) { |
71 | metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto)) |
72 | .IncrementBy(1); |
73 | } |
74 | return result; |
75 | } |
76 | return Status( |
77 | error::Code::NOT_FOUND, |
78 | strings::StrCat("Could not find SavedModel .pb or .pbtxt at supplied " |
79 | "export directory path: " , |
80 | export_dir, |
81 | ". Check that " |
82 | "the directory exists and that you have the right " |
83 | "permissions for accessing it." )); |
84 | } |
85 | |
86 | Status FindMetaGraphDef(const std::unordered_set<string>& tags, |
87 | SavedModel* saved_model_proto, |
88 | MetaGraphDef* meta_graph_def) { |
89 | LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " " ) |
90 | << " }" ; |
91 | for (MetaGraphDef& graph_def : *saved_model_proto->mutable_meta_graphs()) { |
92 | // Get tags from the graph_def. |
93 | std::unordered_set<string> graph_tags; |
94 | for (const string& tag : graph_def.meta_info_def().tags()) { |
95 | graph_tags.insert(tag); |
96 | } |
97 | // Match with the set of tags provided. |
98 | if (graph_tags == tags) { |
99 | *meta_graph_def = std::move(graph_def); |
100 | // Correct the endiness of Tensor content on big-endian system |
101 | if (!port::kLittleEndian) { |
102 | TF_RETURN_IF_ERROR(ByteSwapTensorContent(meta_graph_def)); |
103 | } |
104 | return OkStatus(); |
105 | } |
106 | } |
107 | return Status( |
108 | error::Code::NOT_FOUND, |
109 | strings::StrCat( |
110 | "Could not find meta graph def matching supplied tags: { " , |
111 | absl::StrJoin(tags, " " ), |
112 | " }. To inspect available tag-sets in the SavedModel, please " |
113 | "use the SavedModel CLI: `saved_model_cli`" )); |
114 | } |
115 | } // namespace |
116 | |
117 | Status ReadMetaGraphDefFromSavedModel(const string& export_dir, |
118 | const std::unordered_set<string>& tags, |
119 | MetaGraphDef* const meta_graph_def) { |
120 | SavedModel saved_model_proto; |
121 | TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); |
122 | TF_RETURN_IF_ERROR( |
123 | FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def)); |
124 | return OkStatus(); |
125 | } |
126 | |
127 | Status ReadSavedModelDebugInfoIfPresent( |
128 | const string& export_dir, |
129 | std::unique_ptr<GraphDebugInfo>* debug_info_proto) { |
130 | LOG(INFO) << "Reading SavedModel debug info (if present) from: " |
131 | << export_dir; |
132 | |
133 | const string debug_info_pb_path = |
134 | io::JoinPath(export_dir, "debug" , "saved_model_debug_info.pb" ); |
135 | TF_ASSIGN_OR_RETURN(bool debug_info_pb_exists, |
136 | internal::FileExists(Env::Default(), debug_info_pb_path)); |
137 | if (debug_info_pb_exists) { |
138 | GraphDebugInfo debug_info; |
139 | TF_RETURN_IF_ERROR( |
140 | ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info)); |
141 | *debug_info_proto = |
142 | absl::make_unique<GraphDebugInfo>(std::move(debug_info)); |
143 | } |
144 | return OkStatus(); |
145 | } |
146 | |
147 | } // namespace tensorflow |
148 | |