1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_DATA_METRIC_UTILS_H_
16#define TENSORFLOW_CORE_DATA_METRIC_UTILS_H_
17
18#include <cstdint>
19#include <string>
20#include <vector>
21
22#include "absl/time/time.h"
23#include "tensorflow/core/framework/tensor.h"
24#include "tensorflow/core/platform/env.h"
25#include "tensorflow/core/platform/mutex.h"
26#include "tensorflow/core/platform/thread_annotations.h"
27
28namespace tensorflow {
29namespace data {
30
31// Exports the metrics for `GetNext` calls by tf.data iterators. When the user
32// calls `RecordStart` and `RecordStop`, it will export a latency sample. It
33// also exports throughput, tf.data iterator life time, etc. This class is
34// thread-safe. Example usage:
35//
36// ```
37// IteratorMetricsCollector metrics_collector(DEVICE_CPU, env);
38// absl::Time start_time = metrics_collector.RecordStart();
39// auto status = iterator_->GetNext(IteratorContext(std::move(params)),
40// out_tensors, end_of_sequence);
41// metrics_collector.RecordStop(start_time, *out_tensors);
42// ```
43class IteratorMetricsCollector {
44 public:
45 // Constructs a `IteratorMetricsCollector`. `device_type` is one of the
46 // devices defined in `types.h` (DEVICE_CPU, DEVICE_GPU, DEVICE_TPU, etc).
47 // We only collect metrics for CPU devices. This is a heuristic to avoid
48 // collecting metrics for device-side iterators created by the multi-device
49 // iterator mechanism.
50 IteratorMetricsCollector(const std::string& device_type, const Env& env);
51
52 // Starts the timer for the next `GetNext` call. Returns the start time.
53 absl::Time RecordStart();
54
55 // Records metrics for the most recent `GetNext` call, including the latency,
56 // bytes fetched, iterator life time, etc. `start_time` is the start time
57 // returned by `RecordStart`. `output` is the output of the `GetNext` call.
58 void RecordStop(absl::Time start_time, const std::vector<Tensor>& output);
59
60 private:
61 // We only collect metrics for CPU devices.
62 bool ShouldCollectMetrics() const;
63
64 // One of the devices defined in `types.h`
65 // (DEVICE_CPU, DEVICE_GPU, DEVICE_TPU, etc).
66 const std::string device_type_;
67 const Env& env_;
68
69 mutex mu_;
70
71 // Records the number of currently active `GetNext` calls.
72 uint64_t num_active_calls_ TF_GUARDED_BY(mu_) = 0;
73
74 // Records the start time (in microseconds) of the first `RecordStart()` call
75 // that followed the last period of inactivity.
76 uint64_t first_start_time_us_ TF_GUARDED_BY(mu_) = 0;
77
78 // Records the end time (in microseconds) of the most recent `RecordStop()`
79 // call.
80 uint64_t end_time_us_ TF_GUARDED_BY(mu_) = 0;
81};
82
83} // namespace data
84} // namespace tensorflow
85
86#endif // TENSORFLOW_CORE_DATA_METRIC_UTILS_H_
87