/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_

#include <memory>
#include <string>

#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"

namespace tensorflow {
class CollectiveImplementation;
class DeviceMgr;
class Device;

// Helper interface that aliases regular subfields of a Tensor as separate
// Tensors for in-place update.
class CollectiveAdapter {
 public:
  virtual ~CollectiveAdapter() {}

  // Moves the backing tensor to 'output' with its original storage and
  // shape. After this call this CollectiveAdapter object should be
  // deleted immediately without calling any of its other methods.
  virtual void ConsumeFinalValue(Tensor* output) = 0;
  // Const access to the entire intermediate value, for debugging.
  virtual const Tensor& Value() const = 0;

  // Returns tensor for chunk i which aliases the backing buffer.
  virtual Tensor ChunkAlias(int i) = 0;

  // Returns a tensor allocated on the same device but with its own
  // separate backing buffer. It will have the same type and size as
  // chunk i.
  virtual Tensor TempChunk(int i) const = 0;
  // Returns the number of bytes in chunk i.
  virtual int64_t ChunkBytes(int i) const = 0;

  // Generates a CPU RAM scalar tensor of the same DataType as the
  // backing tensor, with the given integer value.
  virtual Tensor Scalar(int v) const = 0;

  // Generates a scalar tensor of the same DataType and on the same device
  // as the backing tensor.
  virtual Tensor Scalar(Allocator* a,
                        const AllocationAttributes& attr) const = 0;

  // Returns a debugging string describing the buffer location of tensor t.
  virtual string TBounds(const Tensor& t) const = 0;

  virtual string DebugString() const = 0;
  // Computes the number of elements per alias chunk tensor.
  //
  // A CHECK in tensor.cc expects that the memory buffer backing a
  // Tensor will be aligned according to EIGEN_MAX_ALIGN_BYTES. To
  // ensure that all chunk-aliasing Tensors maintain this alignment we
  // need to pick a chunk size that preserves it. Note that in extreme
  // cases (impractical, but possible with very small tensors) one or
  // more tail chunks can end up empty.
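  //
  // Illustrative example (assumes EIGEN_MAX_ALIGN_BYTES == 64, which is not
  // guaranteed on every build): for a float32 tensor (elt_bytes = 4) with
  // total_elts = 100 and num_chunks = 3, the naive chunk size of
  // ceil(100 / 3) = 34 elements (136 bytes) would leave later chunks
  // misaligned, so the chunk size is rounded up to 48 elements (192 bytes),
  // yielding chunks of 48, 48 and 4 elements.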
  static int64_t AlignedChunkElts(int64_t elt_bytes, int64_t total_elts,
                                  int64_t num_chunks);
};

// Creates a CollectiveAdapter wrapping 'output', specialized to its
// data-type and shape. If align_chunks == true then the chunk size may
// be larger than output->NumElements() / num_chunks and one or more
// of the suffix chunks may be empty. Chunks will be arranged to start
// and end on alignment boundaries. If align_chunks == false then
// output->NumElements() % num_chunks must be 0 and all chunks will
// have exactly the same size, ignoring alignment issues.
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks = true);
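//
// Illustrative usage sketch (the tensor shape and the use of cpu_allocator()
// from framework/allocator.h are arbitrary choices for the example):
//
//   Tensor output(DT_FLOAT, TensorShape({1024}));
//   std::unique_ptr<CollectiveAdapter> adapter(
//       MakeCollectiveAdapter(&output, /*num_chunks=*/4, cpu_allocator()));
//   for (int i = 0; i < 4; ++i) {
//     Tensor chunk = adapter->ChunkAlias(i);  // aliases output's storage
//     // ... update `chunk` in place ...
//   }
//   adapter->ConsumeFinalValue(&output);  // do not use `adapter` after this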

// Default implementation of CollectiveExecutor. Delegates the actual
// work of moving data to a class specialized for the operation type,
// arguments and device+interconnect topology.
class BaseCollectiveExecutor : public CollectiveExecutor {
 public:
  BaseCollectiveExecutor(CollectiveExecutorMgrInterface* cem,
                         CollectiveRemoteAccess* remote_access, int64_t step_id,
                         const DeviceMgr* dev_mgr,
                         std::shared_ptr<UnboundedWorkQueue> work_queue)
      : CollectiveExecutor(cem),
        step_id_(step_id),
        dev_mgr_(dev_mgr),
        remote_access_(remote_access),
        work_queue_(std::move(work_queue)) {}

  ~BaseCollectiveExecutor() override;

  void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_);

  void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params,
                    const string& exec_key, StatusCallback done) override;

  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
                           CancellationManager* cancel_mgr,
                           StatusCallback done) override;
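
  // Typical flow (illustrative sketch; the surrounding kernel-side names
  // such as `col_exec`, `device_attrs` and `exec_key` are hypothetical): a
  // collective op kernel first resolves the full group and instance
  // parameters, then runs the collective once resolution succeeds:
  //
  //   col_exec->CompleteParamsAsync(
  //       device_attrs, col_params, cancel_mgr,
  //       [col_exec, ctx, col_params, exec_key, done](const Status& s) {
  //         if (!s.ok()) { done(s); return; }
  //         col_exec->ExecuteAsync(ctx, col_params, exec_key, done);
  //       });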

  CollectiveRemoteAccess* remote_access() override {
    return remote_access_.get();
  }

  void RunClosure(std::function<void()> closure) override {
    work_queue_->Schedule(std::move(closure));
  }
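
  // Example (sketch): closures scheduled here run on the shared
  // UnboundedWorkQueue, which grows its thread pool as needed, so a closure
  // may block without starving other scheduled work:
  //
  //   executor->RunClosure([] { /* potentially blocking work */ });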

  // If an ordering must be enforced on some portion of the collective
  // implementation, and that ordering is encoded via an attribute on the
  // collective op, this function blocks until all dependencies of this
  // collective have completed.
  void WaitForDependencies(const CollectiveParams& col_params) override;
  // Records that this collective has completed the portion of its
  // implementation that must be ordered with respect to other collectives,
  // unblocking any of its dependent ops.
  void UnblockDependencies(const CollectiveParams& col_params) override;
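
  // An ordered collective implementation would typically bracket its launch
  // like this (sketch; LaunchOrderedKernel is a hypothetical stand-in for
  // the implementation-specific launch step):
  //
  //   executor->WaitForDependencies(col_params);   // block until deps launch
  //   LaunchOrderedKernel(...);
  //   executor->UnblockDependencies(col_params);   // release dependent ops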

 protected:
  const int64_t step_id_;
  const DeviceMgr* dev_mgr_;  // Not owned.
  std::unique_ptr<CollectiveRemoteAccess> remote_access_;
  // Ownership of `work_queue_` is shared between `this` and
  // `CollectiveExecutorMgr`.
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
  mutex launch_mu_;
  condition_variable launch_cv_;
  // collective instance key -> number of local devices for which NCCL ops have
  // been launched.
  std::unordered_map<int32, int32> launched_ TF_GUARDED_BY(launch_mu_);
  mutex status_mu_;
  Status status_ TF_GUARDED_BY(status_mu_);

 private:
  Status CreateCollective(const CollectiveParams& col_params,
                          CollectiveImplementationInterface** col_impl);
  // Checks whether all ops on which this collective depends have launched.
  bool CheckDependencies(const CollectiveParams& col_params)
      TF_EXCLUSIVE_LOCKS_REQUIRED(launch_mu_);
  // Returns the status corresponding to the original error, or the aborted
  // status if the collective executor has been aborted.
  Status GetStatus(const Status& s) TF_LOCKS_EXCLUDED(status_mu_);
};

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_