1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_DATA_METRIC_UTILS_H_ |
16 | #define TENSORFLOW_CORE_DATA_METRIC_UTILS_H_ |
17 | |
18 | #include <cstdint> |
19 | #include <string> |
20 | #include <vector> |
21 | |
22 | #include "absl/time/time.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/platform/env.h" |
25 | #include "tensorflow/core/platform/mutex.h" |
26 | #include "tensorflow/core/platform/thread_annotations.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace data { |
30 | |
31 | // Exports the metrics for `GetNext` calls by tf.data iterators. When the user |
32 | // calls `RecordStart` and `RecordStop`, it will export a latency sample. It |
33 | // also exports throughput, tf.data iterator life time, etc. This class is |
34 | // thread-safe. Example usage: |
35 | // |
36 | // ``` |
37 | // IteratorMetricsCollector metrics_collector(DEVICE_CPU, env); |
38 | // absl::Time start_time = metrics_collector.RecordStart(); |
39 | // auto status = iterator_->GetNext(IteratorContext(std::move(params)), |
40 | // out_tensors, end_of_sequence); |
41 | // metrics_collector.RecordStop(start_time, *out_tensors); |
42 | // ``` |
43 | class IteratorMetricsCollector { |
44 | public: |
45 | // Constructs a `IteratorMetricsCollector`. `device_type` is one of the |
46 | // devices defined in `types.h` (DEVICE_CPU, DEVICE_GPU, DEVICE_TPU, etc). |
47 | // We only collect metrics for CPU devices. This is a heuristic to avoid |
48 | // collecting metrics for device-side iterators created by the multi-device |
49 | // iterator mechanism. |
50 | IteratorMetricsCollector(const std::string& device_type, const Env& env); |
51 | |
52 | // Starts the timer for the next `GetNext` call. Returns the start time. |
53 | absl::Time RecordStart(); |
54 | |
55 | // Records metrics for the most recent `GetNext` call, including the latency, |
56 | // bytes fetched, iterator life time, etc. `start_time` is the start time |
57 | // returned by `RecordStart`. `output` is the output of the `GetNext` call. |
58 | void RecordStop(absl::Time start_time, const std::vector<Tensor>& output); |
59 | |
60 | private: |
61 | // We only collect metrics for CPU devices. |
62 | bool ShouldCollectMetrics() const; |
63 | |
64 | // One of the devices defined in `types.h` |
65 | // (DEVICE_CPU, DEVICE_GPU, DEVICE_TPU, etc). |
66 | const std::string device_type_; |
67 | const Env& env_; |
68 | |
69 | mutex mu_; |
70 | |
71 | // Records the number of currently active `GetNext` calls. |
72 | uint64_t num_active_calls_ TF_GUARDED_BY(mu_) = 0; |
73 | |
74 | // Records the start time (in microseconds) of the first `RecordStart()` call |
75 | // that followed the last period of inactivity. |
76 | uint64_t first_start_time_us_ TF_GUARDED_BY(mu_) = 0; |
77 | |
78 | // Records the end time (in microseconds) of the most recent `RecordStop()` |
79 | // call. |
80 | uint64_t end_time_us_ TF_GUARDED_BY(mu_) = 0; |
81 | }; |
82 | |
83 | } // namespace data |
84 | } // namespace tensorflow |
85 | |
86 | #endif // TENSORFLOW_CORE_DATA_METRIC_UTILS_H_ |
87 | |