#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <mutex>
#include <utility>
#include <vector>

#include <torch/torch.h>
5
/// One-shot signalling primitive.
///
/// post() marks the baton as released and wakes every thread blocked in
/// wait(). Once posted it stays posted, so later wait() calls return at once.
class Baton {
 public:
  /// Releases the baton and wakes all current waiters.
  void post() {
    std::lock_guard<std::mutex> guard(lock_);
    done_ = true;
    // Notify while still holding the lock. A waiter cannot return from
    // wait() (and possibly destroy this Baton) until the lock is released.
    cv_.notify_all();
  }

  /// Blocks the calling thread until post() has been called.
  void wait() {
    std::unique_lock<std::mutex> guard(lock_);
    // The predicate form re-checks done_ after every wakeup, so spurious
    // wakeups are ignored.
    cv_.wait(guard, [this] { return done_; });
  }

 private:
  std::mutex lock_;
  std::condition_variable cv_;
  bool done_{false};
};
25
26void AtLaunch_Base(int32_t numIters) {
27 struct Helper {
28 explicit Helper(int32_t lim) : limit_(lim) {}
29 void operator()() {
30 if (++val_ == limit_) {
31 done.post();
32 } else {
33 at::launch([this]() { (*this)(); });
34 }
35 }
36 int val_{0};
37 int limit_;
38 Baton done;
39 };
40 Helper h(numIters);
41 auto start = std::chrono::system_clock::now();
42 h();
43 h.done.wait();
44 std::cout << "NoData "
45 << static_cast<double>(
46 std::chrono::duration_cast<std::chrono::microseconds>(
47 std::chrono::system_clock::now() - start)
48 .count()) /
49 static_cast<double>(numIters)
50 << " usec/each\n";
51}
52
53void AtLaunch_WithData(int32_t numIters, int32_t vecSize) {
54 struct Helper {
55 explicit Helper(int32_t lim) : limit_(lim) {}
56 void operator()(std::vector<int32_t> v) {
57 if (++val_ == limit_) {
58 done.post();
59 } else {
60 at::launch([this, v = std::move(v)]() { (*this)(v); });
61 }
62 }
63 int val_{0};
64 int limit_;
65 Baton done;
66 };
67 Helper h(numIters);
68 std::vector<int32_t> v(vecSize, 0);
69 auto start = std::chrono::system_clock::now();
70 h(v);
71 h.done.wait();
72 std::cout << "WithData(" << vecSize << "): "
73 << static_cast<double>(
74 std::chrono::duration_cast<std::chrono::microseconds>(
75 std::chrono::system_clock::now() - start)
76 .count()) /
77 static_cast<double>(numIters)
78 << " usec/each\n";
79}
80
81int main(int argc, char** argv) {
82 int32_t N = 1000000;
83 AtLaunch_Base(N);
84 AtLaunch_WithData(N, 0);
85 AtLaunch_WithData(N, 4);
86 AtLaunch_WithData(N, 256);
87 return 0;
88}
89