1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "absl/strings/string_view.h"
17#include "pybind11/pybind11.h"
18#include "tensorflow/cc/saved_model/metrics.h"
19
20namespace tensorflow {
21namespace saved_model {
22namespace python {
23
24namespace py = pybind11;
25
26void DefineMetricsModule(py::module main_module) {
27 auto m = main_module.def_submodule("metrics");
28
29 m.doc() = "Python bindings for TensorFlow SavedModel and Checkpoint Metrics.";
30
31 m.def(
32 "IncrementWrite",
33 [](const char* write_version) {
34 metrics::SavedModelWrite(write_version).IncrementBy(1);
35 },
36 py::kw_only(), py::arg("write_version"),
37 py::doc("Increment the '/tensorflow/core/saved_model/write/count' "
38 "counter."));
39
40 m.def(
41 "GetWrite",
42 [](const char* write_version) {
43 return metrics::SavedModelWrite(write_version).value();
44 },
45 py::kw_only(), py::arg("write_version"),
46 py::doc("Get value of '/tensorflow/core/saved_model/write/count' "
47 "counter."));
48
49 m.def(
50 "IncrementWriteApi",
51 [](const char* api_label) {
52 metrics::SavedModelWriteApi(api_label).IncrementBy(1);
53 },
54 py::doc("Increment the '/tensorflow/core/saved_model/write/api' "
55 "counter for API with `api_label`"));
56
57 m.def(
58 "GetWriteApi",
59 [](const char* api_label) {
60 return metrics::SavedModelWriteApi(api_label).value();
61 },
62 py::doc("Get value of '/tensorflow/core/saved_model/write/api' "
63 "counter for `api_label` cell."));
64
65 m.def(
66 "IncrementRead",
67 [](const char* write_version) {
68 metrics::SavedModelRead(write_version).IncrementBy(1);
69 },
70 py::kw_only(), py::arg("write_version"),
71 py::doc("Increment the '/tensorflow/core/saved_model/read/count' "
72 "counter after reading a SavedModel with the specifed "
73 "`write_version`."));
74
75 m.def(
76 "GetRead",
77 [](const char* write_version) {
78 return metrics::SavedModelRead(write_version).value();
79 },
80 py::kw_only(), py::arg("write_version"),
81 py::doc("Get value of '/tensorflow/core/saved_model/read/count' "
82 "counter for SavedModels with the specified `write_version`."));
83
84 m.def(
85 "IncrementReadApi",
86 [](const char* api_label) {
87 metrics::SavedModelReadApi(api_label).IncrementBy(1);
88 },
89 py::doc("Increment the '/tensorflow/core/saved_model/read/api' "
90 "counter for API with `api_label`."));
91
92 m.def(
93 "GetReadApi",
94 [](const char* api_label) {
95 return metrics::SavedModelReadApi(api_label).value();
96 },
97 py::doc("Get value of '/tensorflow/core/saved_model/read/api' "
98 "counter for `api_label` cell."));
99
100 m.def(
101 "AddCheckpointReadDuration",
102 [](const char* api_label, double microseconds) {
103 metrics::CheckpointReadDuration(api_label).Add(microseconds);
104 },
105 py::kw_only(), py::arg("api_label"), py::arg("microseconds"),
106 py::doc("Add `microseconds` to the cell `api_label`for "
107 "'/tensorflow/core/checkpoint/read/read_durations'."));
108
109 m.def(
110 "GetCheckpointReadDurations",
111 [](const char* api_label) {
112 // This function is called sparingly in unit tests, so protobuf
113 // (de)-serialization round trip is not an issue.
114 return py::bytes(metrics::CheckpointReadDuration(api_label)
115 .value()
116 .SerializeAsString());
117 },
118 py::kw_only(), py::arg("api_label"),
119 py::doc("Get serialized HistogramProto of `api_label` cell for "
120 "'/tensorflow/core/checkpoint/read/read_durations'."));
121
122 m.def(
123 "AddCheckpointWriteDuration",
124 [](const char* api_label, double microseconds) {
125 metrics::CheckpointWriteDuration(api_label).Add(microseconds);
126 },
127 py::kw_only(), py::arg("api_label"), py::arg("microseconds"),
128 py::doc("Add `microseconds` to the cell `api_label` for "
129 "'/tensorflow/core/checkpoint/write/write_durations'."));
130
131 m.def(
132 "GetCheckpointWriteDurations",
133 [](const char* api_label) {
134 // This function is called sparingly, so protobuf (de)-serialization
135 // round trip is not an issue.
136 return py::bytes(metrics::CheckpointWriteDuration(api_label)
137 .value()
138 .SerializeAsString());
139 },
140 py::kw_only(), py::arg("api_label"),
141 py::doc("Get serialized HistogramProto of `api_label` cell for "
142 "'/tensorflow/core/checkpoint/write/write_durations'."));
143
144 m.def(
145 "AddAsyncCheckpointWriteDuration",
146 [](const char* api_label, double microseconds) {
147 metrics::AsyncCheckpointWriteDuration(api_label).Add(microseconds);
148 },
149 py::kw_only(), py::arg("api_label"), py::arg("microseconds"),
150 py::doc("Add `microseconds` to the cell `api_label` for "
151 "'/tensorflow/core/checkpoint/write/async_write_durations'."));
152
153 m.def(
154 "GetAsyncCheckpointWriteDurations",
155 [](const char* api_label) {
156 // This function is called sparingly, so protobuf (de)-serialization
157 // round trip is not an issue.
158 return py::bytes(metrics::AsyncCheckpointWriteDuration(api_label)
159 .value()
160 .SerializeAsString());
161 },
162 py::kw_only(), py::arg("api_label"),
163 py::doc("Get serialized HistogramProto of `api_label` cell for "
164 "'/tensorflow/core/checkpoint/write/async_write_durations'."));
165
166 m.def(
167 "AddTrainingTimeSaved",
168 [](const char* api_label, double microseconds) {
169 metrics::TrainingTimeSaved(api_label).IncrementBy(microseconds);
170 },
171 py::kw_only(), py::arg("api_label"), py::arg("microseconds"),
172 py::doc("Add `microseconds` to the cell `api_label` for "
173 "'/tensorflow/core/checkpoint/write/training_time_saved'."));
174
175 m.def(
176 "GetTrainingTimeSaved",
177 [](const char* api_label) {
178 return metrics::TrainingTimeSaved(api_label).value();
179 },
180 py::kw_only(), py::arg("api_label"),
181 py::doc("Get cell `api_label` for "
182 "'/tensorflow/core/checkpoint/write/training_time_saved'."));
183
184 m.def(
185 "CalculateFileSize",
186 [](const char* filename) {
187 Env* env = Env::Default();
188 uint64 filesize = 0;
189 if (!env->GetFileSize(filename, &filesize).ok()) {
190 return (int64_t)-1;
191 }
192 // Convert to MB.
193 int64_t filesize_mb = filesize / 1000;
194 // Round to the nearest 100 MB.
195 // Smaller multiple.
196 int64_t a = (filesize_mb / 100) * 100;
197 // Larger multiple.
198 int64_t b = a + 100;
199 // Return closest of two.
200 return (filesize_mb - a > b - filesize_mb) ? b : a;
201 },
202 py::doc("Calculate filesize (MB) for `filename`, rounding to the nearest "
203 "100MB. Returns -1 if `filename` is invalid."));
204
205 m.def(
206 "RecordCheckpointSize",
207 [](const char* api_label, int64_t filesize) {
208 metrics::CheckpointSize(api_label, filesize).IncrementBy(1);
209 },
210 py::kw_only(), py::arg("api_label"), py::arg("filesize"),
211 py::doc("Increment the "
212 "'/tensorflow/core/checkpoint/write/checkpoint_size' counter for "
213 "cell (api_label, filesize) after writing a checkpoint."));
214
215 m.def(
216 "GetCheckpointSize",
217 [](const char* api_label, uint64 filesize) {
218 return metrics::CheckpointSize(api_label, filesize).value();
219 },
220 py::kw_only(), py::arg("api_label"), py::arg("filesize"),
221 py::doc("Get cell (api_label, filesize) for "
222 "'/tensorflow/core/checkpoint/write/checkpoint_size'."));
223}
224
225} // namespace python
226} // namespace saved_model
227} // namespace tensorflow
228