#include <torch/torch.h>

#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <utility>
#include <vector>

/// One-shot synchronization flag: wait() blocks until post() has been called
/// at least once; after that every wait() returns immediately.
class Baton {
 public:
  /// Marks the baton as posted and wakes all threads blocked in wait().
  void post() {
    std::lock_guard<std::mutex> guard(mutex_);
    posted_ = true;
    // Notify while still holding the mutex: a waiter cannot leave wait() (and
    // possibly destroy this Baton) until the lock is released, so signal_ is
    // guaranteed to still be alive for this call.
    signal_.notify_all();
  }

  /// Blocks the calling thread until post() has been called.
  void wait() {
    std::unique_lock<std::mutex> guard(mutex_);
    // The predicate form re-checks posted_ after every (possibly spurious) wakeup.
    signal_.wait(guard, [this] { return posted_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable signal_;
  bool posted_ = false;
};

26 | void AtLaunch_Base(int32_t numIters) { |
27 | struct Helper { |
28 | explicit Helper(int32_t lim) : limit_(lim) {} |
29 | void operator()() { |
30 | if (++val_ == limit_) { |
31 | done.post(); |
32 | } else { |
33 | at::launch([this]() { (*this)(); }); |
34 | } |
35 | } |
36 | int val_{0}; |
37 | int limit_; |
38 | Baton done; |
39 | }; |
40 | Helper h(numIters); |
41 | auto start = std::chrono::system_clock::now(); |
42 | h(); |
43 | h.done.wait(); |
44 | std::cout << "NoData " |
45 | << static_cast<double>( |
46 | std::chrono::duration_cast<std::chrono::microseconds>( |
47 | std::chrono::system_clock::now() - start) |
48 | .count()) / |
49 | static_cast<double>(numIters) |
50 | << " usec/each\n"; |
51 | } |
52 | |
53 | void AtLaunch_WithData(int32_t numIters, int32_t vecSize) { |
54 | struct Helper { |
55 | explicit Helper(int32_t lim) : limit_(lim) {} |
56 | void operator()(std::vector<int32_t> v) { |
57 | if (++val_ == limit_) { |
58 | done.post(); |
59 | } else { |
60 | at::launch([this, v = std::move(v)]() { (*this)(v); }); |
61 | } |
62 | } |
63 | int val_{0}; |
64 | int limit_; |
65 | Baton done; |
66 | }; |
67 | Helper h(numIters); |
68 | std::vector<int32_t> v(vecSize, 0); |
69 | auto start = std::chrono::system_clock::now(); |
70 | h(v); |
71 | h.done.wait(); |
72 | std::cout << "WithData("<< vecSize << "): " |
73 | << static_cast<double>( |
74 | std::chrono::duration_cast<std::chrono::microseconds>( |
75 | std::chrono::system_clock::now() - start) |
76 | .count()) / |
77 | static_cast<double>(numIters) |
78 | << " usec/each\n"; |
79 | } |
80 | |
81 | int main(int argc, char** argv) { |
82 | int32_t N = 1000000; |
83 | AtLaunch_Base(N); |
84 | AtLaunch_WithData(N, 0); |
85 | AtLaunch_WithData(N, 4); |
86 | AtLaunch_WithData(N, 256); |
87 | return 0; |
88 | } |
89 |