1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "pybind11/pybind11.h" |
17 | #include "tensorflow/core/common_runtime/quantize_training.h" |
18 | #include "tensorflow/python/lib/core/pybind11_lib.h" |
19 | #include "tensorflow/python/lib/core/pybind11_status.h" |
20 | |
21 | namespace py = pybind11; |
22 | |
23 | namespace tensorflow { |
24 | static PyObject* DoQuantizeTrainingOnGraphDefHelper(const string& input_graph, |
25 | int num_bits) { |
26 | string result; |
27 | // TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable. |
28 | tensorflow::MaybeRaiseFromStatus( |
29 | tensorflow::DoQuantizeTrainingOnSerializedGraphDef( |
30 | input_graph, num_bits, "QuantizeAndDequantizeV2" , &result)); |
31 | |
32 | PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size()); |
33 | if (!py_str) { |
34 | tensorflow::MaybeRaiseFromStatus(tensorflow::errors::Internal( |
35 | "Failed to generate serialized string of the rewritten graph." )); |
36 | } |
37 | return py_str; |
38 | } |
39 | } // namespace tensorflow |
40 | |
41 | PYBIND11_MODULE(_pywrap_quantize_training, m) { |
42 | m.def("DoQuantizeTrainingOnGraphDefHelper" , |
43 | [](const py::object input_graph, int num_bits) { |
44 | return tensorflow::PyoOrThrow( |
45 | tensorflow::DoQuantizeTrainingOnGraphDefHelper( |
46 | input_graph.cast<std::string>(), num_bits)); |
47 | }); |
48 | }; |
49 | |