1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_TSL_PLATFORM_TRACING_H_ |
17 | #define TENSORFLOW_TSL_PLATFORM_TRACING_H_ |
18 | |
19 | // Tracing interface |
20 | |
21 | #include <array> |
22 | |
23 | #include "tensorflow/tsl/platform/macros.h" |
24 | #include "tensorflow/tsl/platform/platform.h" |
25 | #include "tensorflow/tsl/platform/stringpiece.h" |
26 | #include "tensorflow/tsl/platform/types.h" |
27 | |
28 | namespace tsl { |
29 | namespace tracing { |
30 | |
// Identifiers for every TensorFlow CPU profiler event category.
// GetEventCategoryName() must be updated whenever this list changes.
enum class EventCategory : unsigned {
  kScheduleClosure = 0,
  kRunClosure = 1,
  kCompute = 2,
  kNumCategories = 3  // Sentinel: must stay the final enumerator.
};

// Count of real categories, i.e. the numeric value of the kNumCategories
// sentinel. Usable in constant expressions (e.g. as an array extent).
constexpr unsigned GetNumEventCategories() {
  return static_cast<unsigned>(EventCategory::kNumCategories);
}

// Returns the name associated with `category`. Defined out of line; its
// implementation has to cover every enumerator listed above.
const char* GetEventCategoryName(EventCategory category);
43 | |
44 | // Interface for CPU profiler events. |
45 | class EventCollector { |
46 | public: |
47 | virtual ~EventCollector() {} |
48 | virtual void RecordEvent(uint64 arg) const = 0; |
49 | virtual void StartRegion(uint64 arg) const = 0; |
50 | virtual void StopRegion() const = 0; |
51 | |
52 | // Annotates the current thread with a name. |
53 | static void SetCurrentThreadName(const char* name); |
54 | // Returns whether event collection is enabled. |
55 | static bool IsEnabled(); |
56 | |
57 | private: |
58 | friend void SetEventCollector(EventCategory, const EventCollector*); |
59 | friend const EventCollector* GetEventCollector(EventCategory); |
60 | |
61 | static std::array<const EventCollector*, GetNumEventCategories()> instances_; |
62 | }; |
63 | // Set the callback for RecordEvent and ScopedRegion of category. |
64 | // Not thread safe. Only call while EventCollector::IsEnabled returns false. |
65 | void SetEventCollector(EventCategory category, const EventCollector* collector); |
66 | |
67 | // Returns the callback for RecordEvent and ScopedRegion of category if |
68 | // EventCollector::IsEnabled(), otherwise returns null. |
69 | inline const EventCollector* GetEventCollector(EventCategory category) { |
70 | if (EventCollector::IsEnabled()) { |
71 | return EventCollector::instances_[static_cast<unsigned>(category)]; |
72 | } |
73 | return nullptr; |
74 | } |
75 | |
76 | // Returns a unique id to pass to RecordEvent/ScopedRegion. Never returns zero. |
77 | uint64 GetUniqueArg(); |
78 | |
79 | // Returns an id for name to pass to RecordEvent/ScopedRegion. |
80 | uint64 GetArgForName(StringPiece name); |
81 | |
82 | // Records an atomic event through the currently registered EventCollector. |
83 | inline void RecordEvent(EventCategory category, uint64 arg) { |
84 | if (auto collector = GetEventCollector(category)) { |
85 | collector->RecordEvent(arg); |
86 | } |
87 | } |
88 | |
89 | // Records an event for the duration of the instance lifetime through the |
90 | // currently registered EventCollector. |
91 | class ScopedRegion { |
92 | public: |
93 | ScopedRegion(ScopedRegion&& other) noexcept // Move-constructible. |
94 | : collector_(other.collector_) { |
95 | other.collector_ = nullptr; |
96 | } |
97 | |
98 | ScopedRegion(EventCategory category, uint64 arg) |
99 | : collector_(GetEventCollector(category)) { |
100 | if (collector_) { |
101 | collector_->StartRegion(arg); |
102 | } |
103 | } |
104 | |
105 | // Same as ScopedRegion(category, GetUniqueArg()), but faster if |
106 | // EventCollector::IsEnabled() returns false. |
107 | explicit ScopedRegion(EventCategory category) |
108 | : collector_(GetEventCollector(category)) { |
109 | if (collector_) { |
110 | collector_->StartRegion(GetUniqueArg()); |
111 | } |
112 | } |
113 | |
114 | // Same as ScopedRegion(category, GetArgForName(name)), but faster if |
115 | // EventCollector::IsEnabled() returns false. |
116 | ScopedRegion(EventCategory category, StringPiece name) |
117 | : collector_(GetEventCollector(category)) { |
118 | if (collector_) { |
119 | collector_->StartRegion(GetArgForName(name)); |
120 | } |
121 | } |
122 | |
123 | ~ScopedRegion() { |
124 | if (collector_) { |
125 | collector_->StopRegion(); |
126 | } |
127 | } |
128 | |
129 | bool IsEnabled() const { return collector_ != nullptr; } |
130 | |
131 | private: |
132 | TF_DISALLOW_COPY_AND_ASSIGN(ScopedRegion); |
133 | |
134 | const EventCollector* collector_; |
135 | }; |
136 | |
// Returns the path of the directory that log files are written to.
// NOTE(review): ownership/lifetime of the returned string is not specified in
// this header; presumably it is static storage. Confirm against the platform
// implementation.
const char* GetLogDir();
139 | |
140 | } // namespace tracing |
141 | } // namespace tsl |
142 | |
143 | #if defined(PLATFORM_GOOGLE) |
144 | #include "tensorflow/tsl/platform/google/tracing_impl.h" |
145 | #else |
146 | #include "tensorflow/tsl/platform/default/tracing_impl.h" |
147 | #endif |
148 | |
149 | #endif // TENSORFLOW_TSL_PLATFORM_TRACING_H_ |
150 | |