1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_OP_H_ |
18 | |
19 | #include <functional> |
20 | #include <unordered_map> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/framework/full_type.pb.h" |
24 | #include "tensorflow/core/framework/full_type_inference_util.h" |
25 | #include "tensorflow/core/framework/full_type_util.h" |
26 | #include "tensorflow/core/framework/op_def_builder.h" |
27 | #include "tensorflow/core/framework/op_def_util.h" |
28 | #include "tensorflow/core/framework/registration/registration.h" |
29 | #include "tensorflow/core/lib/core/errors.h" |
30 | #include "tensorflow/core/lib/core/status.h" |
31 | #include "tensorflow/core/lib/strings/str_util.h" |
32 | #include "tensorflow/core/lib/strings/strcat.h" |
33 | #include "tensorflow/core/platform/logging.h" |
34 | #include "tensorflow/core/platform/macros.h" |
35 | #include "tensorflow/core/platform/mutex.h" |
36 | #include "tensorflow/core/platform/thread_annotations.h" |
37 | #include "tensorflow/core/platform/types.h" |
38 | |
39 | namespace tensorflow { |
40 | |
41 | // Users that want to look up an OpDef by type name should take an |
42 | // OpRegistryInterface. Functions accepting a |
43 | // (const) OpRegistryInterface* may call LookUp() from multiple threads. |
44 | class OpRegistryInterface { |
45 | public: |
46 | virtual ~OpRegistryInterface(); |
47 | |
48 | // Returns an error status and sets *op_reg_data to nullptr if no OpDef is |
49 | // registered under that name, otherwise returns the registered OpDef. |
50 | // Caller must not delete the returned pointer. |
51 | virtual Status LookUp(const std::string& op_type_name, |
52 | const OpRegistrationData** op_reg_data) const = 0; |
53 | |
54 | // Shorthand for calling LookUp to get the OpDef. |
55 | Status LookUpOpDef(const std::string& op_type_name, |
56 | const OpDef** op_def) const; |
57 | }; |
58 | |
59 | // The standard implementation of OpRegistryInterface, along with a |
60 | // global singleton used for registering ops via the REGISTER |
61 | // macros below. Thread-safe. |
62 | // |
63 | // Example registration: |
64 | // OpRegistry::Global()->Register( |
65 | // [](OpRegistrationData* op_reg_data)->Status { |
66 | // // Populate *op_reg_data here. |
67 | // return Status::OK(); |
68 | // }); |
69 | class OpRegistry : public OpRegistryInterface { |
70 | public: |
71 | typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; |
72 | |
73 | OpRegistry(); |
74 | ~OpRegistry() override; |
75 | |
76 | void Register(const OpRegistrationDataFactory& op_data_factory); |
77 | |
78 | Status LookUp(const std::string& op_type_name, |
79 | const OpRegistrationData** op_reg_data) const override; |
80 | |
81 | // Returns OpRegistrationData* of registered op type, else returns nullptr. |
82 | const OpRegistrationData* LookUp(const std::string& op_type_name) const; |
83 | |
84 | // Fills *ops with all registered OpDefs (except those with names |
85 | // starting with '_' if include_internal == false) sorted in |
86 | // ascending alphabetical order. |
87 | void Export(bool include_internal, OpList* ops) const; |
88 | |
89 | // Returns ASCII-format OpList for all registered OpDefs (except |
90 | // those with names starting with '_' if include_internal == false). |
91 | std::string DebugString(bool include_internal) const; |
92 | |
93 | // A singleton available at startup. |
94 | static OpRegistry* Global(); |
95 | |
96 | // Get all registered ops. |
97 | void GetRegisteredOps(std::vector<OpDef>* op_defs); |
98 | |
99 | // Get all `OpRegistrationData`s. |
100 | void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data); |
101 | |
102 | // Registers a function that validates op registry. |
103 | void RegisterValidator( |
104 | std::function<Status(const OpRegistryInterface&)> validator) { |
105 | op_registry_validator_ = std::move(validator); |
106 | } |
107 | |
108 | // Watcher, a function object. |
109 | // The watcher, if set by SetWatcher(), is called every time an op is |
110 | // registered via the Register function. The watcher is passed the Status |
111 | // obtained from building and adding the OpDef to the registry, and the OpDef |
112 | // itself if it was successfully built. A watcher returns a Status which is in |
113 | // turn returned as the final registration status. |
114 | typedef std::function<Status(const Status&, const OpDef&)> Watcher; |
115 | |
116 | // An OpRegistry object has only one watcher. This interface is not thread |
117 | // safe, as different clients are free to set the watcher any time. |
118 | // Clients are expected to atomically perform the following sequence of |
119 | // operations : |
120 | // SetWatcher(a_watcher); |
121 | // Register some ops; |
122 | // op_registry->ProcessRegistrations(); |
123 | // SetWatcher(nullptr); |
124 | // Returns a non-OK status if a non-null watcher is over-written by another |
125 | // non-null watcher. |
126 | Status SetWatcher(const Watcher& watcher); |
127 | |
128 | // Process the current list of deferred registrations. Note that calls to |
129 | // Export, LookUp and DebugString would also implicitly process the deferred |
130 | // registrations. Returns the status of the first failed op registration or |
131 | // Status::OK() otherwise. |
132 | Status ProcessRegistrations() const; |
133 | |
134 | // Defer the registrations until a later call to a function that processes |
135 | // deferred registrations are made. Normally, registrations that happen after |
136 | // calls to Export, LookUp, ProcessRegistrations and DebugString are processed |
137 | // immediately. Call this to defer future registrations. |
138 | void DeferRegistrations(); |
139 | |
140 | // Clear the registrations that have been deferred. |
141 | void ClearDeferredRegistrations(); |
142 | |
143 | private: |
144 | // Ensures that all the functions in deferred_ get called, their OpDef's |
145 | // registered, and returns with deferred_ empty. Returns true the first |
146 | // time it is called. Prints a fatal log if any op registration fails. |
147 | bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
148 | |
149 | // Calls the functions in deferred_ and registers their OpDef's |
150 | // It returns the Status of the first failed op registration or Status::OK() |
151 | // otherwise. |
152 | Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
153 | |
154 | // Add 'def' to the registry with additional data 'data'. On failure, or if |
155 | // there is already an OpDef with that name registered, returns a non-okay |
156 | // status. |
157 | Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) |
158 | const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
159 | |
160 | const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const; |
161 | |
162 | mutable mutex mu_; |
163 | // Functions in deferred_ may only be called with mu_ held. |
164 | mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_); |
165 | // Values are owned. |
166 | mutable std::unordered_map<string, const OpRegistrationData*> registry_ |
167 | TF_GUARDED_BY(mu_); |
168 | mutable bool initialized_ TF_GUARDED_BY(mu_); |
169 | |
170 | // Registry watcher. |
171 | mutable Watcher watcher_ TF_GUARDED_BY(mu_); |
172 | |
173 | std::function<Status(const OpRegistryInterface&)> op_registry_validator_; |
174 | }; |
175 | |
176 | // An adapter to allow an OpList to be used as an OpRegistryInterface. |
177 | // |
178 | // Note that shape inference functions are not passed in to OpListOpRegistry, so |
179 | // it will return an unusable shape inference function for every op it supports; |
180 | // therefore, it should only be used in contexts where this is okay. |
181 | class OpListOpRegistry : public OpRegistryInterface { |
182 | public: |
183 | // Does not take ownership of op_list, *op_list must outlive *this. |
184 | explicit OpListOpRegistry(const OpList* op_list); |
185 | ~OpListOpRegistry() override; |
186 | Status LookUp(const std::string& op_type_name, |
187 | const OpRegistrationData** op_reg_data) const override; |
188 | |
189 | // Returns OpRegistrationData* of op type in list, else returns nullptr. |
190 | const OpRegistrationData* LookUp(const std::string& op_type_name) const; |
191 | |
192 | private: |
193 | // Values are owned. |
194 | std::unordered_map<string, const OpRegistrationData*> index_; |
195 | }; |
196 | |
197 | // Support for defining the OpDef (specifying the semantics of the Op and how |
198 | // it should be created) and registering it in the OpRegistry::Global() |
199 | // registry. Usage: |
200 | // |
201 | // REGISTER_OP("my_op_name") |
202 | // .Attr("<name>:<type>") |
203 | // .Attr("<name>:<type>=<default>") |
204 | // .Input("<name>:<type-expr>") |
205 | // .Input("<name>:Ref(<type-expr>)") |
206 | // .Output("<name>:<type-expr>") |
207 | // .Doc(R"( |
208 | // <1-line summary> |
209 | // <rest of the description (potentially many lines)> |
210 | // <name-of-attr-input-or-output>: <description of name> |
211 | // <name-of-attr-input-or-output>: <description of name; |
212 | // if long, indent the description on subsequent lines> |
213 | // )"); |
214 | // |
215 | // Note: .Doc() should be last. |
216 | // For details, see the OpDefBuilder class in op_def_builder.h. |
217 | |
218 | namespace register_op { |
219 | |
220 | class OpDefBuilderWrapper { |
221 | public: |
222 | explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {} |
223 | OpDefBuilderWrapper& Attr(std::string spec) { |
224 | builder_.Attr(std::move(spec)); |
225 | return *this; |
226 | } |
227 | OpDefBuilderWrapper& Attr(const char* spec) TF_ATTRIBUTE_NOINLINE { |
228 | return Attr(std::string(spec)); |
229 | } |
230 | OpDefBuilderWrapper& Input(std::string spec) { |
231 | builder_.Input(std::move(spec)); |
232 | return *this; |
233 | } |
234 | OpDefBuilderWrapper& Input(const char* spec) TF_ATTRIBUTE_NOINLINE { |
235 | return Input(std::string(spec)); |
236 | } |
237 | OpDefBuilderWrapper& Output(std::string spec) { |
238 | builder_.Output(std::move(spec)); |
239 | return *this; |
240 | } |
241 | OpDefBuilderWrapper& Output(const char* spec) TF_ATTRIBUTE_NOINLINE { |
242 | return Output(std::string(spec)); |
243 | } |
244 | OpDefBuilderWrapper& SetIsCommutative() { |
245 | builder_.SetIsCommutative(); |
246 | return *this; |
247 | } |
248 | OpDefBuilderWrapper& SetIsAggregate() { |
249 | builder_.SetIsAggregate(); |
250 | return *this; |
251 | } |
252 | OpDefBuilderWrapper& SetIsStateful() { |
253 | builder_.SetIsStateful(); |
254 | return *this; |
255 | } |
256 | OpDefBuilderWrapper& SetDoNotOptimize() { |
257 | // We don't have a separate flag to disable optimizations such as constant |
258 | // folding and CSE so we reuse the stateful flag. |
259 | builder_.SetIsStateful(); |
260 | return *this; |
261 | } |
262 | OpDefBuilderWrapper& SetAllowsUninitializedInput() { |
263 | builder_.SetAllowsUninitializedInput(); |
264 | return *this; |
265 | } |
266 | OpDefBuilderWrapper& Deprecated(int version, std::string explanation) { |
267 | builder_.Deprecated(version, std::move(explanation)); |
268 | return *this; |
269 | } |
270 | OpDefBuilderWrapper& Doc(std::string text) { |
271 | builder_.Doc(std::move(text)); |
272 | return *this; |
273 | } |
274 | OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) { |
275 | builder_.SetShapeFn(std::move(fn)); |
276 | return *this; |
277 | } |
278 | OpDefBuilderWrapper& SetIsDistributedCommunication() { |
279 | builder_.SetIsDistributedCommunication(); |
280 | return *this; |
281 | } |
282 | |
283 | OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) { |
284 | builder_.SetTypeConstructor(std::move(fn)); |
285 | return *this; |
286 | } |
287 | |
288 | OpDefBuilderWrapper& SetForwardTypeFn(ForwardTypeInferenceFn fn) { |
289 | builder_.SetForwardTypeFn(std::move(fn)); |
290 | return *this; |
291 | } |
292 | |
293 | OpDefBuilderWrapper& SetReverseTypeFn(int input_number, |
294 | ForwardTypeInferenceFn fn) { |
295 | builder_.SetReverseTypeFn(input_number, std::move(fn)); |
296 | return *this; |
297 | } |
298 | |
299 | const ::tensorflow::OpDefBuilder& builder() const { return builder_; } |
300 | |
301 | InitOnStartupMarker operator()(); |
302 | |
303 | private: |
304 | mutable ::tensorflow::OpDefBuilder builder_; |
305 | }; |
306 | |
307 | } // namespace register_op |
308 | |
// REGISTER_OP_IMPL is the shared implementation of REGISTER_OP and
// REGISTER_SYSTEM_OP. It defines a static, otherwise-unused
// InitOnStartupMarker named register_op<ctr>. Its initializer streams an
// OpDefBuilderWrapper for `name` through TF_INIT_ON_STARTUP_IF, gated on
// `(is_system_op) || SHOULD_REGISTER_OP(name)`. Builder calls chained at the
// use site (e.g. `.Input(...)`) bind to that wrapper, because member access
// binds tighter than `<<`.
//
// `is_system_op` is parenthesized so that any expression passed for it
// groups correctly with the `||`.
#define REGISTER_OP_IMPL(ctr, name, is_system_op)                          \
  static ::tensorflow::InitOnStartupMarker const register_op##ctr           \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF((is_system_op) || SHOULD_REGISTER_OP(name)) \
          << ::tensorflow::register_op::OpDefBuilderWrapper(name)

// Registers the op `name` (see the usage example above). Under selective
// registration, the op is registered only if SHOULD_REGISTER_OP(name) holds.
// TF_NEW_ID_FOR_INIT presumably supplies a unique `ctr`, so that several
// registrations in one translation unit get distinct static names.
#define REGISTER_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)

// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
// that the op is registered unconditionally even when selective
// registration is used.
#define REGISTER_SYSTEM_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op")        \
  TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true)
326 | |
327 | } // namespace tensorflow |
328 | |
329 | #endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_ |
330 | |