1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <algorithm> |
20 | |
21 | #include "./utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace meta_schedule { |
25 | |
26 | /**************** Profiler ****************/ |
27 | |
28 | Map<String, FloatImm> ProfilerNode::Get() const { |
29 | Map<String, FloatImm> ret; |
30 | for (const auto& kv : stats_sec) { |
31 | ret.Set(kv.first, FloatImm(DataType::Float(64), kv.second)); |
32 | } |
33 | return ret; |
34 | } |
35 | |
36 | String ProfilerNode::Table() const { |
37 | CHECK(!stats_sec.empty()) << "ValueError: The stats are empty. Please run the profiler first." ; |
38 | CHECK(stats_sec.count("Total" )) |
39 | << "ValueError: The total time is not recorded. This method should be called only after " |
40 | "exiting the profiler's with scope." ; |
41 | double total = stats_sec.at("Total" ); |
42 | struct Entry { |
43 | String name; |
44 | double minutes; |
45 | double percentage; |
46 | bool operator<(const Entry& other) const { return percentage > other.percentage; } |
47 | }; |
48 | std::vector<Entry> table_entry; |
49 | for (const auto& kv : stats_sec) { |
50 | table_entry.push_back(Entry{kv.first, kv.second / 60.0, kv.second / total * 100.0}); |
51 | } |
52 | std::sort(table_entry.begin(), table_entry.end()); |
53 | support::TablePrinter p; |
54 | p.Row() << "ID" |
55 | << "Name" |
56 | << "Time (min)" |
57 | << "Percentage" ; |
58 | p.Separator(); |
59 | for (int i = 0, n = table_entry.size(); i < n; ++i) { |
60 | if (i == 0) { |
61 | p.Row() << "" << table_entry[i].name << table_entry[i].minutes << table_entry[i].percentage; |
62 | } else { |
63 | p.Row() << i << table_entry[i].name << table_entry[i].minutes << table_entry[i].percentage; |
64 | } |
65 | } |
66 | p.Separator(); |
67 | return p.AsStr(); |
68 | } |
69 | |
70 | Profiler::Profiler() { |
71 | ObjectPtr<ProfilerNode> n = make_object<ProfilerNode>(); |
72 | n->stats_sec.clear(); |
73 | n->total_timer = nullptr; |
74 | data_ = n; |
75 | } |
76 | |
77 | PackedFunc ProfilerTimedScope(String name) { |
78 | if (Optional<Profiler> opt_profiler = Profiler::Current()) { |
79 | return TypedPackedFunc<void()>([profiler = opt_profiler.value(), // |
80 | tik = std::chrono::high_resolution_clock::now(), // |
81 | name = std::move(name)]() { |
82 | auto tok = std::chrono::high_resolution_clock::now(); |
83 | double duration = |
84 | std::chrono::duration_cast<std::chrono::nanoseconds>(tok - tik).count() / 1e9; |
85 | profiler->stats_sec[name] += duration; |
86 | }); |
87 | } |
88 | return nullptr; |
89 | } |
90 | |
91 | ScopedTimer Profiler::TimedScope(String name) { return ScopedTimer(ProfilerTimedScope(name)); } |
92 | |
93 | /**************** Context Manager ****************/ |
94 | |
95 | std::vector<Profiler>* ThreadLocalProfilers() { |
96 | static thread_local std::vector<Profiler> profilers; |
97 | return &profilers; |
98 | } |
99 | |
100 | void Profiler::EnterWithScope() { |
101 | ThreadLocalProfilers()->push_back(*this); |
102 | (*this)->total_timer = ProfilerTimedScope("Total" ); |
103 | } |
104 | |
105 | void Profiler::ExitWithScope() { |
106 | ThreadLocalProfilers()->pop_back(); |
107 | if ((*this)->total_timer != nullptr) { |
108 | (*this)->total_timer(); |
109 | (*this)->total_timer = nullptr; |
110 | } |
111 | } |
112 | |
113 | Optional<Profiler> Profiler::Current() { |
114 | std::vector<Profiler>* profilers = ThreadLocalProfilers(); |
115 | if (profilers->empty()) { |
116 | return NullOpt; |
117 | } else { |
118 | return profilers->back(); |
119 | } |
120 | } |
121 | |
122 | TVM_REGISTER_NODE_TYPE(ProfilerNode); |
123 | TVM_REGISTER_GLOBAL("meta_schedule.Profiler" ).set_body_typed([]() -> Profiler { |
124 | return Profiler(); |
125 | }); |
126 | TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEnterWithScope" ) |
127 | .set_body_method(&Profiler::EnterWithScope); |
128 | TVM_REGISTER_GLOBAL("meta_schedule.ProfilerExitWithScope" ) |
129 | .set_body_method(&Profiler::ExitWithScope); |
130 | TVM_REGISTER_GLOBAL("meta_schedule.ProfilerCurrent" ).set_body_typed(Profiler::Current); |
131 | TVM_REGISTER_GLOBAL("meta_schedule.ProfilerGet" ).set_body_method<Profiler>(&ProfilerNode::Get); |
132 | TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTable" ).set_body_method<Profiler>(&ProfilerNode::Table); |
133 | TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTimedScope" ).set_body_typed(ProfilerTimedScope); |
134 | |
135 | } // namespace meta_schedule |
136 | } // namespace tvm |
137 | |