#include <gtest/gtest.h>

#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/script_profile.h>

namespace torch {
namespace jit {

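// Profiles three executions of a single node and verifies the per-line stats
// attributed to its source range.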
TEST(ScriptProfileTest, Basic) {
  const std::string source_string = R"V0G0N(
def foo(a, b):
  return a + b #
)V0G0N";
  auto begin = source_string.find("return");
  auto end = source_string.find(" #");

  Graph g;
  const auto graph_string = R"IR(
graph(%a : Tensor,
      %b : Tensor):
  %2 : int = prim::Constant[value=1]()
  %3 : Tensor = aten::add(%a, %b, %2)
  return (%3))IR";

  torch::jit::parseIR(graph_string, &g);
  auto source = std::make_shared<Source>(source_string, "", 0);
  auto node = *g.nodes().begin();
  node->setSourceRange(SourceRange{source, begin, end});

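  // Each InstructionSpan records one execution of `node` while the profiler is
  // enabled, so three spans over the same node should contribute a count of 3.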
  ScriptProfile p;
  p.enable();
  {
    profiling::InstructionSpan g0(*node);
    profiling::InstructionSpan g1(*node);
    profiling::InstructionSpan g2(*node);
  }
  p.disable();

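  // All three spans map to the same source line, so the dump should contain a
  // single source entry with a single line stat.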
  auto stats = p.dumpStats();
  EXPECT_EQ(stats.size(), 1);
  auto it = stats.find(*source.get());
  EXPECT_NE(it, stats.end());
  auto& lines = it->second;
  EXPECT_EQ(lines.size(), 1);
  const auto& stat = lines.at(source->lineno_for_offset(begin));
  EXPECT_EQ(stat.count, 3);
}

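// Checks API misuse: dumpStats() must not be called while profiling is still
// enabled, and addDatapoint() must not be called after profiling is disabled.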
TEST(ScriptProfileTest, CallingOrder) {
  ScriptProfile p;
  p.enable();
  EXPECT_THROW(p.dumpStats(), c10::Error);
  p.disable();
  auto dp = std::make_shared<profiling::Datapoint>(SourceRange{});
  EXPECT_THROW(p.addDatapoint(std::move(dp)), c10::Error);
}

} // namespace jit
} // namespace torch