1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/op.h"
17
18#include <algorithm>
19#include <memory>
20#include <vector>
21
22#include "tensorflow/core/framework/full_type.pb.h"
23#include "tensorflow/core/framework/op_def_builder.h"
24#include "tensorflow/core/lib/core/errors.h"
25#include "tensorflow/core/lib/gtl/map_util.h"
26#include "tensorflow/core/lib/strings/str_util.h"
27#include "tensorflow/core/platform/host_info.h"
28#include "tensorflow/core/platform/logging.h"
29#include "tensorflow/core/platform/mutex.h"
30#include "tensorflow/core/platform/protobuf.h"
31#include "tensorflow/core/platform/types.h"
32
33namespace tensorflow {
34
35Status DefaultValidator(const OpRegistryInterface& op_registry) {
36 LOG(WARNING) << "No kernel validator registered with OpRegistry.";
37 return OkStatus();
38}
39
40// OpRegistry -----------------------------------------------------------------
41
42OpRegistryInterface::~OpRegistryInterface() {}
43
44Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
45 const OpDef** op_def) const {
46 *op_def = nullptr;
47 const OpRegistrationData* op_reg_data = nullptr;
48 TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
49 *op_def = &op_reg_data->op_def;
50 return OkStatus();
51}
52
53OpRegistry::OpRegistry()
54 : initialized_(false), op_registry_validator_(DefaultValidator) {}
55
56OpRegistry::~OpRegistry() {
57 for (const auto& e : registry_) delete e.second;
58}
59
60void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
61 mutex_lock lock(mu_);
62 if (initialized_) {
63 TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
64 } else {
65 deferred_.push_back(op_data_factory);
66 }
67}
68
69namespace {
70// Helper function that returns Status message for failed LookUp.
71Status OpNotFound(const string& op_type_name) {
72 Status status = errors::NotFound(
73 "Op type not registered '", op_type_name, "' in binary running on ",
74 port::Hostname(), ". ",
75 "Make sure the Op and Kernel are registered in the binary running in "
76 "this process. Note that if you are loading a saved graph which used ops "
77 "from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
78 "before importing the graph, as contrib ops are lazily registered when "
79 "the module is first accessed.");
80 VLOG(1) << status.ToString();
81 return status;
82}
83} // namespace
84
85Status OpRegistry::LookUp(const string& op_type_name,
86 const OpRegistrationData** op_reg_data) const {
87 if ((*op_reg_data = LookUp(op_type_name))) return OkStatus();
88 return OpNotFound(op_type_name);
89}
90
91const OpRegistrationData* OpRegistry::LookUp(const string& op_type_name) const {
92 {
93 tf_shared_lock l(mu_);
94 if (initialized_) {
95 if (const OpRegistrationData* res =
96 gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
97 return res;
98 }
99 }
100 }
101 return LookUpSlow(op_type_name);
102}
103
104const OpRegistrationData* OpRegistry::LookUpSlow(
105 const string& op_type_name) const {
106 const OpRegistrationData* res = nullptr;
107
108 bool first_call = false;
109 bool first_unregistered = false;
110 { // Scope for lock.
111 mutex_lock lock(mu_);
112 first_call = MustCallDeferred();
113 res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
114
115 static bool unregistered_before = false;
116 first_unregistered = !unregistered_before && (res == nullptr);
117 if (first_unregistered) {
118 unregistered_before = true;
119 }
120 // Note: Can't hold mu_ while calling Export() below.
121 }
122 if (first_call) {
123 TF_QCHECK_OK(op_registry_validator_(*this));
124 }
125 if (res == nullptr) {
126 if (first_unregistered) {
127 OpList op_list;
128 Export(true, &op_list);
129 if (VLOG_IS_ON(3)) {
130 LOG(INFO) << "All registered Ops:";
131 for (const auto& op : op_list.op()) {
132 LOG(INFO) << SummarizeOpDef(op);
133 }
134 }
135 }
136 }
137 return res;
138}
139
140void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
141 mutex_lock lock(mu_);
142 MustCallDeferred();
143 for (const auto& p : registry_) {
144 op_defs->push_back(p.second->op_def);
145 }
146}
147
148void OpRegistry::GetOpRegistrationData(
149 std::vector<OpRegistrationData>* op_data) {
150 mutex_lock lock(mu_);
151 MustCallDeferred();
152 for (const auto& p : registry_) {
153 op_data->push_back(*p.second);
154 }
155}
156
157Status OpRegistry::SetWatcher(const Watcher& watcher) {
158 mutex_lock lock(mu_);
159 if (watcher_ && watcher) {
160 return errors::AlreadyExists(
161 "Cannot over-write a valid watcher with another.");
162 }
163 watcher_ = watcher;
164 return OkStatus();
165}
166
167void OpRegistry::Export(bool include_internal, OpList* ops) const {
168 mutex_lock lock(mu_);
169 MustCallDeferred();
170
171 std::vector<std::pair<string, const OpRegistrationData*>> sorted(
172 registry_.begin(), registry_.end());
173 std::sort(sorted.begin(), sorted.end());
174
175 auto out = ops->mutable_op();
176 out->Clear();
177 out->Reserve(sorted.size());
178
179 for (const auto& item : sorted) {
180 if (include_internal || !absl::StartsWith(item.first, "_")) {
181 *out->Add() = item.second->op_def;
182 }
183 }
184}
185
186void OpRegistry::DeferRegistrations() {
187 mutex_lock lock(mu_);
188 initialized_ = false;
189}
190
191void OpRegistry::ClearDeferredRegistrations() {
192 mutex_lock lock(mu_);
193 deferred_.clear();
194}
195
196Status OpRegistry::ProcessRegistrations() const {
197 mutex_lock lock(mu_);
198 return CallDeferred();
199}
200
201string OpRegistry::DebugString(bool include_internal) const {
202 OpList op_list;
203 Export(include_internal, &op_list);
204 string ret;
205 for (const auto& op : op_list.op()) {
206 strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
207 }
208 return ret;
209}
210
211bool OpRegistry::MustCallDeferred() const {
212 if (initialized_) return false;
213 initialized_ = true;
214 for (size_t i = 0; i < deferred_.size(); ++i) {
215 TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
216 }
217 deferred_.clear();
218 return true;
219}
220
221Status OpRegistry::CallDeferred() const {
222 if (initialized_) return OkStatus();
223 initialized_ = true;
224 for (size_t i = 0; i < deferred_.size(); ++i) {
225 Status s = RegisterAlreadyLocked(deferred_[i]);
226 if (!s.ok()) {
227 return s;
228 }
229 }
230 deferred_.clear();
231 return OkStatus();
232}
233
234Status OpRegistry::RegisterAlreadyLocked(
235 const OpRegistrationDataFactory& op_data_factory) const {
236 std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
237 Status s = op_data_factory(op_reg_data.get());
238 if (s.ok()) {
239 s = ValidateOpDef(op_reg_data->op_def);
240 if (s.ok() &&
241 !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
242 op_reg_data.get())) {
243 s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
244 }
245 }
246 Status watcher_status = s;
247 if (watcher_) {
248 watcher_status = watcher_(s, op_reg_data->op_def);
249 }
250 if (s.ok()) {
251 op_reg_data.release();
252 } else {
253 op_reg_data.reset();
254 }
255 return watcher_status;
256}
257
258// static
259OpRegistry* OpRegistry::Global() {
260 static OpRegistry* global_op_registry = new OpRegistry;
261 return global_op_registry;
262}
263
264// OpListOpRegistry -----------------------------------------------------------
265
266OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
267 for (const OpDef& op_def : op_list->op()) {
268 auto* op_reg_data = new OpRegistrationData();
269 op_reg_data->op_def = op_def;
270 index_[op_def.name()] = op_reg_data;
271 }
272}
273
274OpListOpRegistry::~OpListOpRegistry() {
275 for (const auto& e : index_) delete e.second;
276}
277
278const OpRegistrationData* OpListOpRegistry::LookUp(
279 const string& op_type_name) const {
280 auto iter = index_.find(op_type_name);
281 if (iter == index_.end()) {
282 return nullptr;
283 }
284 return iter->second;
285}
286
287Status OpListOpRegistry::LookUp(const string& op_type_name,
288 const OpRegistrationData** op_reg_data) const {
289 if ((*op_reg_data = LookUp(op_type_name))) return OkStatus();
290 return OpNotFound(op_type_name);
291}
292
293namespace register_op {
294
295InitOnStartupMarker OpDefBuilderWrapper::operator()() {
296 OpRegistry::Global()->Register(
297 [builder =
298 std::move(builder_)](OpRegistrationData* op_reg_data) -> Status {
299 return builder.Finalize(op_reg_data);
300 });
301 return {};
302}
303
304} // namespace register_op
305
306} // namespace tensorflow
307