1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/Optional.h> |
4 | #include <test/cpp/jit/test_utils.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/jit/ir/irparser.h> |
7 | #include <torch/csrc/jit/runtime/script_profile.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | TEST(ScriptProfileTest, Basic) { |
13 | const std::string source_string = R"V0G0N( |
14 | def foo(a, b): |
15 | return a + b # |
16 | )V0G0N" ; |
17 | auto begin = source_string.find("return" ); |
18 | auto end = source_string.find(" #" ); |
19 | |
20 | Graph g; |
21 | const auto graph_string = R"IR( |
22 | graph(%a : Tensor, |
23 | %b : Tensor): |
24 | %2 : int = prim::Constant[value=1]() |
25 | %3 : Tensor = aten::add(%a, %b, %2) |
26 | return (%3))IR" ; |
27 | |
28 | torch::jit::parseIR(graph_string, &g); |
29 | auto source = std::make_shared<Source>(source_string, "" , 0); |
30 | auto node = *g.nodes().begin(); |
31 | node->setSourceRange(SourceRange{source, begin, end}); |
32 | |
33 | ScriptProfile p; |
34 | p.enable(); |
35 | { |
36 | profiling::InstructionSpan g0(*node); |
37 | profiling::InstructionSpan g1(*node); |
38 | profiling::InstructionSpan g2(*node); |
39 | } |
40 | p.disable(); |
41 | |
42 | auto stats = p.dumpStats(); |
43 | EXPECT_EQ(stats.size(), 1); |
44 | auto it = stats.find(*source.get()); |
45 | EXPECT_NE(it, stats.end()); |
46 | auto& lines = it->second; |
47 | EXPECT_EQ(lines.size(), 1); |
48 | const auto& stat = lines.at(source->lineno_for_offset(begin)); |
49 | EXPECT_EQ(stat.count, 3); |
50 | } |
51 | |
52 | TEST(ScriptProfileTest, CallingOrder) { |
53 | ScriptProfile p; |
54 | p.enable(); |
55 | EXPECT_THROW(p.dumpStats(), c10::Error); |
56 | p.disable(); |
57 | auto dp = std::make_shared<profiling::Datapoint>(SourceRange{}); |
58 | EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error); |
59 | } |
60 | |
61 | } // namespace jit |
62 | } // namespace torch |
63 | |