1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/framework/python_op_gen.h" |
17 | |
18 | #include <memory> |
19 | #include <string> |
20 | #include <unordered_set> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/framework/op.h" |
24 | #include "tensorflow/core/framework/op_def.pb.h" |
25 | #include "tensorflow/core/framework/op_gen_lib.h" |
26 | #include "tensorflow/core/lib/core/errors.h" |
27 | #include "tensorflow/core/lib/io/inputbuffer.h" |
28 | #include "tensorflow/core/lib/io/path.h" |
29 | #include "tensorflow/core/lib/strings/scanner.h" |
30 | #include "tensorflow/core/lib/strings/str_util.h" |
31 | #include "tensorflow/core/platform/env.h" |
32 | #include "tensorflow/core/platform/init_main.h" |
33 | #include "tensorflow/core/platform/logging.h" |
34 | |
35 | namespace tensorflow { |
36 | namespace { |
37 | |
38 | Status ReadOpListFromFile(const string& filename, |
39 | std::vector<string>* op_list) { |
40 | std::unique_ptr<RandomAccessFile> file; |
41 | TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename, &file)); |
42 | std::unique_ptr<io::InputBuffer> input_buffer( |
43 | new io::InputBuffer(file.get(), 256 << 10)); |
44 | string line_contents; |
45 | Status s = input_buffer->ReadLine(&line_contents); |
46 | while (s.ok()) { |
47 | // The parser assumes that the op name is the first string on each |
48 | // line with no preceding whitespace, and ignores lines that do |
49 | // not start with an op name as a comment. |
50 | strings::Scanner scanner{StringPiece(line_contents)}; |
51 | StringPiece op_name; |
52 | if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT) |
53 | .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) |
54 | .GetResult(nullptr, &op_name)) { |
55 | op_list->emplace_back(op_name); |
56 | } |
57 | s = input_buffer->ReadLine(&line_contents); |
58 | } |
59 | if (!errors::IsOutOfRange(s)) return s; |
60 | return OkStatus(); |
61 | } |
62 | |
63 | // The argument parsing is deliberately simplistic to support our only |
64 | // known use cases: |
65 | // |
66 | // 1. Read all op names from a file. |
67 | // 2. Read all op names from the arg as a comma-delimited list. |
68 | // |
69 | // Expected command-line argument syntax: |
70 | // ARG ::= '@' FILENAME |
71 | // | OP_NAME [',' OP_NAME]* |
72 | // | '' |
73 | Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) { |
74 | std::vector<string> op_names = str_util::Split(arg, ','); |
75 | if (op_names.size() == 1 && op_names[0].empty()) { |
76 | return OkStatus(); |
77 | } else if (op_names.size() == 1 && op_names[0].substr(0, 1) == "@" ) { |
78 | const string filename = op_names[0].substr(1); |
79 | return tensorflow::ReadOpListFromFile(filename, op_list); |
80 | } else { |
81 | *op_list = std::move(op_names); |
82 | } |
83 | return OkStatus(); |
84 | } |
85 | |
86 | // Use the name of the current executable to infer the C++ source file |
87 | // where the REGISTER_OP() call for the operator can be found. |
88 | // Returns the name of the file. |
89 | // Returns an empty string if the current executable's name does not |
90 | // follow a known pattern. |
91 | string InferSourceFileName(const char* argv_zero) { |
92 | StringPiece command_str = io::Basename(argv_zero); |
93 | |
94 | // For built-in ops, the Bazel build creates a separate executable |
95 | // with the name gen_<op type>_ops_py_wrappers_cc containing the |
96 | // operators defined in <op type>_ops.cc |
97 | const char* kExecPrefix = "gen_" ; |
98 | const char* kExecSuffix = "_py_wrappers_cc" ; |
99 | if (absl::ConsumePrefix(&command_str, kExecPrefix) && |
100 | str_util::EndsWith(command_str, kExecSuffix)) { |
101 | command_str.remove_suffix(strlen(kExecSuffix)); |
102 | return strings::StrCat(command_str, ".cc" ); |
103 | } else { |
104 | return string("" ); |
105 | } |
106 | } |
107 | |
108 | void PrintAllPythonOps(const std::vector<string>& op_list, |
109 | const std::vector<string>& api_def_dirs, |
110 | const string& source_file_name, |
111 | bool op_list_is_allowlist, |
112 | const std::unordered_set<string> type_annotate_ops) { |
113 | OpList ops; |
114 | OpRegistry::Global()->Export(false, &ops); |
115 | |
116 | ApiDefMap api_def_map(ops); |
117 | if (!api_def_dirs.empty()) { |
118 | Env* env = Env::Default(); |
119 | |
120 | for (const auto& api_def_dir : api_def_dirs) { |
121 | std::vector<string> api_files; |
122 | TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt" ), |
123 | &api_files)); |
124 | TF_CHECK_OK(api_def_map.LoadFileList(env, api_files)); |
125 | } |
126 | api_def_map.UpdateDocs(); |
127 | } |
128 | |
129 | if (op_list_is_allowlist) { |
130 | std::unordered_set<string> allowlist(op_list.begin(), op_list.end()); |
131 | OpList pruned_ops; |
132 | for (const auto& op_def : ops.op()) { |
133 | if (allowlist.find(op_def.name()) != allowlist.end()) { |
134 | *pruned_ops.mutable_op()->Add() = op_def; |
135 | } |
136 | } |
137 | PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name, |
138 | type_annotate_ops); |
139 | } else { |
140 | PrintPythonOps(ops, api_def_map, op_list, source_file_name, |
141 | type_annotate_ops); |
142 | } |
143 | } |
144 | |
145 | } // namespace |
146 | } // namespace tensorflow |
147 | |
148 | int main(int argc, char* argv[]) { |
149 | tensorflow::port::InitMain(argv[0], &argc, &argv); |
150 | |
151 | tensorflow::string source_file_name = |
152 | tensorflow::InferSourceFileName(argv[0]); |
153 | |
154 | // Usage: |
155 | // gen_main api_def_dir1,api_def_dir2,... |
156 | // [ @FILENAME | OpName[,OpName]* ] [0 | 1] |
157 | if (argc < 2) { |
158 | return -1; |
159 | } |
160 | std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split( |
161 | argv[1], "," , tensorflow::str_util::SkipEmpty()); |
162 | |
163 | // Add op name here to generate type annotations for it |
164 | const std::unordered_set<tensorflow::string> type_annotate_ops{}; |
165 | |
166 | if (argc == 2) { |
167 | tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name, |
168 | false /* op_list_is_allowlist */, |
169 | type_annotate_ops); |
170 | } else if (argc == 3) { |
171 | std::vector<tensorflow::string> hidden_ops; |
172 | TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops)); |
173 | tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name, |
174 | false /* op_list_is_allowlist */, |
175 | type_annotate_ops); |
176 | } else if (argc == 4) { |
177 | std::vector<tensorflow::string> op_list; |
178 | TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list)); |
179 | tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name, |
180 | tensorflow::string(argv[3]) == "1" , |
181 | type_annotate_ops); |
182 | } else { |
183 | return -1; |
184 | } |
185 | return 0; |
186 | } |
187 | |