1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/python/framework/python_op_gen.h"
17
18#include <memory>
19#include <string>
20#include <unordered_set>
21#include <vector>
22
23#include "tensorflow/core/framework/op.h"
24#include "tensorflow/core/framework/op_def.pb.h"
25#include "tensorflow/core/framework/op_gen_lib.h"
26#include "tensorflow/core/lib/core/errors.h"
27#include "tensorflow/core/lib/io/inputbuffer.h"
28#include "tensorflow/core/lib/io/path.h"
29#include "tensorflow/core/lib/strings/scanner.h"
30#include "tensorflow/core/lib/strings/str_util.h"
31#include "tensorflow/core/platform/env.h"
32#include "tensorflow/core/platform/init_main.h"
33#include "tensorflow/core/platform/logging.h"
34
35namespace tensorflow {
36namespace {
37
38Status ReadOpListFromFile(const string& filename,
39 std::vector<string>* op_list) {
40 std::unique_ptr<RandomAccessFile> file;
41 TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename, &file));
42 std::unique_ptr<io::InputBuffer> input_buffer(
43 new io::InputBuffer(file.get(), 256 << 10));
44 string line_contents;
45 Status s = input_buffer->ReadLine(&line_contents);
46 while (s.ok()) {
47 // The parser assumes that the op name is the first string on each
48 // line with no preceding whitespace, and ignores lines that do
49 // not start with an op name as a comment.
50 strings::Scanner scanner{StringPiece(line_contents)};
51 StringPiece op_name;
52 if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT)
53 .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
54 .GetResult(nullptr, &op_name)) {
55 op_list->emplace_back(op_name);
56 }
57 s = input_buffer->ReadLine(&line_contents);
58 }
59 if (!errors::IsOutOfRange(s)) return s;
60 return OkStatus();
61}
62
63// The argument parsing is deliberately simplistic to support our only
64// known use cases:
65//
66// 1. Read all op names from a file.
67// 2. Read all op names from the arg as a comma-delimited list.
68//
69// Expected command-line argument syntax:
70// ARG ::= '@' FILENAME
71// | OP_NAME [',' OP_NAME]*
72// | ''
73Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) {
74 std::vector<string> op_names = str_util::Split(arg, ',');
75 if (op_names.size() == 1 && op_names[0].empty()) {
76 return OkStatus();
77 } else if (op_names.size() == 1 && op_names[0].substr(0, 1) == "@") {
78 const string filename = op_names[0].substr(1);
79 return tensorflow::ReadOpListFromFile(filename, op_list);
80 } else {
81 *op_list = std::move(op_names);
82 }
83 return OkStatus();
84}
85
86// Use the name of the current executable to infer the C++ source file
87// where the REGISTER_OP() call for the operator can be found.
88// Returns the name of the file.
89// Returns an empty string if the current executable's name does not
90// follow a known pattern.
91string InferSourceFileName(const char* argv_zero) {
92 StringPiece command_str = io::Basename(argv_zero);
93
94 // For built-in ops, the Bazel build creates a separate executable
95 // with the name gen_<op type>_ops_py_wrappers_cc containing the
96 // operators defined in <op type>_ops.cc
97 const char* kExecPrefix = "gen_";
98 const char* kExecSuffix = "_py_wrappers_cc";
99 if (absl::ConsumePrefix(&command_str, kExecPrefix) &&
100 str_util::EndsWith(command_str, kExecSuffix)) {
101 command_str.remove_suffix(strlen(kExecSuffix));
102 return strings::StrCat(command_str, ".cc");
103 } else {
104 return string("");
105 }
106}
107
108void PrintAllPythonOps(const std::vector<string>& op_list,
109 const std::vector<string>& api_def_dirs,
110 const string& source_file_name,
111 bool op_list_is_allowlist,
112 const std::unordered_set<string> type_annotate_ops) {
113 OpList ops;
114 OpRegistry::Global()->Export(false, &ops);
115
116 ApiDefMap api_def_map(ops);
117 if (!api_def_dirs.empty()) {
118 Env* env = Env::Default();
119
120 for (const auto& api_def_dir : api_def_dirs) {
121 std::vector<string> api_files;
122 TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
123 &api_files));
124 TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
125 }
126 api_def_map.UpdateDocs();
127 }
128
129 if (op_list_is_allowlist) {
130 std::unordered_set<string> allowlist(op_list.begin(), op_list.end());
131 OpList pruned_ops;
132 for (const auto& op_def : ops.op()) {
133 if (allowlist.find(op_def.name()) != allowlist.end()) {
134 *pruned_ops.mutable_op()->Add() = op_def;
135 }
136 }
137 PrintPythonOps(pruned_ops, api_def_map, {}, source_file_name,
138 type_annotate_ops);
139 } else {
140 PrintPythonOps(ops, api_def_map, op_list, source_file_name,
141 type_annotate_ops);
142 }
143}
144
145} // namespace
146} // namespace tensorflow
147
148int main(int argc, char* argv[]) {
149 tensorflow::port::InitMain(argv[0], &argc, &argv);
150
151 tensorflow::string source_file_name =
152 tensorflow::InferSourceFileName(argv[0]);
153
154 // Usage:
155 // gen_main api_def_dir1,api_def_dir2,...
156 // [ @FILENAME | OpName[,OpName]* ] [0 | 1]
157 if (argc < 2) {
158 return -1;
159 }
160 std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
161 argv[1], ",", tensorflow::str_util::SkipEmpty());
162
163 // Add op name here to generate type annotations for it
164 const std::unordered_set<tensorflow::string> type_annotate_ops{};
165
166 if (argc == 2) {
167 tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
168 false /* op_list_is_allowlist */,
169 type_annotate_ops);
170 } else if (argc == 3) {
171 std::vector<tensorflow::string> hidden_ops;
172 TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
173 tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
174 false /* op_list_is_allowlist */,
175 type_annotate_ops);
176 } else if (argc == 4) {
177 std::vector<tensorflow::string> op_list;
178 TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
179 tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
180 tensorflow::string(argv[3]) == "1",
181 type_annotate_ops);
182 } else {
183 return -1;
184 }
185 return 0;
186}
187