1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_ |
16 | #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_ |
17 | |
18 | #include <string> |
19 | #include <unordered_map> |
20 | #include <unordered_set> |
21 | #include <vector> |
22 | #include "tensorflow/core/framework/op_def.pb.h" |
23 | #include "tensorflow/core/framework/op_gen_lib.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | // Returns a string containing the generated Python code for the given ops. |
29 | // `ops` is a protobuf, typically generated using OpRegistry::Global()->Export; |
30 | // `api_defs` is typically constructed directly from `ops`. `hidden_ops` lists |
31 | // the op names that should get a leading _ in the output. `source_file_name` |
32 | // (optional) names the C++ source file holding the ops' REGISTER_OP() calls. |
33 | // NOTE(review): `type_annotate_ops` is taken by value, not by reference; it is |
34 | // presumably the set of op names to emit type annotations for -- confirm. |
35 | string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, |
36 | const std::vector<string>& hidden_ops, |
37 | const string& source_file_name, |
38 | const std::unordered_set<string> type_annotate_ops); |
39 | |
40 | // Prints the output of GetPythonOps to stdout. |
41 | // `hidden_ops` lists the op names that should get a leading _ in the output. |
42 | // `source_file_name` (optional) names the original C++ source file where the |
43 | // ops' REGISTER_OP() calls reside. NOTE(review): `type_annotate_ops` is taken |
44 | // by value and is presumably forwarded to GetPythonOps -- confirm in the .cc. |
45 | void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, |
46 | const std::vector<string>& hidden_ops, |
47 | const string& source_file_name, |
48 | const std::unordered_set<string> type_annotate_ops); |
49 | |
50 | // Returns, as a string, the Python wrappers for the ops in an OpList. |
51 | // `op_list_buf` should be a pointer to a buffer containing the binary |
52 | // encoded OpList proto, and `op_list_len` should be the length of that |
53 | // buffer (presumably in bytes -- confirm against the definition). |
54 | string GetPythonWrappers(const char* op_list_buf, size_t op_list_len); |
55 | |
56 | // Returns the type annotation for `arg`, as a string. |
57 | // `arg` should be an input or output of an op. |
58 | // `type_annotations` should contain attr names mapped to TypeVar names. |
59 | string GetArgAnnotation( |
60 | const OpDef::ArgDef& arg, |
61 | const std::unordered_map<string, string>& type_annotations); |
62 | |
63 | } // namespace tensorflow |
64 | |
65 | #endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_ |
66 | |