1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "Python.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "pybind11/cast.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/python/framework/op_def_util.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
39
40namespace py = pybind11;
41
42namespace {
43
using ::tensorflow::AttributeType;
using ::tensorflow::AttributeTypeFromName;
using ::tensorflow::AttrValue;
using ::tensorflow::CheckOpDeprecation;
using ::tensorflow::ConvertPyObjectToAttributeType;
using ::tensorflow::DataType;
using ::tensorflow::DataTypeToPyObject;
using ::tensorflow::MaybeRaiseFromStatus;
using ::tensorflow::OpDef;
using ::tensorflow::OpRegistry;
using ::tensorflow::protobuf::RepeatedField;
using ::tensorflow::protobuf::RepeatedPtrField;
using AttrDef = ::tensorflow::OpDef::AttrDef;
using ArgDef = ::tensorflow::OpDef::ArgDef;
// Allowed dtype enum values for attrs that restrict them.
// Keys: attr.name(); Values: attr_def.allowed_values().list().type()
using AllowedAttrMap =
    absl::flat_hash_map<std::string, absl::flat_hash_set<int>>;
// Default dtype, held as a Python DType object, for attrs that declare one.
// Keys: attr.name(); Values: attr_def.default_value().type() (as a py::object)
using DefaultAttrMap = absl::flat_hash_map<std::string, py::object>;
// Keys: attr.name(); Values: the attr's value converted to an AttrValue proto
using AttrProtosMap = absl::flat_hash_map<std::string, AttrValue>;

// OpDef attr type name of a single-dtype attr.
constexpr char kType[] = "type";
// Names of attributes / methods looked up on Python objects.
constexpr char kTypeEnum[] = "_type_enum";
constexpr char kDType[] = "dtype";
constexpr char kBaseDType[] = "base_dtype";
constexpr char kAsProto[] = "as_proto";
constexpr char kSerialize[] = "SerializeToString";
// Prefix shared by every list attr type name, e.g. "list(int)".
constexpr char kListPrefix[] = "list(";
// Python dict method used to consume keyword arguments.
constexpr char kPop[] = "pop";
74
75inline py::error_already_set PyTypeError(const std::string& error_msg) {
76 PyErr_SetString(PyExc_TypeError, error_msg.c_str());
77 return pybind11::error_already_set();
78}
79
80inline py::error_already_set PyValueError(const std::string& error_msg) {
81 PyErr_SetString(PyExc_ValueError, error_msg.c_str());
82 return pybind11::error_already_set();
83}
84
85inline std::string PyRepr(const py::handle& value) {
86 return value.attr("__repr__")().cast<std::string>();
87}
88
89py::object DataTypeToPybindObject(const DataType& data_type) {
90 return py::reinterpret_borrow<py::object>(
91 DataTypeToPyObject(data_type).release());
92}
93
94// Converts the py:object to the AttributeType.
95// ToAttributeType corrupts the value's representation when it fails. So this
96// should be stored before hand if it is needed for error msgs.
97py::object ToAttributeType(const py::handle& value, const AttributeType type) {
98 auto result = ConvertPyObjectToAttributeType(value.ptr(), type);
99 if (result == nullptr) {
100 throw std::runtime_error("Failed to perform conversion.");
101 }
102 return py::reinterpret_borrow<py::object>(result.release());
103}
104
105inline bool MakeBool(const py::handle& value, const std::string& arg_name) {
106 if (!py::isinstance<py::bool_>(value)) {
107 throw PyTypeError(
108 absl::StrCat("Expected bool for argument '", arg_name, "' not ",
109 value.attr("__repr__")().cast<std::string>(), "."));
110 }
111 return value.cast<py::bool_>();
112}
113
114inline int MakeInt(const py::handle& value) {
115 try {
116 // Needed for TF1 compatibility where a tf.Dimension may be passed in.
117 return value.attr("value").cast<float>();
118 } catch (...) {
119 return value.cast<float>(); // Cast to float to match Python's behaviour.
120 }
121}
122
123inline DataType MakeType(const py::handle& value, const std::string& arg_name) {
124 std::string repr_v = PyRepr(value);
125 try {
126 return ToAttributeType(value, AttributeType::DTYPE)
127 .attr(kBaseDType)
128 .cast<DataType>();
129 } catch (...) {
130 throw PyTypeError(absl::StrCat("Expected DataType for argument '", arg_name,
131 "' not ", repr_v, "."));
132 }
133}
134
135inline std::string MakeShape(const py::handle& value,
136 const std::string& arg_name) {
137 std::string repr_v = PyRepr(value);
138 try {
139 return ToAttributeType(value, AttributeType::SHAPE)
140 .attr(kAsProto)()
141 .attr(kSerialize)()
142 .cast<std::string>();
143 } catch (...) {
144 throw PyTypeError(absl::StrCat("Error converting ", repr_v, " (arg name = ",
145 arg_name, ") to a TensorShape"));
146 }
147}
148
149AttrValue ValueToAttrValue(const py::object& value,
150 const std::string& attr_type,
151 const std::string& arg_name) {
152 AttrValue attr_value;
153 if (absl::StartsWith(attr_type, kListPrefix)) {
154 if (!py::isinstance<py::list>(value) && !py::isinstance<py::tuple>(value)) {
155 throw PyTypeError(absl::StrCat(
156 "Expected list for attr ", arg_name, ", obtained ",
157 py::type::handle_of(value).attr("__name__").cast<std::string>(),
158 " instead."));
159 }
160 }
161
162 try {
163 const AttributeType type_enum = AttributeTypeFromName(attr_type);
164 switch (type_enum) {
165 case AttributeType::STRING:
166 attr_value.set_s(value.cast<std::string>());
167 break;
168 case AttributeType::LIST_STRING: {
169 auto* list = attr_value.mutable_list();
170 for (const auto& v : value) {
171 list->add_s(v.cast<std::string>());
172 }
173 break;
174 }
175 case AttributeType::INT:
176 attr_value.set_i(MakeInt(value));
177 break;
178 case AttributeType::LIST_INT: {
179 auto* list = attr_value.mutable_list();
180 for (const auto& v : value) {
181 list->add_i(MakeInt(v));
182 }
183 break;
184 }
185 case AttributeType::FLOAT:
186 attr_value.set_f(value.cast<float>());
187 break;
188 case AttributeType::LIST_FLOAT: {
189 auto* list = attr_value.mutable_list();
190 for (const auto& v : value) {
191 list->add_f(v.cast<float>());
192 }
193 break;
194 }
195 case AttributeType::BOOL:
196 attr_value.set_b(MakeBool(value, arg_name));
197 break;
198 case AttributeType::LIST_BOOL: {
199 auto* list = attr_value.mutable_list();
200 for (const auto& v : value) {
201 list->add_b(MakeBool(v, arg_name));
202 }
203 break;
204 }
205 case AttributeType::DTYPE: {
206 attr_value.set_type(MakeType(value, arg_name));
207 break;
208 }
209 case AttributeType::LIST_DTYPE: {
210 auto* list = attr_value.mutable_list();
211 for (const auto& v : value) {
212 list->add_type(MakeType(v, arg_name));
213 }
214 break;
215 }
216 case AttributeType::SHAPE:
217 attr_value.mutable_shape()->ParseFromString(MakeShape(value, arg_name));
218 break;
219 case AttributeType::LIST_SHAPE: {
220 auto* list = attr_value.mutable_list();
221 for (const auto& v : value) {
222 list->add_shape()->ParseFromString(MakeShape(v, arg_name));
223 }
224 break;
225 }
226 case AttributeType::TENSOR:
227 attr_value.mutable_tensor()->ParseFromString(
228 ToAttributeType(value, type_enum)
229 .attr(kSerialize)()
230 .cast<std::string>());
231 break;
232 case AttributeType::LIST_TENSOR: {
233 auto* list = attr_value.mutable_list();
234 for (const auto& v : value) {
235 list->add_tensor()->ParseFromString(
236 ToAttributeType(v, AttributeType::TENSOR)
237 .attr(kSerialize)()
238 .cast<std::string>());
239 }
240 break;
241 }
242 default:
243 throw PyTypeError(absl::StrCat("Unrecognized Attr type ", attr_type,
244 " for ", arg_name, "."));
245 }
246 } catch (const py::error_already_set& e) {
247 throw e;
248 } catch (...) {
249 throw PyTypeError(absl::StrCat(
250 "Expected ", attr_type, " for argument '", arg_name, "' not ",
251 value.attr("__repr__")().cast<std::string>(), "."));
252 }
253
254 return attr_value;
255}
256
257py::object AttrValueToSerializedBytesPyObject(const AttrValue& attr_value) {
258 std::string serialized_attr_value;
259 if (!attr_value.SerializeToString(&serialized_attr_value)) {
260 throw std::runtime_error("Failed to serialized AttrValue to string");
261 }
262 return py::reinterpret_borrow<py::object>(py::bytes(serialized_attr_value));
263}
264
265void AssertSatisfiesLengthConstraint(const py::object& attr,
266 const AttrDef& attr_def,
267 const std::string& attr_name,
268 const std::string& op_type_name) {
269 if (!absl::StartsWith(attr_def.type(), kListPrefix)) return;
270 int attr_size = attr.cast<py::list>().size();
271 if (attr_def.has_minimum() && attr_size < attr_def.minimum()) {
272 throw PyValueError(absl::StrCat("Attr '", attr_name, "' of '", op_type_name,
273 "' Op passed list of length ", attr_size,
274 " less than minimum ", attr_def.minimum(),
275 "."));
276 }
277}
278
279void AssertSatisfiesAllowedStringConstraint(
280 const std::string& attr,
281 const RepeatedPtrField<std::string>& allowed_values,
282 const std::string& attr_name, const std::string& op_type_name) {
283 if (!absl::c_linear_search(allowed_values, attr)) {
284 const std::string allowed_values_str =
285 absl::StrJoin(allowed_values, "\", \"");
286 throw PyValueError(absl::StrCat("Attr '", attr_name, "' of '", op_type_name,
287 "' Op passed string '", attr,
288 "' not in: \"", allowed_values_str, "\"."));
289 }
290}
291
292void AssertSatisfiesAllowedStringsConstraint(const AttrValue& attr,
293 const AttrDef& attr_def,
294 const std::string& attr_name,
295 const AttributeType attr_type,
296 const std::string& op_type_name) {
297 if (!attr_def.has_allowed_values()) return;
298 const auto& allowed_values = attr_def.allowed_values().list().s();
299 if (attr_type == AttributeType::STRING) {
300 AssertSatisfiesAllowedStringConstraint(attr.s(), allowed_values, attr_name,
301 op_type_name);
302 } else if (attr_type == AttributeType::LIST_STRING) {
303 for (const std::string& v : attr.list().s()) {
304 AssertSatisfiesAllowedStringConstraint(v, allowed_values, attr_name,
305 op_type_name);
306 }
307 }
308}
309
310void AssertSatisfiesIntMinimumConstraint(const AttrValue& attr,
311 const AttrDef& attr_def,
312 const std::string& attr_name,
313 const AttributeType attr_type,
314 const std::string& op_type_name) {
315 if (attr_def.has_minimum() && attr_type == AttributeType::INT &&
316 attr.i() < attr_def.minimum()) {
317 throw PyValueError(absl::StrCat(
318 "Attr '", attr_name, "' of '", op_type_name, "' Op passed ", attr.i(),
319 " less than minimum ", attr_def.minimum(), "."));
320 }
321}
322
323void AssertSatisfiesAllowedListAttrTypeConstraint(
324 const std::string& type_attr, const AllowedAttrMap& allowed_list_attr_map,
325 const py::object& dtype, const std::string& input_name) {
326 auto it = allowed_list_attr_map.find(type_attr);
327 if (it != allowed_list_attr_map.end() &&
328 !it->second.contains(dtype.cast<DataType>())) {
329 std::vector<std::string> allowed_values;
330 for (const auto& allowed_value : it->second) {
331 allowed_values.emplace_back(
332 DataTypeToPybindObject(static_cast<DataType>(allowed_value))
333 .attr("name")
334 .cast<std::string>());
335 }
336 throw PyTypeError(absl::StrCat("Value passed to parameter '", input_name,
337 "' has DataType ",
338 dtype.attr("name").cast<std::string>(),
339 " not in list of allowed values: ",
340 absl::StrJoin(allowed_values, ", ")));
341 }
342}
343
344void AssertSatisfiesDTypeConstraint(const int attr,
345 const RepeatedField<int>& allowed_values,
346 const std::string& attr_name,
347 const std::string& op_type_name) {
348 if (!absl::c_linear_search(allowed_values, attr)) {
349 std::string allowed_vals_str;
350 for (const auto& v : allowed_values) {
351 if (!allowed_vals_str.empty()) absl::StrAppend(&allowed_vals_str, ", ");
352 absl::StrAppend(&allowed_vals_str,
353 DataTypeToPybindObject(static_cast<DataType>(v))
354 .attr("name")
355 .cast<std::string>());
356 }
357 throw PyTypeError(absl::StrCat(
358 "Value passed to parameter '", attr_name, "' has DataType ",
359 DataTypeToPybindObject(static_cast<DataType>(attr))
360 .attr("name")
361 .cast<std::string>(),
362 " not in list of allowed values: ", allowed_vals_str));
363 }
364}
365
366void AssertSatisfiesTypeConstraint(const AttrValue& attr,
367 const AttrDef& attr_def,
368 const std::string& attr_name,
369 const AttributeType attr_type,
370 const std::string& op_type_name) {
371 if (!attr_def.has_allowed_values()) return;
372 const auto& allowed_values = attr_def.allowed_values().list().type();
373 if (attr_type == AttributeType::DTYPE) {
374 AssertSatisfiesDTypeConstraint(attr.type(), allowed_values, attr_name,
375 op_type_name);
376 } else if (attr_type == AttributeType::LIST_DTYPE) {
377 for (const auto& v : attr.list().type()) {
378 AssertSatisfiesDTypeConstraint(v, allowed_values, attr_name,
379 op_type_name);
380 }
381 }
382}
383
384// Returns the OpDef from the global registry. Raises runtime_error if the
385// OpDef is not found.
386const OpDef* GetOpDef(const std::string& op_type_name, int producer_version) {
387 const OpDef* op_def = nullptr;
388 auto status = OpRegistry::Global()->LookUpOpDef(op_type_name, &op_def);
389 if (!status.ok() || op_def == nullptr) {
390 throw std::runtime_error(
391 absl::StrCat("Unrecognized Op name ", op_type_name));
392 }
393 return op_def;
394}
395
396// Extracts the default_type_attr_map and the allowed_list_attr_map from the
397// OpDef.
398void ExtractDefaultTypesAndAllowedTypes(const OpDef& op_def,
399 DefaultAttrMap& default_type_attr_map,
400 AllowedAttrMap& allowed_list_attr_map) {
401 for (const AttrDef& attr_def : op_def.attr()) {
402 if (attr_def.type() != kType) continue;
403 const std::string& attr_name = attr_def.name();
404 if (attr_def.has_default_value()) {
405 default_type_attr_map[attr_name] =
406 DataTypeToPybindObject(attr_def.default_value().type());
407 }
408 if (attr_def.has_allowed_values()) {
409 const auto& types = attr_def.allowed_values().list().type();
410 absl::flat_hash_set<int> allowed_values(types.begin(), types.end());
411 allowed_list_attr_map[attr_name] = std::move(allowed_values);
412 }
413 }
414}
415
416// Returns the input Tensor corresponding to `input_name` from `keywords`.
417// Updates `input_name` if it is a Python keyword or built-in.
418py::object GetInputTensor(std::string& input_name, const py::dict& keywords,
419 const OpDef& op_def) {
420 if (keywords.contains(input_name)) {
421 return py::reinterpret_borrow<py::object>(
422 keywords.attr(kPop)(input_name.c_str()));
423 } else if (keywords.contains(absl::StrCat(input_name, "_"))) {
424 absl::StrAppend(&input_name, "_");
425 return py::reinterpret_borrow<py::object>(
426 keywords.attr(kPop)(input_name.c_str()));
427 } else {
428 throw PyTypeError(absl::StrCat("No argument for input ", input_name,
429 " found in ", op_def.DebugString()));
430 }
431}
432
433// Returns the input Tensor's DType.
434py::object GetInputType(
435 const py::object& input_tensor, const ArgDef& input_arg,
436 const AllowedAttrMap& allowed_list_attr_map,
437 const std::string& op_type_name, const std::string& input_name,
438 py::dict& attrs,
439 absl::flat_hash_map<std::string, std::string>& inferred_from) {
440 py::object dtype = input_tensor.attr(kDType);
441 py::object base_type = dtype.attr(kBaseDType);
442
443 // Check that the input_arg and the input are compatible.
444 if (input_arg.type() != DataType::DT_INVALID &&
445 input_arg.type() != dtype.cast<DataType>() &&
446 input_arg.type() != base_type.cast<DataType>()) {
447 throw PyTypeError(absl::StrCat("Input '", input_name, "' of '",
448 op_type_name, "' Op has type ",
449 base_type.attr("name").cast<std::string>(),
450 " that does not match expected type of ",
451 DataTypeToPybindObject(input_arg.type())
452 .attr("name")
453 .cast<std::string>(),
454 "."));
455 }
456
457 const std::string& type_attr = input_arg.type_attr();
458 if (!type_attr.empty()) {
459 if (attrs.contains(type_attr) &&
460 attrs[type_attr.c_str()].cast<py::object>() != base_type) {
461 throw PyTypeError(absl::StrCat(
462 "Input '", input_name, "' of '", op_type_name, "' Op has type ",
463 base_type.attr("name").cast<std::string>(),
464 " that does not match type ",
465 attrs[type_attr.c_str()].attr("name").cast<std::string>(),
466 " of argument '", inferred_from.at(type_attr), "'."));
467 } else {
468 AssertSatisfiesAllowedListAttrTypeConstraint(
469 type_attr, allowed_list_attr_map, base_type, input_name);
470 attrs[type_attr.c_str()] = base_type;
471 inferred_from[input_arg.type_attr()] = input_name;
472 }
473 } else if (base_type.cast<DataType>() != input_arg.type()) {
474 // Added to match the python behaviour.
475 throw PyTypeError("Unreachable");
476 }
477 if (input_arg.is_ref()) return dtype;
478 return base_type;
479}
480
481// Extracts `inputs`, `input_types` and `attrs`.
482void ExtractInputsAndAttrs(const std::string& op_type_name, const OpDef& op_def,
483 const AllowedAttrMap& allowed_list_attr_map,
484 py::dict& keywords, py::dict& attrs,
485 py::list& inputs, py::list& input_types) {
486 absl::flat_hash_map<std::string, std::string> inferred_from;
487 for (const ArgDef& input_arg : op_def.input_arg()) {
488 std::string input_name = input_arg.name();
489 py::object input_tensor = GetInputTensor(input_name, keywords, op_def);
490 inputs.append(input_tensor);
491 py::object dtype =
492 GetInputType(input_tensor, input_arg, allowed_list_attr_map,
493 op_type_name, input_name, attrs, inferred_from);
494 input_types.append(dtype);
495 }
496}
497
498// Extracts the remaining attributes from the OpDef to `attrs`.
499void ExtractRemainingAttrs(const std::string& op_type_name, const OpDef& op_def,
500 const py::dict& keywords,
501 const DefaultAttrMap& default_type_attr_map,
502 py::dict& attrs) {
503 for (const AttrDef& attr : op_def.attr()) {
504 const std::string& attr_name = attr.name();
505 if (attrs.contains(attr_name)) {
506 if (keywords.contains(attr_name)) {
507 throw PyTypeError(
508 absl::StrCat("Should not specify value for inferred attr '",
509 attr_name, "' for ", op_type_name, "."));
510 }
511 continue;
512 }
513 if (keywords.contains(attr_name)) {
514 attrs[attr_name.c_str()] =
515 keywords.attr(kPop)(attr_name.c_str()).cast<py::object>();
516 } else if (keywords.contains(absl::StrCat(attr_name, "_"))) {
517 attrs[attr_name.c_str()] =
518 keywords.attr(kPop)(absl::StrCat(attr_name, "_").c_str())
519 .cast<py::object>();
520 } else if (default_type_attr_map.contains(attr_name)) {
521 attrs[attr_name.c_str()] = default_type_attr_map.at(attr_name);
522 } else {
523 throw PyTypeError(absl::StrCat("No argument found for attr ", attr_name,
524 " for ", op_type_name));
525 }
526 }
527}
528
529void SetAttrProto(const std::string& key, const AttrValue& value,
530 py::dict& attr_protos, AttrProtosMap& attr_protos_map) {
531 attr_protos_map[key] = value;
532 attr_protos[key.c_str()] = AttrValueToSerializedBytesPyObject(value);
533}
534
535// Converts attr values to AttrValues.
536void ExtractAttrProto(const std::string& op_type_name, const OpDef& op_def,
537 const py::dict& attrs, py::dict& attr_protos,
538 AttrProtosMap& attr_protos_map) {
539 for (const AttrDef& attr_def : op_def.attr()) {
540 const std::string& attr_name = attr_def.name();
541 const py::object attr = attrs[attr_name.c_str()].cast<py::object>();
542
543 if (attr_def.has_default_value() && attr.is_none()) {
544 SetAttrProto(attr_name, attr_def.default_value(), attr_protos,
545 attr_protos_map);
546 continue;
547 }
548
549 const AttrValue attr_value =
550 ValueToAttrValue(attr, attr_def.type(), attr_name);
551 const AttributeType attr_type = AttributeTypeFromName(attr_def.type());
552 AssertSatisfiesLengthConstraint(attr, attr_def, attr_name, op_type_name);
553 AssertSatisfiesAllowedStringsConstraint(attr_value, attr_def, attr_name,
554 attr_type, op_type_name);
555 AssertSatisfiesIntMinimumConstraint(attr_value, attr_def, attr_name,
556 attr_type, op_type_name);
557 AssertSatisfiesTypeConstraint(attr_value, attr_def, attr_name, attr_type,
558 op_type_name);
559 SetAttrProto(attr_name, attr_value, attr_protos, attr_protos_map);
560 }
561}
562
563inline const AttrValue& MaybeGetAttrValue(const py::dict& attr_protos,
564 const AttrProtosMap& attr_protos_map,
565 const std::string& attr_name,
566 const std::string& op_type_name) {
567 auto it = attr_protos_map.find(attr_name);
568 if (it != attr_protos_map.end()) return it->second;
569 throw PyTypeError(absl::StrCat(
570 "Inconsistent OpDef for '", op_type_name, "', missing attr '", attr_name,
571 "' from '", attr_protos.attr("__repr__")().cast<std::string>(), "'."));
572}
573
574void ExtractOutputStructure(const std::string& op_type_name,
575 const OpDef& op_def, const py::dict& attr_protos,
576 const AttrProtosMap& attr_protos_map,
577 py::list& output_structure) {
578 for (const ArgDef& arg : op_def.output_arg()) {
579 if (!arg.number_attr().empty()) {
580 const auto& value = MaybeGetAttrValue(attr_protos, attr_protos_map,
581 arg.number_attr(), op_type_name);
582 output_structure.append(value.i());
583 } else if (!arg.type_attr().empty()) {
584 const auto& _ = MaybeGetAttrValue(attr_protos, attr_protos_map,
585 arg.type_attr(), op_type_name);
586 output_structure.append(py::none());
587 } else if (!arg.type_list_attr().empty()) {
588 const auto& value = MaybeGetAttrValue(attr_protos, attr_protos_map,
589 arg.type_list_attr(), op_type_name);
590 output_structure.append(value.list().type_size());
591 } else {
592 output_structure.append(py::none());
593 }
594 }
595}
596
597void CheckAllInputsUsed(const std::string& op_type_name,
598 const py::dict& keywords) {
599 if (!keywords.empty()) {
600 std::string all_keywords;
601 for (const auto& item : keywords) {
602 if (!all_keywords.empty()) absl::StrAppend(&all_keywords, ", ");
603 absl::StrAppend(&all_keywords, item.first.cast<std::string>());
604 }
605 throw PyTypeError(absl::StrCat(
606 op_type_name, " got unexpected keyword arguments: ", all_keywords));
607 }
608}
609
610} // namespace
611
612// This module provides a subset of the functionality from op_def_library.py
613// and relies on op_def_library_test.py for test coverage.
614PYBIND11_MODULE(_op_def_library_pybind, m) {
615 // Method assumes all inputs in `keywords` are of type tf.Tensor.
616 m.def("process_inputs", [](std::string& op_type_name, int producer_version,
617 py::dict& keywords) {
618 const OpDef* op_def = GetOpDef(op_type_name, producer_version);
619 MaybeRaiseFromStatus(CheckOpDeprecation(*op_def, producer_version));
620
621 DefaultAttrMap default_type_attr_map;
622 AllowedAttrMap allowed_list_attr_map;
623 AttrProtosMap attr_protos_map;
624 py::dict attrs, attr_protos;
625 py::list inputs, input_types, output_structure;
626
627 ExtractDefaultTypesAndAllowedTypes(*op_def, default_type_attr_map,
628 allowed_list_attr_map);
629 ExtractInputsAndAttrs(op_type_name, *op_def, allowed_list_attr_map,
630 keywords, attrs, inputs, input_types);
631 ExtractRemainingAttrs(op_type_name, *op_def, keywords,
632 default_type_attr_map, attrs);
633 ExtractAttrProto(op_type_name, *op_def, attrs, attr_protos,
634 attr_protos_map);
635 ExtractOutputStructure(op_type_name, *op_def, attr_protos, attr_protos_map,
636 output_structure);
637 CheckAllInputsUsed(op_type_name, keywords);
638
639 return py::make_tuple(attr_protos, inputs, input_types, output_structure);
640 });
641};
642