1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_ |
17 | |
18 | #include "tensorflow/core/common_runtime/buf_rendezvous.h" |
19 | #include "tensorflow/core/common_runtime/device_mgr.h" |
20 | #include "tensorflow/core/framework/collective.h" |
21 | #include "tensorflow/core/framework/rendezvous.h" |
22 | |
23 | namespace tensorflow { |
24 | |
// Basic implementation of CollectiveRemoteAccess, bound to a single step id.
26 | class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess { |
27 | public: |
28 | CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr, |
29 | DeviceResolverInterface* dev_resolver, |
30 | int64_t step_id) |
31 | : dev_mgr_(dev_mgr), |
32 | dev_resolver_(dev_resolver), |
33 | buf_rendezvous_(step_id, dev_mgr), |
34 | step_id_(step_id) {} |
35 | |
36 | ~CollectiveRemoteAccessLocal() override = default; |
37 | |
38 | void StartAbort(const Status& s) override; |
39 | |
40 | void RecvFromPeer(const string& peer_device, const string& peer_task, |
41 | bool peer_is_local, const string& key, Device* to_device, |
42 | DeviceContext* to_device_ctx, |
43 | const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, |
44 | const DeviceLocality& client_locality, |
45 | int dev_to_dev_stream_index, |
46 | CancellationManager* cancellation_manager, |
47 | const StatusCallback& done) override; |
48 | |
49 | void PostToPeer(const string& peer_device, const string& peer_task, |
50 | const string& key, Device* from_device, |
51 | DeviceContext* from_device_ctx, |
52 | const AllocatorAttributes& from_alloc_attr, |
53 | const Tensor* from_tensor, |
54 | const DeviceLocality& client_locality, |
55 | CancellationManager* cancellation_manager, |
56 | const StatusCallback& done) override; |
57 | |
58 | void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms, |
59 | const StatusCallback& done) override; |
60 | |
61 | BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; } |
62 | |
63 | // Copy utility that always copies bytes from src to dst even if |
64 | // they are on the same device, unlike CopyTensor::ViaDMA which will |
65 | // just change the dst buffer pointer in that case. |
66 | static void MemCpyAsync(DeviceContext* src_dev_ctx, |
67 | DeviceContext* dst_dev_ctx, Device* src_dev, |
68 | Device* dst_dev, const AllocatorAttributes& src_attr, |
69 | const AllocatorAttributes& dst_attr, |
70 | const Tensor* src, Tensor* dst, |
71 | int dev_to_dev_stream_index, |
72 | const StatusCallback& done); |
73 | |
74 | protected: |
75 | const DeviceMgr* dev_mgr_; // not owned |
76 | DeviceResolverInterface* dev_resolver_; // not owned |
77 | BufRendezvous buf_rendezvous_; |
78 | int64_t step_id_; |
79 | }; |
80 | |
81 | } // namespace tensorflow |
82 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_ |
83 | |