1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <algorithm>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
26
27#include "pybind11/pybind11.h"
28#include "pybind11/stl.h"
29#include "tensorflow/core/framework/kernel_def.pb.h"
30#include "tensorflow/core/framework/memory_types.h"
31#include "tensorflow/core/framework/op_def.pb.h"
32#include "tensorflow/core/framework/step_stats.pb.h"
33#include "tensorflow/core/grappler/clusters/cluster.h"
34#include "tensorflow/core/grappler/clusters/single_machine.h"
35#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
36#include "tensorflow/core/grappler/costs/cost_estimator.h"
37#include "tensorflow/core/grappler/costs/graph_memory.h"
38#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
39#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
40#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
41#include "tensorflow/core/grappler/costs/utils.h"
42#include "tensorflow/core/grappler/devices.h"
43#include "tensorflow/core/grappler/grappler_item.h"
44#include "tensorflow/core/grappler/utils.h"
45#include "tensorflow/core/platform/status.h"
46#include "tensorflow/core/protobuf/config.pb.h"
47#include "tensorflow/core/protobuf/device_properties.pb.h"
48#include "tensorflow/python/lib/core/pybind11_status.h"
49
50namespace py = pybind11;
51
52tensorflow::Status _GetOpPerformanceDataAndRunTime(
53 const tensorflow::grappler::GrapplerItem& item,
54 tensorflow::grappler::CostEstimator* cost_measure,
55 tensorflow::OpPerformanceList* op_performance_data,
56 tensorflow::grappler::Costs* costs) {
57 tensorflow::Status status = cost_measure->Initialize(item);
58 if (!status.ok()) return status;
59
60 tensorflow::RunMetadata run_metadata;
61 MaybeRaiseRegisteredFromStatus(
62 cost_measure->PredictCosts(item.graph, &run_metadata, costs));
63
64 if (op_performance_data) {
65 *op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
66 run_metadata.cost_graph(), item.graph);
67 }
68 return ::tensorflow::OkStatus();
69}
70
// Declares grappler::Cluster as an opaque type to pybind11 (see the pybind11
// documentation for PYBIND11_MAKE_OPAQUE). The class itself is registered in
// the module definition below.
PYBIND11_MAKE_OPAQUE(tensorflow::grappler::Cluster);
72
// Python extension module `_pywrap_tf_cluster`: bindings for creating,
// inspecting and measuring grappler clusters.
PYBIND11_MODULE(_pywrap_tf_cluster, m) {
  // Register the Cluster class so that Cluster pointers can be returned from
  // and accepted by the functions defined in this module.
  py::class_<tensorflow::grappler::Cluster> grappler_cluster(
      m, "tensorflow::grappler::Cluster");
76
77 m.def("TF_NewCluster",
78 [](bool allow_soft_placement,
79 bool disable_detailed_stats) -> tensorflow::grappler::Cluster* {
80 // TODO(petebu): Make these named arguments with default values
81 // instead.
82 int num_cpu_cores =
83 tensorflow::grappler::GetNumAvailableLogicalCPUCores();
84 int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
85 int timeout_s = 60 * 10;
86 std::unique_ptr<tensorflow::grappler::Cluster> cluster =
87 std::make_unique<tensorflow::grappler::SingleMachine>(
88 timeout_s, num_cpu_cores, num_gpus);
89 cluster->DisableDetailedStats(disable_detailed_stats);
90 cluster->AllowSoftPlacement(allow_soft_placement);
91 cluster->SetNumWarmupSteps(10);
92 MaybeRaiseRegisteredFromStatus(cluster->Provision());
93 return cluster.release();
94 });
95
96 m.def("TF_NewVirtualCluster",
97 [](const std::vector<py::bytes>& serialized_named_devices)
98 -> tensorflow::grappler::Cluster* {
99 std::vector<tensorflow::NamedDevice> named_devices;
100 for (const auto& s : serialized_named_devices) {
101 tensorflow::NamedDevice named_device;
102 if (!named_device.ParseFromString(std::string(s))) {
103 throw std::invalid_argument(
104 "The NamedDevice could not be parsed as a valid protocol "
105 "buffer");
106 }
107 named_devices.push_back(named_device);
108 }
109
110 std::unordered_map<std::string, tensorflow::DeviceProperties> devices;
111 for (const auto& named_device : named_devices) {
112 devices[named_device.name()] = named_device.properties();
113 }
114 std::unique_ptr<tensorflow::grappler::Cluster> cluster =
115 std::make_unique<tensorflow::grappler::VirtualCluster>(devices);
116 {
117 // TODO(petebu): Do we need to hold the GIL here?
118 py::gil_scoped_acquire acquire;
119 MaybeRaiseRegisteredFromStatus(cluster->Provision());
120 }
121 return cluster.release();
122 });
123
124 m.def("TF_ShutdownCluster", [](tensorflow::grappler::Cluster* cluster) {
125 // TODO(petebu): Do we need to hold the GIL here?
126 py::gil_scoped_acquire acquire;
127 (void)cluster->Shutdown();
128 });
129
130 m.def("TF_ListDevices",
131 [](tensorflow::grappler::Cluster* cluster) -> std::vector<py::bytes> {
132 const std::unordered_map<std::string, tensorflow::DeviceProperties>&
133 devices = cluster->GetDevices();
134 std::vector<py::bytes> named_devices;
135 for (auto& dev : devices) {
136 tensorflow::NamedDevice d;
137 d.set_name(dev.first);
138 *d.mutable_properties() = dev.second;
139 named_devices.push_back(d.SerializeAsString());
140 }
141 return named_devices;
142 });
143
144 m.def("TF_ListAvailableOps", []() -> std::vector<std::string> {
145 tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
146 std::vector<tensorflow::OpDef> ops;
147 registry->GetRegisteredOps(&ops);
148 std::vector<std::string> op_names;
149 op_names.reserve(ops.size());
150 for (const tensorflow::OpDef& op : ops) {
151 op_names.push_back(op.name());
152 }
153 std::sort(op_names.begin(), op_names.end());
154 return op_names;
155 });
156
157 m.def(
158 "TF_GetSupportedDevices",
159 [](tensorflow::grappler::Cluster* cluster,
160 tensorflow::grappler::GrapplerItem* item)
161 -> std::unordered_map<std::string, std::vector<std::string>> {
162 if (cluster == nullptr || item == nullptr) {
163 MaybeRaiseRegisteredFromStatus(tensorflow::Status(
164 tensorflow::errors::Internal("You need both a cluster and an "
165 "item to get supported devices.")));
166 }
167 const std::unordered_map<std::string, tensorflow::DeviceProperties>&
168 devices = cluster->GetDevices();
169 std::unordered_map<std::string, std::vector<std::string>> device_types;
170 for (const auto& dev : devices) {
171 device_types[dev.second.type()].push_back(dev.first);
172 }
173
174 std::unordered_map<std::string, std::set<std::string>>
175 supported_device_types;
176 std::unordered_map<std::string, std::set<std::string>>
177 device_restrictions;
178
179 for (const auto& node : item->graph.node()) {
180 for (const auto& dev : device_types) {
181 const std::string& type = dev.first;
182 if (cluster->type() != "single_machine") {
183 // The actual kernel may not be linked in this binary.
184 supported_device_types[node.name()].insert(type);
185 } else {
186 // Check the kernel capabilities
187 const tensorflow::DeviceType dev_type(type);
188 tensorflow::Status s =
189 tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
190 if (s.ok()) {
191 supported_device_types[node.name()].insert(type);
192
193 // Check which inputs are restricted to reside on the host.
194 // TODO: extends this to support outputs as well
195 tensorflow::MemoryTypeVector inp_mtypes;
196 tensorflow::MemoryTypeVector out_mtypes;
197 tensorflow::Status s = tensorflow::MemoryTypesForNode(
198 tensorflow::OpRegistry::Global(), dev_type, node,
199 &inp_mtypes, &out_mtypes);
200 if (s.ok()) {
201 for (size_t i = 0; i < inp_mtypes.size(); ++i) {
202 if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
203 device_restrictions[tensorflow::grappler::NodeName(
204 node.input(i))]
205 .insert("CPU");
206 break;
207 }
208 }
209 }
210 }
211 }
212 }
213 }
214
215 std::unordered_map<std::string, std::vector<std::string>> result;
216 for (const auto& supported_dev : supported_device_types) {
217 const std::string& node = supported_dev.first;
218 std::set<std::string> feasible;
219 const auto it = device_restrictions.find(node);
220 if (it != device_restrictions.end()) {
221 const std::set<std::string>& candidates = supported_dev.second;
222 const std::set<std::string>& valid = it->second;
223 std::set_intersection(candidates.begin(), candidates.end(),
224 valid.begin(), valid.end(),
225 std::inserter(feasible, feasible.begin()));
226 } else {
227 feasible = supported_dev.second;
228 }
229
230 std::vector<std::string> device_names;
231 for (const std::string& type : feasible) {
232 auto it = device_types.find(type);
233 DCHECK(it != device_types.end());
234 for (const std::string& name : it->second) {
235 device_names.push_back(name);
236 }
237 }
238 result[node] = device_names;
239 }
240 return result;
241 });
242
243 m.def("TF_EstimatePerformance", [](const py::bytes& serialized_device) {
244 tensorflow::NamedDevice device;
245 if (!device.ParseFromString(std::string(serialized_device))) {
246 throw std::invalid_argument(
247 "The NamedDevice could not be parsed as a valid protocol buffer");
248 }
249 tensorflow::grappler::OpLevelCostEstimator estimator;
250 tensorflow::grappler::DeviceInfo info =
251 estimator.GetDeviceInfo(device.properties());
252 return info.gigaops;
253 });
254
255 m.def("TF_MeasureCosts",
256 [](tensorflow::grappler::GrapplerItem* item,
257 tensorflow::grappler::Cluster* cluster, bool generate_timeline)
258 -> std::tuple<std::vector<py::bytes>, double, py::bytes> {
259 const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
260 tensorflow::grappler::MeasuringCostEstimator cost_measure(
261 cluster, num_measurements, 0);
262
263 tensorflow::OpPerformanceList op_performance_data;
264 tensorflow::grappler::Costs costs;
265 tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
266 *item, &cost_measure, &op_performance_data, &costs);
267 double run_time = FLT_MAX;
268 if (s.ok()) {
269 run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
270 }
271 tensorflow::StepStats step_stats;
272 if (generate_timeline) {
273 tensorflow::RunMetadata metadata;
274 MaybeRaiseRegisteredFromStatus(
275 cluster->Run(item->graph, item->feed, item->fetch, &metadata));
276 step_stats = metadata.step_stats();
277 }
278
279 std::vector<py::bytes> op_perf_objs;
280 op_perf_objs.resize(op_performance_data.op_performance_size());
281 for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
282 op_perf_objs[i] =
283 op_performance_data.op_performance(i).SerializeAsString();
284 }
285
286 py::bytes step_stats_str = step_stats.SerializeAsString();
287 return std::make_tuple(op_perf_objs, run_time, step_stats_str);
288 });
289
290 using DurationType = tensorflow::grappler::Costs::Duration::rep;
291 using MemoryUsage =
292 std::tuple<std::string, int, size_t, DurationType, DurationType>;
293
294 m.def(
295 "TF_DeterminePeakMemoryUsage",
296 [](tensorflow::grappler::GrapplerItem* item,
297 tensorflow::grappler::Cluster* cluster)
298 -> std::unordered_map<std::string,
299 std::tuple<int64_t, std::vector<MemoryUsage>>> {
300 if (item == nullptr || cluster == nullptr) {
301 MaybeRaiseRegisteredFromStatus(
302 tensorflow::Status(tensorflow::errors::Internal(
303 "You need both a cluster and an item to determine peak "
304 "memory usage.")));
305 }
306 tensorflow::grappler::GraphMemory memory(*item);
307
308 if (cluster->DetailedStatsEnabled()) {
309 MaybeRaiseRegisteredFromStatus(memory.InferDynamically(cluster));
310 } else {
311 MaybeRaiseRegisteredFromStatus(
312 memory.InferStatically(cluster->GetDevices()));
313 }
314
315 std::unordered_map<std::string,
316 std::tuple<int64_t, std::vector<MemoryUsage>>>
317 result;
318 for (const auto& device : cluster->GetDevices()) {
319 const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
320 memory.GetPeakMemoryUsage(device.first);
321 std::vector<MemoryUsage> per_device;
322 for (size_t i = 0; i < usage.live_tensors.size(); ++i) {
323 const auto& live_tensor = usage.live_tensors[i];
324 per_device.push_back(std::make_tuple(
325 live_tensor.node, live_tensor.output_id,
326 live_tensor.memory_used, live_tensor.allocation_time.count(),
327 live_tensor.deallocation_time.count()));
328 }
329 result[device.first] = std::make_tuple(usage.used_memory, per_device);
330 }
331 return result;
332 });
333}
334