1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_def_builder.h" |
17 | |
18 | #include <limits> |
19 | #include <vector> |
20 | |
21 | #include "absl/strings/escaping.h" |
22 | #include "tensorflow/core/framework/attr_value.pb.h" |
23 | #include "tensorflow/core/framework/attr_value_util.h" |
24 | #include "tensorflow/core/framework/op_def_util.h" |
25 | #include "tensorflow/core/framework/types.h" |
26 | #include "tensorflow/core/lib/core/errors.h" |
27 | #include "tensorflow/core/lib/gtl/array_slice.h" |
28 | #include "tensorflow/core/lib/strings/scanner.h" |
29 | #include "tensorflow/core/lib/strings/str_util.h" |
30 | #include "tensorflow/core/lib/strings/strcat.h" |
31 | #include "tensorflow/core/platform/errors.h" |
32 | |
33 | using ::tensorflow::strings::Scanner; |
34 | |
35 | namespace tensorflow { |
36 | |
37 | namespace { |
38 | |
39 | string AttrError(StringPiece orig, const string& op_name) { |
40 | return strings::StrCat(" from Attr(\"" , orig, "\") for Op " , op_name); |
41 | } |
42 | |
43 | bool ConsumeAttrName(StringPiece* sp, StringPiece* out) { |
44 | return Scanner(*sp) |
45 | .One(Scanner::LETTER) |
46 | .Any(Scanner::LETTER_DIGIT_UNDERSCORE) |
47 | .StopCapture() |
48 | .AnySpace() |
49 | .OneLiteral(":" ) |
50 | .AnySpace() |
51 | .GetResult(sp, out); |
52 | } |
53 | |
54 | bool ConsumeListPrefix(StringPiece* sp) { |
55 | return Scanner(*sp) |
56 | .OneLiteral("list" ) |
57 | .AnySpace() |
58 | .OneLiteral("(" ) |
59 | .AnySpace() |
60 | .GetResult(sp); |
61 | } |
62 | |
63 | bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) { |
64 | const string quote_str(1, quote_ch); |
65 | return Scanner(*sp) |
66 | .OneLiteral(quote_str.c_str()) |
67 | .RestartCapture() |
68 | .ScanEscapedUntil(quote_ch) |
69 | .StopCapture() |
70 | .OneLiteral(quote_str.c_str()) |
71 | .AnySpace() |
72 | .GetResult(sp, out); |
73 | } |
74 | |
75 | bool ConsumeAttrType(StringPiece* sp, StringPiece* out) { |
76 | return Scanner(*sp) |
77 | .Many(Scanner::LOWERLETTER_DIGIT) |
78 | .StopCapture() |
79 | .AnySpace() |
80 | .GetResult(sp, out); |
81 | } |
82 | |
83 | bool ConsumeAttrNumber(StringPiece* sp, int64_t* out) { |
84 | Scanner scan(*sp); |
85 | StringPiece match; |
86 | StringPiece remaining; |
87 | |
88 | scan.AnySpace().RestartCapture(); |
89 | if (scan.Peek() == '-') { |
90 | scan.OneLiteral("-" ); |
91 | } |
92 | if (!scan.Many(Scanner::DIGIT) |
93 | .StopCapture() |
94 | .AnySpace() |
95 | .GetResult(&remaining, &match)) { |
96 | return false; |
97 | } |
98 | int64_t value = 0; |
99 | if (!strings::safe_strto64(match, &value)) { |
100 | return false; |
101 | } |
102 | *out = value; |
103 | *sp = remaining; |
104 | return true; |
105 | } |
106 | |
// Used only inside FinalizeAttr(): when `expr` is false, appends an error
// built from the variadic message pieces plus the AttrError() context to
// *errors and returns from the enclosing void function. The enclosing scope
// must therefore provide `errors`, `orig`, and `op_def`.
#define VERIFY(expr, ...)                                                 \
  do {                                                                    \
    if (!(expr)) {                                                        \
      errors->push_back(                                                  \
          strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \
      return;                                                             \
    }                                                                     \
  } while (false)
115 | |
116 | bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) { |
117 | auto capture_data = sp->data(); |
118 | auto capture_begin = sp->begin(); |
119 | if (absl::ConsumePrefix(sp, "numbertype" ) || |
120 | absl::ConsumePrefix(sp, "numerictype" ) || |
121 | absl::ConsumePrefix(sp, "quantizedtype" ) || |
122 | absl::ConsumePrefix(sp, "realnumbertype" ) || |
123 | absl::ConsumePrefix(sp, "realnumberictype" )) { |
124 | *out = StringPiece(capture_data, sp->begin() - capture_begin); |
125 | return true; |
126 | } |
127 | return false; |
128 | } |
129 | |
130 | bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) { |
131 | if (type_string == "numbertype" || type_string == "numerictype" ) { |
132 | for (DataType dt : NumberTypes()) { |
133 | allowed->mutable_list()->add_type(dt); |
134 | } |
135 | } else if (type_string == "quantizedtype" ) { |
136 | for (DataType dt : QuantizedTypes()) { |
137 | allowed->mutable_list()->add_type(dt); |
138 | } |
139 | } else if (type_string == "realnumbertype" || |
140 | type_string == "realnumerictype" ) { |
141 | for (DataType dt : RealNumberTypes()) { |
142 | allowed->mutable_list()->add_type(dt); |
143 | } |
144 | } else { |
145 | return false; |
146 | } |
147 | return true; |
148 | } |
149 | |
150 | void FinalizeAttr(StringPiece spec, bool allow_attr_type_any, OpDef* op_def, |
151 | std::vector<string>* errors) { |
152 | OpDef::AttrDef* attr = op_def->add_attr(); |
153 | StringPiece orig(spec); |
154 | |
155 | // Parse "<name>:" at the beginning. |
156 | StringPiece tmp_name; |
157 | VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'" ); |
158 | attr->set_name(tmp_name.data(), tmp_name.size()); |
159 | |
160 | // Read "<type>" or "list(<type>)". |
161 | bool is_list = ConsumeListPrefix(&spec); |
162 | string type; |
163 | StringPiece type_string; // Used if type == "type" |
164 | if (absl::ConsumePrefix(&spec, "string" )) { |
165 | type = "string" ; |
166 | } else if (absl::ConsumePrefix(&spec, "int" )) { |
167 | type = "int" ; |
168 | } else if (absl::ConsumePrefix(&spec, "float" )) { |
169 | type = "float" ; |
170 | } else if (absl::ConsumePrefix(&spec, "bool" )) { |
171 | type = "bool" ; |
172 | } else if (absl::ConsumePrefix(&spec, "type" )) { |
173 | type = "type" ; |
174 | } else if (absl::ConsumePrefix(&spec, "shape" )) { |
175 | type = "shape" ; |
176 | } else if (absl::ConsumePrefix(&spec, "tensor" )) { |
177 | type = "tensor" ; |
178 | } else if (absl::ConsumePrefix(&spec, "func" )) { |
179 | type = "func" ; |
180 | } else if (absl::ConsumePrefix(&spec, "any" ) && allow_attr_type_any) { |
181 | type = "any" ; |
182 | } else if (ConsumeCompoundAttrType(&spec, &type_string)) { |
183 | type = "type" ; |
184 | AttrValue* allowed = attr->mutable_allowed_values(); |
185 | VERIFY(ProcessCompoundType(type_string, allowed), |
186 | "Expected to see a compound type, saw: " , type_string); |
187 | } else if (absl::ConsumePrefix(&spec, "{" )) { |
188 | // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }" |
189 | AttrValue* allowed = attr->mutable_allowed_values(); |
190 | str_util::RemoveLeadingWhitespace(&spec); |
191 | if (absl::StartsWith(spec, "\"" ) || absl::StartsWith(spec, "'" )) { |
192 | type = "string" ; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }" |
193 | while (true) { |
194 | StringPiece escaped_string; |
195 | VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) || |
196 | ConsumeQuotedString('\'', &spec, &escaped_string), |
197 | "Trouble parsing allowed string at '" , spec, "'" ); |
198 | string unescaped; |
199 | string error; |
200 | VERIFY(absl::CUnescape(escaped_string, &unescaped, &error), |
201 | "Trouble unescaping \"" , escaped_string, |
202 | "\", got error: " , error); |
203 | allowed->mutable_list()->add_s(unescaped); |
204 | if (absl::ConsumePrefix(&spec, "," )) { |
205 | str_util::RemoveLeadingWhitespace(&spec); |
206 | if (absl::ConsumePrefix(&spec, "}" )) |
207 | break; // Allow ending with ", }". |
208 | } else { |
209 | VERIFY(absl::ConsumePrefix(&spec, "}" ), |
210 | "Expected , or } after strings in list, not: '" , spec, "'" ); |
211 | break; |
212 | } |
213 | } |
214 | } else { // "{ bool, numbertype, string }" |
215 | type = "type" ; |
216 | while (true) { |
217 | VERIFY(ConsumeAttrType(&spec, &type_string), |
218 | "Trouble parsing type string at '" , spec, "'" ); |
219 | if (ProcessCompoundType(type_string, allowed)) { |
220 | // Processed a compound type. |
221 | } else { |
222 | DataType dt; |
223 | VERIFY(DataTypeFromString(type_string, &dt), |
224 | "Unrecognized type string '" , type_string, "'" ); |
225 | allowed->mutable_list()->add_type(dt); |
226 | } |
227 | if (absl::ConsumePrefix(&spec, "," )) { |
228 | str_util::RemoveLeadingWhitespace(&spec); |
229 | if (absl::ConsumePrefix(&spec, "}" )) |
230 | break; // Allow ending with ", }". |
231 | } else { |
232 | VERIFY(absl::ConsumePrefix(&spec, "}" ), |
233 | "Expected , or } after types in list, not: '" , spec, "'" ); |
234 | break; |
235 | } |
236 | } |
237 | } |
238 | } else { // if spec.Consume("{") |
239 | VERIFY(false, "Trouble parsing type string at '" , spec, "'" ); |
240 | } |
241 | str_util::RemoveLeadingWhitespace(&spec); |
242 | |
243 | // Write the type into *attr. |
244 | if (is_list) { |
245 | VERIFY(absl::ConsumePrefix(&spec, ")" ), |
246 | "Expected ) to close 'list(', not: '" , spec, "'" ); |
247 | str_util::RemoveLeadingWhitespace(&spec); |
248 | attr->set_type(strings::StrCat("list(" , type, ")" )); |
249 | } else { |
250 | attr->set_type(type); |
251 | } |
252 | |
253 | // Read optional minimum constraint at the end. |
254 | if ((is_list || type == "int" ) && absl::ConsumePrefix(&spec, ">=" )) { |
255 | int64_t min_limit = -999; |
256 | VERIFY(ConsumeAttrNumber(&spec, &min_limit), |
257 | "Could not parse integer lower limit after '>=', found '" , spec, |
258 | "' instead" ); |
259 | attr->set_has_minimum(true); |
260 | attr->set_minimum(min_limit); |
261 | } |
262 | |
263 | // Parse default value, if present. |
264 | if (absl::ConsumePrefix(&spec, "=" )) { |
265 | str_util::RemoveLeadingWhitespace(&spec); |
266 | VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()), |
267 | "Could not parse default value '" , spec, "'" ); |
268 | } else { |
269 | VERIFY(spec.empty(), "Extra '" , spec, "' unparsed at the end" ); |
270 | } |
271 | } |
272 | |
273 | #undef VERIFY |
274 | |
275 | string InOutError(bool is_output, StringPiece orig, const string& op_name) { |
276 | return strings::StrCat(" from " , is_output ? "Output" : "Input" , "(\"" , orig, |
277 | "\") for Op " , op_name); |
278 | } |
279 | |
280 | bool ConsumeInOutName(StringPiece* sp, StringPiece* out) { |
281 | return Scanner(*sp) |
282 | .One(Scanner::LOWERLETTER) |
283 | .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE) |
284 | .StopCapture() |
285 | .AnySpace() |
286 | .OneLiteral(":" ) |
287 | .AnySpace() |
288 | .GetResult(sp, out); |
289 | } |
290 | |
291 | bool ConsumeInOutRefOpen(StringPiece* sp) { |
292 | return Scanner(*sp) |
293 | .OneLiteral("Ref" ) |
294 | .AnySpace() |
295 | .OneLiteral("(" ) |
296 | .AnySpace() |
297 | .GetResult(sp); |
298 | } |
299 | |
300 | bool ConsumeInOutRefClose(StringPiece* sp) { |
301 | return Scanner(*sp).OneLiteral(")" ).AnySpace().GetResult(sp); |
302 | } |
303 | |
304 | bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) { |
305 | return Scanner(*sp) |
306 | .One(Scanner::LETTER) |
307 | .Any(Scanner::LETTER_DIGIT_UNDERSCORE) |
308 | .StopCapture() |
309 | .AnySpace() |
310 | .GetResult(sp, out); |
311 | } |
312 | |
313 | bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) { |
314 | return Scanner(*sp) |
315 | .OneLiteral("*" ) |
316 | .AnySpace() |
317 | .RestartCapture() |
318 | .One(Scanner::LETTER) |
319 | .Any(Scanner::LETTER_DIGIT_UNDERSCORE) |
320 | .StopCapture() |
321 | .AnySpace() |
322 | .GetResult(sp, out); |
323 | } |
324 | |
325 | bool ConsumeControlOutName(StringPiece* sp, StringPiece* out) { |
326 | return Scanner(*sp) |
327 | .One(Scanner::LETTER) |
328 | .Any(Scanner::LETTER_DIGIT_UNDERSCORE) |
329 | .StopCapture() |
330 | .GetResult(sp, out); |
331 | } |
332 | |
// Used only inside FinalizeInputOrOutput(): when `expr` is false, appends an
// error built from the variadic message pieces plus the InOutError() context
// to *errors and returns from the enclosing void function. The enclosing
// scope must therefore provide `errors`, `is_output`, `orig`, and `op_def`.
#define VERIFY(expr, ...)                                             \
  do {                                                                \
    if (!(expr)) {                                                    \
      errors->push_back(strings::StrCat(                              \
          __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \
      return;                                                         \
    }                                                                 \
  } while (false)
341 | |
342 | void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, |
343 | std::vector<string>* errors) { |
344 | OpDef::ArgDef* arg = |
345 | is_output ? op_def->add_output_arg() : op_def->add_input_arg(); |
346 | |
347 | StringPiece orig(spec); |
348 | |
349 | // Parse "<name>:" at the beginning. |
350 | StringPiece tmp_name; |
351 | VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'" ); |
352 | arg->set_name(tmp_name.data(), tmp_name.size()); |
353 | |
354 | // Detect "Ref(...)". |
355 | if (ConsumeInOutRefOpen(&spec)) { |
356 | arg->set_is_ref(true); |
357 | } |
358 | |
359 | { // Parse "<name|type>" or "<name>*<name|type>". |
360 | StringPiece first, second, type_or_attr; |
361 | VERIFY(ConsumeInOutNameOrType(&spec, &first), |
362 | "Trouble parsing either a type or an attr name at '" , spec, "'" ); |
363 | if (ConsumeInOutTimesType(&spec, &second)) { |
364 | arg->set_number_attr(first.data(), first.size()); |
365 | type_or_attr = second; |
366 | } else { |
367 | type_or_attr = first; |
368 | } |
369 | DataType dt; |
370 | if (DataTypeFromString(type_or_attr, &dt)) { |
371 | arg->set_type(dt); |
372 | } else { |
373 | const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def); |
374 | VERIFY(attr != nullptr, "Reference to unknown attr '" , type_or_attr, "'" ); |
375 | if (attr->type() == "type" ) { |
376 | arg->set_type_attr(type_or_attr.data(), type_or_attr.size()); |
377 | } else { |
378 | VERIFY(attr->type() == "list(type)" , "Reference to attr '" , |
379 | type_or_attr, "' with type " , attr->type(), |
380 | " that isn't type or list(type)" ); |
381 | arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size()); |
382 | } |
383 | } |
384 | } |
385 | |
386 | // Closing ) for Ref(. |
387 | if (arg->is_ref()) { |
388 | VERIFY(ConsumeInOutRefClose(&spec), |
389 | "Did not find closing ')' for 'Ref(', instead found: '" , spec, "'" ); |
390 | } |
391 | |
392 | // Should not have anything else. |
393 | VERIFY(spec.empty(), "Extra '" , spec, "' unparsed at the end" ); |
394 | |
395 | // Int attrs that are the length of an input or output get a default |
396 | // minimum of 1. |
397 | if (!arg->number_attr().empty()) { |
398 | OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def); |
399 | if (attr != nullptr && !attr->has_minimum()) { |
400 | attr->set_has_minimum(true); |
401 | attr->set_minimum(1); |
402 | } |
403 | } else if (!arg->type_list_attr().empty()) { |
404 | // If an input or output has type specified by a list(type) attr, |
405 | // it gets a default minimum of 1 as well. |
406 | OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def); |
407 | if (attr != nullptr && attr->type() == "list(type)" && |
408 | !attr->has_minimum()) { |
409 | attr->set_has_minimum(true); |
410 | attr->set_minimum(1); |
411 | } |
412 | } |
413 | |
414 | // If the arg's dtype is resource we should mark the op as stateful as it |
415 | // likely touches a resource manager. This deliberately doesn't cover inputs / |
416 | // outputs which resolve to resource via Attrs as those mostly operate on |
417 | // resource handles as an opaque type (as opposed to ops which explicitly take |
418 | // / produce resources). |
419 | if (arg->type() == DT_RESOURCE) { |
420 | op_def->set_is_stateful(true); |
421 | } |
422 | } |
423 | |
424 | #undef VERIFY |
425 | |
426 | string ControlOutError(StringPiece orig, const string& op_name) { |
427 | return strings::StrCat(" from ControlOutput(\"" , orig, "\") for Op " , |
428 | op_name); |
429 | } |
430 | |
431 | void FinalizeControlOutput(StringPiece name, OpDef* op_def, |
432 | std::vector<string>* errors) { |
433 | StringPiece orig(name); |
434 | |
435 | // Parse control output name. |
436 | StringPiece tmp_name; |
437 | if (!ConsumeControlOutName(&orig, &tmp_name)) { |
438 | errors->push_back(strings::StrCat("Trouble parsing 'name:'" , |
439 | ControlOutError(orig, op_def->name()))); |
440 | } |
441 | |
442 | *op_def->add_control_output() = string(tmp_name.data(), tmp_name.size()); |
443 | } |
444 | |
445 | int num_leading_spaces(StringPiece s) { |
446 | size_t i = 0; |
447 | while (i < s.size() && s[i] == ' ') { |
448 | ++i; |
449 | } |
450 | return i; |
451 | } |
452 | |
453 | bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) { |
454 | return Scanner(*sp) |
455 | .One(Scanner::LETTER) |
456 | .Any(Scanner::LETTER_DIGIT_UNDERSCORE) |
457 | .StopCapture() |
458 | .AnySpace() |
459 | .OneLiteral(":" ) |
460 | .AnySpace() |
461 | .GetResult(sp, out); |
462 | } |
463 | |
464 | bool IsDocNameColon(StringPiece s) { |
465 | return ConsumeDocNameColon(&s, nullptr /* out */); |
466 | } |
467 | |
468 | void FinalizeDoc(const string& text, OpDef* op_def, |
469 | std::vector<string>* errors) { |
470 | std::vector<string> lines = str_util::Split(text, '\n'); |
471 | |
472 | // Remove trailing spaces. |
473 | for (string& line : lines) { |
474 | absl::StripTrailingAsciiWhitespace(&line); |
475 | } |
476 | |
477 | // First non-blank line -> summary. |
478 | int l = 0; |
479 | while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l; |
480 | if (static_cast<size_t>(l) < lines.size()) { |
481 | op_def->set_summary(lines[l]); |
482 | ++l; |
483 | } |
484 | while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l; |
485 | |
486 | // Lines until we see name: -> description. |
487 | int start_l = l; |
488 | while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) { |
489 | ++l; |
490 | } |
491 | int end_l = l; |
492 | // Trim trailing blank lines from the description. |
493 | while (start_l < end_l && lines[end_l - 1].empty()) --end_l; |
494 | string desc = absl::StrJoin( |
495 | gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n" ); |
496 | if (!desc.empty()) op_def->set_description(desc); |
497 | |
498 | // name: description |
499 | // possibly continued on the next line |
500 | // if so, we remove the minimum indent |
501 | StringPiece name; |
502 | std::vector<StringPiece> description; |
503 | while (static_cast<size_t>(l) < lines.size()) { |
504 | description.clear(); |
505 | description.push_back(lines[l]); |
506 | ConsumeDocNameColon(&description.back(), &name); |
507 | ++l; |
508 | while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) { |
509 | description.push_back(lines[l]); |
510 | ++l; |
511 | } |
512 | // Remove any trailing blank lines. |
513 | while (!description.empty() && description.back().empty()) { |
514 | description.pop_back(); |
515 | } |
516 | // Compute the minimum indent of all lines after the first. |
517 | int min_indent = -1; |
518 | for (size_t i = 1; i < description.size(); ++i) { |
519 | if (!description[i].empty()) { |
520 | int indent = num_leading_spaces(description[i]); |
521 | if (min_indent < 0 || indent < min_indent) min_indent = indent; |
522 | } |
523 | } |
524 | // Remove min_indent spaces from all lines after the first. |
525 | for (size_t i = 1; i < description.size(); ++i) { |
526 | if (!description[i].empty()) description[i].remove_prefix(min_indent); |
527 | } |
528 | // Concatenate lines into a single string. |
529 | const string complete(absl::StrJoin(description, "\n" )); |
530 | |
531 | // Find name. |
532 | bool found = false; |
533 | for (int i = 0; !found && i < op_def->input_arg_size(); ++i) { |
534 | if (op_def->input_arg(i).name() == name) { |
535 | op_def->mutable_input_arg(i)->set_description(complete); |
536 | found = true; |
537 | } |
538 | } |
539 | for (int i = 0; !found && i < op_def->output_arg_size(); ++i) { |
540 | if (op_def->output_arg(i).name() == name) { |
541 | op_def->mutable_output_arg(i)->set_description(complete); |
542 | found = true; |
543 | } |
544 | } |
545 | for (int i = 0; !found && i < op_def->attr_size(); ++i) { |
546 | if (op_def->attr(i).name() == name) { |
547 | op_def->mutable_attr(i)->set_description(complete); |
548 | found = true; |
549 | } |
550 | } |
551 | if (!found) { |
552 | errors->push_back( |
553 | strings::StrCat("No matching input/output/attr for name '" , name, |
554 | "' from Doc() for Op " , op_def->name())); |
555 | return; |
556 | } |
557 | } |
558 | } |
559 | |
560 | } // namespace |
561 | |
562 | OpDefBuilder::OpDefBuilder(string op_name) { |
563 | op_def()->set_name(std::move(op_name)); |
564 | } |
565 | |
566 | OpDefBuilder& OpDefBuilder::Attr(string spec) { |
567 | attrs_.push_back(std::move(spec)); |
568 | return *this; |
569 | } |
570 | |
571 | OpDefBuilder& OpDefBuilder::Input(string spec) { |
572 | inputs_.push_back(std::move(spec)); |
573 | return *this; |
574 | } |
575 | |
576 | OpDefBuilder& OpDefBuilder::Output(string spec) { |
577 | outputs_.push_back(std::move(spec)); |
578 | return *this; |
579 | } |
580 | |
581 | OpDefBuilder& OpDefBuilder::ControlOutput(string name) { |
582 | control_outputs_.push_back(std::move(name)); |
583 | return *this; |
584 | } |
585 | |
586 | OpDefBuilder& OpDefBuilder::Doc(string text) { |
587 | #ifndef TF_LEAN_BINARY |
588 | if (!doc_.empty()) { |
589 | errors_.push_back( |
590 | strings::StrCat("Extra call to Doc() for Op " , op_def()->name())); |
591 | } else { |
592 | doc_ = std::move(text); |
593 | } |
594 | #endif |
595 | return *this; |
596 | } |
597 | |
598 | OpDefBuilder& OpDefBuilder::SetIsCommutative() { |
599 | op_def()->set_is_commutative(true); |
600 | return *this; |
601 | } |
602 | |
603 | OpDefBuilder& OpDefBuilder::SetIsAggregate() { |
604 | op_def()->set_is_aggregate(true); |
605 | return *this; |
606 | } |
607 | |
608 | OpDefBuilder& OpDefBuilder::SetIsStateful() { |
609 | op_def()->set_is_stateful(true); |
610 | return *this; |
611 | } |
612 | |
613 | OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() { |
614 | op_def()->set_allows_uninitialized_input(true); |
615 | return *this; |
616 | } |
617 | |
618 | OpDefBuilder& OpDefBuilder::SetIsDistributedCommunication() { |
619 | op_def()->set_is_distributed_communication(true); |
620 | return *this; |
621 | } |
622 | |
623 | OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) { |
624 | if (op_def()->has_deprecation()) { |
625 | errors_.push_back( |
626 | strings::StrCat("Deprecated called twice for Op " , op_def()->name())); |
627 | } else { |
628 | OpDeprecation* deprecation = op_def()->mutable_deprecation(); |
629 | deprecation->set_version(version); |
630 | deprecation->set_explanation(std::move(explanation)); |
631 | } |
632 | return *this; |
633 | } |
634 | |
635 | OpDefBuilder& OpDefBuilder::SetTypeConstructor(OpTypeConstructor c) { |
636 | op_reg_data_.type_ctor = c; |
637 | return *this; |
638 | } |
639 | |
640 | OpDefBuilder& OpDefBuilder::SetForwardTypeFn(ForwardTypeInferenceFn f) { |
641 | op_reg_data_.fwd_type_fn = f; |
642 | return *this; |
643 | } |
644 | |
645 | OpDefBuilder& OpDefBuilder::SetReverseTypeFn(int input_number, |
646 | ForwardTypeInferenceFn f) { |
647 | op_reg_data_.rev_type_fn = f; |
648 | op_reg_data_.rev_type_input = input_number; |
649 | return *this; |
650 | } |
651 | |
652 | OpDefBuilder& OpDefBuilder::SetShapeFn(OpShapeInferenceFn fn) { |
653 | if (op_reg_data_.shape_inference_fn != nullptr) { |
654 | errors_.push_back( |
655 | strings::StrCat("SetShapeFn called twice for Op " , op_def()->name())); |
656 | } else { |
657 | op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn); |
658 | } |
659 | return *this; |
660 | } |
661 | |
662 | OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() { |
663 | allow_attr_type_any_ = true; |
664 | return *this; |
665 | } |
666 | |
667 | Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { |
668 | std::vector<string> errors = errors_; |
669 | *op_reg_data = op_reg_data_; |
670 | |
671 | OpDef* op_def = &op_reg_data->op_def; |
672 | for (StringPiece attr : attrs_) { |
673 | FinalizeAttr(attr, allow_attr_type_any_, op_def, &errors); |
674 | } |
675 | for (StringPiece input : inputs_) { |
676 | FinalizeInputOrOutput(input, false, op_def, &errors); |
677 | } |
678 | for (StringPiece output : outputs_) { |
679 | FinalizeInputOrOutput(output, true, op_def, &errors); |
680 | } |
681 | for (StringPiece control_output : control_outputs_) { |
682 | FinalizeControlOutput(control_output, op_def, &errors); |
683 | } |
684 | FinalizeDoc(doc_, op_def, &errors); |
685 | |
686 | if (op_reg_data->type_ctor != nullptr) { |
687 | TF_RETURN_IF_ERROR(op_reg_data->type_ctor(op_def)); |
688 | } |
689 | |
690 | if (errors.empty()) return OkStatus(); |
691 | return errors::InvalidArgument(absl::StrJoin(errors, "\n" )); |
692 | } |
693 | |
694 | } // namespace tensorflow |
695 | |