1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "Python.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "pybind11/cast.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/python/framework/op_def_util.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
39 | |
40 | namespace py = pybind11; |
41 | |
42 | namespace { |
43 | |
44 | using ::tensorflow::AttributeType; |
45 | using ::tensorflow::AttributeTypeFromName; |
46 | using ::tensorflow::AttrValue; |
47 | using ::tensorflow::CheckOpDeprecation; |
48 | using ::tensorflow::ConvertPyObjectToAttributeType; |
49 | using ::tensorflow::DataType; |
50 | using ::tensorflow::DataTypeToPyObject; |
51 | using ::tensorflow::MaybeRaiseFromStatus; |
52 | using ::tensorflow::OpDef; |
53 | using ::tensorflow::OpRegistry; |
54 | using ::tensorflow::protobuf::RepeatedField; |
55 | using ::tensorflow::protobuf::RepeatedPtrField; |
56 | using AttrDef = ::tensorflow::OpDef::AttrDef; |
57 | using ArgDef = ::tensorflow::OpDef::ArgDef; |
58 | // Keys: attr.name(); Values: attr_def.allowed_values().list().type() |
59 | using AllowedAttrMap = |
60 | absl::flat_hash_map<std::string, absl::flat_hash_set<int>>; |
61 | // Keys: attr.name(); Values; attr_def.default_value().type() |
62 | using DefaultAttrMap = absl::flat_hash_map<std::string, py::object>; |
63 | // Keys: attr.name(); Values: corresponding attr serialized as an AttrValue |
64 | using AttrProtosMap = absl::flat_hash_map<std::string, AttrValue>; |
65 | |
66 | constexpr char kType[] = "type" ; |
67 | constexpr char kTypeEnum[] = "_type_enum" ; |
68 | constexpr char kDType[] = "dtype" ; |
69 | constexpr char kBaseDType[] = "base_dtype" ; |
70 | constexpr char kAsProto[] = "as_proto" ; |
71 | constexpr char kSerialize[] = "SerializeToString" ; |
72 | constexpr char kListPrefix[] = "list(" ; |
73 | constexpr char kPop[] = "pop" ; |
74 | |
75 | inline py::error_already_set PyTypeError(const std::string& error_msg) { |
76 | PyErr_SetString(PyExc_TypeError, error_msg.c_str()); |
77 | return pybind11::error_already_set(); |
78 | } |
79 | |
80 | inline py::error_already_set PyValueError(const std::string& error_msg) { |
81 | PyErr_SetString(PyExc_ValueError, error_msg.c_str()); |
82 | return pybind11::error_already_set(); |
83 | } |
84 | |
85 | inline std::string PyRepr(const py::handle& value) { |
86 | return value.attr("__repr__" )().cast<std::string>(); |
87 | } |
88 | |
89 | py::object DataTypeToPybindObject(const DataType& data_type) { |
90 | return py::reinterpret_borrow<py::object>( |
91 | DataTypeToPyObject(data_type).release()); |
92 | } |
93 | |
94 | // Converts the py:object to the AttributeType. |
95 | // ToAttributeType corrupts the value's representation when it fails. So this |
96 | // should be stored before hand if it is needed for error msgs. |
97 | py::object ToAttributeType(const py::handle& value, const AttributeType type) { |
98 | auto result = ConvertPyObjectToAttributeType(value.ptr(), type); |
99 | if (result == nullptr) { |
100 | throw std::runtime_error("Failed to perform conversion." ); |
101 | } |
102 | return py::reinterpret_borrow<py::object>(result.release()); |
103 | } |
104 | |
105 | inline bool MakeBool(const py::handle& value, const std::string& arg_name) { |
106 | if (!py::isinstance<py::bool_>(value)) { |
107 | throw PyTypeError( |
108 | absl::StrCat("Expected bool for argument '" , arg_name, "' not " , |
109 | value.attr("__repr__" )().cast<std::string>(), "." )); |
110 | } |
111 | return value.cast<py::bool_>(); |
112 | } |
113 | |
114 | inline int MakeInt(const py::handle& value) { |
115 | try { |
116 | // Needed for TF1 compatibility where a tf.Dimension may be passed in. |
117 | return value.attr("value" ).cast<float>(); |
118 | } catch (...) { |
119 | return value.cast<float>(); // Cast to float to match Python's behaviour. |
120 | } |
121 | } |
122 | |
123 | inline DataType MakeType(const py::handle& value, const std::string& arg_name) { |
124 | std::string repr_v = PyRepr(value); |
125 | try { |
126 | return ToAttributeType(value, AttributeType::DTYPE) |
127 | .attr(kBaseDType) |
128 | .cast<DataType>(); |
129 | } catch (...) { |
130 | throw PyTypeError(absl::StrCat("Expected DataType for argument '" , arg_name, |
131 | "' not " , repr_v, "." )); |
132 | } |
133 | } |
134 | |
135 | inline std::string MakeShape(const py::handle& value, |
136 | const std::string& arg_name) { |
137 | std::string repr_v = PyRepr(value); |
138 | try { |
139 | return ToAttributeType(value, AttributeType::SHAPE) |
140 | .attr(kAsProto)() |
141 | .attr(kSerialize)() |
142 | .cast<std::string>(); |
143 | } catch (...) { |
144 | throw PyTypeError(absl::StrCat("Error converting " , repr_v, " (arg name = " , |
145 | arg_name, ") to a TensorShape" )); |
146 | } |
147 | } |
148 | |
149 | AttrValue ValueToAttrValue(const py::object& value, |
150 | const std::string& attr_type, |
151 | const std::string& arg_name) { |
152 | AttrValue attr_value; |
153 | if (absl::StartsWith(attr_type, kListPrefix)) { |
154 | if (!py::isinstance<py::list>(value) && !py::isinstance<py::tuple>(value)) { |
155 | throw PyTypeError(absl::StrCat( |
156 | "Expected list for attr " , arg_name, ", obtained " , |
157 | py::type::handle_of(value).attr("__name__" ).cast<std::string>(), |
158 | " instead." )); |
159 | } |
160 | } |
161 | |
162 | try { |
163 | const AttributeType type_enum = AttributeTypeFromName(attr_type); |
164 | switch (type_enum) { |
165 | case AttributeType::STRING: |
166 | attr_value.set_s(value.cast<std::string>()); |
167 | break; |
168 | case AttributeType::LIST_STRING: { |
169 | auto* list = attr_value.mutable_list(); |
170 | for (const auto& v : value) { |
171 | list->add_s(v.cast<std::string>()); |
172 | } |
173 | break; |
174 | } |
175 | case AttributeType::INT: |
176 | attr_value.set_i(MakeInt(value)); |
177 | break; |
178 | case AttributeType::LIST_INT: { |
179 | auto* list = attr_value.mutable_list(); |
180 | for (const auto& v : value) { |
181 | list->add_i(MakeInt(v)); |
182 | } |
183 | break; |
184 | } |
185 | case AttributeType::FLOAT: |
186 | attr_value.set_f(value.cast<float>()); |
187 | break; |
188 | case AttributeType::LIST_FLOAT: { |
189 | auto* list = attr_value.mutable_list(); |
190 | for (const auto& v : value) { |
191 | list->add_f(v.cast<float>()); |
192 | } |
193 | break; |
194 | } |
195 | case AttributeType::BOOL: |
196 | attr_value.set_b(MakeBool(value, arg_name)); |
197 | break; |
198 | case AttributeType::LIST_BOOL: { |
199 | auto* list = attr_value.mutable_list(); |
200 | for (const auto& v : value) { |
201 | list->add_b(MakeBool(v, arg_name)); |
202 | } |
203 | break; |
204 | } |
205 | case AttributeType::DTYPE: { |
206 | attr_value.set_type(MakeType(value, arg_name)); |
207 | break; |
208 | } |
209 | case AttributeType::LIST_DTYPE: { |
210 | auto* list = attr_value.mutable_list(); |
211 | for (const auto& v : value) { |
212 | list->add_type(MakeType(v, arg_name)); |
213 | } |
214 | break; |
215 | } |
216 | case AttributeType::SHAPE: |
217 | attr_value.mutable_shape()->ParseFromString(MakeShape(value, arg_name)); |
218 | break; |
219 | case AttributeType::LIST_SHAPE: { |
220 | auto* list = attr_value.mutable_list(); |
221 | for (const auto& v : value) { |
222 | list->add_shape()->ParseFromString(MakeShape(v, arg_name)); |
223 | } |
224 | break; |
225 | } |
226 | case AttributeType::TENSOR: |
227 | attr_value.mutable_tensor()->ParseFromString( |
228 | ToAttributeType(value, type_enum) |
229 | .attr(kSerialize)() |
230 | .cast<std::string>()); |
231 | break; |
232 | case AttributeType::LIST_TENSOR: { |
233 | auto* list = attr_value.mutable_list(); |
234 | for (const auto& v : value) { |
235 | list->add_tensor()->ParseFromString( |
236 | ToAttributeType(v, AttributeType::TENSOR) |
237 | .attr(kSerialize)() |
238 | .cast<std::string>()); |
239 | } |
240 | break; |
241 | } |
242 | default: |
243 | throw PyTypeError(absl::StrCat("Unrecognized Attr type " , attr_type, |
244 | " for " , arg_name, "." )); |
245 | } |
246 | } catch (const py::error_already_set& e) { |
247 | throw e; |
248 | } catch (...) { |
249 | throw PyTypeError(absl::StrCat( |
250 | "Expected " , attr_type, " for argument '" , arg_name, "' not " , |
251 | value.attr("__repr__" )().cast<std::string>(), "." )); |
252 | } |
253 | |
254 | return attr_value; |
255 | } |
256 | |
257 | py::object AttrValueToSerializedBytesPyObject(const AttrValue& attr_value) { |
258 | std::string serialized_attr_value; |
259 | if (!attr_value.SerializeToString(&serialized_attr_value)) { |
260 | throw std::runtime_error("Failed to serialized AttrValue to string" ); |
261 | } |
262 | return py::reinterpret_borrow<py::object>(py::bytes(serialized_attr_value)); |
263 | } |
264 | |
265 | void AssertSatisfiesLengthConstraint(const py::object& attr, |
266 | const AttrDef& attr_def, |
267 | const std::string& attr_name, |
268 | const std::string& op_type_name) { |
269 | if (!absl::StartsWith(attr_def.type(), kListPrefix)) return; |
270 | int attr_size = attr.cast<py::list>().size(); |
271 | if (attr_def.has_minimum() && attr_size < attr_def.minimum()) { |
272 | throw PyValueError(absl::StrCat("Attr '" , attr_name, "' of '" , op_type_name, |
273 | "' Op passed list of length " , attr_size, |
274 | " less than minimum " , attr_def.minimum(), |
275 | "." )); |
276 | } |
277 | } |
278 | |
279 | void AssertSatisfiesAllowedStringConstraint( |
280 | const std::string& attr, |
281 | const RepeatedPtrField<std::string>& allowed_values, |
282 | const std::string& attr_name, const std::string& op_type_name) { |
283 | if (!absl::c_linear_search(allowed_values, attr)) { |
284 | const std::string allowed_values_str = |
285 | absl::StrJoin(allowed_values, "\", \"" ); |
286 | throw PyValueError(absl::StrCat("Attr '" , attr_name, "' of '" , op_type_name, |
287 | "' Op passed string '" , attr, |
288 | "' not in: \"" , allowed_values_str, "\"." )); |
289 | } |
290 | } |
291 | |
292 | void AssertSatisfiesAllowedStringsConstraint(const AttrValue& attr, |
293 | const AttrDef& attr_def, |
294 | const std::string& attr_name, |
295 | const AttributeType attr_type, |
296 | const std::string& op_type_name) { |
297 | if (!attr_def.has_allowed_values()) return; |
298 | const auto& allowed_values = attr_def.allowed_values().list().s(); |
299 | if (attr_type == AttributeType::STRING) { |
300 | AssertSatisfiesAllowedStringConstraint(attr.s(), allowed_values, attr_name, |
301 | op_type_name); |
302 | } else if (attr_type == AttributeType::LIST_STRING) { |
303 | for (const std::string& v : attr.list().s()) { |
304 | AssertSatisfiesAllowedStringConstraint(v, allowed_values, attr_name, |
305 | op_type_name); |
306 | } |
307 | } |
308 | } |
309 | |
310 | void AssertSatisfiesIntMinimumConstraint(const AttrValue& attr, |
311 | const AttrDef& attr_def, |
312 | const std::string& attr_name, |
313 | const AttributeType attr_type, |
314 | const std::string& op_type_name) { |
315 | if (attr_def.has_minimum() && attr_type == AttributeType::INT && |
316 | attr.i() < attr_def.minimum()) { |
317 | throw PyValueError(absl::StrCat( |
318 | "Attr '" , attr_name, "' of '" , op_type_name, "' Op passed " , attr.i(), |
319 | " less than minimum " , attr_def.minimum(), "." )); |
320 | } |
321 | } |
322 | |
323 | void AssertSatisfiesAllowedListAttrTypeConstraint( |
324 | const std::string& type_attr, const AllowedAttrMap& allowed_list_attr_map, |
325 | const py::object& dtype, const std::string& input_name) { |
326 | auto it = allowed_list_attr_map.find(type_attr); |
327 | if (it != allowed_list_attr_map.end() && |
328 | !it->second.contains(dtype.cast<DataType>())) { |
329 | std::vector<std::string> allowed_values; |
330 | for (const auto& allowed_value : it->second) { |
331 | allowed_values.emplace_back( |
332 | DataTypeToPybindObject(static_cast<DataType>(allowed_value)) |
333 | .attr("name" ) |
334 | .cast<std::string>()); |
335 | } |
336 | throw PyTypeError(absl::StrCat("Value passed to parameter '" , input_name, |
337 | "' has DataType " , |
338 | dtype.attr("name" ).cast<std::string>(), |
339 | " not in list of allowed values: " , |
340 | absl::StrJoin(allowed_values, ", " ))); |
341 | } |
342 | } |
343 | |
344 | void AssertSatisfiesDTypeConstraint(const int attr, |
345 | const RepeatedField<int>& allowed_values, |
346 | const std::string& attr_name, |
347 | const std::string& op_type_name) { |
348 | if (!absl::c_linear_search(allowed_values, attr)) { |
349 | std::string allowed_vals_str; |
350 | for (const auto& v : allowed_values) { |
351 | if (!allowed_vals_str.empty()) absl::StrAppend(&allowed_vals_str, ", " ); |
352 | absl::StrAppend(&allowed_vals_str, |
353 | DataTypeToPybindObject(static_cast<DataType>(v)) |
354 | .attr("name" ) |
355 | .cast<std::string>()); |
356 | } |
357 | throw PyTypeError(absl::StrCat( |
358 | "Value passed to parameter '" , attr_name, "' has DataType " , |
359 | DataTypeToPybindObject(static_cast<DataType>(attr)) |
360 | .attr("name" ) |
361 | .cast<std::string>(), |
362 | " not in list of allowed values: " , allowed_vals_str)); |
363 | } |
364 | } |
365 | |
366 | void AssertSatisfiesTypeConstraint(const AttrValue& attr, |
367 | const AttrDef& attr_def, |
368 | const std::string& attr_name, |
369 | const AttributeType attr_type, |
370 | const std::string& op_type_name) { |
371 | if (!attr_def.has_allowed_values()) return; |
372 | const auto& allowed_values = attr_def.allowed_values().list().type(); |
373 | if (attr_type == AttributeType::DTYPE) { |
374 | AssertSatisfiesDTypeConstraint(attr.type(), allowed_values, attr_name, |
375 | op_type_name); |
376 | } else if (attr_type == AttributeType::LIST_DTYPE) { |
377 | for (const auto& v : attr.list().type()) { |
378 | AssertSatisfiesDTypeConstraint(v, allowed_values, attr_name, |
379 | op_type_name); |
380 | } |
381 | } |
382 | } |
383 | |
384 | // Returns the OpDef from the global registry. Raises runtime_error if the |
385 | // OpDef is not found. |
386 | const OpDef* GetOpDef(const std::string& op_type_name, int producer_version) { |
387 | const OpDef* op_def = nullptr; |
388 | auto status = OpRegistry::Global()->LookUpOpDef(op_type_name, &op_def); |
389 | if (!status.ok() || op_def == nullptr) { |
390 | throw std::runtime_error( |
391 | absl::StrCat("Unrecognized Op name " , op_type_name)); |
392 | } |
393 | return op_def; |
394 | } |
395 | |
396 | // Extracts the default_type_attr_map and the allowed_list_attr_map from the |
397 | // OpDef. |
398 | void ExtractDefaultTypesAndAllowedTypes(const OpDef& op_def, |
399 | DefaultAttrMap& default_type_attr_map, |
400 | AllowedAttrMap& allowed_list_attr_map) { |
401 | for (const AttrDef& attr_def : op_def.attr()) { |
402 | if (attr_def.type() != kType) continue; |
403 | const std::string& attr_name = attr_def.name(); |
404 | if (attr_def.has_default_value()) { |
405 | default_type_attr_map[attr_name] = |
406 | DataTypeToPybindObject(attr_def.default_value().type()); |
407 | } |
408 | if (attr_def.has_allowed_values()) { |
409 | const auto& types = attr_def.allowed_values().list().type(); |
410 | absl::flat_hash_set<int> allowed_values(types.begin(), types.end()); |
411 | allowed_list_attr_map[attr_name] = std::move(allowed_values); |
412 | } |
413 | } |
414 | } |
415 | |
416 | // Returns the input Tensor corresponding to `input_name` from `keywords`. |
417 | // Updates `input_name` if it is a Python keyword or built-in. |
418 | py::object GetInputTensor(std::string& input_name, const py::dict& keywords, |
419 | const OpDef& op_def) { |
420 | if (keywords.contains(input_name)) { |
421 | return py::reinterpret_borrow<py::object>( |
422 | keywords.attr(kPop)(input_name.c_str())); |
423 | } else if (keywords.contains(absl::StrCat(input_name, "_" ))) { |
424 | absl::StrAppend(&input_name, "_" ); |
425 | return py::reinterpret_borrow<py::object>( |
426 | keywords.attr(kPop)(input_name.c_str())); |
427 | } else { |
428 | throw PyTypeError(absl::StrCat("No argument for input " , input_name, |
429 | " found in " , op_def.DebugString())); |
430 | } |
431 | } |
432 | |
433 | // Returns the input Tensor's DType. |
434 | py::object GetInputType( |
435 | const py::object& input_tensor, const ArgDef& input_arg, |
436 | const AllowedAttrMap& allowed_list_attr_map, |
437 | const std::string& op_type_name, const std::string& input_name, |
438 | py::dict& attrs, |
439 | absl::flat_hash_map<std::string, std::string>& inferred_from) { |
440 | py::object dtype = input_tensor.attr(kDType); |
441 | py::object base_type = dtype.attr(kBaseDType); |
442 | |
443 | // Check that the input_arg and the input are compatible. |
444 | if (input_arg.type() != DataType::DT_INVALID && |
445 | input_arg.type() != dtype.cast<DataType>() && |
446 | input_arg.type() != base_type.cast<DataType>()) { |
447 | throw PyTypeError(absl::StrCat("Input '" , input_name, "' of '" , |
448 | op_type_name, "' Op has type " , |
449 | base_type.attr("name" ).cast<std::string>(), |
450 | " that does not match expected type of " , |
451 | DataTypeToPybindObject(input_arg.type()) |
452 | .attr("name" ) |
453 | .cast<std::string>(), |
454 | "." )); |
455 | } |
456 | |
457 | const std::string& type_attr = input_arg.type_attr(); |
458 | if (!type_attr.empty()) { |
459 | if (attrs.contains(type_attr) && |
460 | attrs[type_attr.c_str()].cast<py::object>() != base_type) { |
461 | throw PyTypeError(absl::StrCat( |
462 | "Input '" , input_name, "' of '" , op_type_name, "' Op has type " , |
463 | base_type.attr("name" ).cast<std::string>(), |
464 | " that does not match type " , |
465 | attrs[type_attr.c_str()].attr("name" ).cast<std::string>(), |
466 | " of argument '" , inferred_from.at(type_attr), "'." )); |
467 | } else { |
468 | AssertSatisfiesAllowedListAttrTypeConstraint( |
469 | type_attr, allowed_list_attr_map, base_type, input_name); |
470 | attrs[type_attr.c_str()] = base_type; |
471 | inferred_from[input_arg.type_attr()] = input_name; |
472 | } |
473 | } else if (base_type.cast<DataType>() != input_arg.type()) { |
474 | // Added to match the python behaviour. |
475 | throw PyTypeError("Unreachable" ); |
476 | } |
477 | if (input_arg.is_ref()) return dtype; |
478 | return base_type; |
479 | } |
480 | |
481 | // Extracts `inputs`, `input_types` and `attrs`. |
482 | void ExtractInputsAndAttrs(const std::string& op_type_name, const OpDef& op_def, |
483 | const AllowedAttrMap& allowed_list_attr_map, |
484 | py::dict& keywords, py::dict& attrs, |
485 | py::list& inputs, py::list& input_types) { |
486 | absl::flat_hash_map<std::string, std::string> inferred_from; |
487 | for (const ArgDef& input_arg : op_def.input_arg()) { |
488 | std::string input_name = input_arg.name(); |
489 | py::object input_tensor = GetInputTensor(input_name, keywords, op_def); |
490 | inputs.append(input_tensor); |
491 | py::object dtype = |
492 | GetInputType(input_tensor, input_arg, allowed_list_attr_map, |
493 | op_type_name, input_name, attrs, inferred_from); |
494 | input_types.append(dtype); |
495 | } |
496 | } |
497 | |
498 | // Extracts the remaining attributes from the OpDef to `attrs`. |
499 | void ExtractRemainingAttrs(const std::string& op_type_name, const OpDef& op_def, |
500 | const py::dict& keywords, |
501 | const DefaultAttrMap& default_type_attr_map, |
502 | py::dict& attrs) { |
503 | for (const AttrDef& attr : op_def.attr()) { |
504 | const std::string& attr_name = attr.name(); |
505 | if (attrs.contains(attr_name)) { |
506 | if (keywords.contains(attr_name)) { |
507 | throw PyTypeError( |
508 | absl::StrCat("Should not specify value for inferred attr '" , |
509 | attr_name, "' for " , op_type_name, "." )); |
510 | } |
511 | continue; |
512 | } |
513 | if (keywords.contains(attr_name)) { |
514 | attrs[attr_name.c_str()] = |
515 | keywords.attr(kPop)(attr_name.c_str()).cast<py::object>(); |
516 | } else if (keywords.contains(absl::StrCat(attr_name, "_" ))) { |
517 | attrs[attr_name.c_str()] = |
518 | keywords.attr(kPop)(absl::StrCat(attr_name, "_" ).c_str()) |
519 | .cast<py::object>(); |
520 | } else if (default_type_attr_map.contains(attr_name)) { |
521 | attrs[attr_name.c_str()] = default_type_attr_map.at(attr_name); |
522 | } else { |
523 | throw PyTypeError(absl::StrCat("No argument found for attr " , attr_name, |
524 | " for " , op_type_name)); |
525 | } |
526 | } |
527 | } |
528 | |
529 | void SetAttrProto(const std::string& key, const AttrValue& value, |
530 | py::dict& attr_protos, AttrProtosMap& attr_protos_map) { |
531 | attr_protos_map[key] = value; |
532 | attr_protos[key.c_str()] = AttrValueToSerializedBytesPyObject(value); |
533 | } |
534 | |
535 | // Converts attr values to AttrValues. |
536 | void (const std::string& op_type_name, const OpDef& op_def, |
537 | const py::dict& attrs, py::dict& attr_protos, |
538 | AttrProtosMap& attr_protos_map) { |
539 | for (const AttrDef& attr_def : op_def.attr()) { |
540 | const std::string& attr_name = attr_def.name(); |
541 | const py::object attr = attrs[attr_name.c_str()].cast<py::object>(); |
542 | |
543 | if (attr_def.has_default_value() && attr.is_none()) { |
544 | SetAttrProto(attr_name, attr_def.default_value(), attr_protos, |
545 | attr_protos_map); |
546 | continue; |
547 | } |
548 | |
549 | const AttrValue attr_value = |
550 | ValueToAttrValue(attr, attr_def.type(), attr_name); |
551 | const AttributeType attr_type = AttributeTypeFromName(attr_def.type()); |
552 | AssertSatisfiesLengthConstraint(attr, attr_def, attr_name, op_type_name); |
553 | AssertSatisfiesAllowedStringsConstraint(attr_value, attr_def, attr_name, |
554 | attr_type, op_type_name); |
555 | AssertSatisfiesIntMinimumConstraint(attr_value, attr_def, attr_name, |
556 | attr_type, op_type_name); |
557 | AssertSatisfiesTypeConstraint(attr_value, attr_def, attr_name, attr_type, |
558 | op_type_name); |
559 | SetAttrProto(attr_name, attr_value, attr_protos, attr_protos_map); |
560 | } |
561 | } |
562 | |
563 | inline const AttrValue& MaybeGetAttrValue(const py::dict& attr_protos, |
564 | const AttrProtosMap& attr_protos_map, |
565 | const std::string& attr_name, |
566 | const std::string& op_type_name) { |
567 | auto it = attr_protos_map.find(attr_name); |
568 | if (it != attr_protos_map.end()) return it->second; |
569 | throw PyTypeError(absl::StrCat( |
570 | "Inconsistent OpDef for '" , op_type_name, "', missing attr '" , attr_name, |
571 | "' from '" , attr_protos.attr("__repr__" )().cast<std::string>(), "'." )); |
572 | } |
573 | |
574 | void (const std::string& op_type_name, |
575 | const OpDef& op_def, const py::dict& attr_protos, |
576 | const AttrProtosMap& attr_protos_map, |
577 | py::list& output_structure) { |
578 | for (const ArgDef& arg : op_def.output_arg()) { |
579 | if (!arg.number_attr().empty()) { |
580 | const auto& value = MaybeGetAttrValue(attr_protos, attr_protos_map, |
581 | arg.number_attr(), op_type_name); |
582 | output_structure.append(value.i()); |
583 | } else if (!arg.type_attr().empty()) { |
584 | const auto& _ = MaybeGetAttrValue(attr_protos, attr_protos_map, |
585 | arg.type_attr(), op_type_name); |
586 | output_structure.append(py::none()); |
587 | } else if (!arg.type_list_attr().empty()) { |
588 | const auto& value = MaybeGetAttrValue(attr_protos, attr_protos_map, |
589 | arg.type_list_attr(), op_type_name); |
590 | output_structure.append(value.list().type_size()); |
591 | } else { |
592 | output_structure.append(py::none()); |
593 | } |
594 | } |
595 | } |
596 | |
597 | void CheckAllInputsUsed(const std::string& op_type_name, |
598 | const py::dict& keywords) { |
599 | if (!keywords.empty()) { |
600 | std::string all_keywords; |
601 | for (const auto& item : keywords) { |
602 | if (!all_keywords.empty()) absl::StrAppend(&all_keywords, ", " ); |
603 | absl::StrAppend(&all_keywords, item.first.cast<std::string>()); |
604 | } |
605 | throw PyTypeError(absl::StrCat( |
606 | op_type_name, " got unexpected keyword arguments: " , all_keywords)); |
607 | } |
608 | } |
609 | |
610 | } // namespace |
611 | |
612 | // This module provides a subset of the functionality from op_def_library.py |
613 | // and relies on op_def_library_test.py for test coverage. |
614 | PYBIND11_MODULE(_op_def_library_pybind, m) { |
615 | // Method assumes all inputs in `keywords` are of type tf.Tensor. |
616 | m.def("process_inputs" , [](std::string& op_type_name, int producer_version, |
617 | py::dict& keywords) { |
618 | const OpDef* op_def = GetOpDef(op_type_name, producer_version); |
619 | MaybeRaiseFromStatus(CheckOpDeprecation(*op_def, producer_version)); |
620 | |
621 | DefaultAttrMap default_type_attr_map; |
622 | AllowedAttrMap allowed_list_attr_map; |
623 | AttrProtosMap attr_protos_map; |
624 | py::dict attrs, attr_protos; |
625 | py::list inputs, input_types, output_structure; |
626 | |
627 | ExtractDefaultTypesAndAllowedTypes(*op_def, default_type_attr_map, |
628 | allowed_list_attr_map); |
629 | ExtractInputsAndAttrs(op_type_name, *op_def, allowed_list_attr_map, |
630 | keywords, attrs, inputs, input_types); |
631 | ExtractRemainingAttrs(op_type_name, *op_def, keywords, |
632 | default_type_attr_map, attrs); |
633 | ExtractAttrProto(op_type_name, *op_def, attrs, attr_protos, |
634 | attr_protos_map); |
635 | ExtractOutputStructure(op_type_name, *op_def, attr_protos, attr_protos_map, |
636 | output_structure); |
637 | CheckAllInputsUsed(op_type_name, keywords); |
638 | |
639 | return py::make_tuple(attr_protos, inputs, input_types, output_structure); |
640 | }); |
641 | }; |
642 | |