1#include <gtest/gtest.h>
2
3#include <torch/csrc/jit/frontend/parser.h>
4#include <torch/csrc/jit/frontend/resolver.h>
5
6namespace torch {
7namespace jit {
8constexpr c10::string_view testSource = R"JIT(
9 class FooTest:
10 def __init__(self, x):
11 self.x = x
12
13 def get_x(self):
14 return self.x
15
16 an_attribute : Tensor
17)JIT";
18
19TEST(ClassParserTest, Basic) {
20 Parser p(std::make_shared<Source>(testSource));
21 std::vector<Def> definitions;
22 std::vector<Resolver> resolvers;
23
24 const auto classDef = ClassDef(p.parseClass());
25 p.lexer().expect(TK_EOF);
26
27 ASSERT_EQ(classDef.name().name(), "FooTest");
28 ASSERT_EQ(classDef.body().size(), 3);
29 ASSERT_EQ(Def(classDef.body()[0]).name().name(), "__init__");
30 ASSERT_EQ(Def(classDef.body()[1]).name().name(), "get_x");
31 ASSERT_EQ(
32 Var(Assign(classDef.body()[2]).lhs()).name().name(), "an_attribute");
33 ASSERT_FALSE(Assign(classDef.body()[2]).rhs().present());
34 ASSERT_TRUE(Assign(classDef.body()[2]).type().present());
35}
36} // namespace jit
37} // namespace torch
38