1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/executor_factory.h" |
17 | |
18 | #include <unordered_map> |
19 | |
20 | #include "tensorflow/core/graph/graph.h" |
21 | #include "tensorflow/core/lib/core/errors.h" |
22 | #include "tensorflow/core/lib/strings/str_util.h" |
23 | #include "tensorflow/core/platform/logging.h" |
24 | #include "tensorflow/core/platform/mutex.h" |
25 | #include "tensorflow/core/platform/types.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace { |
29 | |
30 | static mutex executor_factory_lock(LINKER_INITIALIZED); |
31 | |
32 | typedef std::unordered_map<string, ExecutorFactory*> ExecutorFactories; |
33 | ExecutorFactories* executor_factories() { |
34 | static ExecutorFactories* factories = new ExecutorFactories; |
35 | return factories; |
36 | } |
37 | |
38 | } // namespace |
39 | |
40 | void ExecutorFactory::Register(const string& executor_type, |
41 | ExecutorFactory* factory) { |
42 | mutex_lock l(executor_factory_lock); |
43 | if (!executor_factories()->insert({executor_type, factory}).second) { |
44 | LOG(FATAL) << "Two executor factories are being registered " |
45 | << "under" << executor_type; |
46 | } |
47 | } |
48 | |
49 | namespace { |
50 | const string RegisteredFactoriesErrorMessageLocked() |
51 | TF_SHARED_LOCKS_REQUIRED(executor_factory_lock) { |
52 | std::vector<string> factory_types; |
53 | for (const auto& executor_factory : *executor_factories()) { |
54 | factory_types.push_back(executor_factory.first); |
55 | } |
56 | return strings::StrCat("Registered factories are {" , |
57 | absl::StrJoin(factory_types, ", " ), "}." ); |
58 | } |
59 | } // namespace |
60 | |
61 | Status ExecutorFactory::GetFactory(const string& executor_type, |
62 | ExecutorFactory** out_factory) { |
63 | tf_shared_lock l(executor_factory_lock); |
64 | |
65 | auto iter = executor_factories()->find(executor_type); |
66 | if (iter == executor_factories()->end()) { |
67 | return errors::NotFound( |
68 | "No executor factory registered for the given executor type: " , |
69 | executor_type, " " , RegisteredFactoriesErrorMessageLocked()); |
70 | } |
71 | |
72 | *out_factory = iter->second; |
73 | return OkStatus(); |
74 | } |
75 | |
76 | Status NewExecutor(const string& executor_type, |
77 | const LocalExecutorParams& params, const Graph& graph, |
78 | std::unique_ptr<Executor>* out_executor) { |
79 | ExecutorFactory* factory = nullptr; |
80 | TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory)); |
81 | return factory->NewExecutor(params, std::move(graph), out_executor); |
82 | } |
83 | |
84 | } // namespace tensorflow |
85 | |