1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "absl/strings/string_view.h" |
17 | #include "pybind11/pybind11.h" |
18 | #include "tensorflow/cc/saved_model/metrics.h" |
19 | |
20 | namespace tensorflow { |
21 | namespace saved_model { |
22 | namespace python { |
23 | |
24 | namespace py = pybind11; |
25 | |
26 | void DefineMetricsModule(py::module main_module) { |
27 | auto m = main_module.def_submodule("metrics" ); |
28 | |
29 | m.doc() = "Python bindings for TensorFlow SavedModel and Checkpoint Metrics." ; |
30 | |
31 | m.def( |
32 | "IncrementWrite" , |
33 | [](const char* write_version) { |
34 | metrics::SavedModelWrite(write_version).IncrementBy(1); |
35 | }, |
36 | py::kw_only(), py::arg("write_version" ), |
37 | py::doc("Increment the '/tensorflow/core/saved_model/write/count' " |
38 | "counter." )); |
39 | |
40 | m.def( |
41 | "GetWrite" , |
42 | [](const char* write_version) { |
43 | return metrics::SavedModelWrite(write_version).value(); |
44 | }, |
45 | py::kw_only(), py::arg("write_version" ), |
46 | py::doc("Get value of '/tensorflow/core/saved_model/write/count' " |
47 | "counter." )); |
48 | |
49 | m.def( |
50 | "IncrementWriteApi" , |
51 | [](const char* api_label) { |
52 | metrics::SavedModelWriteApi(api_label).IncrementBy(1); |
53 | }, |
54 | py::doc("Increment the '/tensorflow/core/saved_model/write/api' " |
55 | "counter for API with `api_label`" )); |
56 | |
57 | m.def( |
58 | "GetWriteApi" , |
59 | [](const char* api_label) { |
60 | return metrics::SavedModelWriteApi(api_label).value(); |
61 | }, |
62 | py::doc("Get value of '/tensorflow/core/saved_model/write/api' " |
63 | "counter for `api_label` cell." )); |
64 | |
65 | m.def( |
66 | "IncrementRead" , |
67 | [](const char* write_version) { |
68 | metrics::SavedModelRead(write_version).IncrementBy(1); |
69 | }, |
70 | py::kw_only(), py::arg("write_version" ), |
71 | py::doc("Increment the '/tensorflow/core/saved_model/read/count' " |
72 | "counter after reading a SavedModel with the specifed " |
73 | "`write_version`." )); |
74 | |
75 | m.def( |
76 | "GetRead" , |
77 | [](const char* write_version) { |
78 | return metrics::SavedModelRead(write_version).value(); |
79 | }, |
80 | py::kw_only(), py::arg("write_version" ), |
81 | py::doc("Get value of '/tensorflow/core/saved_model/read/count' " |
82 | "counter for SavedModels with the specified `write_version`." )); |
83 | |
84 | m.def( |
85 | "IncrementReadApi" , |
86 | [](const char* api_label) { |
87 | metrics::SavedModelReadApi(api_label).IncrementBy(1); |
88 | }, |
89 | py::doc("Increment the '/tensorflow/core/saved_model/read/api' " |
90 | "counter for API with `api_label`." )); |
91 | |
92 | m.def( |
93 | "GetReadApi" , |
94 | [](const char* api_label) { |
95 | return metrics::SavedModelReadApi(api_label).value(); |
96 | }, |
97 | py::doc("Get value of '/tensorflow/core/saved_model/read/api' " |
98 | "counter for `api_label` cell." )); |
99 | |
100 | m.def( |
101 | "AddCheckpointReadDuration" , |
102 | [](const char* api_label, double microseconds) { |
103 | metrics::CheckpointReadDuration(api_label).Add(microseconds); |
104 | }, |
105 | py::kw_only(), py::arg("api_label" ), py::arg("microseconds" ), |
106 | py::doc("Add `microseconds` to the cell `api_label`for " |
107 | "'/tensorflow/core/checkpoint/read/read_durations'." )); |
108 | |
109 | m.def( |
110 | "GetCheckpointReadDurations" , |
111 | [](const char* api_label) { |
112 | // This function is called sparingly in unit tests, so protobuf |
113 | // (de)-serialization round trip is not an issue. |
114 | return py::bytes(metrics::CheckpointReadDuration(api_label) |
115 | .value() |
116 | .SerializeAsString()); |
117 | }, |
118 | py::kw_only(), py::arg("api_label" ), |
119 | py::doc("Get serialized HistogramProto of `api_label` cell for " |
120 | "'/tensorflow/core/checkpoint/read/read_durations'." )); |
121 | |
122 | m.def( |
123 | "AddCheckpointWriteDuration" , |
124 | [](const char* api_label, double microseconds) { |
125 | metrics::CheckpointWriteDuration(api_label).Add(microseconds); |
126 | }, |
127 | py::kw_only(), py::arg("api_label" ), py::arg("microseconds" ), |
128 | py::doc("Add `microseconds` to the cell `api_label` for " |
129 | "'/tensorflow/core/checkpoint/write/write_durations'." )); |
130 | |
131 | m.def( |
132 | "GetCheckpointWriteDurations" , |
133 | [](const char* api_label) { |
134 | // This function is called sparingly, so protobuf (de)-serialization |
135 | // round trip is not an issue. |
136 | return py::bytes(metrics::CheckpointWriteDuration(api_label) |
137 | .value() |
138 | .SerializeAsString()); |
139 | }, |
140 | py::kw_only(), py::arg("api_label" ), |
141 | py::doc("Get serialized HistogramProto of `api_label` cell for " |
142 | "'/tensorflow/core/checkpoint/write/write_durations'." )); |
143 | |
144 | m.def( |
145 | "AddAsyncCheckpointWriteDuration" , |
146 | [](const char* api_label, double microseconds) { |
147 | metrics::AsyncCheckpointWriteDuration(api_label).Add(microseconds); |
148 | }, |
149 | py::kw_only(), py::arg("api_label" ), py::arg("microseconds" ), |
150 | py::doc("Add `microseconds` to the cell `api_label` for " |
151 | "'/tensorflow/core/checkpoint/write/async_write_durations'." )); |
152 | |
153 | m.def( |
154 | "GetAsyncCheckpointWriteDurations" , |
155 | [](const char* api_label) { |
156 | // This function is called sparingly, so protobuf (de)-serialization |
157 | // round trip is not an issue. |
158 | return py::bytes(metrics::AsyncCheckpointWriteDuration(api_label) |
159 | .value() |
160 | .SerializeAsString()); |
161 | }, |
162 | py::kw_only(), py::arg("api_label" ), |
163 | py::doc("Get serialized HistogramProto of `api_label` cell for " |
164 | "'/tensorflow/core/checkpoint/write/async_write_durations'." )); |
165 | |
166 | m.def( |
167 | "AddTrainingTimeSaved" , |
168 | [](const char* api_label, double microseconds) { |
169 | metrics::TrainingTimeSaved(api_label).IncrementBy(microseconds); |
170 | }, |
171 | py::kw_only(), py::arg("api_label" ), py::arg("microseconds" ), |
172 | py::doc("Add `microseconds` to the cell `api_label` for " |
173 | "'/tensorflow/core/checkpoint/write/training_time_saved'." )); |
174 | |
175 | m.def( |
176 | "GetTrainingTimeSaved" , |
177 | [](const char* api_label) { |
178 | return metrics::TrainingTimeSaved(api_label).value(); |
179 | }, |
180 | py::kw_only(), py::arg("api_label" ), |
181 | py::doc("Get cell `api_label` for " |
182 | "'/tensorflow/core/checkpoint/write/training_time_saved'." )); |
183 | |
184 | m.def( |
185 | "CalculateFileSize" , |
186 | [](const char* filename) { |
187 | Env* env = Env::Default(); |
188 | uint64 filesize = 0; |
189 | if (!env->GetFileSize(filename, &filesize).ok()) { |
190 | return (int64_t)-1; |
191 | } |
192 | // Convert to MB. |
193 | int64_t filesize_mb = filesize / 1000; |
194 | // Round to the nearest 100 MB. |
195 | // Smaller multiple. |
196 | int64_t a = (filesize_mb / 100) * 100; |
197 | // Larger multiple. |
198 | int64_t b = a + 100; |
199 | // Return closest of two. |
200 | return (filesize_mb - a > b - filesize_mb) ? b : a; |
201 | }, |
202 | py::doc("Calculate filesize (MB) for `filename`, rounding to the nearest " |
203 | "100MB. Returns -1 if `filename` is invalid." )); |
204 | |
205 | m.def( |
206 | "RecordCheckpointSize" , |
207 | [](const char* api_label, int64_t filesize) { |
208 | metrics::CheckpointSize(api_label, filesize).IncrementBy(1); |
209 | }, |
210 | py::kw_only(), py::arg("api_label" ), py::arg("filesize" ), |
211 | py::doc("Increment the " |
212 | "'/tensorflow/core/checkpoint/write/checkpoint_size' counter for " |
213 | "cell (api_label, filesize) after writing a checkpoint." )); |
214 | |
215 | m.def( |
216 | "GetCheckpointSize" , |
217 | [](const char* api_label, uint64 filesize) { |
218 | return metrics::CheckpointSize(api_label, filesize).value(); |
219 | }, |
220 | py::kw_only(), py::arg("api_label" ), py::arg("filesize" ), |
221 | py::doc("Get cell (api_label, filesize) for " |
222 | "'/tensorflow/core/checkpoint/write/checkpoint_size'." )); |
223 | } |
224 | |
225 | } // namespace python |
226 | } // namespace saved_model |
227 | } // namespace tensorflow |
228 | |