1#include <c10/util/ThreadLocal.h>
2#include <gtest/gtest.h>
3
4#include <atomic>
5#include <thread>
6
7namespace {
8
9TEST(ThreadLocal, TestNoOpScopeWithOneVar) {
10 C10_DEFINE_TLS_static(std::string, str);
11}
12
13TEST(ThreadLocalTest, TestNoOpScopeWithTwoVars) {
14 C10_DEFINE_TLS_static(std::string, str);
15 C10_DEFINE_TLS_static(std::string, str2);
16}
17
18TEST(ThreadLocalTest, TestScopeWithOneVar) {
19 C10_DEFINE_TLS_static(std::string, str);
20 EXPECT_EQ(*str, std::string());
21 EXPECT_EQ(*str, "");
22
23 *str = "abc";
24 EXPECT_EQ(*str, "abc");
25 EXPECT_EQ(str->length(), 3);
26 EXPECT_EQ(str.get(), "abc");
27}
28
29TEST(ThreadLocalTest, TestScopeWithTwoVars) {
30 C10_DEFINE_TLS_static(std::string, str);
31 EXPECT_EQ(*str, "");
32
33 C10_DEFINE_TLS_static(std::string, str2);
34
35 *str = "abc";
36 EXPECT_EQ(*str, "abc");
37 EXPECT_EQ(*str2, "");
38
39 *str2 = *str;
40 EXPECT_EQ(*str, "abc");
41 EXPECT_EQ(*str2, "abc");
42
43 str->clear();
44 EXPECT_EQ(*str, "");
45 EXPECT_EQ(*str2, "abc");
46}
47
48TEST(ThreadLocalTest, TestInnerScopeWithTwoVars) {
49 C10_DEFINE_TLS_static(std::string, str);
50 *str = "abc";
51
52 {
53 C10_DEFINE_TLS_static(std::string, str2);
54 EXPECT_EQ(*str2, "");
55
56 *str2 = *str;
57 EXPECT_EQ(*str, "abc");
58 EXPECT_EQ(*str2, "abc");
59
60 str->clear();
61 EXPECT_EQ(*str2, "abc");
62 }
63
64 EXPECT_EQ(*str, "");
65}
66
67struct Foo {
68 C10_DECLARE_TLS_class_static(Foo, std::string, str_);
69};
70
71C10_DEFINE_TLS_class_static(Foo, std::string, str_);
72
73TEST(ThreadLocalTest, TestClassScope) {
74 EXPECT_EQ(*Foo::str_, "");
75
76 *Foo::str_ = "abc";
77 EXPECT_EQ(*Foo::str_, "abc");
78 EXPECT_EQ(Foo::str_->length(), 3);
79 EXPECT_EQ(Foo::str_.get(), "abc");
80}
81
82C10_DEFINE_TLS_static(std::string, global_);
83C10_DEFINE_TLS_static(std::string, global2_);
84TEST(ThreadLocalTest, TestTwoGlobalScopeVars) {
85 EXPECT_EQ(*global_, "");
86 EXPECT_EQ(*global2_, "");
87
88 *global_ = "abc";
89 EXPECT_EQ(global_->length(), 3);
90 EXPECT_EQ(*global_, "abc");
91 EXPECT_EQ(*global2_, "");
92
93 *global2_ = *global_;
94 EXPECT_EQ(*global_, "abc");
95 EXPECT_EQ(*global2_, "abc");
96
97 global_->clear();
98 EXPECT_EQ(*global_, "");
99 EXPECT_EQ(*global2_, "abc");
100 EXPECT_EQ(global2_.get(), "abc");
101}
102
103C10_DEFINE_TLS_static(std::string, global3_);
104TEST(ThreadLocalTest, TestGlobalWithLocalScopeVars) {
105 *global3_ = "abc";
106
107 C10_DEFINE_TLS_static(std::string, str);
108
109 std::swap(*global3_, *str);
110 EXPECT_EQ(*str, "abc");
111 EXPECT_EQ(*global3_, "");
112}
113
114TEST(ThreadLocalTest, TestThreadWithLocalScopeVar) {
115 C10_DEFINE_TLS_static(std::string, str);
116 *str = "abc";
117
118 std::atomic_bool b(false);
119 std::thread t([&b]() {
120 EXPECT_EQ(*str, "");
121 *str = "def";
122 b = true;
123 EXPECT_EQ(*str, "def");
124 });
125 t.join();
126
127 EXPECT_TRUE(b);
128 EXPECT_EQ(*str, "abc");
129}
130
131C10_DEFINE_TLS_static(std::string, global4_);
132TEST(ThreadLocalTest, TestThreadWithGlobalScopeVar) {
133 *global4_ = "abc";
134
135 std::atomic_bool b(false);
136 std::thread t([&b]() {
137 EXPECT_EQ(*global4_, "");
138 *global4_ = "def";
139 b = true;
140 EXPECT_EQ(*global4_, "def");
141 });
142 t.join();
143
144 EXPECT_TRUE(b);
145 EXPECT_EQ(*global4_, "abc");
146}
147
148TEST(ThreadLocalTest, TestObjectsAreReleased) {
149 static std::atomic<int> ctors{0};
150 static std::atomic<int> dtors{0};
151 struct A {
152 A() : i() {
153 ++ctors;
154 }
155
156 ~A() {
157 ++dtors;
158 }
159
160 A(const A&) = delete;
161 A& operator=(const A&) = delete;
162
163 int i;
164 };
165
166 C10_DEFINE_TLS_static(A, a);
167
168 std::atomic_bool b(false);
169 std::thread t([&b]() {
170 EXPECT_EQ(a->i, 0);
171 a->i = 1;
172 EXPECT_EQ(a->i, 1);
173 b = true;
174 });
175 t.join();
176
177 EXPECT_TRUE(b);
178
179 EXPECT_EQ(ctors, 1);
180 EXPECT_EQ(dtors, 1);
181}
182
183TEST(ThreadLocalTest, TestObjectsAreReleasedByNonstaticThreadLocal) {
184 static std::atomic<int> ctors(0);
185 static std::atomic<int> dtors(0);
186 struct A {
187 A() : i() {
188 ++ctors;
189 }
190
191 ~A() {
192 ++dtors;
193 }
194
195 A(const A&) = delete;
196 A& operator=(const A&) = delete;
197
198 int i;
199 };
200
201 std::atomic_bool b(false);
202 std::thread t([&b]() {
203#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
204 ::c10::ThreadLocal<A> a;
205#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
206 ::c10::ThreadLocal<A> a([]() {
207 static thread_local A var;
208 return &var;
209 });
210#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
211
212 EXPECT_EQ(a->i, 0);
213 a->i = 1;
214 EXPECT_EQ(a->i, 1);
215 b = true;
216 });
217 t.join();
218
219 EXPECT_TRUE(b);
220
221 EXPECT_EQ(ctors, 1);
222 EXPECT_EQ(dtors, 1);
223}
224
225} // namespace
226