1 | #include <gtest/gtest.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/frontend/parser.h> |
4 | #include <torch/csrc/jit/frontend/resolver.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | constexpr c10::string_view testSource = R"JIT( |
9 | class FooTest: |
10 | def __init__(self, x): |
11 | self.x = x |
12 | |
13 | def get_x(self): |
14 | return self.x |
15 | |
16 | an_attribute : Tensor |
17 | )JIT"; |
18 | |
19 | TEST(ClassParserTest, Basic) { |
20 | Parser p(std::make_shared<Source>(testSource)); |
21 | std::vector<Def> definitions; |
22 | std::vector<Resolver> resolvers; |
23 | |
24 | const auto classDef = ClassDef(p.parseClass()); |
25 | p.lexer().expect(TK_EOF); |
26 | |
27 | ASSERT_EQ(classDef.name().name(), "FooTest"); |
28 | ASSERT_EQ(classDef.body().size(), 3); |
29 | ASSERT_EQ(Def(classDef.body()[0]).name().name(), "__init__"); |
30 | ASSERT_EQ(Def(classDef.body()[1]).name().name(), "get_x"); |
31 | ASSERT_EQ( |
32 | Var(Assign(classDef.body()[2]).lhs()).name().name(), "an_attribute"); |
33 | ASSERT_FALSE(Assign(classDef.body()[2]).rhs().present()); |
34 | ASSERT_TRUE(Assign(classDef.body()[2]).type().present()); |
35 | } |
36 | } // namespace jit |
37 | } // namespace torch |
38 |