1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
17#define TENSORFLOW_CORE_FRAMEWORK_OP_H_
18
19#include <functional>
20#include <unordered_map>
21#include <vector>
22
23#include "tensorflow/core/framework/full_type.pb.h"
24#include "tensorflow/core/framework/full_type_inference_util.h"
25#include "tensorflow/core/framework/full_type_util.h"
26#include "tensorflow/core/framework/op_def_builder.h"
27#include "tensorflow/core/framework/op_def_util.h"
28#include "tensorflow/core/framework/registration/registration.h"
29#include "tensorflow/core/lib/core/errors.h"
30#include "tensorflow/core/lib/core/status.h"
31#include "tensorflow/core/lib/strings/str_util.h"
32#include "tensorflow/core/lib/strings/strcat.h"
33#include "tensorflow/core/platform/logging.h"
34#include "tensorflow/core/platform/macros.h"
35#include "tensorflow/core/platform/mutex.h"
36#include "tensorflow/core/platform/thread_annotations.h"
37#include "tensorflow/core/platform/types.h"
38
39namespace tensorflow {
40
41// Users that want to look up an OpDef by type name should take an
42// OpRegistryInterface. Functions accepting a
43// (const) OpRegistryInterface* may call LookUp() from multiple threads.
44class OpRegistryInterface {
45 public:
46 virtual ~OpRegistryInterface();
47
48 // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
49 // registered under that name, otherwise returns the registered OpDef.
50 // Caller must not delete the returned pointer.
51 virtual Status LookUp(const std::string& op_type_name,
52 const OpRegistrationData** op_reg_data) const = 0;
53
54 // Shorthand for calling LookUp to get the OpDef.
55 Status LookUpOpDef(const std::string& op_type_name,
56 const OpDef** op_def) const;
57};
58
59// The standard implementation of OpRegistryInterface, along with a
60// global singleton used for registering ops via the REGISTER
61// macros below. Thread-safe.
62//
63// Example registration:
64// OpRegistry::Global()->Register(
65// [](OpRegistrationData* op_reg_data)->Status {
66// // Populate *op_reg_data here.
67// return Status::OK();
68// });
69class OpRegistry : public OpRegistryInterface {
70 public:
71 typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
72
73 OpRegistry();
74 ~OpRegistry() override;
75
76 void Register(const OpRegistrationDataFactory& op_data_factory);
77
78 Status LookUp(const std::string& op_type_name,
79 const OpRegistrationData** op_reg_data) const override;
80
81 // Returns OpRegistrationData* of registered op type, else returns nullptr.
82 const OpRegistrationData* LookUp(const std::string& op_type_name) const;
83
84 // Fills *ops with all registered OpDefs (except those with names
85 // starting with '_' if include_internal == false) sorted in
86 // ascending alphabetical order.
87 void Export(bool include_internal, OpList* ops) const;
88
89 // Returns ASCII-format OpList for all registered OpDefs (except
90 // those with names starting with '_' if include_internal == false).
91 std::string DebugString(bool include_internal) const;
92
93 // A singleton available at startup.
94 static OpRegistry* Global();
95
96 // Get all registered ops.
97 void GetRegisteredOps(std::vector<OpDef>* op_defs);
98
99 // Get all `OpRegistrationData`s.
100 void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
101
102 // Registers a function that validates op registry.
103 void RegisterValidator(
104 std::function<Status(const OpRegistryInterface&)> validator) {
105 op_registry_validator_ = std::move(validator);
106 }
107
108 // Watcher, a function object.
109 // The watcher, if set by SetWatcher(), is called every time an op is
110 // registered via the Register function. The watcher is passed the Status
111 // obtained from building and adding the OpDef to the registry, and the OpDef
112 // itself if it was successfully built. A watcher returns a Status which is in
113 // turn returned as the final registration status.
114 typedef std::function<Status(const Status&, const OpDef&)> Watcher;
115
116 // An OpRegistry object has only one watcher. This interface is not thread
117 // safe, as different clients are free to set the watcher any time.
118 // Clients are expected to atomically perform the following sequence of
119 // operations :
120 // SetWatcher(a_watcher);
121 // Register some ops;
122 // op_registry->ProcessRegistrations();
123 // SetWatcher(nullptr);
124 // Returns a non-OK status if a non-null watcher is over-written by another
125 // non-null watcher.
126 Status SetWatcher(const Watcher& watcher);
127
128 // Process the current list of deferred registrations. Note that calls to
129 // Export, LookUp and DebugString would also implicitly process the deferred
130 // registrations. Returns the status of the first failed op registration or
131 // Status::OK() otherwise.
132 Status ProcessRegistrations() const;
133
134 // Defer the registrations until a later call to a function that processes
135 // deferred registrations are made. Normally, registrations that happen after
136 // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
137 // immediately. Call this to defer future registrations.
138 void DeferRegistrations();
139
140 // Clear the registrations that have been deferred.
141 void ClearDeferredRegistrations();
142
143 private:
144 // Ensures that all the functions in deferred_ get called, their OpDef's
145 // registered, and returns with deferred_ empty. Returns true the first
146 // time it is called. Prints a fatal log if any op registration fails.
147 bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
148
149 // Calls the functions in deferred_ and registers their OpDef's
150 // It returns the Status of the first failed op registration or Status::OK()
151 // otherwise.
152 Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
153
154 // Add 'def' to the registry with additional data 'data'. On failure, or if
155 // there is already an OpDef with that name registered, returns a non-okay
156 // status.
157 Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
158 const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
159
160 const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const;
161
162 mutable mutex mu_;
163 // Functions in deferred_ may only be called with mu_ held.
164 mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_);
165 // Values are owned.
166 mutable std::unordered_map<string, const OpRegistrationData*> registry_
167 TF_GUARDED_BY(mu_);
168 mutable bool initialized_ TF_GUARDED_BY(mu_);
169
170 // Registry watcher.
171 mutable Watcher watcher_ TF_GUARDED_BY(mu_);
172
173 std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
174};
175
176// An adapter to allow an OpList to be used as an OpRegistryInterface.
177//
178// Note that shape inference functions are not passed in to OpListOpRegistry, so
179// it will return an unusable shape inference function for every op it supports;
180// therefore, it should only be used in contexts where this is okay.
181class OpListOpRegistry : public OpRegistryInterface {
182 public:
183 // Does not take ownership of op_list, *op_list must outlive *this.
184 explicit OpListOpRegistry(const OpList* op_list);
185 ~OpListOpRegistry() override;
186 Status LookUp(const std::string& op_type_name,
187 const OpRegistrationData** op_reg_data) const override;
188
189 // Returns OpRegistrationData* of op type in list, else returns nullptr.
190 const OpRegistrationData* LookUp(const std::string& op_type_name) const;
191
192 private:
193 // Values are owned.
194 std::unordered_map<string, const OpRegistrationData*> index_;
195};
196
197// Support for defining the OpDef (specifying the semantics of the Op and how
198// it should be created) and registering it in the OpRegistry::Global()
199// registry. Usage:
200//
201// REGISTER_OP("my_op_name")
202// .Attr("<name>:<type>")
203// .Attr("<name>:<type>=<default>")
204// .Input("<name>:<type-expr>")
205// .Input("<name>:Ref(<type-expr>)")
206// .Output("<name>:<type-expr>")
207// .Doc(R"(
208// <1-line summary>
209// <rest of the description (potentially many lines)>
210// <name-of-attr-input-or-output>: <description of name>
211// <name-of-attr-input-or-output>: <description of name;
212// if long, indent the description on subsequent lines>
213// )");
214//
215// Note: .Doc() should be last.
216// For details, see the OpDefBuilder class in op_def_builder.h.
217
218namespace register_op {
219
220class OpDefBuilderWrapper {
221 public:
222 explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
223 OpDefBuilderWrapper& Attr(std::string spec) {
224 builder_.Attr(std::move(spec));
225 return *this;
226 }
227 OpDefBuilderWrapper& Attr(const char* spec) TF_ATTRIBUTE_NOINLINE {
228 return Attr(std::string(spec));
229 }
230 OpDefBuilderWrapper& Input(std::string spec) {
231 builder_.Input(std::move(spec));
232 return *this;
233 }
234 OpDefBuilderWrapper& Input(const char* spec) TF_ATTRIBUTE_NOINLINE {
235 return Input(std::string(spec));
236 }
237 OpDefBuilderWrapper& Output(std::string spec) {
238 builder_.Output(std::move(spec));
239 return *this;
240 }
241 OpDefBuilderWrapper& Output(const char* spec) TF_ATTRIBUTE_NOINLINE {
242 return Output(std::string(spec));
243 }
244 OpDefBuilderWrapper& SetIsCommutative() {
245 builder_.SetIsCommutative();
246 return *this;
247 }
248 OpDefBuilderWrapper& SetIsAggregate() {
249 builder_.SetIsAggregate();
250 return *this;
251 }
252 OpDefBuilderWrapper& SetIsStateful() {
253 builder_.SetIsStateful();
254 return *this;
255 }
256 OpDefBuilderWrapper& SetDoNotOptimize() {
257 // We don't have a separate flag to disable optimizations such as constant
258 // folding and CSE so we reuse the stateful flag.
259 builder_.SetIsStateful();
260 return *this;
261 }
262 OpDefBuilderWrapper& SetAllowsUninitializedInput() {
263 builder_.SetAllowsUninitializedInput();
264 return *this;
265 }
266 OpDefBuilderWrapper& Deprecated(int version, std::string explanation) {
267 builder_.Deprecated(version, std::move(explanation));
268 return *this;
269 }
270 OpDefBuilderWrapper& Doc(std::string text) {
271 builder_.Doc(std::move(text));
272 return *this;
273 }
274 OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) {
275 builder_.SetShapeFn(std::move(fn));
276 return *this;
277 }
278 OpDefBuilderWrapper& SetIsDistributedCommunication() {
279 builder_.SetIsDistributedCommunication();
280 return *this;
281 }
282
283 OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) {
284 builder_.SetTypeConstructor(std::move(fn));
285 return *this;
286 }
287
288 OpDefBuilderWrapper& SetForwardTypeFn(ForwardTypeInferenceFn fn) {
289 builder_.SetForwardTypeFn(std::move(fn));
290 return *this;
291 }
292
293 OpDefBuilderWrapper& SetReverseTypeFn(int input_number,
294 ForwardTypeInferenceFn fn) {
295 builder_.SetReverseTypeFn(input_number, std::move(fn));
296 return *this;
297 }
298
299 const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
300
301 InitOnStartupMarker operator()();
302
303 private:
304 mutable ::tensorflow::OpDefBuilder builder_;
305};
306
307} // namespace register_op
308
// Implementation detail of REGISTER_OP / REGISTER_SYSTEM_OP.
//
// Defines a static InitOnStartupMarker named register_op<ctr> (`ctr` is
// supplied by TF_NEW_ID_FOR_INIT). It is initialized from an
// OpDefBuilderWrapper for `name`, and any builder calls written after the
// macro invocation (.Attr(), .Input(), ...) bind to that wrapper. Registration
// happens only if `is_system_op` is true or SHOULD_REGISTER_OP(name) holds
// (selective registration).
//
// `is_system_op` is parenthesized so that an arbitrary boolean expression
// passed as the argument keeps its meaning next to `||`.
#define REGISTER_OP_IMPL(ctr, name, is_system_op)                           \
  static ::tensorflow::InitOnStartupMarker const register_op##ctr           \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF((is_system_op) || SHOULD_REGISTER_OP(name)) \
          << ::tensorflow::register_op::OpDefBuilderWrapper(name)
314
// Defines and registers an op named `name` in the OpRegistry::Global()
// registry. The op is subject to selective registration: REGISTER_OP_IMPL is
// passed is_system_op == false, so it only registers when
// SHOULD_REGISTER_OP(name) holds. Follow the macro with OpDefBuilderWrapper
// calls, e.g.:
//   REGISTER_OP("MyOp").Input("x: float").Output("y: float");
// NOTE(review): the TF_ATTRIBUTE_ANNOTATE strings are presumably consumed by
// build tooling; confirm before changing them.
#define REGISTER_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)

// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
// that the op is registered unconditionally even when selective
// registration is used (it passes is_system_op == true to REGISTER_OP_IMPL).
// It additionally carries the "tf:op:system" annotation.
#define REGISTER_SYSTEM_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op")        \
  TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true)
326
327} // namespace tensorflow
328
329#endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_
330