1#pragma once
2
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/record_function.h>
#include <c10/macros/Macros.h>
#include <c10/util/Optional.h>
#include <c10/util/hash.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
16
17#ifndef _WIN32
18#include <ctime>
19#endif
20#if defined(C10_IOS) && defined(C10_MOBILE)
21#include <sys/time.h> // for gettimeofday()
22#endif
23
24#if defined(__i386__) || defined(__x86_64__) || defined(__amd64__)
25#define C10_RDTSC
26#if defined(_MSC_VER)
27#include <intrin.h>
28#elif defined(__CUDACC__) || defined(__HIPCC__)
29#undef C10_RDTSC
30#elif defined(__clang__)
31// `__rdtsc` is available by default.
32// NB: This has to be first, because Clang will also define `__GNUC__`
33#elif defined(__GNUC__)
34#include <x86intrin.h>
35#else
36#undef C10_RDTSC
37#endif
38#endif
39
// TODO: replace with pytorch/rfcs#43 when it is ready.
// Non-fatal assertion. Evaluates to `true` when `cond` holds. On failure it
// logs through torch::profiler::impl::logSoftAssert and then either raises
// (TORCH_INTERNAL_ASSERT) or only warns (TORCH_WARN), depending on
// softAssertRaises(); in that case the expression yields `false`.
// Implemented as an immediately-invoked lambda so it can be used inline as a
// boolean expression, e.g. `if (!SOFT_ASSERT(cond)) return;`.
#define SOFT_ASSERT(cond, ...)                           \
  [&]() -> bool {                                        \
    if (C10_UNLIKELY(!(cond))) {                         \
      torch::profiler::impl::logSoftAssert(              \
          __func__,                                      \
          __FILE__,                                      \
          static_cast<uint32_t>(__LINE__),               \
          #cond,                                         \
          ::c10::str(__VA_ARGS__));                      \
      if (torch::profiler::impl::softAssertRaises()) {   \
        TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__);        \
      } else {                                           \
        TORCH_WARN(__VA_ARGS__);                         \
      }                                                  \
      return false;                                      \
    }                                                    \
    return true;                                         \
  }()
59
60namespace torch {
61namespace profiler {
62namespace impl {
// Returns true if a failed SOFT_ASSERT should raise via
// TORCH_INTERNAL_ASSERT instead of only warning.
TORCH_API bool softAssertRaises();
// Overrides the raise-vs-warn behavior; pass c10::nullopt to restore the
// default. NOTE(review): the default lives in the .cpp, not visible here.
TORCH_API void setSoftAssertRaises(c10::optional<bool> value);
// Records a SOFT_ASSERT failure. The three overloads cover the argument
// forms produced by ::c10::str(__VA_ARGS__) in the SOFT_ASSERT macro.
TORCH_API void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const char* args);
// Zero-argument case: ::c10::str() yields a CompileTimeEmptyString, which is
// forwarded to the const char* overload via its conversion operator.
TORCH_API inline void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    ::c10::detail::CompileTimeEmptyString args) {
  logSoftAssert(func, file, line, cond, (const char*)args);
}
// std::string case (multiple or non-literal message arguments).
TORCH_API void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const std::string& args);
85
// Profiler timestamps are signed 64-bit nanosecond counts (see getTime /
// getTimeSinceEpoch below). Shadows the global ::time_t inside this
// namespace.
using time_t = int64_t;
// Prefer high_resolution_clock, but only when it is actually steady;
// otherwise fall back to the guaranteed-steady steady_clock.
using steady_clock_t = std::conditional<
    std::chrono::high_resolution_clock::is_steady,
    std::chrono::high_resolution_clock,
    std::chrono::steady_clock>::type;
91
// Wall-clock time: nanoseconds elapsed since the Unix epoch.
inline time_t getTimeSinceEpoch() {
  using std::chrono::duration_cast;
  using std::chrono::nanoseconds;
  using std::chrono::system_clock;
  const auto since_epoch = system_clock::now().time_since_epoch();
  return duration_cast<nanoseconds>(since_epoch).count();
}
96
// Returns a timestamp in nanoseconds. By default this reads a wall
// (epoch-based) clock; passing allow_monotonic=true permits a monotonic
// source on platforms where one is selected below (the Linux branch).
inline time_t getTime(bool allow_monotonic = false) {
#if defined(C10_IOS) && defined(C10_MOBILE)
  // clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS
  // can't rely on CLOCK_REALTIME, as it is defined no matter if clock_gettime
  // is implemented or not
  struct timeval now;
  gettimeofday(&now, NULL);
  return static_cast<time_t>(now.tv_sec) * 1000000000 +
      static_cast<time_t>(now.tv_usec) * 1000;
#elif defined(_WIN32) || defined(__MACH__)
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             steady_clock_t::now().time_since_epoch())
      .count();
#else
  // clock_gettime is *much* faster than std::chrono implementation on Linux
  struct timespec ts {};
  const auto clock_id = allow_monotonic ? CLOCK_MONOTONIC : CLOCK_REALTIME;
  clock_gettime(clock_id, &ts);
  return static_cast<time_t>(ts.tv_sec) * 1000000000 +
      static_cast<time_t>(ts.tv_nsec);
#endif
}
122
123// We often do not need to capture true wall times. If a fast mechanism such
124// as TSC is available we can use that instead and convert back to epoch time
125// during post processing. This greatly reduce the clock's contribution to
126// profiling.
127// http://btorpey.github.io/blog/2014/02/18/clock-sources-in-linux/
128// https://quick-bench.com/q/r8opkkGZSJMu9wM_XTbDouq-0Io
129// TODO: We should use
130// `https://github.com/google/benchmark/blob/main/src/cycleclock.h`
// Returns the fastest available timestamp: raw TSC ticks on x86 via __rdtsc
// (unitless; must be converted to wall time during post processing — see
// ApproximateClockToUnixTimeConverter), otherwise getTime() nanoseconds.
inline auto getApproximateTime() {
#if defined(C10_RDTSC)
  return static_cast<uint64_t>(__rdtsc());
#else
  return getTime();
#endif
}

// int64_t when backed by getTime(), uint64_t when backed by rdtsc.
using approx_time_t = decltype(getApproximateTime());
static_assert(
    std::is_same<approx_time_t, int64_t>::value ||
        std::is_same<approx_time_t, uint64_t>::value,
    "Expected either int64_t (`getTime`) or uint64_t (some TSC reads).");
144
// Convert `getApproximateTime` results to nanoseconds since unix epoch.
// The constructor samples both clocks; `makeConverter` then yields a
// function that maps later approximate readings onto the unix timeline.
class ApproximateClockToUnixTimeConverter final {
 public:
  ApproximateClockToUnixTimeConverter();
  // Returns a mapping from an approximate reading to unix time (ns).
  // NOTE(review): the fitting logic lives in the .cpp; not visible here.
  std::function<time_t(approx_time_t)> makeConverter();

  // One near-simultaneous reading of the unix clock and approximate clock.
  struct UnixAndApproximateTimePair {
    time_t t_;
    approx_t_; // NOTE: field of type approx_time_t named approx_t_
    approx_time_t approx_t_;
  };
  static UnixAndApproximateTimePair measurePair();

 private:
  // Odd replicate count so a true median element exists; presumably used to
  // reject outlier measurements -- TODO confirm against the .cpp.
  static constexpr size_t replicates = 1001;
  using time_pairs = std::array<UnixAndApproximateTimePair, replicates>;
  time_pairs measurePairs();

  // Baseline pairs captured at construction time.
  time_pairs start_times_;
};
164
// Builds the annotation string for an NVTX range from the op name, autograd
// sequence number, input shapes, and producer op ids. NOTE(review): the
// exact output format is defined in the .cpp; not visible here.
std::string getNvtxStr(
    const char* name,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes,
    at::RecordFunctionHandle op_id = 0,
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids =
        {});

// One frame of a TorchScript callstack: source file, line number, and
// function name.
struct TORCH_API FileLineFunc {
  std::string filename;
  size_t line;
  std::string funcname;
};
178
// Converts TorchScript stack entries into (file, line, function) records.
TORCH_API std::vector<FileLineFunc> prepareCallstack(
    const std::vector<jit::StackEntry>& cs);
// Renders each callstack frame as a single string.
TORCH_API std::vector<std::string> callstackStr(
    const std::vector<FileLineFunc>& cs);
// Joins pre-rendered stack strings with `delim`.
TORCH_API std::string stacksToStr(
    const std::vector<std::string>& stacks,
    const char* delim);
// Extracts per-input sizes from a RecordFunction. When
// `flatten_list_enabled` is set, list inputs are presumably flattened into
// the result -- confirm against the .cpp.
TORCH_API std::vector<std::vector<int64_t>> inputSizes(
    const at::RecordFunction& fn,
    const bool flatten_list_enabled = false);
// String renderings used by trace/profile exporters.
TORCH_API std::string shapesToStr(
    const std::vector<std::vector<int64_t>>& shapes);
TORCH_API std::string dtypesToStr(const std::vector<std::string>& types);
TORCH_API std::string inputOpIdsToStr(
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids);
// Names the dtypes of a RecordFunction's inputs.
TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn);

// Captures extra op arguments needed by post-processing (e.g. computeFlops).
std::unordered_map<std::string, c10::IValue> TORCH_API
saveExtraArgs(const at::RecordFunction& fn);

// Estimates the FLOP count of `op_name` from the saved extra args.
// NOTE(review): the set of supported ops is defined in the .cpp.
uint64_t TORCH_API computeFlops(
    const std::string& op_name,
    const std::unordered_map<std::string, c10::IValue>& extra_args);
202
203template <typename T>
204class TORCH_API GlobalStateManager {
205 public:
206 static GlobalStateManager& singleton() {
207 static GlobalStateManager singleton_;
208 return singleton_;
209 }
210
211 static void push(std::shared_ptr<T>&& state) {
212 if (singleton().state_) {
213 LOG(WARNING) << "GlobalStatePtr already exists!";
214 } else {
215 singleton().state_ = std::move(state);
216 }
217 }
218
219 static auto* get() {
220 return singleton().state_.get();
221 }
222
223 static std::shared_ptr<T> pop() {
224 auto out = singleton().state_;
225 singleton().state_.reset();
226 return out;
227 }
228
229 private:
230 GlobalStateManager() = default;
231
232 std::shared_ptr<T> state_;
233};
234
235struct HashCombine {
236 template <typename T0, typename T1>
237 size_t operator()(const std::pair<T0, T1>& i) {
238 return c10::get_hash((*this)(i.first), (*this)(i.second));
239 }
240
241 template <typename... Args>
242 size_t operator()(const std::tuple<Args...>& i) {
243 return c10::get_hash(i);
244 }
245
246 template <typename T>
247 size_t operator()(const T& i) {
248 return c10::get_hash(i);
249 }
250};
251
252} // namespace impl
253} // namespace profiler
254} // namespace torch
255
namespace torch {
namespace autograd {
namespace profiler {
// Re-exports for backward compatibility: these utilities are also reachable
// under the historical torch::autograd::profiler namespace. NOTE(review):
// presumably kept for external callers -- confirm before removing.
using torch::profiler::impl::computeFlops;
using torch::profiler::impl::getTime;
} // namespace profiler
} // namespace autograd
} // namespace torch
264