1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/distributed_runtime/scheduler.h"

#include <algorithm>
#include <deque>
#include <queue>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/util/util.h"
24 | |
25 | namespace tensorflow { |
26 | |
27 | namespace { |
28 | |
29 | // Initialize the pending count for each node. |
30 | void InitializePending(const Graph* graph, std::vector<int>* pending) { |
31 | pending->resize(graph->num_node_ids()); |
32 | for (const Node* node : graph->nodes()) { |
33 | const int id = node->id(); |
34 | int num_in_edges = 0; |
35 | if (IsMerge(node)) { |
36 | // For forward execution order, Merge nodes are special. We process |
37 | // them only once when one of its inputs is processed. |
38 | for (const Edge* edge : node->in_edges()) { |
39 | if (edge->IsControlEdge()) { |
40 | // Bit 0 is reserved to indicate if there is a data input. |
41 | num_in_edges += 2; |
42 | } |
43 | } |
44 | } else { |
45 | num_in_edges = node->in_edges().size(); |
46 | } |
47 | (*pending)[id] = num_in_edges; |
48 | } |
49 | } |
50 | |
51 | // Return true if the update makes the destination of the edge ready to run. |
52 | bool UpdatePending(const Edge* edge, std::vector<int>* pending_count) { |
53 | const Node* out = edge->dst(); |
54 | if (IsMerge(out)) { |
55 | if (edge->IsControlEdge()) { |
56 | (*pending_count)[out->id()] -= 2; |
57 | // Return true if we already got at least one input edge |
58 | // and a control edge is the enabling one. |
59 | return ((*pending_count)[out->id()] == 1); |
60 | } else { |
61 | int count = (*pending_count)[out->id()]; |
62 | (*pending_count)[out->id()] |= 0x1; |
63 | // If the first input edge is the enabling one, the count goes from |
64 | // 0 to 1 in this step. Return true iff count was zero. |
65 | return (count == 0); |
66 | } |
67 | } else { |
68 | return (--(*pending_count)[out->id()] == 0); |
69 | } |
70 | } |
71 | |
72 | } // end namespace |
73 | |
// Prepares slack analysis over `g`, using `cost_model` for per-node
// execution-time estimates. Stores (but does not own) both pointers;
// they must outlive this object.
SlackAnalysis::SlackAnalysis(const Graph* g, const CostModel* cost_model)
    : graph_(g), cost_model_(cost_model) {}
76 | |
77 | Microseconds SlackAnalysis::ComputeAsap(std::vector<Microseconds>* asap_times) { |
78 | asap_times->resize(graph_->num_node_ids()); |
79 | |
80 | std::vector<int> pending_count(graph_->num_node_ids()); |
81 | InitializePending(graph_, &pending_count); |
82 | |
83 | std::deque<const Node*> queue; |
84 | Node* srcNode = graph_->source_node(); |
85 | queue.push_back(srcNode); |
86 | (*asap_times)[srcNode->id()] = 0; |
87 | |
88 | while (!queue.empty()) { |
89 | const Node* curr = queue.front(); |
90 | queue.pop_front(); |
91 | Microseconds ctime = cost_model_->TimeEstimate(curr); |
92 | for (const Edge* out_edge : curr->out_edges()) { |
93 | // The time needed for 'out' to get its input from 'curr'. |
94 | Microseconds copy_time(0); |
95 | const Node* out = out_edge->dst(); |
96 | if (!out_edge->IsControlEdge() && |
97 | curr->assigned_device_name() != out->assigned_device_name()) { |
98 | // Add an arbitrary 10microsecs for each copy. |
99 | // TODO(yuanbyu): Use below with the real cost model. |
100 | // int index = out_edge->src_output(); |
101 | // Bytes nb = cost_model_->SizeEstimate(curr, index); |
102 | // copy_time = CostModel::CopyTimeEstimate(nb); |
103 | copy_time = 10; |
104 | } |
105 | Microseconds new_asap = (*asap_times)[curr->id()] + ctime + copy_time; |
106 | if ((*asap_times)[out->id()] < new_asap) { |
107 | (*asap_times)[out->id()] = new_asap; |
108 | } |
109 | |
110 | bool is_ready = UpdatePending(out_edge, &pending_count); |
111 | if (is_ready) { |
112 | queue.push_back(out); |
113 | } |
114 | } |
115 | } |
116 | return (*asap_times)[graph_->sink_node()->id()]; |
117 | } |
118 | |
// Computes each node's latest possible start time (ALAP schedule)
// without delaying the sink, via a reverse sweep from the sink. Times
// are anchored so the sink's ALAP time is 0; every other node gets a
// non-positive value. Returns the ALAP time of the source, i.e. minus
// the critical-path length.
Microseconds SlackAnalysis::ComputeAlap(std::vector<Microseconds>* alap_times) {
  alap_times->resize(graph_->num_node_ids());

  // Pending counts for the reverse traversal: a node is enqueued once
  // its counter reaches zero, i.e. once enough of its consumers have
  // been processed.
  std::vector<int> pending_count;
  pending_count.resize(graph_->num_node_ids());
  for (const Node* n : graph_->nodes()) {
    // For reverse execution order, Switch nodes are special. We process
    // them only once when one of its outputs is processed.
    if (IsSwitch(n)) {
      int32_t num_control_edges = 0;
      for (const Edge* edge : n->out_edges()) {
        if (edge->IsControlEdge()) {
          num_control_edges++;
        }
      }
      // All data outputs of the Switch share a single count ("+ 1"), so
      // the Switch becomes ready when the first of them is processed
      // (after all control edges have been accounted for).
      pending_count[n->id()] = num_control_edges + 1;
    } else {
      pending_count[n->id()] = n->out_edges().size();
    }
  }

  std::deque<const Node*> queue;
  Node* sinkNode = graph_->sink_node();
  queue.push_back(sinkNode);
  // Anchor: the sink starts at time 0; predecessors get negative times.
  (*alap_times)[sinkNode->id()] = 0;

  while (!queue.empty()) {
    const Node* curr = queue.front();
    queue.pop_front();
    for (const Edge* in_edge : curr->in_edges()) {
      // The time needed for 'curr' to get its input from 'src'.
      Microseconds copy_time(0);
      const Node* src = in_edge->src();
      if (!in_edge->IsControlEdge() &&
          src->assigned_device_name() != curr->assigned_device_name()) {
        // TODO(yuanbyu): Use the real cost model
        // int index = out_edge->src_output();
        // Bytes nb = cost_model_->SizeEstimate(curr, index);
        // copy_time = CostModel::CopyTimeEstimate(nb);
        copy_time = 10;
      }
      Microseconds ctime = cost_model_->TimeEstimate(src);
      // 'src' must start early enough that its output (compute + copy)
      // reaches 'curr' by curr's latest start. Keep the tightest
      // (smallest) such bound over all consumers.
      Microseconds new_latest = (*alap_times)[curr->id()] - ctime - copy_time;
      if ((*alap_times)[src->id()] > new_latest) {
        (*alap_times)[src->id()] = new_latest;
      }

      int count = --pending_count[src->id()];
      if (count == 0) {
        queue.push_back(src);
      }
    }
  }
  return (*alap_times)[graph_->source_node()->id()];
}
174 | |
175 | void SlackAnalysis::ComputeSlack(std::vector<int64_t>* slacks) { |
176 | std::vector<Microseconds> asap_times; |
177 | std::vector<Microseconds> alap_times; |
178 | ComputeAsap(&asap_times); |
179 | ComputeAlap(&alap_times); |
180 | slacks->resize(graph_->num_node_ids()); |
181 | Node* srcNode = graph_->source_node(); |
182 | Microseconds makespan = alap_times[srcNode->id()]; |
183 | for (Node* node : graph_->nodes()) { |
184 | Microseconds latest_stime = alap_times[node->id()] - makespan; |
185 | (*slacks)[node->id()] = (latest_stime - asap_times[node->id()]).value(); |
186 | } |
187 | } |
188 | |
189 | GreedyScheduler::GreedyScheduler(const DeviceSet* devices, |
190 | const CostModel* cost_model, const Graph* g, |
191 | std::vector<int64_t>* priority) |
192 | : devices_(devices), |
193 | cost_model_(cost_model), |
194 | graph_(g), |
195 | priority_(priority) { |
196 | for (Device* d : devices_->devices()) { |
197 | Sim* s = new Sim; |
198 | // The number of compute units on a device. Set to 2 for now. |
199 | s->degree_parallelism = 2; |
200 | s->num_running = 0; |
201 | device_states_.insert(std::make_pair(d->name(), s)); |
202 | } |
203 | } |
204 | |
205 | GreedyScheduler::~GreedyScheduler() { |
206 | for (auto& ds : device_states_) { |
207 | delete ds.second; |
208 | } |
209 | } |
210 | |
211 | Microseconds GreedyScheduler::ComputeSchedule( |
212 | std::vector<Microseconds>* start_times) { |
213 | // Initialize pending_count |
214 | std::vector<int> pending_count(graph_->num_node_ids()); |
215 | InitializePending(graph_, &pending_count); |
216 | |
217 | // Initialize event queue |
218 | std::priority_queue<Event> event_queue; |
219 | Event src_event; |
220 | src_event.node = graph_->source_node(); |
221 | src_event.time = 0; |
222 | src_event.is_completion = true; |
223 | event_queue.push(src_event); |
224 | Microseconds max_completion = Microseconds(0); |
225 | |
226 | while (!event_queue.empty()) { |
227 | Event event = event_queue.top(); |
228 | event_queue.pop(); |
229 | if (event.is_completion) { |
230 | Sim* sim = device_states_[event.node->assigned_device_name()]; |
231 | --sim->num_running; |
232 | |
233 | if (event.time > max_completion) { |
234 | max_completion = event.time; |
235 | } |
236 | |
237 | for (const Edge* out_edge : event.node->out_edges()) { |
238 | Microseconds copy_time(0); |
239 | const Node* out = out_edge->dst(); |
240 | if (!out_edge->IsControlEdge() && |
241 | event.node->assigned_device_name() != out->assigned_device_name()) { |
242 | // TODO(yuanbyu): Use below with the real cost model. |
243 | // int index = out_edge->src_output(); |
244 | // Bytes nb = cost_model_->SizeEstimate(event.node, index); |
245 | // copy_time = CostModel::CopyTimeEstimate(nb); |
246 | copy_time = 10; |
247 | } |
248 | if ((*start_times)[out->id()] < event.time + copy_time) { |
249 | (*start_times)[out->id()] = event.time + copy_time; |
250 | } |
251 | |
252 | bool is_ready = UpdatePending(out_edge, &pending_count); |
253 | if (is_ready) { |
254 | Event e{out, (*start_times)[out->id()], false}; |
255 | event_queue.push(e); |
256 | } |
257 | } |
258 | } else { |
259 | Sim* sim = device_states_[event.node->assigned_device_name()]; |
260 | sim->ready_nodes.push_back(event.node); |
261 | } |
262 | |
263 | for (auto& x : device_states_) { |
264 | Sim* sim = x.second; |
265 | while (sim->num_running < sim->degree_parallelism && |
266 | !sim->ready_nodes.empty()) { |
267 | Event e; |
268 | e.node = GetNodeWithHighestPriority(sim->ready_nodes); |
269 | e.time = event.time + cost_model_->TimeEstimate(e.node); |
270 | e.is_completion = true; |
271 | event_queue.push(e); |
272 | (*start_times)[e.node->id()] = event.time; |
273 | ++sim->num_running; |
274 | } |
275 | } |
276 | } |
277 | return max_completion; |
278 | } |
279 | |
280 | const Node* GreedyScheduler::GetNodeWithHighestPriority( |
281 | const std::vector<const Node*>& nodes) { |
282 | const Node* curr_node = nullptr; |
283 | int64_t curr_priority = kint64max; |
284 | for (const Node* n : nodes) { |
285 | if ((*priority_)[n->id()] < curr_priority) { |
286 | curr_node = n; |
287 | curr_priority = (*priority_)[n->id()]; |
288 | } |
289 | } |
290 | return curr_node; |
291 | } |
292 | |
// Schedules `g` on `devices` using slack-derived priorities (see
// ComputeSchedule). Stores (but does not own) all three pointers; they
// must outlive this object.
PriorityScheduler::PriorityScheduler(const DeviceSet* devices,
                                     const CostModel* cost_model,
                                     const Graph* g)
    : devices_(devices), cost_model_(cost_model), graph_(g) {}
297 | |
298 | Microseconds PriorityScheduler::ComputeSchedule( |
299 | std::vector<Microseconds>* start_times) { |
300 | std::vector<int64_t> slacks; |
301 | SlackAnalysis slack(graph_, cost_model_); |
302 | slack.ComputeSlack(&slacks); |
303 | GreedyScheduler greedysched(devices_, cost_model_, graph_, &slacks); |
304 | return greedysched.ComputeSchedule(start_times); |
305 | } |
306 | |
307 | Microseconds PriorityScheduler::AssignPriorities( |
308 | std::vector<int64_t>* priorities) { |
309 | std::vector<Microseconds> start_times; |
310 | Microseconds makespan = ComputeSchedule(&start_times); |
311 | |
312 | for (const Node* n : graph_->nodes()) { |
313 | (*priorities)[n->id()] = start_times[n->id()].value(); |
314 | } |
315 | return makespan; |
316 | } |
317 | |
318 | } // namespace tensorflow |
319 | |