1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/cc/saved_model/loader_util.h" |
17 | |
18 | #include <vector> |
19 | |
20 | #include "tensorflow/cc/saved_model/constants.h" |
21 | #include "tensorflow/core/lib/strings/strcat.h" |
22 | #include "tensorflow/core/platform/errors.h" |
23 | #include "tensorflow/core/platform/protobuf_internal.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace internal { |
27 | |
28 | // A SavedModel may store the name of the initialization op to run in the |
29 | // in the SignatureDef (v2) or a collection (v1). If an init_op collection |
30 | // exists, then the collection must contain exactly one op. |
31 | Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, |
32 | string* init_op_name) { |
33 | const auto& sig_def_map = meta_graph_def.signature_def(); |
34 | const auto& init_op_sig_it = |
35 | meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey); |
36 | if (init_op_sig_it != sig_def_map.end()) { |
37 | const auto& sig_def_outputs = init_op_sig_it->second.outputs(); |
38 | const auto& sig_def_outputs_it = |
39 | sig_def_outputs.find(kSavedModelInitOpSignatureKey); |
40 | if (sig_def_outputs_it == sig_def_outputs.end()) { |
41 | return errors::FailedPrecondition("Could not find output " , |
42 | kSavedModelInitOpSignatureKey); |
43 | } |
44 | *init_op_name = sig_def_outputs_it->second.name(); |
45 | return OkStatus(); |
46 | } |
47 | |
48 | const auto& collection_def_map = meta_graph_def.collection_def(); |
49 | string init_op_collection_key; |
50 | if (collection_def_map.find(kSavedModelMainOpKey) != |
51 | collection_def_map.end()) { |
52 | init_op_collection_key = kSavedModelMainOpKey; |
53 | } else { |
54 | init_op_collection_key = kSavedModelLegacyInitOpKey; |
55 | } |
56 | |
57 | const auto init_op_it = collection_def_map.find(init_op_collection_key); |
58 | if (init_op_it != collection_def_map.end()) { |
59 | if (init_op_it->second.node_list().value_size() != 1) { |
60 | return errors::FailedPrecondition( |
61 | strings::StrCat("Expected exactly one main op in : " , export_dir)); |
62 | } |
63 | *init_op_name = init_op_it->second.node_list().value(0); |
64 | } |
65 | return OkStatus(); |
66 | } |
67 | |
68 | Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, |
69 | std::vector<AssetFileDef>* asset_file_defs) { |
70 | // With SavedModel v2, we write asset file def into metagraph instead of |
71 | // collection, so read from metagraph first. |
72 | if (meta_graph_def.asset_file_def_size() > 0) { |
73 | for (const auto& asset : meta_graph_def.asset_file_def()) { |
74 | asset_file_defs->push_back(asset); |
75 | } |
76 | return OkStatus(); |
77 | } |
78 | // Fall back to read from collection to be backward compatible with v1. |
79 | const auto& collection_def_map = meta_graph_def.collection_def(); |
80 | const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); |
81 | if (assets_it == collection_def_map.end()) { |
82 | return OkStatus(); |
83 | } |
84 | const auto& any_assets = assets_it->second.any_list().value(); |
85 | for (const auto& any_asset : any_assets) { |
86 | AssetFileDef asset_file_def; |
87 | TF_RETURN_IF_ERROR( |
88 | ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef" )); |
89 | asset_file_defs->push_back(asset_file_def); |
90 | } |
91 | return OkStatus(); |
92 | } |
93 | |
94 | } // namespace internal |
95 | } // namespace tensorflow |
96 | |