1#pragma once
2
#ifndef _WIN32
#include <signal.h>
#include <sys/wait.h>
#include <unistd.h>
#endif

#include <sys/types.h>
#include <cerrno>
#include <cstdlib>
#include <cstring>

#include <condition_variable>
#include <mutex>
#include <string>
#include <system_error>
#include <vector>
17
18namespace c10d {
19namespace test {
20
/// Counting semaphore built from a mutex and a condition variable.
///
/// post(n) adds n permits and wakes all waiters; wait(n) blocks until at least
/// n permits are available and then consumes exactly n of them at once.
class Semaphore {
 public:
  void post(int n = 1) {
    std::unique_lock<std::mutex> guard(m_);
    n_ += n;
    // Waiters may ask for different permit counts, so wake all of them and let
    // each re-check its own condition.
    cv_.notify_all();
  }

  void wait(int n = 1) {
    std::unique_lock<std::mutex> guard(m_);
    // Predicate form re-checks after every (possibly spurious) wakeup.
    cv_.wait(guard, [this, n] { return n_ >= n; });
    n_ -= n;
  }

 protected:
  int n_ = 0;                   // permits currently available; guarded by m_
  std::mutex m_;                // protects n_
  std::condition_variable cv_;  // notified by post() after n_ increases
};
42
#ifdef _WIN32
/// Generates a unique temporary file name with tmpnam_s.
///
/// Only a name is produced; no file is created on disk.
/// `inline` because this is a header: a non-inline definition breaks the ODR
/// as soon as two translation units include it.
///
/// @throws std::system_error if tmpnam_s fails.
inline std::string autoGenerateTmpFilePath() {
  char tmp[L_tmpnam_s];
  const errno_t err = tmpnam_s(tmp, L_tmpnam_s);
  if (err != 0) {
    // tmpnam_s reports failure through its return value (an errno-style
    // code), so use that rather than the global errno, and pair it with the
    // generic (errno) category instead of the Win32 system category.
    throw std::system_error(err, std::generic_category());
  }
  return std::string(tmp);
}

/// Returns a temporary file path for tests.
///
/// If the TMPFILE environment variable is set, its value is returned verbatim
/// (manual test runs use it to pin the path). Otherwise a fresh name from
/// autoGenerateTmpFilePath() is returned.
inline std::string tmppath() {
  if (const char* tmpfile = getenv("TMPFILE")) {
    return std::string(tmpfile);
  }
  return autoGenerateTmpFilePath();
}
#else
/// Returns a temporary file path for tests.
///
/// TMPFILE is for manual test execution, during which the user specifies the
/// full temp file path through the environment variable TMPFILE; it is
/// returned verbatim. Otherwise a new, empty file named
/// "$TMPDIR/testXXXXXX" (TMPDIR defaulting to "/tmp") is created with
/// mkstemp(), its descriptor is closed, and its path is returned.
/// `inline` because this is a header: a non-inline definition breaks the ODR
/// as soon as two translation units include it.
///
/// @throws std::system_error (errno from mkstemp) if the file cannot be created.
inline std::string tmppath() {
  if (const char* tmpfile = getenv("TMPFILE")) {
    return std::string(tmpfile);
  }

  const char* tmpdir = getenv("TMPDIR");
  if (tmpdir == nullptr) {
    tmpdir = "/tmp";
  }

  // mkstemp() rewrites the trailing "XXXXXX" in place. A std::string is sized
  // to fit any TMPDIR (no fixed buffer to truncate) and always keeps a NUL
  // after its last character, which mkstemp() requires.
  std::string path = std::string(tmpdir) + "/testXXXXXX";
  const int fd = mkstemp(&path[0]);
  if (fd == -1) {
    throw std::system_error(errno, std::system_category());
  }
  close(fd);
  return path;
}
#endif
92
/// Returns true iff the PYTORCH_TEST_WITH_TSAN environment variable is set to
/// exactly "1"; an unset variable or any other value yields false.
///
/// `inline` because this function is defined in a header: without it, two
/// translation units including the header produce a duplicate-symbol link
/// error (ODR violation).
inline bool isTSANEnabled() {
  const char* s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  return s != nullptr && std::strcmp(s, "1") == 0;
}
97struct TemporaryFile {
98 std::string path;
99
100 TemporaryFile() {
101 path = tmppath();
102 }
103
104 ~TemporaryFile() {
105 unlink(path.c_str());
106 }
107};
108
#ifndef _WIN32
/// Calls fork() on construction.
///
/// In the parent, `pid` is the child's process id; in the child it is 0 (see
/// isChild()). Destroying a parent-side Fork sends SIGKILL to the child and
/// reaps it with waitpid(); a child-side Fork does nothing on destruction.
///
/// NOTE(review): the struct is copyable; destroying two copies in the parent
/// would signal and wait on the same pid twice (the pid may have been reused
/// by then). Avoid copying instances.
struct Fork {
  pid_t pid;

  /// @throws std::system_error (errno from fork) if the fork fails.
  Fork() {
    pid = fork();
    if (pid < 0) {
      throw std::system_error(errno, std::system_category(), "fork");
    }
  }

  ~Fork() {
    if (pid > 0) {
      kill(pid, SIGKILL);
      // Retry when a signal interrupts the wait; giving up on EINTR would
      // leave the killed child unreaped (a zombie).
      while (waitpid(pid, nullptr, 0) == -1 && errno == EINTR) {
      }
    }
  }

  /// True in the child process, false in the parent.
  bool isChild() const {
    return pid == 0;
  }
};
#endif
132
133} // namespace test
134} // namespace c10d
135