1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_ |
18 | |
19 | #include <vector> |
20 | |
21 | // clang-format off |
22 | // Required for IS_MOBILE_PLATFORM |
23 | #include "tensorflow/core/platform/platform.h" |
24 | // clang-format on |
25 | |
26 | #include "absl/container/flat_hash_map.h" |
27 | #include "absl/types/optional.h" |
28 | #include "absl/types/variant.h" |
29 | #include "tensorflow/core/framework/attr_value.pb.h" |
30 | #include "tensorflow/core/framework/attr_value_util.h" |
31 | #include "tensorflow/core/framework/cancellation.h" |
32 | #include "tensorflow/core/framework/function.pb.h" |
33 | #include "tensorflow/core/framework/node_def_util.h" |
34 | #include "tensorflow/core/framework/op.h" |
35 | #include "tensorflow/core/framework/op_kernel.h" |
36 | #include "tensorflow/core/framework/registration/registration.h" |
37 | #include "tensorflow/core/framework/types.h" |
38 | #include "tensorflow/core/lib/gtl/flatmap.h" |
39 | #include "tensorflow/core/lib/hash/hash.h" |
40 | #include "tensorflow/core/lib/random/random.h" |
41 | #include "tensorflow/core/platform/env.h" |
42 | #include "tensorflow/core/platform/macros.h" |
43 | #include "tensorflow/core/platform/mutex.h" |
44 | #include "tensorflow/core/platform/protobuf.h" |
45 | #include "tensorflow/core/platform/threadpool_interface.h" |
46 | #include "tensorflow/core/protobuf/config.pb.h" |
47 | #if !defined(IS_MOBILE_PLATFORM) |
48 | #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" |
49 | #endif // IS_MOBILE_PLATFORM |
50 | |
51 | namespace tensorflow { |
52 | |
53 | class CollectiveExecutor; |
54 | class DeviceSet; |
55 | class Graph; |
56 | class GraphDef; |
57 | class OpKernel; |
58 | class ProcessFunctionLibraryRuntime; |
59 | class ResourceMgr; |
60 | class Rendezvous; |
61 | class ScopedStepContainer; |
62 | class StepStatsCollectorInterface; |
63 | class Node; |
64 | |
65 | // FunctionDefHelper::Create is a convenient helper to construct a |
66 | // FunctionDef proto. |
67 | // E.g., |
68 | // FunctionDef my_func = FunctionDefHelper::Create( |
69 | // "my_func_name", |
70 | // {"x:T", "y:T" /* one string per argument */}, |
71 | // {"z:T" /* one string per return value */}, |
72 | // {"T: {float, double}" /* one string per attribute */}, |
73 | // { |
74 | // {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}} |
75 | // /* one entry per function node */ |
76 | // }, |
77 | // /* Mapping between function returns and function node outputs. */ |
78 | // {{"z", "o:z"}}); |
79 | // |
80 | // For the old Function::Node approach, use FunctionDefHelper::Define() |
81 | // E.g., |
82 | // FunctionDef my_func = FunctionDefHelper::Define( |
83 | // "my_func_name", |
84 | // {"x:T", "y:T" /* one string per argument */}, |
85 | // {"z:T" /* one string per return value */}, |
86 | // {"T: {float, double}" /* one string per attribute */}, |
87 | // { |
88 | // {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}} |
89 | // /* one entry per function node */ |
90 | // }); |
class FunctionDefHelper {
 public:
  // AttrValueWrapper has copy constructors for the type T so that
  // it's easy to construct a simple AttrValue proto.
  //
  // If T is a string type (const char*, string, or StringPiece), and
  // it starts with "$", we construct an AttrValue of "placeholder".
  //
  // E.g.,
  //   std::pair<string, AttrValueWrapper> x = {"T", "$T"}
  // is a named attr value placeholder.
  struct AttrValueWrapper {
    AttrValue proto;

    AttrValueWrapper() {}

    // Intentionally implicit so that initializer lists like
    // {{"T", "$T"}} convert directly to AttrValueWrapper.
    template <typename T>
    AttrValueWrapper(T val) {  // NOLINT(runtime/explicit)
      SetAttrValue(val, &proto);
    }

   private:
    // Fills `proto` from a string value; implemented out-of-line so the
    // string specializations below can share the "$"-placeholder handling.
    void InitFromString(StringPiece val);
  };

  // Constructs an AttrValue.func given the "name" and "attrs".
  static AttrValueWrapper FunctionRef(
      const std::string& name,
      gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
  static AttrValueWrapper FunctionRef(const std::string& name) {
    return FunctionRef(name, {});
  }

  // Node is used to construct FunctionDef.Node using initialization
  // lists. E.g.,
  //    Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}};  // z = x * y
  //
  // If the op has no outputs, then `name` must be specified explicitly
  // (otherwise the node name is derived from ret[0]; see GetName()). E.g.,
  //    Node n = {{}, "AssignVariable", {"resource", "val"},
  //              {{"dtype", "DT_FLOAT"}}, {"update0"}, "CPU:0", "update1"};
  struct Node {
    // When constructing a NodeDef, the first entry in ret is used as
    // the node name, the remaining values are ignored.
    std::vector<string> ret;
    std::string op;
    std::vector<string> arg;
    std::vector<std::pair<string, AttrValueWrapper>> attr;
    std::vector<string> dep;
    std::string device;

    // Required if the op has zero outputs. Otherwise, ret[0] is used as the
    // name if `name` is left empty.
    std::string name;

    // Returns the node name: `name` if set, otherwise ret[0].
    // CHECK-fails if both are empty.
    std::string GetName() const {
      if (!name.empty()) return name;
      CHECK(!ret.empty());
      return ret[0];
    }
    std::vector<string> original_node_names;
    std::vector<string> original_func_names;

    // Converts this initializer-list form into a NodeDef proto.
    NodeDef ToNodeDef() const;
  };

  // Creates a FunctionDef from the given parameters. Node inputs must use
  // function encoding (node_name:output_name[:output_index]).
  // - `ret_def` holds a mapping from the function output names from `out_def`
  //   to the node outputs from `node_def`.
  // - `control_ret_def` holds a mapping from the function control
  //   output names to the nodes from `node_def`.
  static FunctionDef Create(
      const std::string& function_name, gtl::ArraySlice<string> in_def,
      gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
      gtl::ArraySlice<Node> node_def,
      gtl::ArraySlice<std::pair<string, string>> ret_def,
      gtl::ArraySlice<std::pair<string, string>> control_ret_def);

  // Creates a FunctionDef from the given parameters. Node inputs must use
  // function encoding (node_name:output_name[:output_index]).
  // - `ret_def` holds a mapping from the function output names from `out_def`
  //   to the node outputs from `node_def`.
  static FunctionDef Create(const std::string& function_name,
                            gtl::ArraySlice<string> in_def,
                            gtl::ArraySlice<string> out_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def,
                            gtl::ArraySlice<std::pair<string, string>> ret_def);

  // TODO(josh11b): Get rid of these and transition to the one above.
  static FunctionDef Define(const std::string& function_name,
                            gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Defines an anonymous function. I.e., its name is not relevant.
  static FunctionDef Define(gtl::ArraySlice<string> arg_def,
                            gtl::ArraySlice<string> ret_def,
                            gtl::ArraySlice<string> attr_def,
                            gtl::ArraySlice<Node> node_def);

  // Helper to construct a "Const" node producing a scalar of type T
  // holding `val`.
  template <typename T>
  static Node Const(const std::string& name, const T& val) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    Tensor t(dtype, TensorShape({}));
    t.scalar<T>()() = val;
    n.attr.push_back({"value", t});
    return n;
  }

  // Helper to construct a "Const" node producing a 1-D tensor of type T
  // holding the elements of `vals`.
  template <typename T>
  static Node Const(const std::string& name, gtl::ArraySlice<T> vals) {
    Node n = {{name}, "Const"};
    const DataType dtype = DataTypeToEnum<T>::value;
    n.attr.push_back({"dtype", dtype});
    int64_t num = vals.size();
    Tensor t(dtype, TensorShape({num}));
    for (size_t i = 0; i < vals.size(); ++i) {
      t.flat<T>()(i) = vals[i];
    }
    n.attr.push_back({"value", t});
    return n;
  }
};
220 | |
// Specializations of the AttrValueWrapper constructor for string-like types
// (const char*, std::string, StringPiece). These route through
// InitFromString() so that values starting with "$" are treated as attr
// placeholders (see the AttrValueWrapper comment above) rather than being
// stored as literal strings by the generic SetAttrValue path.
template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
    const std::string& val) {
  InitFromString(val);
}

template <>
inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
  InitFromString(val);
}
236 | |
237 | // Instantiate a function. |
238 | // |
239 | // "fdef" encodes a TF function with some attrs in fdef.signature.attr |
240 | // containing placeholders. InstantiateFunction binds these |
241 | // placeholders and produces an instantiated function encoded in |
// "result.nodes". The value to substitute a placeholder is given by
243 | // "attr_values", which is a map from a placeholder name to an attr |
244 | // value. |
245 | // |
246 | // InstantiateFunction calls "get_function" to find signatures of other |
247 | // functions and primitive ops. |
248 | |
249 | // GetFunctionSignature(func name, opdef) returns OK if the func name is found |
250 | // and opdef is filled with a pointer to the corresponding signature |
// (an OpDef proto). Otherwise, returns an error.
252 | typedef std::function<Status(const string&, const OpDef**)> |
253 | GetFunctionSignature; |
254 | |
// Output of InstantiateFunction(): the instantiated function's argument and
// return types, plus its body as a flat list of NodeDefs.
struct InstantiationResult {
  DataTypeVector arg_types;    // Types of the function's arguments.
  DataTypeVector ret_types;    // Types of the function's return values.
  std::vector<NodeDef> nodes;  // Nodes of the instantiated function body.
};
260 | Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, |
261 | GetFunctionSignature get_function, |
262 | InstantiationResult* result); |
263 | |
264 | // Returns a debug string for a function definition. |
265 | // |
266 | // The returned text is multiple-line. It is intended to be |
267 | // human-readable rather than being friendly to parsers. It is _NOT_ |
268 | // intended to be the canonical string representation of "func_def". |
269 | // Particularly, it may not include all information presented in |
270 | // "func_def" (e.g., comments, description of the function arguments, |
271 | // etc.) |
272 | std::string DebugString(const FunctionDef& func_def); |
273 | std::string DebugString(const GraphDef& instantiated_func_def); |
274 | std::string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes); |
275 | |
276 | // Returns a debug string for a top level graph (the main program and |
277 | // its supporting functions defined in its library). |
278 | std::string DebugStringWhole(const GraphDef& gdef); |
279 | |
280 | // Returns true if f1 == f2. Compares all fields, including descriptions. Order |
281 | // of NodeDefs doesn't matter. |
282 | bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); |
283 | |
284 | // Return a hash of `fdef` that is consistent with FunctionDefsEqual method. |
285 | // In other words, if two fdefs compare equal, their hash values will be the |
286 | // same. |
287 | uint64 FunctionDefHash(const FunctionDef& fdef); |
288 | |
// Abstract interface for reading a function call's arguments and writing its
// return values. See FunctionCallFrame below for the standard implementation.
class CallFrameInterface {
 public:
  virtual ~CallFrameInterface() {}

  // Number of arguments available to the callee.
  virtual size_t num_args() const = 0;
  // Number of return values the callee is expected to produce.
  virtual size_t num_retvals() const = 0;

  // On success, points `*val` at the argument at `index`. The tensor
  // remains owned by the frame.
  virtual Status GetArg(int index, const Tensor** val) = 0;

  // Optimized implementation of `GetArg()` that allows the caller to take
  // ownership of the tensor. This method may only be called once per
  // value of `index` and `CallFrameInterface` instance.
  //
  // REQUIRES: `this->CanConsumeArg(index) == true`.
  //
  // Default implementation only logs an error; implementations that return
  // true from CanConsumeArg() must override it.
  virtual void ConsumeArg(int index, Tensor* val) {
    LOG(ERROR) << "This `CallFrameInterface` implementation does not support "
                  "consuming arguments.";
  }
  // Returns true iff ConsumeArg() may be called with `index`.
  virtual bool CanConsumeArg(int index) const { return false; }

  // Stores `val` as the return value at `index`.
  virtual Status SetRetval(int index, const Tensor& val) = 0;
};
311 | |
// Represents a function call frame. I.e., the data structure used to
// pass arguments to a function and retrieve its results.
//
// Runtime must arrange accesses to one FunctionCallFrame s.t.
//   1. SetArgs() happens before any GetArg();
//   2. GetRetvals happens after all SetRetval();
class FunctionCallFrame : public CallFrameInterface {
 public:
  // The argument/return types are fixed for the lifetime of the frame.
  FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
  ~FunctionCallFrame() override;

  // Caller methods.
  Status SetArgs(gtl::ArraySlice<Tensor> args);
  Status GetRetvals(std::vector<Tensor>* rets) const;

  // Moves the return values from the frame to rets. If allow_dead_tensors is
  // false it will fail if any of the retvals do not have a value.
  Status ConsumeRetvals(std::vector<Tensor>* rets, bool allow_dead_tensors);

  size_t num_args() const override { return arg_types_.size(); }
  size_t num_retvals() const override { return ret_types_.size(); }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override;
  Status SetRetval(int index, const Tensor& val) override;

 private:
  DataTypeVector arg_types_;  // Declared argument types (set at construction).
  DataTypeVector ret_types_;  // Declared return types (set at construction).
  gtl::InlinedVector<Tensor, 4> args_;
  // One slot per return value; `has_val` distinguishes "set via SetRetval()"
  // from "never set".
  struct Retval {
    bool has_val = false;
    Tensor val;
  };
  gtl::InlinedVector<Retval, 4> rets_;

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
};
350 | |
// Language agnostic stack traces.
class AbstractStackTrace {
 public:
  // Options controlling how ToString() renders the trace.
  struct TracePrintingOptions {
    // Show inline the contents of each stack line.
    bool show_line_contents = false;

    // Drop the common largest prefix of all filenames in stack frames.
    bool filter_common_prefix = false;

    // Do not show internal frames.
    bool drop_internal_frames = false;
  };

  virtual ~AbstractStackTrace() {}

  // The returned span is alive as long as the AbstractStackTrace is alive.
  virtual absl::Span<StackFrame const> ToFrames() const = 0;

  // Returns the last stack frame from user code, attempting to ignore the
  // framework code. Returns an empty frame if no such stack frame was found.
  virtual StackFrame LastUserFrame() const = 0;

  // Renders the trace as a human-readable string according to `opts`.
  virtual std::string ToString(const TracePrintingOptions& opts) const = 0;
};
375 | |
376 | using StackTracesMap = |
377 | std::unordered_map<std::string, |
378 | std::shared_ptr<tensorflow::AbstractStackTrace>>; |
379 | |
380 | // Helper to maintain a map between function names in a given |
381 | // FunctionDefLibrary and function definitions. |
382 | // |
383 | // This class is thread-safe. |
class FunctionLibraryDefinition : public OpRegistryInterface {
 public:
  // Ops created for function arguments bear the name given by `kArgOp`; those
  // created for return values bear the name given by `kRetOp`.
  static constexpr const char* const kArgOp = "_Arg";
  static constexpr const char* const kDeviceArgOp = "_DeviceArg";
  static constexpr const char* const kRetOp = "_Retval";
  static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
  static constexpr const char* const kIntsOnDeviceAttr =
      "experimental_ints_on_device";
  static constexpr const char* const kSharedRendezvousAttr =
      "shared_rendezvous";

  static constexpr const char* const kGradientOp = "SymbolicGradient";
  static constexpr const char* const kFuncAttr = "f";

  // Note: This constructor grabs `lib_def`'s lock in shared mode.
  FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
  FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
                            const FunctionDefLibrary& lib_def = {});
  ~FunctionLibraryDefinition() override;

  // Copy-assignment is disallowed; copy-construct instead (or use
  // CopyFunctionDefFrom() for individual functions).
  FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
      delete;

  // Returns True if the library contains `func`, False otherwise.
  bool Contains(const std::string& func) const;

  // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
  // returns its definition proto.
  //
  // NB: This function returns a borrowed pointer, which can be invalidated by a
  // subsequent call to `ReplaceFunction()` with the given name.
  const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_);

  // Adds function definition 'fdef' to this function library.
  // Returns status 'ok' on success, or error otherwise. This is a no-op if
  // 'fdef' already exists in this function library.
  // If 'fdef' is successfully added to the library, it will be accessible
  // from 'LookUp' and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  //
  // `stack_traces` is stored alongside the function definition and is
  // retrievable via GetStackTraces().
  Status AddFunctionDef(const FunctionDef& fdef,
                        const StackTracesMap& stack_traces = {})
      TF_LOCKS_EXCLUDED(mu_);

  // Adds gradient definition 'grad' to this function library.
  // This is a no-op if 'grad' already exists in this function library.
  // If 'grad' is successfully added, it will be accessible via 'FindGradient'
  // and included in the proto returned by 'ToProto'.
  // This operation is atomic.
  Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);

  // Replaces the function corresponding to `func` with `fdef`. Returns
  // a non-OK status if "func" was not found in the library, OK otherwise.
  // Please be careful when replacing function: make sure all previous pointers
  // returned by `Find()` are no longer in use.
  Status ReplaceFunction(const std::string& func, const FunctionDef& fdef,
                         const StackTracesMap& stack_traces = {})
      TF_LOCKS_EXCLUDED(mu_);

  // Replaces the gradient corresponding to `grad.function_name()`. Returns
  // a non-OK status if "grad.function_name()" was not found in the library, OK
  // otherwise.
  Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_);

  // Removes the function corresponding to 'func'. Returns a non-OK status if
  // 'func' was not found in the library, OK otherwise.
  // Please be careful when removing function: make sure there are no other
  // nodes using the function, and all previous pointers returned by `Find()`
  // are no longer in use.
  Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_);

  // Removes all the functions and gradient functions.
  void Clear() TF_LOCKS_EXCLUDED(mu_);

  // Adds the functions and gradients in 'other' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionLibraryDefinition& other)
      TF_LOCKS_EXCLUDED(mu_);

  // Adds the functions and gradients in 'lib_def' to this function library.
  // Duplicate functions and gradients are ignored.
  // This operation is atomic.
  Status AddLibrary(const FunctionDefLibrary& lib_def) TF_LOCKS_EXCLUDED(mu_);

  // If the gradient function for 'func' is specified explicitly in
  // the library, returns the gradient function name. Otherwise,
  // returns an empty string.
  std::string FindGradient(const std::string& func) const
      TF_LOCKS_EXCLUDED(mu_);

  // OpRegistryInterface method. Useful for constructing a Graph.
  //
  // If "op" is defined in the library, returns its signature.
  // Otherwise, assume "op" is a primitive op and returns its op
  // signature and shape inference function.
  //
  // NB: This function outputs a borrowed pointer, which can be invalidated by a
  // subsequent call to `ReplaceFunction()` with the given name.
  Status LookUp(const std::string& op_type_name,
                const OpRegistrationData** op_reg_data) const override
      TF_LOCKS_EXCLUDED(mu_);

  // Generates new function name with the specified prefix that is unique
  // across this library.
  std::string UniqueFunctionName(StringPiece prefix) const
      TF_LOCKS_EXCLUDED(mu_);

  // Given a node def 'ndef', inspects attributes of the callee
  // function to derive the attribute 'value' for 'attr'. Returns OK
  // iff the attribute is given by the function's definition.
  // TODO(irving): Remove; keep only the const Node& version.
  template <typename T>
  Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const;

  // Given a node, inspects attributes of the callee function to derive the
  // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
  // function's definition.
  template <typename T>
  Status GetAttr(const Node& node, const std::string& attr, T* value) const;

  // Returns a proto representation of the state of this function library.
  FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the number of functions currently registered.
  size_t num_functions() const {
    tf_shared_lock l(mu_);
    return function_defs_.size();
  }

  // Returns all the function names in the FunctionLibraryDefinition.
  std::vector<string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_);

  // Accessors for the registry used for names not found in this library.
  const OpRegistryInterface* default_registry() const {
    return default_registry_;
  }
  void set_default_registry(const OpRegistryInterface* registry) {
    default_registry_ = registry;
  }

  // Returns a copy of `*this` with only the subset of functions that are
  // reachable from the nodes of `graph` or `func`.
  FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const;
  FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const;

  // Copies the function named `func` from `other` to this
  // FunctionLibraryDefinition.
  // REQUIRES: `this->default_registry() == other.default_registry()`.
  // Returns OK on success, or error otherwise. This is a no-op if a function
  // name `func` already exists in this function library, and has the same
  // implementation as in `other`. If the implementations conflict, an invalid
  // argument error is returned.
  Status CopyFunctionDefFrom(const std::string& func,
                             const FunctionLibraryDefinition& other)
      TF_LOCKS_EXCLUDED(mu_);

  // Returns the stack traces map recorded for the function `func_name`.
  // Returns a reference to a shared empty map (never a dangling reference)
  // if the function is not found.
  const StackTracesMap& GetStackTraces(const std::string& func_name) const {
    tf_shared_lock l(mu_);
    std::shared_ptr<FunctionDefAndOpRegistration> entry = FindHelper(func_name);
    if (entry) {
      return entry->stack_traces;
    }
    // Intentionally leaked so the returned reference remains valid for the
    // lifetime of the process.
    static const auto* empty_map = new StackTracesMap;
    return *empty_map;
  }

 private:
  // Shape inference for functions is handled separately by ShapeRefiner.

  // A function definition together with its derived op registration data and
  // the stack traces supplied when it was added.
  struct FunctionDefAndOpRegistration {
    explicit FunctionDefAndOpRegistration(
        const FunctionDef& fdef_in, const StackTracesMap& stack_traces = {});

    const FunctionDef fdef;
    const OpRegistrationData op_registration_data;
    const StackTracesMap stack_traces;
  };

  // Looks up the registration entry for `func`; returns nullptr if absent.
  std::shared_ptr<FunctionDefAndOpRegistration> FindHelper(
      const string& func) const TF_SHARED_LOCKS_REQUIRED(mu_);
  std::string FindGradientHelper(const std::string& func) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  Status AddHelper(std::shared_ptr<FunctionDefAndOpRegistration> registration,
                   bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Same as AddFunctionDef/AddGradientDef except these methods set
  // `added` to true if the `fdef`/`grad` were actually added to this.
  Status AddFunctionDefHelper(const FunctionDef& fdef,
                              const StackTracesMap& stack_traces, bool* added)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status AddGradientDefHelper(const GradientDef& grad, bool* added)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper function for GetAttr. Returns the FunctionDef* to get the
  // attr from.
  const FunctionDef* GetAttrImpl(const NodeDef& ndef) const
      TF_LOCKS_EXCLUDED(mu_);

  // Remove all functions in `funcs` and all gradients of functions in
  // `funcs_with_grads` from this library.
  Status Remove(const std::vector<string>& funcs,
                const std::vector<string>& funcs_with_grads)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Remove `func` from the library. Returns non-OK Status unless `func` is in
  // the library. This should only be called when there is a guarantee that the
  // function being removed hasn't been retrieved with `Find`.
  Status RemoveFunctionHelper(const std::string& func)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Remove gradient of function `func` from the library. Returns non-OK Status
  // unless `func` has a gradient.
  Status RemoveGradient(const std::string& func)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutable mutex mu_;  // Guards function_defs_ and func_grad_.
  const OpRegistryInterface* default_registry_;
  // Maps function name -> definition + derived registration data.
  gtl::FlatMap<string, std::shared_ptr<FunctionDefAndOpRegistration>>
      function_defs_ TF_GUARDED_BY(mu_);
  // Maps function name -> name of its gradient function.
  gtl::FlatMap<string, string> func_grad_ TF_GUARDED_BY(mu_);
};
611 | |
612 | // Forward declare. Defined in common_runtime/function.h |
613 | struct FunctionBody; |
614 | |
615 | // Forward declare. Defined in common_runtime/device.h |
616 | class Device; |
617 | // Forward declare. Defined in common_runtime/device_mgr.h |
618 | class DeviceMgr; |
619 | |
// Index of an _Arg node.
struct FunctionArgIndex {
  explicit FunctionArgIndex(const int index) : index(index) {}
  FunctionArgIndex(const int index, const int sub_index)
      : index(index), sub_index(sub_index) {}

  // The value of the attribute "Index" of the _Arg node.
  int index;
  // Set only when the _Arg node represents multiple arguments (e.g. an _Arg
  // node is replicated to multiple devices/subgraphs). Use sub-index to
  // distinguish arguments with the same index. -1 means "no sub-index".
  int sub_index = -1;
};
633 | |
634 | class FunctionLibraryRuntime { |
635 | public: |
636 | virtual ~FunctionLibraryRuntime() {} |
637 | |
638 | // Instantiate a function with the given "attrs". |
639 | // |
640 | // Returns OK and fills in "handle" if the instantiation succeeds. |
641 | // Otherwise returns an error and "handle" is undefined. |
642 | struct InstantiateOptions { |
643 | // The canonical device name of the device on which the function |
644 | // should be instantiated. If empty, the function will be |
645 | // instantiated on the local device. |
646 | std::string target; |
647 | |
648 | // Should the function be instantiated as a multi-device function? |
649 | bool is_multi_device_function = false; |
650 | |
651 | // If true, graph passes will be skipped when instantiating the function |
652 | // since they have already run on the main function side. |
653 | bool is_component_function = false; |
654 | |
655 | // For multi-device functions, a vector of canonical device names for |
656 | // function's inputs. The device of resource inputs must be the device |
657 | // backing the resource, not the CPU device backing the resource handle. |
658 | // Must have the same length as number of inputs to the function. |
659 | std::vector<string> input_devices; |
660 | |
661 | // For multi-device functions, a vector of canonical device names for |
662 | // function's outputs. |
663 | // |
664 | // (a) If specified (must have the same length as number of outputs): |
665 | // |
666 | // Specified devices will be assigned to Retval nodes inserted into the |
667 | // function body graph in place of function outputs. It is allowed to |
668 | // specify output device as empty string, in this case Retval device |
669 | // assignment will be inferred later when function graph will be placed |
670 | // before partitioning (this is required for resource outputs). Placer will |
671 | // respect colocation constraints. |
672 | // |
673 | // (b) If not specified: |
674 | // |
675 | // Function runtime will infer Retval device by following input edges, until |
676 | // it will reach a node with a device specification. This device |
677 | // specification must identify a unique device, i.e. a general specification |
678 | // like "job:foo" matching multiple devices will result in an error. |
679 | // |
680 | // IMPORTANT: Resource outputs |
681 | // |
682 | // Multi device functions might return resources on a devices different from |
683 | // the function call device. If output device is not specified for the |
684 | // resource output, and node producing that resource is a function call, |
685 | // runtime will leave device specification empty and will rely on Placer to |
686 | // infer correct device. |
687 | std::vector<string> output_devices; |
688 | |
689 | // If set, it indicates the original output indices of a component function. |
690 | absl::optional<std::vector<int>> ret_indices = absl::nullopt; |
691 | |
692 | // Maps from a CompositeDevice name to a list of underlying physical |
693 | // devices. |
694 | absl::flat_hash_map<string, const std::vector<string>*> composite_devices; |
695 | |
696 | // This interface is EXPERIMENTAL and subject to change. |
697 | // |
698 | // For multi-device functions, a mapping from _Arg node index to type and |
699 | // shape for input resources. |
700 | // REQUIRES: if input_resource_dtypes_and_shapes.count(i) > 0 then i-th |
701 | // argument type must be DT_RESOURCE. |
702 | std::unordered_map<int, DtypeAndPartialTensorShape> |
703 | input_resource_dtypes_and_shapes; |
704 | |
705 | // This interface is EXPERIMENTAL and subject to change. |
706 | // |
707 | // If non-null, the runtime will use `lib_def` to resolve function(s) named |
708 | // in `function_name` and `attrs`. Otherwise, the runtime will use its |
709 | // internal library. |
710 | // |
711 | // NOTE(mrry): If provided, all functions defined in `lib_def` must be |
712 | // self-contained, and cannot refer to functions defined in other libraries. |
713 | const FunctionLibraryDefinition* lib_def = nullptr; |
714 | |
715 | // This interface is EXPERIMENTAL and subject to change. |
716 | // |
    // If non-empty, the runtime will use `state_handle` to identify
    // cached state related to the instantiated function. Two functions
    // of the same name and attrs, instantiated with the same
    // `state_handle` will have the same handle and share the same
    // state (in stateful kernels); and two functions with different
    // values for `state_handle` will have independent state.
723 | std::string state_handle; |
724 | |
725 | // This interface is EXPERIMENTAL and subject to change. |
726 | // |
727 | // Instantiates the function using an executor of the given type. If empty, |
728 | // the default TensorFlow executor will be used. |
729 | std::string executor_type; |
730 | |
731 | // If true, the runtime will attempt to create kernels for the function at |
732 | // instantiation time, rather than on the first run. This can be used to |
733 | // surface errors earlier. |
734 | bool create_kernels_eagerly = false; |
735 | |
736 | // This interface is EXPERIMENTAL and subject to change. |
737 | // |
738 | // Instantiates the function with the provided config_proto. |
739 | ConfigProto config_proto; |
740 | |
741 | // If provided, this optimization function will be invoked before |
742 | // the placer for multi-device functions. |
743 | std::function<Status(std::vector<string> /*ret_node_names*/, |
744 | std::vector<string> /*keep_node_names*/, |
745 | FunctionLibraryDefinition*, const DeviceSet&, |
746 | Device* /*cpu_device*/, std::unique_ptr<Graph>*)> |
747 | optimize_graph_fn; |
748 | |
749 | // If set, partitioned functions will be added to `graph_collector`. |
750 | // `graph_collector` must be alive during the call to Instantiate. |
751 | GraphCollector* graph_collector = nullptr; |
752 | |
    // Indicates whether the multi-device function backend should default the
    // placement of ops without an explicitly requested device to `target`.
755 | bool default_device_to_target = true; |
756 | |
757 | // If true, the optimized Graph will be stored so that |
758 | // `FunctionLibraryRuntime::DebugString(handle)` contains the optimized |
759 | // Graph. Otherwise, the unoptimized function Graph will be returned. |
760 | bool include_optimized_graph_in_debug_string = false; |
761 | |
    // If true, the function library runtime caches the function instantiation.
763 | bool use_function_cache = false; |
764 | |
765 | // This interface is EXPERIMENTAL and subject to change. |
766 | // |
767 | // If True, allow optimizations which should be targeted at a limited |
768 | // set of small functions. For example, running kernels synchronously can |
769 | // be faster under some conditions. |
770 | bool allow_small_function_optimizations = false; |
771 | |
772 | // This interface is EXPERIMENTAL and subject to change. |
773 | // |
774 | // If True, allow graphs containing control flow nodes to be run on the |
775 | // single threaded executor. |
776 | bool allow_control_flow_sync_execution = false; |
777 | |
778 | // TODO(b/176491312): Remove this if shape inference on import flag is |
779 | // removed. If True, allows mlir roundtrip to run shape inference on import. |
780 | bool shape_inference_on_tfe_dialect_import = true; |
781 | |
782 | // Force int32 _Arg and _Retvals nodes to be left on device instead of |
783 | // pinning to host. |
784 | // |
785 | // Note that we do not pin int32 nodes to host for subgraphs running in |
786 | // TPU/XLA devices. So this is mainly used to handle the case of multi-CPU |
787 | // and GPU (non-XLA) graphs. |
788 | bool int_args_and_retvals_on_device = false; |
789 | |
790 | // This interface is EXPERIMENTAL and subject to change. |
791 | // |
792 | // Instantiates the function for XLA compilation on device_type. If empty, |
793 | // function is not compiled. |
794 | std::string xla_compile_device_type; |
795 | }; |
796 | typedef uint64 Handle; |
797 | virtual Status Instantiate(const std::string& function_name, AttrSlice attrs, |
798 | const InstantiateOptions& options, |
799 | Handle* handle) = 0; |
800 | Status Instantiate(const std::string& function_name, AttrSlice attrs, |
801 | Handle* handle) { |
802 | auto opts = absl::make_unique<InstantiateOptions>(); |
803 | return Instantiate(function_name, attrs, *opts, handle); |
804 | } |
805 | |
806 | // Releases state associated with the handle. |
807 | virtual Status ReleaseHandle(Handle handle) = 0; |
808 | |
809 | // Returns the function body for the instantiated function given its |
810 | // handle 'h'. Returns nullptr if "h" is not found. |
811 | // |
812 | // *this keeps the ownership of the returned object, which remains alive |
813 | // as long as *this. |
814 | virtual const FunctionBody* GetFunctionBody(Handle h) = 0; |
815 | |
816 | // Returns the return types for the function identified by handle `h`. |
817 | virtual Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0; |
818 | |
819 | // Asynchronously invokes the instantiated function identified by |
820 | // "handle". |
821 | // |
822 | // If function execution succeeds, "done" is called with OK and |
823 | // "*rets" is filled with the function's return values. Otherwise, |
824 | // "done" is called with an error status. |
825 | // |
826 | // Does not take ownership of "rets". |
827 | // In the cross-process scenario, runner isn't used for making the Async |
828 | // RPC calls. |
829 | struct Options { |
830 | Options() {} |
831 | explicit Options(const int64_t step_id) : step_id(step_id) {} |
832 | // Choose a step ID that is guaranteed not to clash with any |
833 | // Session-generated step ID. DirectSession only generates |
834 | // non-negative step IDs (contiguous, starting from 0), and |
835 | // MasterSession generates 56-bit random step IDs whose MSB is |
836 | // always 0, so a negative random step ID should suffice. |
837 | const int64_t step_id = -std::abs(static_cast<int64_t>(random::New64())); |
838 | |
839 | // op_id of the function running in eager mode. Set when we want to copy |
840 | // remote outputs lazily. All components of a remote multi-device function |
841 | // should use the same op_id, in order to correctly map remote output |
842 | // tensors to the remote TensorHandles in the default device. |
843 | absl::optional<int64_t> op_id = absl::nullopt; |
844 | |
845 | RendezvousInterface* rendezvous = nullptr; |
846 | CancellationManager* cancellation_manager = nullptr; |
847 | CollectiveExecutor* collective_executor = nullptr; |
848 | ScopedStepContainer* step_container = nullptr; |
849 | StepStatsCollectorInterface* stats_collector = nullptr; |
850 | CoordinationServiceAgent* coordination_service_agent = nullptr; |
851 | |
852 | absl::optional<ManagedStackTrace> stack_trace = absl::nullopt; |
853 | |
854 | std::function<void(std::function<void()>)>* runner = nullptr; |
855 | |
856 | // Parameters for remote function execution. |
857 | bool remote_execution = false; |
858 | std::string source_device = "" ; // Fully specified device name. |
859 | |
860 | // Allocator attributes specifying where the args are / rets should be put. |
861 | // These should either be {} or match the length of args / retvals. If {}, |
862 | // the default allocator attributes will be assumed for all args / retvals. |
863 | std::vector<AllocatorAttributes> args_alloc_attrs; |
864 | std::vector<AllocatorAttributes> rets_alloc_attrs; |
865 | |
866 | // If true, we create a new IntraProcessRendezvous, else use the existing |
867 | // one. |
868 | bool create_rendezvous = false; |
869 | |
870 | // If True, allow returning dead tensors. |
871 | bool allow_dead_tensors = false; |
872 | |
873 | // If True, hint that all kernels should be treated as "inexpensive", and |
874 | // hence executed on the scheduling thread. |
875 | bool run_all_kernels_inline = false; |
876 | |
877 | // If not null, use this thread pool for intra op scheduling. |
878 | thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr; |
879 | |
880 | // Returns a human readable representation of this. |
881 | std::string DebugString() const; |
882 | }; |
883 | typedef std::function<void(const Status&)> DoneCallback; |
884 | virtual void Run(const Options& opts, Handle handle, |
885 | gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, |
886 | DoneCallback done) = 0; |
887 | virtual void Run(const Options& opts, Handle handle, |
888 | CallFrameInterface* call_frame, DoneCallback done) = 0; |
889 | |
890 | virtual Status RunSync(Options opts, Handle handle, |
891 | gtl::ArraySlice<Tensor> args, |
892 | std::vector<Tensor>* rets) = 0; |
893 | virtual Status RunSync(Options opts, Handle handle, |
894 | CallFrameInterface* call_frame) = 0; |
895 | |
896 | // Creates a "kernel" for the given NodeProperties "props". |
897 | // |
898 | // If succeeds, returns OK and the caller takes the ownership of the |
899 | // returned "*kernel". Otherwise, returns an error. |
900 | virtual Status CreateKernel( |
901 | const std::shared_ptr<const NodeProperties>& props, |
902 | OpKernel** kernel) = 0; |
903 | |
904 | // Returns true iff the function named `function_name` is stateful. |
905 | // |
906 | // NOTE(mrry): This method assumes that the runtime is associated with a |
907 | // default function library, and looks up `function_name` in that library. |
908 | // It does not support overriding the function library. |
909 | virtual bool IsStateful(const std::string& function_name) const = 0; |
910 | |
911 | // Returns the device on which the function executes. |
912 | virtual Device* device() = 0; |
913 | virtual const Device* device() const = 0; |
914 | |
915 | // Returns the default runner in which the ops should be launched. If the |
916 | // device on which the function executes has a private thread pool, return |
917 | // runner on the device local thread pool. |
918 | virtual std::function<void(std::function<void()>)>* runner() = 0; |
919 | |
920 | // Get the DeviceMgr from which the device was obtained. |
921 | virtual const DeviceMgr* device_mgr() const = 0; |
922 | |
923 | // Returns the function library definition that backs this runtime. |
924 | // |
925 | // NOTE(mrry): The returned library definition is the default function library |
926 | // for this runtime. The caller may override the function library used by the |
927 | // runtime to instantiate functions, which will not be reflected in the return |
928 | // value of this function. |
929 | virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition() |
930 | const = 0; |
931 | |
932 | // Returns the environment on which the function executes. |
933 | virtual Env* env() = 0; |
934 | |
935 | // Returns the ConfigProto passed to the session used to create the function. |
936 | virtual const ConfigProto* const config_proto() = 0; |
937 | |
938 | // Returns a debug string showing the definition of the function of |
939 | // 'handle'. |
940 | virtual std::string DebugString(Handle handle) = 0; |
941 | |
942 | // Returns the graph version number. |
943 | virtual int graph_def_version() const = 0; |
944 | |
945 | typedef uint64 LocalHandle; |
946 | |
947 | // Creates a copy of ProcessFunctionLibraryRuntime (transferring ownership to |
948 | // the caller), FunctionLibraryRuntime (owned by the returned |
949 | // ProcessFunctionLibraryRuntime), FunctionLibraryDefinition (transferring |
950 | // ownership to the caller). Note that both the ProcessFunctionLibraryRuntime |
951 | // and FunctionLibraryRuntime borrow a pointer to the |
952 | // FunctionLibraryDefinition and so the FunctionLibraryDefinition should |
953 | // outlive both. |
954 | // |
955 | // The `skip_flib_def` argument controls whether the method should clone the |
956 | // FunctionLibraryDefinition (default behavior) or return an empty function |
957 | // library. The latter is used by tf.data, which manages |
958 | // FunctionLibraryDefinitions for its functions independently (and passes |
959 | // these into the FunctionLibraryRuntime through an overlay), to avoid linear |
960 | // runtime w.r.t. to number of functions in the current function library. |
961 | virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, |
962 | std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, |
963 | FunctionLibraryRuntime** out_flr, |
964 | bool skip_flib_def = false) = 0; |
965 | |
966 | // Returns the name of the executor class (in the sense of |
967 | // `ExecutorFactory::GetFactory()`) that will be used based on the given |
968 | // dynamic `options` and static `attrs`. If none is specified, this method |
969 | // will return an empty string, which leaves the decision up to the runtime. |
970 | static std::string ExecutorType(const InstantiateOptions& options, |
971 | AttrSlice attrs); |
972 | }; |
973 | |
974 | // Returns the device of the `arg_index`-th function input. Update |
975 | // `composite_devices` if the input device is a composite device. |
976 | std::string GetFunctionResourceInputDevice( |
977 | const Tensor& input, const int arg_index, const FunctionDef& function_def, |
978 | absl::flat_hash_map<string, std::vector<string>>* composite_devices); |
979 | |
980 | // Returns a canonicalized string for the instantiation of the function of the |
981 | // given "name", attributes "attrs", and "options". |
982 | // |
// The returned string is guaranteed to be stable within one address space. But
// it may change as the implementation evolves. Therefore, it should not be
// persisted or compared across address spaces.
986 | std::string Canonicalize( |
987 | const std::string& funcname, AttrSlice attrs, |
988 | const FunctionLibraryRuntime::InstantiateOptions& options); |
989 | std::string Canonicalize(const std::string& funcname, AttrSlice attrs); |
990 | |
// Sentinel values for "no handle". Handle and LocalHandle are uint64, so the
// -1 initializer wraps to the maximum unsigned value, which real handles
// never take.
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
993 | |
// Interface for plugging in custom kernel-creation logic: an implementation
// first reports whether it can handle a given node (CanCreateKernel) and, if
// so, is asked to build the kernel itself (CreateKernel).
class CustomKernelCreator {
 public:
  virtual ~CustomKernelCreator() {}

  // Returns true if this creator can build a kernel for the node described
  // by `props` under the function library runtime `flr`.
  virtual bool CanCreateKernel(
      const FunctionLibraryRuntime& flr,
      const std::shared_ptr<const NodeProperties>& props) const = 0;

  // Given a supported node (per CanCreateKernel), returns a kernel that
  // computes it. On success, ownership of `*kernel` passes to the caller via
  // the unique_ptr.
  virtual Status CreateKernel(
      FunctionLibraryRuntime* flr,
      const std::shared_ptr<const NodeProperties>& props,
      std::unique_ptr<OpKernel>* kernel) const = 0;
};
1010 | |
1011 | typedef |
1012 | #if !defined(IS_MOBILE_PLATFORM) |
1013 | absl::variant<Tensor, eager::RemoteTensorHandle*> |
1014 | FunctionArg; |
1015 | #else |
1016 | absl::variant<Tensor> |
1017 | FunctionArg; |
1018 | #endif |
1019 | |
1020 | // Either a local tensor or the shape of a remote tensor. |
1021 | typedef absl::variant<Tensor, TensorShape> FunctionRet; |
1022 | |
// Used to instantiate and run functions in a distributed system.
// Pure interface; all operations are asynchronous and report completion via
// FunctionLibraryRuntime::DoneCallback.
class DistributedFunctionLibraryRuntime {
 public:
  virtual ~DistributedFunctionLibraryRuntime() {}

  // Instantiate a function on a remote target specified in `options.target`,
  // by sending the name and definition of the function to the remote worker.
  // The local `handle` is filled for the instantiated function data and can
  // be used for subsequent run function calls on the remote target.
  // `done` is invoked with the final status once instantiation completes or
  // fails.
  virtual void Instantiate(
      const std::string& function_name,
      const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      FunctionLibraryRuntime::LocalHandle* handle,
      FunctionLibraryRuntime::DoneCallback done) = 0;

  // Run an instantiated remote function (specified by `handle`) with a list of
  // input Tensors in `args` and get its output Tensors in `rets`. The input
  // tensor data will be sent with the function execution request, and must be
  // available on the current caller side.
  // opts.runner isn't used for execution.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;

  // Run an instantiated remote function (specified by `handle`) with a list of
  // input Tensors or RemoteTensorHandles as `args` and get its output Tensors
  // or TensorShapes in `rets`. When using RemoteTensorHandles as function
  // inputs or TensorShapes as outputs, the corresponding tensor data will be
  // resolved on the remote worker, so it is not required to be locally
  // available on the caller side. Using RemoteTensorHandle inputs is not
  // supported in TensorFlow v1 runtime.
  virtual void Run(const FunctionLibraryRuntime::Options& opts,
                   FunctionLibraryRuntime::LocalHandle handle,
                   gtl::ArraySlice<FunctionArg> args,
                   std::vector<FunctionRet>* rets,
                   FunctionLibraryRuntime::DoneCallback done) = 0;

  // Clean up a previously instantiated function on the remote worker; `done`
  // is invoked when the cleanup for `step_id` finishes.
  virtual void CleanUp(uint64 step_id,
                       FunctionLibraryRuntime::LocalHandle handle,
                       FunctionLibraryRuntime::DoneCallback done) = 0;

  // DeviceMgr with *all* available devices (i.e., local and remote).
  virtual DeviceMgr* remote_device_mgr() const = 0;
};
1070 | |
1071 | // Extracts the actual type from "attr_values" based on its definition |
1072 | // "arg_def". |
1073 | // |
1074 | // If "arg_def" is a N*T type, *is_type_list is set to false, and |
1075 | // *dtypes is set to be a vector of size N and each element is T. |
1076 | // |
1077 | // If "arg_def" is a list(type), *is_type_list is set to true, and |
1078 | // *dtypes is set to be a vector of types specified in attrs for |
1079 | // arg_def. |
1080 | // |
1081 | // Otherwise (arg_def is a simple type T), *is_type_list is set to |
1082 | // false, and *dtypes is set to a single element vector, whose only |
1083 | // element is T. |
1084 | Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, |
1085 | bool* is_type_list, DataTypeVector* dtypes); |
1086 | |
1087 | // To register a gradient function for a builtin op, one should use |
1088 | // REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>); |
1089 | // |
// Typically, the c++ grad factory is a plain function that can be
1091 | // converted into ::tensorflow::gradient::Creator, which is |
1092 | // std::function<Status(const AttrSlice&, FunctionDef*)>. |
1093 | // |
// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
// definition of a brain function which computes the gradient for the
// <op_name> when the <op_name> is instantiated with the given attrs.
1097 | // |
1098 | // E.g., |
1099 | // |
1100 | // Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { |
1101 | // bool transpose_a; |
1102 | // TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a)); |
1103 | // bool transpose_b; |
1104 | // TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b)); |
1105 | // DataType dtype; |
1106 | // TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype)); |
1107 | // if (!transpose_a && !transpose_b) { |
1108 | // *g = FunctionDefHelper::Define( |
1109 | // "MatMulGrad", |
1110 | // {"x:T ", "y:T", "dz:T"}, // Inputs to this function |
1111 | // {"dx:T", "dy:T"}, // Outputs from this function |
1112 | // {"T: {float, double}"}, // Attributes needed by this function |
1113 | // { |
1114 | // {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}}, |
1115 | // {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}}, |
1116 | // {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}}, |
1117 | // {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}}, |
1118 | // }); |
1119 | // } else { |
1120 | // ... ... |
1121 | // } |
1122 | // return Status::OK(); |
1123 | // } |
1124 | // |
1125 | // NOTE: $T is substituted with the type variable "T" when the |
1126 | // gradient function MatMul is instantiated. |
1127 | // |
1128 | // TODO(zhifengc): Better documentation somewhere. |
1129 | |
1130 | // Macros to define a gradient function factory for a primitive |
1131 | // operation. |
// Registers `fn` as the gradient factory for op `name`.
#define REGISTER_OP_GRADIENT(name, fn) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)

// Registers op `name` as explicitly having no gradient (nullptr creator).
#define REGISTER_OP_NO_GRADIENT(name) \
  REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)

// Extra indirection so that __COUNTER__ is macro-expanded before being
// token-pasted into the unique variable name below.
#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)

// Performs the registration as a side effect of initializing a uniquely
// named file-scope bool; SHOULD_REGISTER_OP_GRADIENT short-circuits the
// registration out of builds that don't want it.
#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)      \
  static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
      SHOULD_REGISTER_OP_GRADIENT &&                  \
      ::tensorflow::gradient::RegisterOp(name, fn)
1145 | |
1146 | namespace gradient { |
1147 | // Register a gradient creator for the "op". |
1148 | typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator; |
1149 | bool RegisterOp(const std::string& op, Creator func); |
1150 | |
1151 | // Returns OK the gradient creator for the "op" is found (may be |
1152 | // nullptr if REGISTER_OP_NO_GRADIENT is used. |
1153 | Status GetOpGradientCreator(const std::string& op, Creator* creator); |
1154 | }; // namespace gradient |
1155 | |
// Declare explicit instantiations of FunctionLibraryDefinition::GetAttr for
// the `string` and `bool` attribute types. `extern template` suppresses
// implicit instantiation in every including TU; the actual instantiations
// are provided elsewhere.
#define GET_ATTR(T)                                          \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const Node&, const string&, T*) const;                 \
  extern template Status FunctionLibraryDefinition::GetAttr( \
      const NodeDef&, const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
1165 | |
1166 | } // end namespace tensorflow |
1167 | |
1168 | #endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_ |
1169 | |