#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>

#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

#include <atomic>
#include <memory>
12
13namespace torch {
14namespace jit {
15
16static constexpr c10::string_view moduleInterfaceSrc = R"JIT(
17class OneInterface(ModuleInterface):
18 def one(self, x: Tensor, y: Tensor) -> Tensor:
19 pass
20)JIT";
21
22static const std::vector<std::string> subModuleMethodsSrc = {R"JIT(
23def one(self, x: Tensor, y: Tensor) -> Tensor:
24 return self.attr * x + y + 1
25
26def forward(self, x: Tensor) -> Tensor:
27 return self.attr + x
28)JIT"};
29
30static const std::string parentForward = R"JIT(
31def forward(self, x: Tensor) -> Tensor:
32 return self.subMod1.one(x, x) + self.subMod2.one(x, x)
33)JIT";
34
35static void import_libs(
36 std::shared_ptr<CompilationUnit> cu,
37 const std::string& class_name,
38 const std::shared_ptr<Source>& src,
39 const std::vector<at::IValue>& tensor_table) {
40 SourceImporter si(
41 cu,
42 &tensor_table,
43 [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
44 /*version=*/2);
45 si.loadType(QualifiedName(class_name));
46}
47
48TEST(ModuleAPITest, MethodRunAsync) {
49 // Module m("m");
50 // m.define(R"(
51 // def forward(self):
52 // r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
53 // r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
54 // return r1.wait() + r2.wait()
55 // )");
56 std::string filePath(__FILE__);
57 auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
58 // borrow model file from TEST(GraphExecutorTest, runAsync_executor)
59 testModelFile.append("test_interpreter_async.pt");
60 auto m = load(testModelFile);
61
62 auto counter = 0;
63 std::mutex mtx;
64
65 auto launcher = [&](std::function<void()> f) {
66 mtx.lock();
67 ++counter;
68 mtx.unlock();
69 at::launch(std::move(f));
70 };
71
72 auto method = m.get_method("forward");
73
74 std::vector<IValue> stack;
75 auto kwargs = std::unordered_map<std::string, at::IValue>();
76 auto future = method.run_async(stack, kwargs, launcher);
77
78 future->wait();
79
80 // expect 2 forks and 2 wait callbacks being excuted on provided taskLauncher
81 // but ivalue::Future would be marked completed and release wait before
82 // finishing all callbacks
83 ASSERT_GE(counter, 2);
84}
85
86TEST(ModuleAPITest, Clone) {
87 auto cu = std::make_shared<CompilationUnit>();
88 // creating child module
89 auto child = ClassType::create("child", cu, true);
90 auto attr_name = "attr";
91 child->addAttribute(attr_name, IntType::get());
92 Module c1(cu, child);
93 auto v1 = IValue(2);
94 c1.register_attribute(attr_name, IntType::get(), v1, false);
95 Module c2(cu, child);
96 auto v2 = IValue(3);
97 c2.register_attribute(attr_name, IntType::get(), v2, false);
98
99 // attach two child module instance to parent that shares
100 // ClassType
101 auto parent = ClassType::create("parent", cu, true);
102 Module p(cu, parent);
103 p.register_attribute("c1", c1.type(), c1._ivalue(), false);
104 p.register_attribute("c2", c2.type(), c2._ivalue(), false);
105
106 // clone parent
107 Module p2 = p.clone();
108 // check the two child module has the same ClassType
109 ASSERT_EQ(p2.attr("c1").type(), p2.attr("c2").type());
110 // but different instances
111 ASSERT_EQ(Module(p2.attr("c1").toObject()).attr(attr_name).toInt(), 2);
112 ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3);
113}
114
115TEST(ModuleAPITest, CloneWithModuleInterface) {
116 auto cu = std::make_shared<CompilationUnit>();
117
118 // define a initial module with two submods share same interface
119 Module parentMod("parentMod", cu);
120 Module subMod1("subMod1", cu);
121 Module subMod2("subMod2", cu);
122
123 std::vector<at::IValue> constantTable;
124 import_libs(
125 cu,
126 "__torch__.OneInterface",
127 std::make_shared<Source>(moduleInterfaceSrc),
128 constantTable);
129
130 auto v1 = IValue(2);
131 subMod1.register_attribute("attr", IntType::get(), v1, false);
132
133 auto v2 = IValue(4);
134 subMod2.register_attribute("attr", IntType::get(), v2, false);
135
136 for (const std::string& method : subModuleMethodsSrc) {
137 subMod1.define(method, nativeResolver());
138 subMod2.define(method, nativeResolver());
139 }
140
141 parentMod.register_attribute(
142 "subMod1",
143 cu->get_interface("__torch__.OneInterface"),
144 subMod1._ivalue());
145 parentMod.register_attribute(
146 "subMod2",
147 cu->get_interface("__torch__.OneInterface"),
148 subMod2._ivalue());
149
150 parentMod.define(parentForward, nativeResolver());
151
152 Module clonedMod = parentMod.clone();
153
154 // clone will copy both type and data, therefore we'll have a
155 // different type
156 ASSERT_NE(clonedMod.type(), parentMod.type());
157}
158
159TEST(ModuleAPITest, Copy) {
160 auto cu = std::make_shared<CompilationUnit>();
161 auto cls = ClassType::create("foo.bar", cu, true);
162 auto attr_name = "attr";
163 cls->addAttribute(attr_name, IntType::get());
164 Module m(cu, cls);
165 auto v = IValue(2);
166 m.register_attribute(attr_name, IntType::get(), v, false);
167
168 Module m2 = m.clone();
169 Module m3 = m.copy();
170
171 // Make sure copy works
172 ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
173 ASSERT_EQ(m3.attr(attr_name).toInt(), 2);
174
175 // clone will copy both type and data, therefore we'll have a
176 // different type
177 ASSERT_NE(m.type(), m2.type());
178 // copy only copies data, type is shared
179 ASSERT_EQ(m.type(), m3.type());
180
181 // change value of copied instance
182 m3.register_attribute(attr_name, IntType::get(), IValue(3), false);
183 // Verify value of original instance doesn't change
184 ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
185 ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
186}
187
188TEST(ModuleAPITest, DeepCopy) {
189 auto cu = std::make_shared<CompilationUnit>();
190 auto cls = ClassType::create("foo.bar", cu, true);
191 auto str_attr = "str_attr";
192 auto int_attr = "int_attr";
193 auto tensor_attr = "tensor_attr";
194 auto tensor_list_attr = "tensor_list_attr";
195 cls->addAttribute(int_attr, IntType::get());
196 cls->addAttribute(str_attr, StringType::get());
197 cls->addAttribute(tensor_attr, TensorType::get());
198 cls->addAttribute(tensor_list_attr, ListType::ofTensors());
199 Module m(cu, cls);
200 c10::List<at::Tensor> list({at::rand(5), at::rand(5)});
201 m.setattr(int_attr, IValue(2));
202 m.setattr(str_attr, IValue("str"));
203 m.setattr(tensor_attr, at::randn(5));
204 m.setattr(tensor_list_attr, list);
205
206 Module m2 = m.deepcopy();
207 Module m3 = m.copy();
208 // Make sure copy works
209 ASSERT_EQ(m2.attr(int_attr).toInt(), 2);
210 ASSERT_EQ(m3.attr(int_attr).toInt(), 2);
211
212 // Test overlaps
213 ASSERT_TRUE(!IValue(m2._ivalue()).overlaps(IValue(m._ivalue())));
214 ASSERT_TRUE(IValue(m3._ivalue()).overlaps(IValue(m._ivalue())));
215
216 // Both deepcopy and copy will preserve the type
217 ASSERT_EQ(m.type(), m2.type());
218 ASSERT_EQ(m.type(), m3.type());
219
220 // change int value of copied instances
221 m2.setattr(int_attr, IValue(3));
222 m3.setattr(int_attr, IValue(4));
223
224 // Verify value of original instance doesn't change
225 ASSERT_EQ(m.attr(int_attr).toInt(), 2);
226 ASSERT_EQ(m2.attr(int_attr).toInt(), 3);
227 ASSERT_EQ(m3.attr(int_attr).toInt(), 4);
228
229 // change Tensor value of copied instances
230 at::Tensor t1 = m.attr(tensor_attr).toTensor();
231 at::Tensor t2 =
232 m2.attr(tensor_attr).toTensor(); // deepcopy will copy the Tensor
233 at::Tensor t3 =
234 m3.attr(tensor_attr).toTensor(); // copy will not copy the Tensor
235 // check copy works
236 ASSERT_TRUE(t1.equal(t2));
237 ASSERT_TRUE(t1.equal(t3));
238
239 // zero out t1
240 t1.zero_();
241 // check that t2 is not affected because it is a deep copy
242 ASSERT_TRUE(!t1.equal(t2));
243 // check that t3 is the same as t1 since it is a shallow copy
244 ASSERT_TRUE(t1.equal(t3));
245}
246
247TEST(ModuleAPITest, DeepCopyString) {
248 auto cu = std::make_shared<CompilationUnit>();
249 auto cls = ClassType::create("foo.bar", cu, true);
250 auto attr1 = "attr1";
251 cls->addAttribute(attr1, StringType::get());
252 std::string str = "str";
253 Module m(cu, cls);
254 m.setattr(attr1, str);
255 auto copied = m.deepcopy();
256 auto original_str = str;
257 ASSERT_EQ(copied.attr(attr1).toStringRef(), original_str);
258 // check string mutation is not reflected in the copied module
259 str += "str";
260 ASSERT_EQ(copied.attr(attr1).toStringRef(), original_str);
261}
262
263TEST(ModuleAPITest, DeepCopyEnum) {
264 auto cu = std::make_shared<CompilationUnit>();
265 auto cls = ClassType::create("foo.bar", cu, true);
266 auto enum_attr = "enum_attr";
267 auto int_enum_type = EnumType::create(
268 "enum_class",
269 IntType::get(),
270 {{"enum_name_1", 1}, {"enum_name_2", 2}},
271 cu);
272 cls->addAttribute(enum_attr, int_enum_type);
273 Module m(cu, cls);
274 m.setattr(
275 enum_attr,
276 IValue(c10::make_intrusive<ivalue::EnumHolder>(
277 int_enum_type, "enum_name_1", 1)));
278 Module m2 = m.deepcopy();
279
280 // Make sure deepcopy works
281 c10::ivalue::EnumHolder* m2_holder = m2.attr(enum_attr).toEnumHolder().get();
282 ASSERT_EQ(m2_holder->value().toInt(), 1);
283 ASSERT_EQ(m2_holder->name(), "enum_name_1");
284 ASSERT_EQ(m2_holder->type(), int_enum_type);
285
286 // Test overlaps
287 ASSERT_TRUE(!IValue(m2._ivalue()).overlaps(IValue(m._ivalue())));
288
289 // Deepcopy will preserve the type
290 ASSERT_EQ(m.type(), m2.type());
291
292 // Change original, should not affect deepcopy
293 m.setattr(
294 enum_attr,
295 IValue(c10::make_intrusive<ivalue::EnumHolder>(
296 int_enum_type, "enum_name_2", 2)));
297 ASSERT_NE(
298 m.attr(enum_attr).toEnumHolder().get()->value().toInt(),
299 m2.attr(enum_attr).toEnumHolder().get()->value().toInt());
300}
301
302TEST(ModuleAPITest, DeepCopyPreservesAliasing) {
303 // check deepcopy preserves aliasing
304 auto cu = std::make_shared<CompilationUnit>();
305 auto cls = ClassType::create("foo.bar", cu, true);
306 auto attr1 = "attr1";
307 auto attr2 = "attr2";
308 auto attr3 = "attr3";
309 auto attr4 = "attr4";
310 cls->addAttribute(attr1, ListType::ofTensors());
311 cls->addAttribute(attr2, ListType::ofTensors());
312 cls->addAttribute(attr3, TensorType::get());
313 cls->addAttribute(attr4, TensorType::get());
314 Module m(cu, cls);
315 auto t1 = at::rand(5);
316 auto t2 = at::rand(5);
317 auto t3 = at::rand(5);
318 auto t4 = at::rand({5, 2});
319 c10::List<at::Tensor> list1({t1, t2});
320 c10::List<at::Tensor> list2({t1, t3});
321 // first element of attr1 and attr2 are aliased
322 m.setattr(attr1, list1);
323 m.setattr(attr2, list2);
324 m.setattr(attr3, t4);
325 m.setattr(attr4, t4.view(-1));
326
327 auto copied = m.deepcopy();
328 // test tensor aliasing
329 auto copied_attr1_t1 = copied.attr(attr1).toList().get(0);
330 auto copied_attr2_t1 = copied.attr(attr2).toList().get(0);
331 ASSERT_TRUE(copied_attr1_t1.isAliasOf(copied_attr2_t1));
332
333 // test aliasing from view
334 auto copied_attr3 = copied.attr(attr3);
335 auto copied_attr4 = copied.attr(attr3);
336 ASSERT_TRUE(copied_attr3.isAliasOf(copied_attr4));
337}
338
339TEST(ModuleAPITest, Constants) {
340 auto cu = std::make_shared<CompilationUnit>();
341 auto cls = ClassType::create("foo.bar", cu, true);
342 auto attr_name = "attr";
343 auto const_name = "const";
344 cls->addAttribute(attr_name, IntType::get());
345 cls->addConstant(const_name, IValue(3));
346 Module m(cu, cls);
347 auto v = IValue(2);
348 m.register_attribute(attr_name, IntType::get(), v, false);
349 ASSERT_TRUE(m.hasattr(attr_name));
350 ASSERT_TRUE(m.hasattr(const_name));
351 ASSERT_EQ(m.attr(attr_name).toInt(), 2);
352 ASSERT_EQ(m.attr(const_name).toInt(), 3);
353}
354
355TEST(ModuleAPITest, Parameters) {
356 auto cu = std::make_shared<CompilationUnit>();
357 auto cls = ClassType::create("foo.bar", cu, true);
358 Module m(cu, cls);
359 // Tensor parameter
360 m.register_parameter(
361 "tensor_param", at::empty({3}, at::kFloat), /* is_buffer */ false);
362 // None parameter
363 m.register_attribute(
364 "none_param", NoneType::get(), IValue(), /* is_param */ true);
365 m.register_attribute(
366 "none_param2", NoneType::get(), IValue(), /* is_param */ true);
367 auto param_list = m.parameters();
368 ASSERT_EQ(param_list.size(), 1);
369 ASSERT_TRUE(m.hasattr("tensor_param"));
370 ASSERT_TRUE(m.hasattr("none_param"));
371 ASSERT_TRUE(m.hasattr("none_param2"));
372}
373
374TEST(ModuleAPITest, Define) {
375 Module m("m");
376 m.register_parameter("foo", torch::ones({}), false);
377 m.define(R"(
378 def add_it(self, x, b : int = 4):
379 return self.foo + x + b
380 )");
381 auto result = m.run_method("add_it", torch::ones({}));
382 AT_ASSERT(result.toTensor().item<float>() == 6);
383}
384
385TEST(ModuleAPITest, Freezing) {
386 Module m("m");
387 m.register_parameter("foo", torch::ones({}), false);
388 m.define(R"(
389 def forward(self, x, b : int = 4):
390 return self.foo + x + b
391 )");
392 m.eval();
393 auto forward_g = m.get_method("forward").graph();
394 testing::FileCheck().check("GetAttr")->run(*forward_g);
395
396 // Removal of GetAttr is done by freezing
397 auto frozen_mod = torch::jit::freeze(m);
398 forward_g = frozen_mod.get_method("forward").graph();
399 testing::FileCheck().check_not("GetAttr")->run(*forward_g);
400
401 // If no training mode is set, the module is NOT frozen by OFI
402 auto frozen_mod2 = torch::jit::optimize_for_inference(m);
403 forward_g = frozen_mod2.get_method("forward").graph();
404 testing::FileCheck().check("GetAttr")->run(*forward_g);
405}
406
407TEST(ModuleAPITest, OfiFreezesTraining) {
408 Module m("m");
409 m.register_parameter("foo", torch::ones({}), false);
410 m.define(R"(
411 def forward(self, x, b : int = 4):
412 return self.foo + x + b
413 )");
414 m.register_attribute("training", BoolType::get(), true);
415 m.eval();
416
417 // Before freezing, we have a GetAttr check
418 auto forward_g = m.get_method("forward").graph();
419 testing::FileCheck().check("GetAttr")->run(*forward_g);
420
421 // Demonstrate that freezing happens when OFI is called
422 // Removal of GetAttr is done by freezing, but only when training
423 // attribute is set
424 auto frozen_mod = torch::jit::optimize_for_inference(m);
425 forward_g = frozen_mod.get_method("forward").graph();
426 testing::FileCheck().check_not("GetAttr")->run(*forward_g);
427}
428
429TEST(ModuleAPITest, To_CUDA) {
430 Module m("test");
431 {
432 // test cuda to cpu for params and buffers
433 m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
434 m.register_buffer("bar", torch::ones({}, at::kCUDA));
435
436 m.to(at::kCUDA);
437 m.to(at::kCPU);
438 AT_ASSERT(m.attr("foo").toTensor().device().is_cpu());
439 AT_ASSERT(m.attr("bar").toTensor().device().is_cpu());
440 }
441 {
442 // test cpu to cuda for params and buffers
443 m.register_parameter("foo", torch::ones({}), false);
444 m.register_buffer("bar", torch::ones({}));
445
446 m.to(at::kCUDA);
447 AT_ASSERT(m.attr("foo").toTensor().device().is_cuda());
448 AT_ASSERT(m.attr("bar").toTensor().device().is_cuda());
449 }
450}
451
452} // namespace jit
453} // namespace torch
454