1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "pybind11/detail/common.h" |
17 | #include "pybind11/pybind11.h" |
18 | #include "tensorflow/core/framework/types.h" |
19 | #include "tensorflow/core/framework/types.pb.h" |
20 | |
21 | namespace { |
22 | |
23 | inline int DataTypeId(tensorflow::DataType dt) { return static_cast<int>(dt); } |
24 | |
25 | // A variant of tensorflow::DataTypeString which uses fixed-width names |
26 | // for floating point data types. This behavior is compatible with that of |
27 | // existing pure Python DType. |
28 | const std::string DataTypeStringCompat(tensorflow::DataType dt) { |
29 | switch (dt) { |
30 | case tensorflow::DataType::DT_HALF: |
31 | return "float16" ; |
32 | case tensorflow::DataType::DT_HALF_REF: |
33 | return "float16_ref" ; |
34 | case tensorflow::DataType::DT_FLOAT: |
35 | return "float32" ; |
36 | case tensorflow::DataType::DT_FLOAT_REF: |
37 | return "float32_ref" ; |
38 | case tensorflow::DataType::DT_DOUBLE: |
39 | return "float64" ; |
40 | case tensorflow::DataType::DT_DOUBLE_REF: |
41 | return "float64_ref" ; |
42 | default: |
43 | return tensorflow::DataTypeString(dt); |
44 | } |
45 | } |
46 | |
47 | } // namespace |
48 | |
49 | namespace tensorflow { |
50 | |
51 | constexpr DataTypeSet kNumPyIncompatibleTypes = |
52 | ToSet(DataType::DT_RESOURCE) | ToSet(DataType::DT_VARIANT); |
53 | |
54 | inline bool DataTypeIsNumPyCompatible(DataType dt) { |
55 | return !kNumPyIncompatibleTypes.Contains(dt); |
56 | } |
57 | |
58 | } // namespace tensorflow |
59 | |
60 | namespace py = pybind11; |
61 | |
62 | PYBIND11_MODULE(_dtypes, m) { |
63 | py::class_<tensorflow::DataType>(m, "DType" ) |
64 | .def(py::init([](py::object obj) { |
65 | auto id = static_cast<int>(py::int_(obj)); |
66 | if (tensorflow::DataType_IsValid(id) && |
67 | id != static_cast<int>(tensorflow::DT_INVALID)) { |
68 | return static_cast<tensorflow::DataType>(id); |
69 | } |
70 | throw py::type_error( |
71 | py::str("{} does not correspond to a valid tensorflow::DataType" ) |
72 | .format(id)); |
73 | })) |
74 | // For compatibility with pure-Python DType. |
75 | .def_property_readonly("_type_enum" , &DataTypeId) |
76 | .def_property_readonly( |
77 | "as_datatype_enum" , &DataTypeId, |
78 | "Returns a `types_pb2.DataType` enum value based on this data type." ) |
79 | |
80 | .def_property_readonly("name" , |
81 | [](tensorflow::DataType self) { |
82 | #if PY_MAJOR_VERSION < 3 |
83 | return py::bytes(DataTypeStringCompat(self)); |
84 | #else |
85 | return DataTypeStringCompat(self); |
86 | #endif |
87 | }) |
88 | .def_property_readonly( |
89 | "size" , |
90 | [](tensorflow::DataType self) { |
91 | return tensorflow::DataTypeSize(tensorflow::BaseType(self)); |
92 | }) |
93 | |
94 | .def("__repr__" , |
95 | [](tensorflow::DataType self) { |
96 | return py::str("tf.{}" ).format(DataTypeStringCompat(self)); |
97 | }) |
98 | .def("__str__" , |
99 | [](tensorflow::DataType self) { |
100 | return py::str("<dtype: {!r}>" ) |
101 | #if PY_MAJOR_VERSION < 3 |
102 | .format(py::bytes(DataTypeStringCompat(self))); |
103 | #else |
104 | .format(DataTypeStringCompat(self)); |
105 | #endif |
106 | }) |
107 | .def("__hash__" , &DataTypeId) |
108 | |
109 | .def_property_readonly( |
110 | "is_numpy_compatible" , |
111 | [](tensorflow::DataType self) { |
112 | return tensorflow::DataTypeIsNumPyCompatible( |
113 | tensorflow::BaseType(self)); |
114 | }, |
115 | "Returns whether this data type has a compatible NumPy data type." ) |
116 | |
117 | .def_property_readonly( |
118 | "is_bool" , |
119 | [](tensorflow::DataType self) { |
120 | return tensorflow::BaseType(self) == tensorflow::DT_BOOL; |
121 | }, |
122 | "Returns whether this is a boolean data type." ) |
123 | .def_property_readonly( |
124 | "is_complex" , |
125 | [](tensorflow::DataType self) { |
126 | return tensorflow::DataTypeIsComplex(tensorflow::BaseType(self)); |
127 | }, |
128 | "Returns whether this is a complex floating point type." ) |
129 | .def_property_readonly( |
130 | "is_floating" , |
131 | [](tensorflow::DataType self) { |
132 | return tensorflow::DataTypeIsFloating(tensorflow::BaseType(self)); |
133 | }, |
134 | "Returns whether this is a (non-quantized, real) floating point " |
135 | "type." ) |
136 | .def_property_readonly( |
137 | "is_integer" , |
138 | [](tensorflow::DataType self) { |
139 | return tensorflow::DataTypeIsInteger(tensorflow::BaseType(self)); |
140 | }, |
141 | "Returns whether this is a (non-quantized) integer type." ) |
142 | .def_property_readonly( |
143 | "is_quantized" , |
144 | [](tensorflow::DataType self) { |
145 | return tensorflow::DataTypeIsQuantized(tensorflow::BaseType(self)); |
146 | }, |
147 | "Returns whether this is a quantized data type." ) |
148 | .def_property_readonly( |
149 | "is_unsigned" , |
150 | [](tensorflow::DataType self) { |
151 | return tensorflow::DataTypeIsUnsigned(tensorflow::BaseType(self)); |
152 | }, |
153 | R"doc(Returns whether this type is unsigned. |
154 | |
155 | Non-numeric, unordered, and quantized types are not considered unsigned, and |
156 | this function returns `False`.)doc" ); |
157 | } |
158 | |