1// Copyright 2020 The Marl Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// marl::DAG<> provides an ahead of time, declarative, directed acyclic
16// task graph.
17
18#ifndef marl_dag_h
19#define marl_dag_h
20
21#include "containers.h"
22#include "export.h"
23#include "memory.h"
24#include "scheduler.h"
25#include "waitgroup.h"
26
27namespace marl {
28namespace detail {
29using DAGCounter = std::atomic<uint32_t>;
30template <typename T>
31struct DAGRunContext {
32 T data;
33 Allocator::unique_ptr<DAGCounter> counters;
34
35 template <typename F>
36 MARL_NO_EXPORT inline void invoke(F&& f) {
37 f(data);
38 }
39};
40template <>
41struct DAGRunContext<void> {
42 Allocator::unique_ptr<DAGCounter> counters;
43
44 template <typename F>
45 MARL_NO_EXPORT inline void invoke(F&& f) {
46 f();
47 }
48};
49template <typename T>
50struct DAGWork {
51 using type = std::function<void(T)>;
52};
53template <>
54struct DAGWork<void> {
55 using type = std::function<void()>;
56};
57} // namespace detail
58
59///////////////////////////////////////////////////////////////////////////////
60// Forward declarations
61///////////////////////////////////////////////////////////////////////////////
// The graph, its builder and its per-node builder refer to one another, so
// each is declared up front.
template <typename T> class DAG;
template <typename T> class DAGBuilder;
template <typename T> class DAGNodeBuilder;
70
71///////////////////////////////////////////////////////////////////////////////
72// DAGBase<T>
73///////////////////////////////////////////////////////////////////////////////
74
75// DAGBase is derived by DAG<T> and DAG<void>. It has no public API.
76template <typename T>
77class DAGBase {
78 protected:
79 friend DAGBuilder<T>;
80 friend DAGNodeBuilder<T>;
81
82 using RunContext = detail::DAGRunContext<T>;
83 using Counter = detail::DAGCounter;
84 using NodeIndex = size_t;
85 using Work = typename detail::DAGWork<T>::type;
86 static const constexpr size_t NumReservedNodes = 32;
87 static const constexpr size_t NumReservedNumOuts = 4;
88 static const constexpr size_t InvalidCounterIndex = ~static_cast<size_t>(0);
89 static const constexpr NodeIndex RootIndex = 0;
90 static const constexpr NodeIndex InvalidNodeIndex =
91 ~static_cast<NodeIndex>(0);
92
93 // DAG work node.
94 struct Node {
95 MARL_NO_EXPORT inline Node() = default;
96 MARL_NO_EXPORT inline Node(Work&& work);
97
98 // The work to perform for this node in the graph.
99 Work work;
100
101 // counterIndex if valid, is the index of the counter in the RunContext for
102 // this node. The counter is decremented for each completed dependency task
103 // (ins), and once it reaches 0, this node will be invoked.
104 size_t counterIndex = InvalidCounterIndex;
105
106 // Indices for all downstream nodes.
107 containers::vector<NodeIndex, NumReservedNumOuts> outs;
108 };
109
110 // initCounters() allocates and initializes the ctx->coutners from
111 // initialCounters.
112 MARL_NO_EXPORT inline void initCounters(RunContext* ctx,
113 Allocator* allocator);
114
115 // notify() is called each time a dependency task (ins) has completed for the
116 // node with the given index.
117 // If all dependency tasks have completed (or this is the root node) then
118 // notify() returns true and the caller should then call invoke().
119 MARL_NO_EXPORT inline bool notify(RunContext*, NodeIndex);
120
121 // invoke() calls the work function for the node with the given index, then
122 // calls notify() and possibly invoke() for all the dependee nodes.
123 MARL_NO_EXPORT inline void invoke(RunContext*, NodeIndex, WaitGroup*);
124
125 // nodes is the full list of the nodes in the graph.
126 // nodes[0] is always the root node, which has no dependencies (ins).
127 containers::vector<Node, NumReservedNodes> nodes;
128
129 // initialCounters is a list of initial counter values to be copied to
130 // RunContext::counters on DAG<>::run().
131 // initialCounters is indexed by Node::counterIndex, and only contains counts
132 // for nodes that have at least 2 dependencies (ins) - because of this the
133 // number of entries in initialCounters may be fewer than nodes.
134 containers::vector<uint32_t, NumReservedNodes> initialCounters;
135};
136
137template <typename T>
138DAGBase<T>::Node::Node(Work&& work) : work(std::move(work)) {}
139
140template <typename T>
141void DAGBase<T>::initCounters(RunContext* ctx, Allocator* allocator) {
142 auto numCounters = initialCounters.size();
143 ctx->counters = allocator->make_unique_n<Counter>(numCounters);
144 for (size_t i = 0; i < numCounters; i++) {
145 ctx->counters.get()[i] = {initialCounters[i]};
146 }
147}
148
149template <typename T>
150bool DAGBase<T>::notify(RunContext* ctx, NodeIndex nodeIdx) {
151 Node* node = &nodes[nodeIdx];
152
153 // If we have multiple dependencies, decrement the counter and check whether
154 // we've reached 0.
155 if (node->counterIndex == InvalidCounterIndex) {
156 return true;
157 }
158 auto counters = ctx->counters.get();
159 auto counter = --counters[node->counterIndex];
160 return counter == 0;
161}
162
163template <typename T>
164void DAGBase<T>::invoke(RunContext* ctx, NodeIndex nodeIdx, WaitGroup* wg) {
165 Node* node = &nodes[nodeIdx];
166
167 // Run this node's work.
168 if (node->work) {
169 ctx->invoke(node->work);
170 }
171
172 // Then call notify() on all dependees (outs), and invoke() those that
173 // returned true.
174 // We buffer the node to invoke (toInvoke) so we can schedule() all but the
175 // last node to invoke(), and directly call the last invoke() on this thread.
176 // This is done to avoid the overheads of scheduling when a direct call would
177 // suffice.
178 NodeIndex toInvoke = InvalidNodeIndex;
179 for (NodeIndex idx : node->outs) {
180 if (notify(ctx, idx)) {
181 if (toInvoke != InvalidNodeIndex) {
182 wg->add(1);
183 // Schedule while promoting the WaitGroup capture from a pointer
184 // reference to a value. This ensures that the WaitGroup isn't dropped
185 // while in use.
186 schedule(
187 [=](WaitGroup wg) {
188 invoke(ctx, toInvoke, &wg);
189 wg.done();
190 },
191 *wg);
192 }
193 toInvoke = idx;
194 }
195 }
196 if (toInvoke != InvalidNodeIndex) {
197 invoke(ctx, toInvoke, wg);
198 }
199}
200
201///////////////////////////////////////////////////////////////////////////////
202// DAGNodeBuilder<T>
203///////////////////////////////////////////////////////////////////////////////
204
205// DAGNodeBuilder is the builder interface for a DAG node.
206template <typename T>
207class DAGNodeBuilder {
208 using NodeIndex = typename DAGBase<T>::NodeIndex;
209
210 public:
211 // then() builds and returns a new DAG node that will be invoked after this
212 // node has completed.
213 //
214 // F is a function that will be called when the new DAG node is invoked, with
215 // the signature:
216 // void(T) when T is not void
217 // or
218 // void() when T is void
219 template <typename F>
220 MARL_NO_EXPORT inline DAGNodeBuilder then(F&&);
221
222 private:
223 friend DAGBuilder<T>;
224 MARL_NO_EXPORT inline DAGNodeBuilder(DAGBuilder<T>*, NodeIndex);
225 DAGBuilder<T>* builder;
226 NodeIndex index;
227};
228
229template <typename T>
230DAGNodeBuilder<T>::DAGNodeBuilder(DAGBuilder<T>* builder, NodeIndex index)
231 : builder(builder), index(index) {}
232
233template <typename T>
234template <typename F>
235DAGNodeBuilder<T> DAGNodeBuilder<T>::then(F&& work) {
236 auto node = builder->node(std::move(work));
237 builder->addDependency(*this, node);
238 return node;
239}
240
241///////////////////////////////////////////////////////////////////////////////
242// DAGBuilder<T>
243///////////////////////////////////////////////////////////////////////////////
244template <typename T>
245class DAGBuilder {
246 public:
247 // DAGBuilder constructor
248 MARL_NO_EXPORT inline DAGBuilder(Allocator* allocator = Allocator::Default);
249
250 // root() returns the root DAG node.
251 MARL_NO_EXPORT inline DAGNodeBuilder<T> root();
252
253 // node() builds and returns a new DAG node with no initial dependencies.
254 // The returned node must be attached to the graph in order to invoke F or any
255 // of the dependees of this returned node.
256 //
257 // F is a function that will be called when the new DAG node is invoked, with
258 // the signature:
259 // void(T) when T is not void
260 // or
261 // void() when T is void
262 template <typename F>
263 MARL_NO_EXPORT inline DAGNodeBuilder<T> node(F&& work);
264
265 // node() builds and returns a new DAG node that depends on all the tasks in
266 // after to be completed before invoking F.
267 //
268 // F is a function that will be called when the new DAG node is invoked, with
269 // the signature:
270 // void(T) when T is not void
271 // or
272 // void() when T is void
273 template <typename F>
274 MARL_NO_EXPORT inline DAGNodeBuilder<T> node(
275 F&& work,
276 std::initializer_list<DAGNodeBuilder<T>> after);
277
278 // addDependency() adds parent as dependency on child. All dependencies of
279 // child must have completed before child is invoked.
280 MARL_NO_EXPORT inline void addDependency(DAGNodeBuilder<T> parent,
281 DAGNodeBuilder<T> child);
282
283 // build() constructs and returns the DAG. No other methods of this class may
284 // be called after calling build().
285 MARL_NO_EXPORT inline Allocator::unique_ptr<DAG<T>> build();
286
287 private:
288 static const constexpr size_t NumReservedNumIns = 4;
289 using Node = typename DAG<T>::Node;
290
291 // The DAG being built.
292 Allocator::unique_ptr<DAG<T>> dag;
293
294 // Number of dependencies (ins) for each node in dag->nodes.
295 containers::vector<uint32_t, NumReservedNumIns> numIns;
296};
297
298template <typename T>
299DAGBuilder<T>::DAGBuilder(Allocator* allocator /* = Allocator::Default */)
300 : dag(allocator->make_unique<DAG<T>>()), numIns(allocator) {
301 // Add root
302 dag->nodes.emplace_back(Node{});
303 numIns.emplace_back(0);
304}
305
306template <typename T>
307DAGNodeBuilder<T> DAGBuilder<T>::root() {
308 return DAGNodeBuilder<T>{this, DAGBase<T>::RootIndex};
309}
310
311template <typename T>
312template <typename F>
313DAGNodeBuilder<T> DAGBuilder<T>::node(F&& work) {
314 return node(std::forward<F>(work), {});
315}
316
317template <typename T>
318template <typename F>
319DAGNodeBuilder<T> DAGBuilder<T>::node(
320 F&& work,
321 std::initializer_list<DAGNodeBuilder<T>> after) {
322 MARL_ASSERT(numIns.size() == dag->nodes.size(),
323 "NodeBuilder vectors out of sync");
324 auto index = dag->nodes.size();
325 numIns.emplace_back(0);
326 dag->nodes.emplace_back(Node{std::move(work)});
327 auto node = DAGNodeBuilder<T>{this, index};
328 for (auto in : after) {
329 addDependency(in, node);
330 }
331 return node;
332}
333
334template <typename T>
335void DAGBuilder<T>::addDependency(DAGNodeBuilder<T> parent,
336 DAGNodeBuilder<T> child) {
337 numIns[child.index]++;
338 dag->nodes[parent.index].outs.push_back(child.index);
339}
340
341template <typename T>
342Allocator::unique_ptr<DAG<T>> DAGBuilder<T>::build() {
343 auto numNodes = dag->nodes.size();
344 MARL_ASSERT(numIns.size() == dag->nodes.size(),
345 "NodeBuilder vectors out of sync");
346 for (size_t i = 0; i < numNodes; i++) {
347 if (numIns[i] > 1) {
348 auto& node = dag->nodes[i];
349 node.counterIndex = dag->initialCounters.size();
350 dag->initialCounters.push_back(numIns[i]);
351 }
352 }
353 return std::move(dag);
354}
355
356///////////////////////////////////////////////////////////////////////////////
357// DAG<T>
358///////////////////////////////////////////////////////////////////////////////
359template <typename T = void>
360class DAG : public DAGBase<T> {
361 public:
362 using Builder = DAGBuilder<T>;
363 using NodeBuilder = DAGNodeBuilder<T>;
364
365 // run() invokes the function of each node in the graph of the DAG, passing
366 // data to each, starting with the root node. All dependencies need to have
367 // completed their function before dependees will be invoked.
368 MARL_NO_EXPORT inline void run(T& data,
369 Allocator* allocator = Allocator::Default);
370};
371
372template <typename T>
373void DAG<T>::run(T& arg, Allocator* allocator /* = Allocator::Default */) {
374 typename DAGBase<T>::RunContext ctx{arg};
375 this->initCounters(&ctx, allocator);
376 WaitGroup wg;
377 this->invoke(&ctx, this->RootIndex, &wg);
378 wg.wait();
379}
380
381///////////////////////////////////////////////////////////////////////////////
382// DAG<void>
383///////////////////////////////////////////////////////////////////////////////
384template <>
385class DAG<void> : public DAGBase<void> {
386 public:
387 using Builder = DAGBuilder<void>;
388 using NodeBuilder = DAGNodeBuilder<void>;
389
390 // run() invokes the function of each node in the graph of the DAG, starting
391 // with the root node. All dependencies need to have completed their function
392 // before dependees will be invoked.
393 MARL_NO_EXPORT inline void run(Allocator* allocator = Allocator::Default);
394};
395
396void DAG<void>::run(Allocator* allocator /* = Allocator::Default */) {
397 typename DAGBase<void>::RunContext ctx{};
398 this->initCounters(&ctx, allocator);
399 WaitGroup wg;
400 this->invoke(&ctx, this->RootIndex, &wg);
401 wg.wait();
402}
403
404} // namespace marl
405
406#endif // marl_dag_h
407