1#include <gtest/gtest.h>
2
3#include <test/cpp/jit/test_utils.h>
4#include <torch/csrc/jit/testing/file_check.h>
5#include <torch/torch.h>
6
7namespace torch {
8namespace jit {
9
10TEST(ClassTypeTest, AddRemoveAttr) {
11 auto cu = std::make_shared<CompilationUnit>();
12 auto cls = ClassType::create("foo.bar", cu, true);
13 cls->addAttribute("attr1", TensorType::get(), true);
14 cls->addAttribute("attr2", TensorType::get());
15 cls->addAttribute("attr3", TensorType::get());
16 ASSERT_TRUE(cls->hasAttribute("attr1"));
17 ASSERT_TRUE(cls->hasAttribute("attr2"));
18 ASSERT_TRUE(cls->hasAttribute("attr3"));
19
20 // removing attribute attr2
21 cls->unsafeRemoveAttribute("attr2");
22 ASSERT_TRUE(cls->hasAttribute("attr1"));
23 ASSERT_FALSE(cls->hasAttribute("attr2"));
24 ASSERT_TRUE(cls->hasAttribute("attr3"));
25
26 // removing parameter attr1
27 cls->unsafeRemoveAttribute("attr1");
28 ASSERT_FALSE(cls->hasAttribute("attr1"));
29 ASSERT_FALSE(cls->hasAttribute("attr2"));
30 ASSERT_TRUE(cls->hasAttribute("attr3"));
31
32 // check that we can still add a non-parameter attr1 with
33 // different type
34 cls->addAttribute("attr1", IntType::get());
35}
36
37TEST(ClassTypeTest, AddRemoveConstant) {
38 auto cu = std::make_shared<CompilationUnit>();
39 auto cls = ClassType::create("foo.bar", cu);
40 cls->addConstant("const1", IValue(1));
41 cls->addConstant("const2", IValue(2));
42 cls->addConstant("const3", IValue(3));
43 ASSERT_EQ(cls->numConstants(), 3);
44 ASSERT_TRUE(cls->hasConstant("const1"));
45 ASSERT_TRUE(cls->hasConstant("const2"));
46 ASSERT_TRUE(cls->hasConstant("const3"));
47 ASSERT_FALSE(cls->hasConstant("const4"));
48
49 ASSERT_EQ(cls->getConstant("const1").toInt(), 1);
50 ASSERT_EQ(cls->getConstant("const2").toInt(), 2);
51 ASSERT_EQ(cls->getConstant("const3").toInt(), 3);
52
53 cls->unsafeRemoveConstant("const2");
54 ASSERT_TRUE(cls->hasConstant("const1"));
55 ASSERT_FALSE(cls->hasConstant("const2"));
56 ASSERT_TRUE(cls->hasConstant("const3"));
57}
58
59TEST(ClassTypeTest, IdenticalTypesDifferentCus) {
60 auto cu1 = std::make_shared<CompilationUnit>();
61 auto cu2 = std::make_shared<CompilationUnit>();
62
63 // Create two identically named ClassTypes and put them
64 // in separate compilation units.
65 auto cls1 = ClassType::create("foo", cu1);
66 auto cls2 = ClassType::create("foo", cu2);
67
68 // Create a function that accepts "foo" (cls1) as input.
69 Argument arg("arg", cls1);
70 Argument ret("ret", IntType::get());
71
72 FunctionSchema schema("fn", "", {arg}, {ret});
73
74 jit::BuiltinOpFunction method(
75 "method",
76 std::move(schema),
77 [](jit::Stack& stack) mutable -> void {
78 pop(stack);
79 push(stack, 0);
80 },
81 "");
82
83 // Create an object of type cls2.
84 Object obj(cu2, cls2);
85
86 // Call method with the above object; this should
87 // throw an error because the types have identical
88 // names but are in different compilation units.
89 Stack stack;
90 push(stack, obj._ivalue());
91 try {
92 method(stack, {});
93 } catch (const std::exception& e) {
94 // Check that the exception contains the address of the compilation unit
95 // in addition to the ClassType's name.
96 testing::FileCheck()
97 .check("foo (of Python compilation unit at: 0x")
98 ->check_same(")")
99 ->check("foo (of Python compilation unit at: 0x")
100 ->check_same(")")
101 ->run(e.what());
102
103 return;
104 }
105
106 // This should never execute.
107 ASSERT_TRUE(false);
108}
109
110} // namespace jit
111} // namespace torch
112