1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ |
17 | |
18 | #include <deque> |
19 | #include <memory> |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/common_runtime/base_collective_executor.h" |
24 | #include "tensorflow/core/framework/collective.h" |
25 | |
26 | namespace tensorflow { |
27 | class Device; |
28 | |
29 | // Basic ring-algorithm implementation to be further specialized |
30 | // for specific collective functions. |
31 | class RingAlg : public CollectiveImplementationInterface { |
32 | public: |
33 | explicit RingAlg(CollectiveType type, const string& name); |
34 | ~RingAlg() override {} |
35 | |
36 | // Establishes the requested number of subdivision permutations based on the |
37 | // ring order implicit in the device order. |
38 | Status InitializeCollectiveParams(CollectiveParams* col_params) override; |
39 | |
40 | // Initializes members of CollectiveContext not yet initialized, i.e. device |
41 | // and device_locality. Also saves the CollectiveContext in this object. |
42 | Status InitializeCollectiveContext( |
43 | std::shared_ptr<CollectiveContext> col_ctx) override; |
44 | |
45 | protected: |
46 | // Called when a bad status is received that implies we should terminate |
47 | // execution and return a bad status. |
48 | void StartAbort(const Status& s); |
49 | void Finish(bool ok); |
50 | |
51 | // Current status of a RingField |
52 | enum RingFieldAction { |
53 | RF_INIT = 0, // Just initialized for a pass |
54 | RF_RECV, // Recv pending |
55 | RF_REDUCE, // Reduce pending |
56 | RF_FINALIZE, // FinalOp pending |
57 | RF_SEND_READY, // Ready to send |
58 | RF_SEND, // Send pending |
59 | RF_DONE, // No more work |
60 | }; |
61 | |
62 | // Tracks progress of actions on a single subfield of the entire tensor. |
63 | struct RingField { |
64 | int16 chunk_idx; // major division index |
65 | int16 subdiv_idx; // minor division index |
66 | int16 sc_idx; // subchunk index |
67 | int16 rank; // rank within subdiv permutation |
68 | int16 recv_dev_idx; // dev from which value should be recv'd |
69 | RingFieldAction action; |
70 | bool second_pass; |
71 | bool recv_is_remote = false; |
72 | bool send_is_remote = false; |
73 | bool do_send = false; // is the value sent in this pass? |
74 | bool do_recv = false; // is the value recv'd in this pass? |
75 | bool is_final = false; // is the last field in the pass for this rank |
76 | Tensor chunk; // alias to field values |
77 | Tensor tmp_chunk; |
78 | Status status; |
79 | string DebugString() const; |
80 | }; |
81 | virtual void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, |
82 | int field_idx); |
83 | void AdvanceToSecondPass(RingField* rf); |
84 | void DispatchSend(RingField* rf, const StatusCallback& done); |
85 | void DispatchRecv(RingField* rf, const StatusCallback& done); |
86 | |
87 | // For constructing log messages for debugging. |
88 | string FieldState(); |
89 | string TensorDebugString(const Tensor& tensor); |
90 | |
91 | // Producer/Consumer Queue of RingField structs. |
92 | class PCQueue { |
93 | public: |
94 | void Enqueue(RingField* rf); |
95 | RingField* Dequeue(); |
96 | |
97 | private: |
98 | mutex pcq_mu_; |
99 | condition_variable cv_; |
100 | int waiter_count_ TF_GUARDED_BY(pcq_mu_) = 0; |
101 | std::deque<RingField*> deque_ TF_GUARDED_BY(pcq_mu_); |
102 | }; |
103 | |
104 | const CollectiveType type_; |
105 | const string name_; |
106 | std::shared_ptr<CollectiveContext> col_ctx_; |
107 | const CollectiveParams* col_params_; // Not owned |
108 | StatusCallback done_; |
109 | int group_size_; |
110 | int num_subdivs_; |
111 | Tensor group_size_tensor_; |
112 | Notification group_size_tensor_ready_; |
113 | std::unique_ptr<CollectiveAdapter> ca_; |
114 | mutex status_mu_; |
115 | Status status_ TF_GUARDED_BY(status_mu_); |
116 | std::vector<RingField> rfv_; |
117 | }; |
118 | |
119 | } // namespace tensorflow |
120 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ |
121 | |