1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_
17
18#include <deque>
19#include <memory>
20#include <string>
21#include <vector>
22
23#include "tensorflow/core/common_runtime/base_collective_executor.h"
24#include "tensorflow/core/framework/collective.h"
25
26namespace tensorflow {
// Forward declaration. NOTE(review): `Device` is not referenced anywhere else
// in this header; it may be vestigial. Confirm that no includer relies on it
// before removing.
class Device;
28
29// Basic ring-algorithm implementation to be further specialized
30// for specific collective functions.
31class RingAlg : public CollectiveImplementationInterface {
32 public:
33 explicit RingAlg(CollectiveType type, const string& name);
34 ~RingAlg() override {}
35
36 // Establishes the requested number of subdivision permutations based on the
37 // ring order implicit in the device order.
38 Status InitializeCollectiveParams(CollectiveParams* col_params) override;
39
40 // Initializes members of CollectiveContext not yet initialized, i.e. device
41 // and device_locality. Also saves the CollectiveContext in this object.
42 Status InitializeCollectiveContext(
43 std::shared_ptr<CollectiveContext> col_ctx) override;
44
45 protected:
46 // Called when a bad status is received that implies we should terminate
47 // execution and return a bad status.
48 void StartAbort(const Status& s);
49 void Finish(bool ok);
50
51 // Current status of a RingField
52 enum RingFieldAction {
53 RF_INIT = 0, // Just initialized for a pass
54 RF_RECV, // Recv pending
55 RF_REDUCE, // Reduce pending
56 RF_FINALIZE, // FinalOp pending
57 RF_SEND_READY, // Ready to send
58 RF_SEND, // Send pending
59 RF_DONE, // No more work
60 };
61
62 // Tracks progress of actions on a single subfield of the entire tensor.
63 struct RingField {
64 int16 chunk_idx; // major division index
65 int16 subdiv_idx; // minor division index
66 int16 sc_idx; // subchunk index
67 int16 rank; // rank within subdiv permutation
68 int16 recv_dev_idx; // dev from which value should be recv'd
69 RingFieldAction action;
70 bool second_pass;
71 bool recv_is_remote = false;
72 bool send_is_remote = false;
73 bool do_send = false; // is the value sent in this pass?
74 bool do_recv = false; // is the value recv'd in this pass?
75 bool is_final = false; // is the last field in the pass for this rank
76 Tensor chunk; // alias to field values
77 Tensor tmp_chunk;
78 Status status;
79 string DebugString() const;
80 };
81 virtual void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
82 int field_idx);
83 void AdvanceToSecondPass(RingField* rf);
84 void DispatchSend(RingField* rf, const StatusCallback& done);
85 void DispatchRecv(RingField* rf, const StatusCallback& done);
86
87 // For constructing log messages for debugging.
88 string FieldState();
89 string TensorDebugString(const Tensor& tensor);
90
91 // Producer/Consumer Queue of RingField structs.
92 class PCQueue {
93 public:
94 void Enqueue(RingField* rf);
95 RingField* Dequeue();
96
97 private:
98 mutex pcq_mu_;
99 condition_variable cv_;
100 int waiter_count_ TF_GUARDED_BY(pcq_mu_) = 0;
101 std::deque<RingField*> deque_ TF_GUARDED_BY(pcq_mu_);
102 };
103
104 const CollectiveType type_;
105 const string name_;
106 std::shared_ptr<CollectiveContext> col_ctx_;
107 const CollectiveParams* col_params_; // Not owned
108 StatusCallback done_;
109 int group_size_;
110 int num_subdivs_;
111 Tensor group_size_tensor_;
112 Notification group_size_tensor_ready_;
113 std::unique_ptr<CollectiveAdapter> ca_;
114 mutex status_mu_;
115 Status status_ TF_GUARDED_BY(status_mu_);
116 std::vector<RingField> rfv_;
117};
118
119} // namespace tensorflow
120#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_
121