1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/common_runtime/session_factory.h"

#include <algorithm>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
27 | |
28 | namespace tensorflow { |
29 | namespace { |
30 | |
31 | static mutex* get_session_factory_lock() { |
32 | static mutex session_factory_lock(LINKER_INITIALIZED); |
33 | return &session_factory_lock; |
34 | } |
35 | |
36 | typedef std::unordered_map<string, SessionFactory*> SessionFactories; |
37 | SessionFactories* session_factories() { |
38 | static SessionFactories* factories = new SessionFactories; |
39 | return factories; |
40 | } |
41 | |
42 | } // namespace |
43 | |
44 | void SessionFactory::Register(const string& runtime_type, |
45 | SessionFactory* factory) { |
46 | mutex_lock l(*get_session_factory_lock()); |
47 | if (!session_factories()->insert({runtime_type, factory}).second) { |
48 | LOG(ERROR) << "Two session factories are being registered " |
49 | << "under " << runtime_type; |
50 | } |
51 | } |
52 | |
53 | namespace { |
54 | const string RegisteredFactoriesErrorMessageLocked() { |
55 | std::vector<string> factory_types; |
56 | for (const auto& session_factory : *session_factories()) { |
57 | factory_types.push_back(session_factory.first); |
58 | } |
59 | return strings::StrCat("Registered factories are {" , |
60 | absl::StrJoin(factory_types, ", " ), "}." ); |
61 | } |
62 | string SessionOptionsToString(const SessionOptions& options) { |
63 | return strings::StrCat("target: \"" , options.target, |
64 | "\" config: " , options.config.ShortDebugString()); |
65 | } |
66 | } // namespace |
67 | |
68 | Status SessionFactory::GetFactory(const SessionOptions& options, |
69 | SessionFactory** out_factory) { |
70 | mutex_lock l(*get_session_factory_lock()); // could use reader lock |
71 | |
72 | std::vector<std::pair<string, SessionFactory*>> candidate_factories; |
73 | for (const auto& session_factory : *session_factories()) { |
74 | if (session_factory.second->AcceptsOptions(options)) { |
75 | VLOG(2) << "SessionFactory type " << session_factory.first |
76 | << " accepts target: " << options.target; |
77 | candidate_factories.push_back(session_factory); |
78 | } else { |
79 | VLOG(2) << "SessionFactory type " << session_factory.first |
80 | << " does not accept target: " << options.target; |
81 | } |
82 | } |
83 | |
84 | if (candidate_factories.size() == 1) { |
85 | *out_factory = candidate_factories[0].second; |
86 | return OkStatus(); |
87 | } else if (candidate_factories.size() > 1) { |
88 | // NOTE(mrry): This implementation assumes that the domains (in |
89 | // terms of acceptable SessionOptions) of the registered |
90 | // SessionFactory implementations do not overlap. This is fine for |
91 | // now, but we may need an additional way of distinguishing |
92 | // different runtimes (such as an additional session option) if |
93 | // the number of sessions grows. |
94 | // TODO(mrry): Consider providing a system-default fallback option |
95 | // in this case. |
96 | std::vector<string> factory_types; |
97 | factory_types.reserve(candidate_factories.size()); |
98 | for (const auto& candidate_factory : candidate_factories) { |
99 | factory_types.push_back(candidate_factory.first); |
100 | } |
101 | return errors::Internal( |
102 | "Multiple session factories registered for the given session " |
103 | "options: {" , |
104 | SessionOptionsToString(options), "} Candidate factories are {" , |
105 | absl::StrJoin(factory_types, ", " ), "}. " , |
106 | RegisteredFactoriesErrorMessageLocked()); |
107 | } else { |
108 | return errors::NotFound( |
109 | "No session factory registered for the given session options: {" , |
110 | SessionOptionsToString(options), "} " , |
111 | RegisteredFactoriesErrorMessageLocked()); |
112 | } |
113 | } |
114 | |
115 | } // namespace tensorflow |
116 | |