1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
17
18#include <functional>
19#include <string>
20
21#include "absl/container/flat_hash_map.h"
22#include "absl/strings/string_view.h"
23#include "tensorflow/core/framework/allocator.h"
24#include "tensorflow/core/framework/cancellation.h"
25#include "tensorflow/core/lib/core/status.h"
26#include "tensorflow/core/platform/mutex.h"
27
28namespace tensorflow {
// Forward declarations: this header refers to these types only through
// pointers, so their full definitions are not needed here.
class Device;
class DeviceContext;
class DeviceMgr;
class Tensor;
33
34// EXPERIMENTAL: RDMA oriented producer/consumer rendezvous on a local
35// Tensor value for which DMAHelper::CanUseDMA() is true, i.e. dense
36// numeric types. Similar to Rendezvous but never owns a Ref on the
37// tensor, instead it uses an explicit callback to the producer when
38// the consumer side is finished with the value. This allows the
39// producer to perform in-place updates on the source buffer or to take
40// other actions that depend on knowing the consumer has passed a certain
41// execution point.
42class BufRendezvous {
43 public:
44 explicit BufRendezvous(uint64 step_id, const DeviceMgr* dev_mgr)
45 : step_id_(step_id), dev_mgr_(dev_mgr) {}
46
47 virtual ~BufRendezvous();
48
49 // Inform all waiting parties that this BufRendezvous is defunct because of
50 // an error Status interrupting the Step.
51 void StartAbort(const Status& s);
52
53 struct Hook;
54 // Provided by the consumer to be called when access to the buffer
55 // is available. If the Status arg is not OK, then hook will not
56 // be populated. Ownership of Hook passes to consumer with the
57 // callback.
58 typedef std::function<void(const Status&, Hook*)> ConsumerCallback;
59 // Provided by the producer to be called when the consumer has finished
60 // reading the buffer and will no longer access it.
61 typedef std::function<void(const Status&)> ProducerCallback;
62
63 struct Hook {
64 Device* prod_dev;
65 DeviceContext* prod_ctx;
66 const Tensor* prod_value;
67 AllocatorAttributes prod_attr;
68 ProducerCallback prod_cb;
69 ConsumerCallback cons_cb;
70 CancellationManager* cancellation_manager;
71 CancellationToken cancellation_token;
72 explicit Hook(CancellationManager* cancellation_manager,
73 CancellationToken cancellation_token)
74 : prod_dev(nullptr),
75 prod_ctx(nullptr),
76 prod_value(nullptr),
77 prod_cb(nullptr),
78 cons_cb(nullptr),
79 cancellation_manager(cancellation_manager),
80 cancellation_token(cancellation_token) {}
81 string DebugString() const;
82 };
83
84 // Called to advertise availability of a Tensor value corresponding
85 // to key. That value must stay valid until done is called.
86 //
87 // If a non-null cancellation manager is provided, this function registers a
88 // callback to delete the hook and invoke provider/consumer callbacks with
89 // cancelled error.
90 void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx,
91 const Tensor* v, const AllocatorAttributes& attr,
92 const ProducerCallback& done,
93 CancellationManager* cancellation_manager);
94
95 // Called to request access to a Tensor value corresponding to key.
96 // Consumer is provided with a Hook as soon as available.
97 //
98 // This function also checks that the current incarnation number of the
99 // `device` that produced this value matches the `incarnation` expected by the
100 // consumer, and invokes `done` with `FailedPrecondition` status and
101 // `nullptr` hook if it does not match.
102 //
103 // If a non-null cancellation manager is provided, this function registers a
104 // callback to delete the hook and invoke provider/consumer callbacks with
105 // cancelled error.
106 virtual void ConsumeBuf(const string& key, const string& device,
107 const uint64 incarnation,
108 const ConsumerCallback& done,
109 CancellationManager* cancellation_manager);
110
111 // Cancel the rendezvous entry corresponding to `key`. Triggered by the
112 // cancellation manager. No-op if the rendezvous was already successful.
113 void CancelHook(const string& key);
114
115 // Consumer must call this function when it's done reading the Hook provided
116 // by the ConsumerCallback. This function will invoke the producer callback
117 // and then delete h.
118 static void DoneWithHook(Hook* h);
119
120 // Write the current contents of the table to the INFO log.
121 void LogContents();
122
123 protected:
124 const uint64 step_id_;
125 const DeviceMgr* const dev_mgr_; // Not owned.
126 mutex mu_;
127 Status status_ TF_GUARDED_BY(mu_);
128 typedef absl::flat_hash_map<string, Hook*> HookTable;
129 HookTable hook_table_ TF_GUARDED_BY(mu_);
130
131 void PurgeTable(const Status& s, HookTable* table);
132};
133} // namespace tensorflow
134#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
135