1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_ |
17 | |
18 | #include <functional> |
19 | #include <string> |
20 | |
21 | #include "absl/container/flat_hash_map.h" |
22 | #include "absl/strings/string_view.h" |
23 | #include "tensorflow/core/framework/allocator.h" |
24 | #include "tensorflow/core/framework/cancellation.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | #include "tensorflow/core/platform/mutex.h" |
27 | |
28 | namespace tensorflow { |
29 | class Device; |
30 | class DeviceContext; |
31 | class DeviceMgr; |
32 | class Tensor; |
33 | |
34 | // EXPERIMENTAL: RDMA oriented producer/consumer rendezvous on a local |
35 | // Tensor value for which DMAHelper::CanUseDMA() is true, i.e. dense |
36 | // numeric types. Similar to Rendezvous but never owns a Ref on the |
37 | // tensor, instead it uses an explicit callback to the producer when |
38 | // the consumer side is finished with the value. This allows the |
39 | // producer to perform in-place updates on the source buffer or to take |
40 | // other actions that depend on knowing the consumer has passed a certain |
41 | // execution point. |
42 | class BufRendezvous { |
43 | public: |
44 | explicit BufRendezvous(uint64 step_id, const DeviceMgr* dev_mgr) |
45 | : step_id_(step_id), dev_mgr_(dev_mgr) {} |
46 | |
47 | virtual ~BufRendezvous(); |
48 | |
49 | // Inform all waiting parties that this BufRendezvous is defunct because of |
50 | // an error Status interrupting the Step. |
51 | void StartAbort(const Status& s); |
52 | |
53 | struct Hook; |
54 | // Provided by the consumer to be called when access to the buffer |
55 | // is available. If the Status arg is not OK, then hook will not |
56 | // be populated. Ownership of Hook passes to consumer with the |
57 | // callback. |
58 | typedef std::function<void(const Status&, Hook*)> ConsumerCallback; |
59 | // Provided by the producer to be called when the consumer has finished |
60 | // reading the buffer and will no longer access it. |
61 | typedef std::function<void(const Status&)> ProducerCallback; |
62 | |
63 | struct Hook { |
64 | Device* prod_dev; |
65 | DeviceContext* prod_ctx; |
66 | const Tensor* prod_value; |
67 | AllocatorAttributes prod_attr; |
68 | ProducerCallback prod_cb; |
69 | ConsumerCallback cons_cb; |
70 | CancellationManager* cancellation_manager; |
71 | CancellationToken cancellation_token; |
72 | explicit Hook(CancellationManager* cancellation_manager, |
73 | CancellationToken cancellation_token) |
74 | : prod_dev(nullptr), |
75 | prod_ctx(nullptr), |
76 | prod_value(nullptr), |
77 | prod_cb(nullptr), |
78 | cons_cb(nullptr), |
79 | cancellation_manager(cancellation_manager), |
80 | cancellation_token(cancellation_token) {} |
81 | string DebugString() const; |
82 | }; |
83 | |
84 | // Called to advertise availability of a Tensor value corresponding |
85 | // to key. That value must stay valid until done is called. |
86 | // |
87 | // If a non-null cancellation manager is provided, this function registers a |
88 | // callback to delete the hook and invoke provider/consumer callbacks with |
89 | // cancelled error. |
90 | void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx, |
91 | const Tensor* v, const AllocatorAttributes& attr, |
92 | const ProducerCallback& done, |
93 | CancellationManager* cancellation_manager); |
94 | |
95 | // Called to request access to a Tensor value corresponding to key. |
96 | // Consumer is provided with a Hook as soon as available. |
97 | // |
98 | // This function also checks that the current incarnation number of the |
99 | // `device` that produced this value matches the `incarnation` expected by the |
100 | // consumer, and invokes `done` with `FailedPrecondition` status and |
101 | // `nullptr` hook if it does not match. |
102 | // |
103 | // If a non-null cancellation manager is provided, this function registers a |
104 | // callback to delete the hook and invoke provider/consumer callbacks with |
105 | // cancelled error. |
106 | virtual void ConsumeBuf(const string& key, const string& device, |
107 | const uint64 incarnation, |
108 | const ConsumerCallback& done, |
109 | CancellationManager* cancellation_manager); |
110 | |
111 | // Cancel the rendezvous entry corresponding to `key`. Triggered by the |
112 | // cancellation manager. No-op if the rendezvous was already successful. |
113 | void CancelHook(const string& key); |
114 | |
115 | // Consumer must call this function when it's done reading the Hook provided |
116 | // by the ConsumerCallback. This function will invoke the producer callback |
117 | // and then delete h. |
118 | static void DoneWithHook(Hook* h); |
119 | |
120 | // Write the current contents of the table to the INFO log. |
121 | void LogContents(); |
122 | |
123 | protected: |
124 | const uint64 step_id_; |
125 | const DeviceMgr* const dev_mgr_; // Not owned. |
126 | mutex mu_; |
127 | Status status_ TF_GUARDED_BY(mu_); |
128 | typedef absl::flat_hash_map<string, Hook*> HookTable; |
129 | HookTable hook_table_ TF_GUARDED_BY(mu_); |
130 | |
131 | void PurgeTable(const Status& s, HookTable* table); |
132 | }; |
133 | } // namespace tensorflow |
134 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_ |
135 | |