1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ |
18 | |
19 | #include <string> |
20 | |
21 | #include "tensorflow/core/framework/cancellation.h" |
22 | #include "tensorflow/core/framework/control_flow.h" |
23 | #include "tensorflow/core/framework/device_base.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/lib/core/refcount.h" |
26 | #include "tensorflow/core/lib/core/status.h" |
27 | #include "tensorflow/core/util/device_name_utils.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | class DeviceMgr; |
32 | |
33 | // A Rendezvous is an abstraction for passing tensors from producers |
34 | // to consumers. A rendezvous is a table of channels. Each channel is |
35 | // keyed by a rendezvous key. The key encodes a pair of <producer, |
36 | // consumer>, where the producer and the consumer are tensorflow |
37 | // devices. |
38 | // |
39 | // The producer calls the Send() method to send one tensor over one |
40 | // named channel. The consumer calls the Recv() method to receive one |
41 | // tensor from a named channel. A sequence of tensors can be passed |
42 | // from the producer to the consumer. The consumer receives them in |
43 | // the order as the producer sends them. |
44 | // |
45 | // A consumer may safely request the tensor before or after it has |
46 | // been produced. A consumer has the choice of making a blocking call |
47 | // or providing a callback: in either case, the consumer receives the |
48 | // Tensor as soon as it is available. A producer never blocks. |
49 | class RendezvousInterface { |
50 | public: |
51 | struct Args { |
52 | DeviceContext* device_context = nullptr; |
53 | AllocatorAttributes alloc_attrs; |
54 | CancellationManager* cancellation_manager = nullptr; // not owned. |
55 | }; |
56 | |
57 | // Parses the key constructed by CreateKey and parse src/dst device |
58 | // names into structures respectively. |
59 | struct ParsedKey { |
60 | StringPiece src_device; |
61 | DeviceNameUtils::ParsedName src; |
62 | uint64 src_incarnation = 0; |
63 | StringPiece dst_device; |
64 | DeviceNameUtils::ParsedName dst; |
65 | StringPiece edge_name; |
66 | |
67 | ParsedKey() {} |
68 | ParsedKey(const ParsedKey& b) { *this = b; } |
69 | |
70 | ParsedKey& operator=(const ParsedKey& b); |
71 | StringPiece FullKey() const { return buf_; } |
72 | |
73 | private: |
74 | friend class Rendezvous; |
75 | friend class SendOp; |
76 | friend class RecvOp; |
77 | std::string buf_; |
78 | }; |
79 | |
80 | // The caller is a tensor producer and it sends a message (a tensor |
81 | // "val" and a bool "is_dead") under the given "key". |
82 | // |
83 | // {val, is_dead} is bundled as a message sent and received. |
84 | // Typically, is_dead is set by some control flow nodes |
85 | // (e.g., a not-taken branch). args is passed by Send to the |
86 | // Recv function to communicate any information that the Recv |
87 | // function might need. This is typically only necessary for |
88 | // Send/Recv on the same worker. |
89 | // |
90 | // Send() never blocks. |
91 | virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val, |
92 | const bool is_dead) = 0; |
93 | |
94 | // Callback provided by a tensor consumer waiting on the rendezvous. |
95 | // It will be invoked when the tensor is available, or when a non-OK |
96 | // status arises in the production of that tensor. It also gets |
97 | // two Rendezvous::Args, one provided by the sender, the other by the |
98 | // receiver, which may be needed when a non-CPU device is in use |
99 | // by either side. |
100 | typedef std::function<void(const Status&, const Args&, const Args&, |
101 | const Tensor&, const bool)> |
102 | DoneCallback; |
103 | |
104 | virtual void RecvAsync(const ParsedKey& key, const Args& args, |
105 | DoneCallback done) = 0; |
106 | |
107 | // Synchronous wrapper for RecvAsync. |
108 | Status Recv(const ParsedKey& key, const Args& args, Tensor* val, |
109 | bool* is_dead, int64_t timeout_ms); |
110 | Status Recv(const ParsedKey& key, const Args& args, Tensor* val, |
111 | bool* is_dead); |
112 | |
113 | // Aborts all pending and future Send/Recv with the given "status". |
114 | // |
115 | // StartAbort() does not wait for ongoing calls to finish. |
116 | // REQUIRES: !status.ok() |
117 | virtual void StartAbort(const Status& status) = 0; |
118 | |
119 | protected: |
120 | virtual ~RendezvousInterface(); |
121 | |
122 | virtual bool is_cross_process() { return false; } |
123 | friend class ProcessFunctionLibraryRuntime; |
124 | }; |
125 | |
126 | // A reference-counted implementation of RendezvousInterface. |
127 | // |
128 | // This class is used in cases where a rendezvous may be shared between multiple |
129 | // threads with no clear owner. |
130 | class Rendezvous : public RendezvousInterface, public core::RefCounted { |
131 | public: |
132 | class Factory { |
133 | public: |
134 | // Default to a factory that evaluates to false. |
135 | Factory() : valid_(false) {} |
136 | |
137 | Factory(std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)> |
138 | create_fn, |
139 | std::function<Status(const int64_t)> cleanup_fn) |
140 | : valid_(true), |
141 | create_fn_(std::move(create_fn)), |
142 | cleanup_fn_(std::move(cleanup_fn)) {} |
143 | |
144 | // If no clean up fn is provided, just put in a dummy. |
145 | // For backwards compatibility. |
146 | explicit Factory( |
147 | std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)> |
148 | create_fn) |
149 | : valid_(true), |
150 | create_fn_(std::move(create_fn)), |
151 | cleanup_fn_([](const int64_t step_id) { return OkStatus(); }) {} |
152 | |
153 | explicit operator bool() const { return valid_; } |
154 | |
155 | Status operator()(const int64_t step_id, const DeviceMgr* device_mgr, |
156 | Rendezvous** rendez) const { |
157 | return create_fn_(step_id, device_mgr, rendez); |
158 | } |
159 | |
160 | Status CleanUp(const int64_t step_id) const { return cleanup_fn_(step_id); } |
161 | |
162 | private: |
163 | bool valid_; |
164 | std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)> |
165 | create_fn_; |
166 | std::function<Status(const int64_t)> cleanup_fn_; |
167 | }; |
168 | |
169 | // Constructs a rendezvous key for the tensor of "name" sent from |
170 | // "src_device" to "dst_device". The tensor is generated in the frame |
171 | // and iteration specified by "frame_iter". |
172 | static std::string CreateKey(const std::string& src_device, |
173 | uint64 src_incarnation, |
174 | const std::string& dst_device, |
175 | const std::string& name, |
176 | const FrameAndIter& frame_iter); |
177 | |
178 | static Status ParseKey(StringPiece key, ParsedKey* out); |
179 | }; |
180 | |
181 | // Returns a Rendezvous instance that is limited to use only by |
182 | // producers and consumers in the local process. The caller assumes |
183 | // ownership of one Ref() on the returned object. |
184 | Rendezvous* NewLocalRendezvous(); |
185 | |
186 | } // end namespace tensorflow |
187 | |
188 | #endif // TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ |
189 | |