1 | #pragma once |
2 | |
#ifndef _WIN32
#include <signal.h>
#include <sys/wait.h>
#include <unistd.h>
#endif

#include <sys/types.h>

#include <cerrno>
#include <condition_variable>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <string>
#include <system_error>
#include <utility>
#include <vector>
17 | |
18 | namespace c10d { |
19 | namespace test { |
20 | |
/// Counting semaphore built from a mutex and a condition variable, intended
/// for synchronizing threads inside tests.
class Semaphore {
 public:
  /// Adds `n` permits and wakes all waiters so each re-checks the count.
  void post(int n = 1) {
    std::lock_guard<std::mutex> guard(m_);
    n_ += n;
    cv_.notify_all();
  }

  /// Blocks until at least `n` permits are available, then consumes `n`.
  void wait(int n = 1) {
    std::unique_lock<std::mutex> guard(m_);
    // The predicate form re-checks after every wake-up, which also covers
    // spurious wake-ups.
    cv_.wait(guard, [this, n] { return n_ >= n; });
    n_ -= n;
  }

 protected:
  int n_ = 0;
  std::mutex m_;
  std::condition_variable cv_;
};
42 | |
43 | #ifdef _WIN32 |
44 | std::string autoGenerateTmpFilePath() { |
45 | char tmp[L_tmpnam_s]; |
46 | errno_t err; |
47 | err = tmpnam_s(tmp, L_tmpnam_s); |
48 | if (err != 0) |
49 | { |
50 | throw std::system_error(errno, std::system_category()); |
51 | } |
52 | return std::string(tmp); |
53 | } |
54 | |
55 | std::string tmppath() { |
56 | const char* tmpfile = getenv("TMPFILE" ); |
57 | if (tmpfile) { |
58 | return std::string(tmpfile); |
59 | } |
60 | else { |
61 | return autoGenerateTmpFilePath(); |
62 | } |
63 | } |
64 | #else |
/// Returns a path for tests to use as a temporary file.
///
/// If TMPFILE is set (manual test runs where the user specifies the full temp
/// file path), that value is returned verbatim and nothing is created.
/// Otherwise a file named `test` followed by six characters chosen by
/// mkstemp() is created in $TMPDIR (or /tmp when TMPDIR is unset or empty).
/// The file descriptor is closed, but the file is left on disk.
///
/// Declared `inline` because it is defined in a header.
///
/// @throws std::system_error if mkstemp() fails.
inline std::string tmppath() {
  if (const char* tmpfile = std::getenv("TMPFILE")) {
    return std::string(tmpfile);
  }

  const char* tmpdir = std::getenv("TMPDIR");
  if (tmpdir == nullptr || *tmpdir == '\0') {
    tmpdir = "/tmp";
  }

  // Build the template without a fixed-size buffer so long TMPDIR values are
  // not truncated (truncation would drop the XXXXXX suffix). The NUL
  // terminator is kept inside the vector's size because mkstemp() reads and
  // rewrites the template in place as a C string.
  const std::string pattern = std::string(tmpdir) + "/testXXXXXX";
  std::vector<char> tmpl(pattern.begin(), pattern.end());
  tmpl.push_back('\0');

  const int fd = mkstemp(tmpl.data());
  if (fd == -1) {
    throw std::system_error(errno, std::system_category());
  }
  close(fd);
  // mkstemp() only replaces the XXXXXX characters, so the length is unchanged.
  return std::string(tmpl.data(), pattern.size());
}
91 | #endif |
92 | |
/// Returns true iff the PYTORCH_TEST_WITH_TSAN environment variable is set to
/// exactly "1" (presumably set when the suite runs under ThreadSanitizer).
///
/// Declared `inline` because it is defined in a header that may be included
/// from several translation units.
inline bool isTSANEnabled() {
  const char* flag = std::getenv("PYTORCH_TEST_WITH_TSAN");
  return flag != nullptr && std::strcmp(flag, "1") == 0;
}
97 | struct TemporaryFile { |
98 | std::string path; |
99 | |
100 | TemporaryFile() { |
101 | path = tmppath(); |
102 | } |
103 | |
104 | ~TemporaryFile() { |
105 | unlink(path.c_str()); |
106 | } |
107 | }; |
108 | |
109 | #ifndef _WIN32 |
/// Calls fork() on construction.
///
/// In the parent, `pid` holds the child's pid; in the child it is 0. When the
/// parent's object is destroyed, the child is sent SIGKILL and then reaped.
///
/// Non-copyable: a copy would signal and reap the same pid twice, and after
/// the first reap that pid may already belong to an unrelated process.
struct Fork {
  pid_t pid;

  /// @throws std::system_error if fork() fails.
  Fork() {
    pid = fork();
    if (pid < 0) {
      throw std::system_error(errno, std::system_category(), "fork");
    }
  }

  Fork(const Fork&) = delete;
  Fork& operator=(const Fork&) = delete;

  // Ownership of the child transfers to the new object. The moved-from
  // object gets pid -1, so its destructor does nothing.
  Fork(Fork&& other) noexcept : pid(other.pid) {
    other.pid = -1;
  }

  ~Fork() {
    if (pid > 0) {
      kill(pid, SIGKILL);
      // Retry if a signal handler interrupts the wait, so the child is always
      // reaped rather than left as a zombie.
      while (waitpid(pid, nullptr, 0) == -1 && errno == EINTR) {
      }
    }
  }

  bool isChild() const {
    return pid == 0;
  }
};
131 | #endif |
132 | |
133 | } // namespace test |
134 | } // namespace c10d |
135 | |