1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ |
17 | |
#include <functional>
#include <map>
#include <vector>

#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
22 | |
23 | namespace tensorflow { |
24 | |
25 | typedef std::map<string, Tensor> NamedTensors; |
26 | typedef std::function<void(const Status&)> StatusCallback; |
27 | |
28 | // Uses `rendezvous` to send tensors in `tensors_to_send`. `device_context` |
29 | // should be the DeviceContext associated with the source of the tensors. |
30 | // `alloc_attrs` contains information about how the `tensors_to_send` are |
31 | // allocated. `alloc_attrs` should either be {} or should match the length of |
32 | // `keys`. |
33 | Status SendTensorsToRendezvous( |
34 | RendezvousInterface* rendezvous, DeviceContext* device_context, |
35 | const std::vector<AllocatorAttributes>& alloc_attrs, |
36 | const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send); |
37 | |
38 | // Uses `rendezvous` to obtain tensors. `device_context` should be the |
39 | // DeviceContext associated with the receiving device. `alloc_attrs` contains |
40 | // information as how to store the received tensors. Should be {} or match the |
41 | // length of `keys`. |
42 | void RecvOutputsFromRendezvousAsync( |
43 | RendezvousInterface* rendezvous, DeviceContext* device_context, |
44 | const std::vector<AllocatorAttributes>& alloc_attrs, |
45 | const std::vector<string>& keys, std::vector<Tensor>* received_tensors, |
46 | StatusCallback done); |
47 | |
48 | Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, |
49 | NamedTensors* out, |
50 | const Rendezvous::Args& args); |
51 | |
52 | } // namespace tensorflow |
53 | |
54 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ |
55 | |