1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | #include <torch/csrc/jit/testing/file_check.h> |
5 | #include <torch/torch.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | TEST(ClassTypeTest, AddRemoveAttr) { |
11 | auto cu = std::make_shared<CompilationUnit>(); |
12 | auto cls = ClassType::create("foo.bar" , cu, true); |
13 | cls->addAttribute("attr1" , TensorType::get(), true); |
14 | cls->addAttribute("attr2" , TensorType::get()); |
15 | cls->addAttribute("attr3" , TensorType::get()); |
16 | ASSERT_TRUE(cls->hasAttribute("attr1" )); |
17 | ASSERT_TRUE(cls->hasAttribute("attr2" )); |
18 | ASSERT_TRUE(cls->hasAttribute("attr3" )); |
19 | |
20 | // removing attribute attr2 |
21 | cls->unsafeRemoveAttribute("attr2" ); |
22 | ASSERT_TRUE(cls->hasAttribute("attr1" )); |
23 | ASSERT_FALSE(cls->hasAttribute("attr2" )); |
24 | ASSERT_TRUE(cls->hasAttribute("attr3" )); |
25 | |
26 | // removing parameter attr1 |
27 | cls->unsafeRemoveAttribute("attr1" ); |
28 | ASSERT_FALSE(cls->hasAttribute("attr1" )); |
29 | ASSERT_FALSE(cls->hasAttribute("attr2" )); |
30 | ASSERT_TRUE(cls->hasAttribute("attr3" )); |
31 | |
32 | // check that we can still add a non-parameter attr1 with |
33 | // different type |
34 | cls->addAttribute("attr1" , IntType::get()); |
35 | } |
36 | |
37 | TEST(ClassTypeTest, AddRemoveConstant) { |
38 | auto cu = std::make_shared<CompilationUnit>(); |
39 | auto cls = ClassType::create("foo.bar" , cu); |
40 | cls->addConstant("const1" , IValue(1)); |
41 | cls->addConstant("const2" , IValue(2)); |
42 | cls->addConstant("const3" , IValue(3)); |
43 | ASSERT_EQ(cls->numConstants(), 3); |
44 | ASSERT_TRUE(cls->hasConstant("const1" )); |
45 | ASSERT_TRUE(cls->hasConstant("const2" )); |
46 | ASSERT_TRUE(cls->hasConstant("const3" )); |
47 | ASSERT_FALSE(cls->hasConstant("const4" )); |
48 | |
49 | ASSERT_EQ(cls->getConstant("const1" ).toInt(), 1); |
50 | ASSERT_EQ(cls->getConstant("const2" ).toInt(), 2); |
51 | ASSERT_EQ(cls->getConstant("const3" ).toInt(), 3); |
52 | |
53 | cls->unsafeRemoveConstant("const2" ); |
54 | ASSERT_TRUE(cls->hasConstant("const1" )); |
55 | ASSERT_FALSE(cls->hasConstant("const2" )); |
56 | ASSERT_TRUE(cls->hasConstant("const3" )); |
57 | } |
58 | |
59 | TEST(ClassTypeTest, IdenticalTypesDifferentCus) { |
60 | auto cu1 = std::make_shared<CompilationUnit>(); |
61 | auto cu2 = std::make_shared<CompilationUnit>(); |
62 | |
63 | // Create two identically named ClassTypes and put them |
64 | // in separate compilation units. |
65 | auto cls1 = ClassType::create("foo" , cu1); |
66 | auto cls2 = ClassType::create("foo" , cu2); |
67 | |
68 | // Create a function that accepts "foo" (cls1) as input. |
69 | Argument arg("arg" , cls1); |
70 | Argument ret("ret" , IntType::get()); |
71 | |
72 | FunctionSchema schema("fn" , "" , {arg}, {ret}); |
73 | |
74 | jit::BuiltinOpFunction method( |
75 | "method" , |
76 | std::move(schema), |
77 | [](jit::Stack& stack) mutable -> void { |
78 | pop(stack); |
79 | push(stack, 0); |
80 | }, |
81 | "" ); |
82 | |
83 | // Create an object of type cls2. |
84 | Object obj(cu2, cls2); |
85 | |
86 | // Call method with the above object; this should |
87 | // throw an error because the types have identical |
88 | // names but are in different compilation units. |
89 | Stack stack; |
90 | push(stack, obj._ivalue()); |
91 | try { |
92 | method(stack, {}); |
93 | } catch (const std::exception& e) { |
94 | // Check that the exception contains the address of the compilation unit |
95 | // in addition to the ClassType's name. |
96 | testing::FileCheck() |
97 | .check("foo (of Python compilation unit at: 0x" ) |
98 | ->check_same(")" ) |
99 | ->check("foo (of Python compilation unit at: 0x" ) |
100 | ->check_same(")" ) |
101 | ->run(e.what()); |
102 | |
103 | return; |
104 | } |
105 | |
106 | // This should never execute. |
107 | ASSERT_TRUE(false); |
108 | } |
109 | |
110 | } // namespace jit |
111 | } // namespace torch |
112 | |