1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ |
18 | |
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
22 | |
23 | namespace stream_executor { |
24 | class Stream; |
25 | } // namespace stream_executor |
26 | |
27 | namespace tensorflow { |
28 | |
29 | class GPUDeviceContext : public DeviceContext { |
30 | public: |
31 | // Does not take ownership of streams. |
32 | GPUDeviceContext(int stream_id, se::Stream* stream, |
33 | #if TENSORFLOW_USE_ROCM |
34 | se::Stream* nccl_stream, |
35 | #endif |
36 | se::Stream* host_to_device_stream, |
37 | se::Stream* device_to_host_stream, |
38 | gtl::InlinedVector<se::Stream*, 4> device_to_device_stream, |
39 | Allocator* host_memory_allocator) |
40 | : stream_id_(stream_id), |
41 | stream_(stream), |
42 | #if TENSORFLOW_USE_ROCM |
43 | nccl_stream_(nccl_stream), |
44 | #endif |
45 | host_to_device_stream_(host_to_device_stream), |
46 | device_to_host_stream_(device_to_host_stream), |
47 | device_to_device_stream_(device_to_device_stream), |
48 | host_memory_allocator_(host_memory_allocator) { |
49 | } |
50 | |
51 | ~GPUDeviceContext() override {} |
52 | |
53 | se::Stream* stream() const override { return stream_; } |
54 | #if TENSORFLOW_USE_ROCM |
55 | se::Stream* nccl_stream() const { return nccl_stream_; } |
56 | #endif |
57 | se::Stream* host_to_device_stream() const { return host_to_device_stream_; } |
58 | se::Stream* device_to_host_stream() const { return device_to_host_stream_; } |
59 | se::Stream* device_to_device_stream(int index) const { |
60 | return device_to_device_stream_[index % device_to_device_stream_.size()]; |
61 | } |
62 | int stream_id() const { return stream_id_; } |
63 | Allocator* host_memory_allocator() const override { |
64 | return host_memory_allocator_; |
65 | } |
66 | |
67 | void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, |
68 | Tensor* device_tensor, StatusCallback done, |
69 | bool sync_dst_compute) const override; |
70 | |
71 | void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece edge_name, |
72 | Device* device, Tensor* cpu_tensor, |
73 | StatusCallback done) override; |
74 | |
75 | void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, |
76 | Tensor* output_tensor, |
77 | StatusCallback done) const override; |
78 | |
79 | void MaintainLifetimeOnStream(const Tensor* t, |
80 | se::Stream* stream) const override {} |
81 | |
82 | Status ThenExecute(Device* device, se::Stream* stream, |
83 | std::function<void()> func) override; |
84 | |
85 | private: |
86 | int stream_id_; |
87 | // The default primary stream to use for this context. |
88 | // All the memory belongs to this stream. |
89 | se::Stream* stream_; |
90 | #if TENSORFLOW_USE_ROCM |
91 | // The stream to use for nccl operations. |
92 | se::Stream* nccl_stream_; |
93 | #endif |
94 | // The stream to use for copying data from host into GPU. |
95 | se::Stream* host_to_device_stream_; |
96 | // The stream to use for copying data from GPU to host. |
97 | se::Stream* device_to_host_stream_; |
98 | // Streams to use for copying data between GPUs. |
99 | gtl::InlinedVector<se::Stream*, 4> device_to_device_stream_; |
100 | // The allocator to use for allocating pinned host memory. |
101 | // Not owned. |
102 | Allocator* host_memory_allocator_; |
103 | }; |
104 | |
105 | } // namespace tensorflow |
106 | |
107 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ |
108 | |