1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | |
18 | #include "pybind11/cast.h" |
19 | #include "pybind11/pybind11.h" |
20 | #include "pybind11/pytypes.h" |
21 | #include "pybind11/stl.h" |
22 | #include "tensorflow/compiler/aot/compile.h" |
23 | #include "tensorflow/compiler/aot/flags.h" |
24 | #include "tensorflow/python/lib/core/pybind11_lib.h" |
25 | #include "tensorflow/python/lib/core/pybind11_status.h" |
26 | |
27 | namespace py = pybind11; |
28 | |
29 | PYBIND11_MODULE(_pywrap_tfcompile, m) { |
30 | m.doc() = R"pbdoc( |
31 | _pywrap_tfcompile |
32 | ----- |
33 | )pbdoc" ; |
34 | |
35 | m.def( |
36 | "Compile" , |
37 | [](std::string graph, std::string config, std::string target_triple, |
38 | std::string target_cpu, std::string target_features, |
39 | std::string entry_point, std::string cpp_class, |
40 | std::string out_function_object, std::string out_metadata_object, |
41 | std::string , std::string out_session_module, |
42 | std::string mlir_components, bool gen_name_to_index, |
43 | bool gen_program_shape) { |
44 | tensorflow::tfcompile::MainFlags flags; |
45 | flags.graph = std::move(graph); |
46 | flags.config = std::move(config); |
47 | flags.target_triple = std::move(target_triple); |
48 | flags.target_cpu = std::move(target_cpu); |
49 | flags.target_features = std::move(target_features); |
50 | flags.entry_point = std::move(entry_point); |
51 | flags.cpp_class = std::move(cpp_class); |
52 | flags.out_function_object = std::move(out_function_object); |
53 | flags.out_metadata_object = std::move(out_metadata_object); |
54 | flags.out_header = std::move(out_header); |
55 | flags.out_session_module = std::move(out_session_module); |
56 | flags.mlir_components = std::move(mlir_components); |
57 | |
58 | // C++ codegen options |
59 | flags.gen_name_to_index = gen_name_to_index; |
60 | flags.gen_program_shape = gen_program_shape; |
61 | |
62 | tensorflow::MaybeRaiseFromStatus(tensorflow::tfcompile::Main(flags)); |
63 | }, |
64 | py::arg("graph" ) = "" , py::arg("config" ) = "" , |
65 | py::arg("target_triple" ) = "x86_64-pc-linux" , py::arg("target_cpu" ) = "" , |
66 | py::arg("target_features" ) = "" , py::arg("entry_point" ) = "entry" , |
67 | py::arg("cpp_class" ) = "" , py::arg("out_function_object" ) = "out_model.o" , |
68 | py::arg("out_metadata_object" ) = "out_helper.o" , |
69 | py::arg("out_header" ) = "out.h" , py::arg("out_session_module" ) = "" , |
70 | py::arg("mlir_components" ) = "" , py::arg("gen_name_to_index" ) = false, |
71 | py::arg("gen_program_shape" ) = false); |
72 | } |
73 | |