1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
17
18#include <memory>
19#include <unordered_map>
20#include <vector>
21
22#include "tensorflow/core/framework/allocator.h"
23#include "tensorflow/core/framework/step_stats.pb.h"
24#include "tensorflow/core/framework/tracking_allocator.h"
25#include "tensorflow/core/lib/gtl/inlined_vector.h"
26#include "tensorflow/core/platform/env.h"
27#include "tensorflow/core/platform/mutex.h"
28#include "tensorflow/core/platform/thread_annotations.h"
29#include "tensorflow/core/platform/types.h"
30
31namespace tensorflow {
32
// Forward declarations for types that this header names only through
// pointers (raw or smart); their full definitions live elsewhere.
class AllocatorMemoryUsed;
class CostModelManager;
class Graph;
class NodeDef;
class NodeExecStats;
class OpKernelContext;
class StepStats;
class StepStatsCollector;
class Tensor;
42
43// Statistics collection interface for individual node execution.
44//
45// See `NodeExecStatsWrapper` for a concrete implementation of this interface
46// that interfaces with the `Session` layer.
47class NodeExecStatsInterface {
48 public:
49 virtual ~NodeExecStatsInterface() {}
50
51 // Called when the statistics collection for the node has finished. Once this
52 // method is called, the caller should not make assumptions about the validity
53 // of this object.
54 virtual void Done(const string& device) = 0;
55
56 // Called immediately after this node starts being processed by the executor.
57 virtual void RecordExecutorStarted() = 0;
58
59 // Called immediately before this node's `Compute()` or `ComputeAsync()`
60 // method is called.
61 virtual void RecordComputeStarted() = 0;
62
63 // Called immediately after this node's `Compute()` method returned (or, for
64 // asynchronous operations, the callback passed to its `ComputeAsync()` method
65 // was called).
66 virtual void RecordComputeEnded() = 0;
67
68 // Called immediately after this executor finishes processing this node.
69 virtual void RecordExecutorEnded() = 0;
70
71 // Returns `true` if this object should track memory allocations.
72 virtual bool TrackAllocations() const = 0;
73
74 // Records information about the memory allocated during the execution of this
75 // node.
76 //
77 // Takes ownership of any `TrackingAllocator` objects stored in `ctx`.
78 virtual void SetMemory(OpKernelContext* ctx) = 0;
79
80 // Records information about the tensor produced by this node at the given
81 // output slot.
82 virtual void SetOutput(int slot, const Tensor* tensor) = 0;
83
84 // Records the absolute time in nanoseconds at which this node became
85 // runnable (i.e. was scheduled for execution).
86 virtual void SetScheduled(int64_t nanos) = 0;
87};
88
89// Wraps NodeExecStats and adds allocation to it.
90class NodeExecStatsWrapper : public NodeExecStatsInterface {
91 public:
92 // Does not take ownership of `node` or `step_stats_collector`.
93 NodeExecStatsWrapper(const NodeDef* node,
94 StepStatsCollector* step_stats_collector);
95
96 // Takes ownership of 'stats' but not `node` or `step_stats_collector`.
97 NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats,
98 const NodeDef* node,
99 StepStatsCollector* step_stats_collector);
100
101 // Destructor calls Finalize() to release the TrackingAllocators.
102 ~NodeExecStatsWrapper() override { Finalize(); }
103
104 void Done(const string& device) override;
105 void RecordExecutorStarted() override;
106 void RecordComputeStarted() override;
107 void RecordComputeEnded() override;
108 void RecordExecutorEnded() override;
109 bool TrackAllocations() const override { return true; }
110 void SetMemory(OpKernelContext* ctx) override;
111 void SetOutput(int slot, const Tensor* tensor) override;
112 void SetScheduled(int64_t nanos) override;
113
114 private:
115 friend class StepStatsCollector;
116
117 NodeExecStats* stats() { return stats_.get(); }
118
119 // Populates stats_ and releases TrackingAllocator.
120 void Finalize();
121
122 // Does not take ownership of the `allocator`.
123 // Takes ownership of `tracking_allocator`.
124 void AddAllocation(Allocator* allocator,
125 TrackingAllocator* tracking_allocator);
126
127 gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
128 allocations_;
129 std::unique_ptr<NodeExecStats> stats_;
130 const NodeDef* const node_; // Not owned.
131 StepStatsCollector* const step_stats_collector_; // Not owned.
132};
133
134// Statistics collection interface for step execution.
135//
136// See `StepStatsCollector` for a concrete implementation of this interface
137// that interfaces with the `Session` layer.
138class StepStatsCollectorInterface {
139 public:
140 virtual ~StepStatsCollectorInterface() {}
141
142 // Creates an instance of `NodeExecStatsInterface` that should be used for
143 // collecting statistics about individual node execution.
144 virtual NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) = 0;
145
146 // Generates a string reporting the currently used memory based
147 // on ResourceExhausted OOM `err` message.
148 // `err` message needs to contain device name and allocator name, e.g.:
149 // "ResourceExhaustedError: OOM when allocating tensor ...
150 // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc"
151 virtual string ReportAllocsOnResourceExhausted(const string& err) = 0;
152};
153
154// StepStatsCollector manages the collection of a StepStats object.
155// The StepStats object holds multiple DeviceStats.
156// Each DeviceStats object holds multiple NodeExecStats.
157class StepStatsCollector : public StepStatsCollectorInterface {
158 public:
159 // Does not take ownership of `step_stats`.
160 explicit StepStatsCollector(StepStats* step_stats);
161
162 // BuildCostModel builds or updates a CostModel managed by cost_model_manager,
163 // using the currently collected DeviceStats associated with the devices in
164 // device_map.
165 void BuildCostModel(
166 CostModelManager* cost_model_manager,
167 const std::unordered_map<string, const Graph*>& device_map);
168
169 // Saves node statistics to the DeviceStats object associated with device.
170 // Should be called before Finalize.
171 void Save(const string& device, NodeExecStats* node_stats_pb);
172 void Save(const string& device, NodeExecStatsWrapper* node_stats);
173
174 // Saves thread name.
175 void SaveThreadName(const string& device, const uint32 thread_id,
176 const string& thread_name);
177
178 NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override;
179 string ReportAllocsOnResourceExhausted(const string& err) override;
180
181 // The following 2 Finalize methods populate the StepStats passed
182 // from the constructor. Calling it more than once won't have any effect.
183 // User shouldn't call Save() methods after Finalize.
184 void Finalize();
185 // swaps the content of StepStats* from constructor with 'ss'.
186 void FinalizeAndSwap(StepStats* step_stats);
187
188 private:
189 // TODO(suharshs): Make this configurable if its not possible to find a value
190 // that works for all cases.
191 static constexpr uint64 kMaxCollectedNodes = 1 << 20;
192
193 typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector;
194 typedef std::unordered_map<uint32, string> ThreadNamesMap;
195
196 void FinalizeInternal() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
197
198 mutex mu_;
199 bool finalized_ TF_GUARDED_BY(mu_);
200 std::unordered_map<string, NodeStatsVector> dev_stats_ TF_GUARDED_BY(mu_);
201 std::unordered_map<string, ThreadNamesMap> thread_names_ TF_GUARDED_BY(mu_);
202 StepStats* step_stats_ TF_GUARDED_BY(mu_);
203 uint64 collected_nodes_ TF_GUARDED_BY(mu_) = 0;
204};
205
206} // namespace tensorflow
207
208#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
209