1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <limits>
#include <string>

#include "pybind11/detail/common.h"
#include "pybind11/pybind11.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
20
21namespace {
22
23inline int DataTypeId(tensorflow::DataType dt) { return static_cast<int>(dt); }
24
25// A variant of tensorflow::DataTypeString which uses fixed-width names
26// for floating point data types. This behavior is compatible with that of
27// existing pure Python DType.
28const std::string DataTypeStringCompat(tensorflow::DataType dt) {
29 switch (dt) {
30 case tensorflow::DataType::DT_HALF:
31 return "float16";
32 case tensorflow::DataType::DT_HALF_REF:
33 return "float16_ref";
34 case tensorflow::DataType::DT_FLOAT:
35 return "float32";
36 case tensorflow::DataType::DT_FLOAT_REF:
37 return "float32_ref";
38 case tensorflow::DataType::DT_DOUBLE:
39 return "float64";
40 case tensorflow::DataType::DT_DOUBLE_REF:
41 return "float64_ref";
42 default:
43 return tensorflow::DataTypeString(dt);
44 }
45}
46
47} // namespace
48
49namespace tensorflow {
50
51constexpr DataTypeSet kNumPyIncompatibleTypes =
52 ToSet(DataType::DT_RESOURCE) | ToSet(DataType::DT_VARIANT);
53
54inline bool DataTypeIsNumPyCompatible(DataType dt) {
55 return !kNumPyIncompatibleTypes.Contains(dt);
56}
57
58} // namespace tensorflow
59
60namespace py = pybind11;
61
62PYBIND11_MODULE(_dtypes, m) {
63 py::class_<tensorflow::DataType>(m, "DType")
64 .def(py::init([](py::object obj) {
65 auto id = static_cast<int>(py::int_(obj));
66 if (tensorflow::DataType_IsValid(id) &&
67 id != static_cast<int>(tensorflow::DT_INVALID)) {
68 return static_cast<tensorflow::DataType>(id);
69 }
70 throw py::type_error(
71 py::str("{} does not correspond to a valid tensorflow::DataType")
72 .format(id));
73 }))
74 // For compatibility with pure-Python DType.
75 .def_property_readonly("_type_enum", &DataTypeId)
76 .def_property_readonly(
77 "as_datatype_enum", &DataTypeId,
78 "Returns a `types_pb2.DataType` enum value based on this data type.")
79
80 .def_property_readonly("name",
81 [](tensorflow::DataType self) {
82#if PY_MAJOR_VERSION < 3
83 return py::bytes(DataTypeStringCompat(self));
84#else
85 return DataTypeStringCompat(self);
86#endif
87 })
88 .def_property_readonly(
89 "size",
90 [](tensorflow::DataType self) {
91 return tensorflow::DataTypeSize(tensorflow::BaseType(self));
92 })
93
94 .def("__repr__",
95 [](tensorflow::DataType self) {
96 return py::str("tf.{}").format(DataTypeStringCompat(self));
97 })
98 .def("__str__",
99 [](tensorflow::DataType self) {
100 return py::str("<dtype: {!r}>")
101#if PY_MAJOR_VERSION < 3
102 .format(py::bytes(DataTypeStringCompat(self)));
103#else
104 .format(DataTypeStringCompat(self));
105#endif
106 })
107 .def("__hash__", &DataTypeId)
108
109 .def_property_readonly(
110 "is_numpy_compatible",
111 [](tensorflow::DataType self) {
112 return tensorflow::DataTypeIsNumPyCompatible(
113 tensorflow::BaseType(self));
114 },
115 "Returns whether this data type has a compatible NumPy data type.")
116
117 .def_property_readonly(
118 "is_bool",
119 [](tensorflow::DataType self) {
120 return tensorflow::BaseType(self) == tensorflow::DT_BOOL;
121 },
122 "Returns whether this is a boolean data type.")
123 .def_property_readonly(
124 "is_complex",
125 [](tensorflow::DataType self) {
126 return tensorflow::DataTypeIsComplex(tensorflow::BaseType(self));
127 },
128 "Returns whether this is a complex floating point type.")
129 .def_property_readonly(
130 "is_floating",
131 [](tensorflow::DataType self) {
132 return tensorflow::DataTypeIsFloating(tensorflow::BaseType(self));
133 },
134 "Returns whether this is a (non-quantized, real) floating point "
135 "type.")
136 .def_property_readonly(
137 "is_integer",
138 [](tensorflow::DataType self) {
139 return tensorflow::DataTypeIsInteger(tensorflow::BaseType(self));
140 },
141 "Returns whether this is a (non-quantized) integer type.")
142 .def_property_readonly(
143 "is_quantized",
144 [](tensorflow::DataType self) {
145 return tensorflow::DataTypeIsQuantized(tensorflow::BaseType(self));
146 },
147 "Returns whether this is a quantized data type.")
148 .def_property_readonly(
149 "is_unsigned",
150 [](tensorflow::DataType self) {
151 return tensorflow::DataTypeIsUnsigned(tensorflow::BaseType(self));
152 },
153 R"doc(Returns whether this type is unsigned.
154
155Non-numeric, unordered, and quantized types are not considered unsigned, and
156this function returns `False`.)doc");
157}
158