1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <cfloat> |
18 | #include <cstdint> |
19 | #include <memory> |
20 | #include <set> |
21 | #include <stdexcept> |
22 | #include <string> |
23 | #include <tuple> |
24 | #include <unordered_map> |
25 | #include <vector> |
26 | |
27 | #include "pybind11/pybind11.h" |
28 | #include "pybind11/stl.h" |
29 | #include "tensorflow/core/framework/kernel_def.pb.h" |
30 | #include "tensorflow/core/framework/memory_types.h" |
31 | #include "tensorflow/core/framework/op_def.pb.h" |
32 | #include "tensorflow/core/framework/step_stats.pb.h" |
33 | #include "tensorflow/core/grappler/clusters/cluster.h" |
34 | #include "tensorflow/core/grappler/clusters/single_machine.h" |
35 | #include "tensorflow/core/grappler/clusters/virtual_cluster.h" |
36 | #include "tensorflow/core/grappler/costs/cost_estimator.h" |
37 | #include "tensorflow/core/grappler/costs/graph_memory.h" |
38 | #include "tensorflow/core/grappler/costs/measuring_cost_estimator.h" |
39 | #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" |
40 | #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" |
41 | #include "tensorflow/core/grappler/costs/utils.h" |
42 | #include "tensorflow/core/grappler/devices.h" |
43 | #include "tensorflow/core/grappler/grappler_item.h" |
44 | #include "tensorflow/core/grappler/utils.h" |
45 | #include "tensorflow/core/platform/status.h" |
46 | #include "tensorflow/core/protobuf/config.pb.h" |
47 | #include "tensorflow/core/protobuf/device_properties.pb.h" |
48 | #include "tensorflow/python/lib/core/pybind11_status.h" |
49 | |
50 | namespace py = pybind11; |
51 | |
52 | tensorflow::Status _GetOpPerformanceDataAndRunTime( |
53 | const tensorflow::grappler::GrapplerItem& item, |
54 | tensorflow::grappler::CostEstimator* cost_measure, |
55 | tensorflow::OpPerformanceList* op_performance_data, |
56 | tensorflow::grappler::Costs* costs) { |
57 | tensorflow::Status status = cost_measure->Initialize(item); |
58 | if (!status.ok()) return status; |
59 | |
60 | tensorflow::RunMetadata run_metadata; |
61 | MaybeRaiseRegisteredFromStatus( |
62 | cost_measure->PredictCosts(item.graph, &run_metadata, costs)); |
63 | |
64 | if (op_performance_data) { |
65 | *op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData( |
66 | run_metadata.cost_graph(), item.graph); |
67 | } |
68 | return ::tensorflow::OkStatus(); |
69 | } |
70 | |
71 | PYBIND11_MAKE_OPAQUE(tensorflow::grappler::Cluster); |
72 | |
73 | PYBIND11_MODULE(_pywrap_tf_cluster, m) { |
74 | py::class_<tensorflow::grappler::Cluster> grappler_cluster( |
75 | m, "tensorflow::grappler::Cluster" ); |
76 | |
77 | m.def("TF_NewCluster" , |
78 | [](bool allow_soft_placement, |
79 | bool disable_detailed_stats) -> tensorflow::grappler::Cluster* { |
80 | // TODO(petebu): Make these named arguments with default values |
81 | // instead. |
82 | int num_cpu_cores = |
83 | tensorflow::grappler::GetNumAvailableLogicalCPUCores(); |
84 | int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); |
85 | int timeout_s = 60 * 10; |
86 | std::unique_ptr<tensorflow::grappler::Cluster> cluster = |
87 | std::make_unique<tensorflow::grappler::SingleMachine>( |
88 | timeout_s, num_cpu_cores, num_gpus); |
89 | cluster->DisableDetailedStats(disable_detailed_stats); |
90 | cluster->AllowSoftPlacement(allow_soft_placement); |
91 | cluster->SetNumWarmupSteps(10); |
92 | MaybeRaiseRegisteredFromStatus(cluster->Provision()); |
93 | return cluster.release(); |
94 | }); |
95 | |
96 | m.def("TF_NewVirtualCluster" , |
97 | [](const std::vector<py::bytes>& serialized_named_devices) |
98 | -> tensorflow::grappler::Cluster* { |
99 | std::vector<tensorflow::NamedDevice> named_devices; |
100 | for (const auto& s : serialized_named_devices) { |
101 | tensorflow::NamedDevice named_device; |
102 | if (!named_device.ParseFromString(std::string(s))) { |
103 | throw std::invalid_argument( |
104 | "The NamedDevice could not be parsed as a valid protocol " |
105 | "buffer" ); |
106 | } |
107 | named_devices.push_back(named_device); |
108 | } |
109 | |
110 | std::unordered_map<std::string, tensorflow::DeviceProperties> devices; |
111 | for (const auto& named_device : named_devices) { |
112 | devices[named_device.name()] = named_device.properties(); |
113 | } |
114 | std::unique_ptr<tensorflow::grappler::Cluster> cluster = |
115 | std::make_unique<tensorflow::grappler::VirtualCluster>(devices); |
116 | { |
117 | // TODO(petebu): Do we need to hold the GIL here? |
118 | py::gil_scoped_acquire acquire; |
119 | MaybeRaiseRegisteredFromStatus(cluster->Provision()); |
120 | } |
121 | return cluster.release(); |
122 | }); |
123 | |
124 | m.def("TF_ShutdownCluster" , [](tensorflow::grappler::Cluster* cluster) { |
125 | // TODO(petebu): Do we need to hold the GIL here? |
126 | py::gil_scoped_acquire acquire; |
127 | (void)cluster->Shutdown(); |
128 | }); |
129 | |
130 | m.def("TF_ListDevices" , |
131 | [](tensorflow::grappler::Cluster* cluster) -> std::vector<py::bytes> { |
132 | const std::unordered_map<std::string, tensorflow::DeviceProperties>& |
133 | devices = cluster->GetDevices(); |
134 | std::vector<py::bytes> named_devices; |
135 | for (auto& dev : devices) { |
136 | tensorflow::NamedDevice d; |
137 | d.set_name(dev.first); |
138 | *d.mutable_properties() = dev.second; |
139 | named_devices.push_back(d.SerializeAsString()); |
140 | } |
141 | return named_devices; |
142 | }); |
143 | |
144 | m.def("TF_ListAvailableOps" , []() -> std::vector<std::string> { |
145 | tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global(); |
146 | std::vector<tensorflow::OpDef> ops; |
147 | registry->GetRegisteredOps(&ops); |
148 | std::vector<std::string> op_names; |
149 | op_names.reserve(ops.size()); |
150 | for (const tensorflow::OpDef& op : ops) { |
151 | op_names.push_back(op.name()); |
152 | } |
153 | std::sort(op_names.begin(), op_names.end()); |
154 | return op_names; |
155 | }); |
156 | |
157 | m.def( |
158 | "TF_GetSupportedDevices" , |
159 | [](tensorflow::grappler::Cluster* cluster, |
160 | tensorflow::grappler::GrapplerItem* item) |
161 | -> std::unordered_map<std::string, std::vector<std::string>> { |
162 | if (cluster == nullptr || item == nullptr) { |
163 | MaybeRaiseRegisteredFromStatus(tensorflow::Status( |
164 | tensorflow::errors::Internal("You need both a cluster and an " |
165 | "item to get supported devices." ))); |
166 | } |
167 | const std::unordered_map<std::string, tensorflow::DeviceProperties>& |
168 | devices = cluster->GetDevices(); |
169 | std::unordered_map<std::string, std::vector<std::string>> device_types; |
170 | for (const auto& dev : devices) { |
171 | device_types[dev.second.type()].push_back(dev.first); |
172 | } |
173 | |
174 | std::unordered_map<std::string, std::set<std::string>> |
175 | supported_device_types; |
176 | std::unordered_map<std::string, std::set<std::string>> |
177 | device_restrictions; |
178 | |
179 | for (const auto& node : item->graph.node()) { |
180 | for (const auto& dev : device_types) { |
181 | const std::string& type = dev.first; |
182 | if (cluster->type() != "single_machine" ) { |
183 | // The actual kernel may not be linked in this binary. |
184 | supported_device_types[node.name()].insert(type); |
185 | } else { |
186 | // Check the kernel capabilities |
187 | const tensorflow::DeviceType dev_type(type); |
188 | tensorflow::Status s = |
189 | tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr); |
190 | if (s.ok()) { |
191 | supported_device_types[node.name()].insert(type); |
192 | |
193 | // Check which inputs are restricted to reside on the host. |
194 | // TODO: extends this to support outputs as well |
195 | tensorflow::MemoryTypeVector inp_mtypes; |
196 | tensorflow::MemoryTypeVector out_mtypes; |
197 | tensorflow::Status s = tensorflow::MemoryTypesForNode( |
198 | tensorflow::OpRegistry::Global(), dev_type, node, |
199 | &inp_mtypes, &out_mtypes); |
200 | if (s.ok()) { |
201 | for (size_t i = 0; i < inp_mtypes.size(); ++i) { |
202 | if (inp_mtypes[i] == tensorflow::HOST_MEMORY) { |
203 | device_restrictions[tensorflow::grappler::NodeName( |
204 | node.input(i))] |
205 | .insert("CPU" ); |
206 | break; |
207 | } |
208 | } |
209 | } |
210 | } |
211 | } |
212 | } |
213 | } |
214 | |
215 | std::unordered_map<std::string, std::vector<std::string>> result; |
216 | for (const auto& supported_dev : supported_device_types) { |
217 | const std::string& node = supported_dev.first; |
218 | std::set<std::string> feasible; |
219 | const auto it = device_restrictions.find(node); |
220 | if (it != device_restrictions.end()) { |
221 | const std::set<std::string>& candidates = supported_dev.second; |
222 | const std::set<std::string>& valid = it->second; |
223 | std::set_intersection(candidates.begin(), candidates.end(), |
224 | valid.begin(), valid.end(), |
225 | std::inserter(feasible, feasible.begin())); |
226 | } else { |
227 | feasible = supported_dev.second; |
228 | } |
229 | |
230 | std::vector<std::string> device_names; |
231 | for (const std::string& type : feasible) { |
232 | auto it = device_types.find(type); |
233 | DCHECK(it != device_types.end()); |
234 | for (const std::string& name : it->second) { |
235 | device_names.push_back(name); |
236 | } |
237 | } |
238 | result[node] = device_names; |
239 | } |
240 | return result; |
241 | }); |
242 | |
243 | m.def("TF_EstimatePerformance" , [](const py::bytes& serialized_device) { |
244 | tensorflow::NamedDevice device; |
245 | if (!device.ParseFromString(std::string(serialized_device))) { |
246 | throw std::invalid_argument( |
247 | "The NamedDevice could not be parsed as a valid protocol buffer" ); |
248 | } |
249 | tensorflow::grappler::OpLevelCostEstimator estimator; |
250 | tensorflow::grappler::DeviceInfo info = |
251 | estimator.GetDeviceInfo(device.properties()); |
252 | return info.gigaops; |
253 | }); |
254 | |
255 | m.def("TF_MeasureCosts" , |
256 | [](tensorflow::grappler::GrapplerItem* item, |
257 | tensorflow::grappler::Cluster* cluster, bool generate_timeline) |
258 | -> std::tuple<std::vector<py::bytes>, double, py::bytes> { |
259 | const int num_measurements = cluster->type() == "virtual" ? 1 : 10; |
260 | tensorflow::grappler::MeasuringCostEstimator cost_measure( |
261 | cluster, num_measurements, 0); |
262 | |
263 | tensorflow::OpPerformanceList op_performance_data; |
264 | tensorflow::grappler::Costs costs; |
265 | tensorflow::Status s = _GetOpPerformanceDataAndRunTime( |
266 | *item, &cost_measure, &op_performance_data, &costs); |
267 | double run_time = FLT_MAX; |
268 | if (s.ok()) { |
269 | run_time = static_cast<double>(costs.execution_time.count()) / 1e9; |
270 | } |
271 | tensorflow::StepStats step_stats; |
272 | if (generate_timeline) { |
273 | tensorflow::RunMetadata metadata; |
274 | MaybeRaiseRegisteredFromStatus( |
275 | cluster->Run(item->graph, item->feed, item->fetch, &metadata)); |
276 | step_stats = metadata.step_stats(); |
277 | } |
278 | |
279 | std::vector<py::bytes> op_perf_objs; |
280 | op_perf_objs.resize(op_performance_data.op_performance_size()); |
281 | for (int i = 0; i < op_performance_data.op_performance_size(); i++) { |
282 | op_perf_objs[i] = |
283 | op_performance_data.op_performance(i).SerializeAsString(); |
284 | } |
285 | |
286 | py::bytes step_stats_str = step_stats.SerializeAsString(); |
287 | return std::make_tuple(op_perf_objs, run_time, step_stats_str); |
288 | }); |
289 | |
290 | using DurationType = tensorflow::grappler::Costs::Duration::rep; |
291 | using MemoryUsage = |
292 | std::tuple<std::string, int, size_t, DurationType, DurationType>; |
293 | |
294 | m.def( |
295 | "TF_DeterminePeakMemoryUsage" , |
296 | [](tensorflow::grappler::GrapplerItem* item, |
297 | tensorflow::grappler::Cluster* cluster) |
298 | -> std::unordered_map<std::string, |
299 | std::tuple<int64_t, std::vector<MemoryUsage>>> { |
300 | if (item == nullptr || cluster == nullptr) { |
301 | MaybeRaiseRegisteredFromStatus( |
302 | tensorflow::Status(tensorflow::errors::Internal( |
303 | "You need both a cluster and an item to determine peak " |
304 | "memory usage." ))); |
305 | } |
306 | tensorflow::grappler::GraphMemory memory(*item); |
307 | |
308 | if (cluster->DetailedStatsEnabled()) { |
309 | MaybeRaiseRegisteredFromStatus(memory.InferDynamically(cluster)); |
310 | } else { |
311 | MaybeRaiseRegisteredFromStatus( |
312 | memory.InferStatically(cluster->GetDevices())); |
313 | } |
314 | |
315 | std::unordered_map<std::string, |
316 | std::tuple<int64_t, std::vector<MemoryUsage>>> |
317 | result; |
318 | for (const auto& device : cluster->GetDevices()) { |
319 | const tensorflow::grappler::GraphMemory::MemoryUsage& usage = |
320 | memory.GetPeakMemoryUsage(device.first); |
321 | std::vector<MemoryUsage> per_device; |
322 | for (size_t i = 0; i < usage.live_tensors.size(); ++i) { |
323 | const auto& live_tensor = usage.live_tensors[i]; |
324 | per_device.push_back(std::make_tuple( |
325 | live_tensor.node, live_tensor.output_id, |
326 | live_tensor.memory_used, live_tensor.allocation_time.count(), |
327 | live_tensor.deallocation_time.count())); |
328 | } |
329 | result[device.first] = std::make_tuple(usage.used_memory, per_device); |
330 | } |
331 | return result; |
332 | }); |
333 | } |
334 | |