1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/util.h" |
16 | |
17 | #include <stddef.h> |
18 | #include <stdint.h> |
19 | |
20 | #include <algorithm> |
21 | #include <complex> |
22 | #include <cstring> |
23 | #include <initializer_list> |
24 | #include <memory> |
25 | #include <string> |
26 | #include <vector> |
27 | |
28 | #include "tensorflow/lite/builtin_ops.h" |
29 | #include "tensorflow/lite/c/common.h" |
30 | #include "tensorflow/lite/core/macros.h" |
31 | #include "tensorflow/lite/schema/schema_generated.h" |
32 | |
33 | namespace tflite { |
34 | namespace { |
35 | |
36 | TfLiteStatus UnresolvedOpInvoke(TfLiteContext* context, TfLiteNode* node) { |
37 | TF_LITE_KERNEL_LOG(context, |
38 | "Encountered an unresolved custom op. Did you miss " |
39 | "a custom op or delegate?" ); |
40 | return kTfLiteError; |
41 | } |
42 | |
43 | } // namespace |
44 | |
45 | bool IsFlexOp(const char* custom_name) { |
46 | return custom_name && strncmp(custom_name, kFlexCustomCodePrefix, |
47 | strlen(kFlexCustomCodePrefix)) == 0; |
48 | } |
49 | |
50 | std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray( |
51 | const std::vector<int>& data) { |
52 | std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result( |
53 | TfLiteIntArrayCreate(data.size())); |
54 | std::copy(data.begin(), data.end(), result->data); |
55 | return result; |
56 | } |
57 | |
58 | TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) { |
59 | return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()), |
60 | input.data()); |
61 | } |
62 | |
63 | TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int ndims, const int* dims) { |
64 | TfLiteIntArray* output = TfLiteIntArrayCreate(ndims); |
65 | for (size_t i = 0; i < ndims; i++) { |
66 | output->data[i] = dims[i]; |
67 | } |
68 | return output; |
69 | } |
70 | |
71 | bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, |
72 | const int* b) { |
73 | if (!a) return false; |
74 | if (a->size != b_size) return false; |
75 | for (int i = 0; i < a->size; ++i) { |
76 | if (a->data[i] != b[i]) return false; |
77 | } |
78 | return true; |
79 | } |
80 | |
// Folds `hashes` left to right into a single value using the same mixing step
// as TensorFlow core's hash combiner. An empty list yields 0. The result
// depends on the order of the inputs.
size_t CombineHashes(std::initializer_list<size_t> hashes) {
  // Additive constant of TensorFlow core's hash combiner.
  constexpr unsigned long long kMixConstant = 0x9e3779b97f4a7800ULL;
  size_t combined = 0;
  for (const size_t value : hashes) {
    combined ^= value + kMixConstant + (combined << 10) + (combined >> 4);
  }
  return combined;
}
90 | |
91 | TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, |
92 | size_t* bytes) { |
93 | // TODO(levp): remove the default case so that new types produce compilation |
94 | // error. |
95 | switch (type) { |
96 | case kTfLiteFloat32: |
97 | *bytes = sizeof(float); |
98 | break; |
99 | case kTfLiteInt32: |
100 | *bytes = sizeof(int32_t); |
101 | break; |
102 | case kTfLiteUInt32: |
103 | *bytes = sizeof(uint32_t); |
104 | break; |
105 | case kTfLiteUInt8: |
106 | *bytes = sizeof(uint8_t); |
107 | break; |
108 | case kTfLiteInt64: |
109 | *bytes = sizeof(int64_t); |
110 | break; |
111 | case kTfLiteUInt64: |
112 | *bytes = sizeof(uint64_t); |
113 | break; |
114 | case kTfLiteBool: |
115 | *bytes = sizeof(bool); |
116 | break; |
117 | case kTfLiteComplex64: |
118 | *bytes = sizeof(std::complex<float>); |
119 | break; |
120 | case kTfLiteComplex128: |
121 | *bytes = sizeof(std::complex<double>); |
122 | break; |
123 | case kTfLiteUInt16: |
124 | *bytes = sizeof(uint16_t); |
125 | break; |
126 | case kTfLiteInt16: |
127 | *bytes = sizeof(int16_t); |
128 | break; |
129 | case kTfLiteInt8: |
130 | *bytes = sizeof(int8_t); |
131 | break; |
132 | case kTfLiteFloat16: |
133 | *bytes = sizeof(TfLiteFloat16); |
134 | break; |
135 | case kTfLiteFloat64: |
136 | *bytes = sizeof(double); |
137 | break; |
138 | default: |
139 | if (context) { |
140 | TF_LITE_KERNEL_LOG( |
141 | context, |
142 | "Type %d is unsupported. Only float16, float32, float64, int8, " |
143 | "int16, int32, int64, uint8, uint64, bool, complex64 and " |
144 | "complex128 supported currently." , |
145 | type); |
146 | } |
147 | return kTfLiteError; |
148 | } |
149 | return kTfLiteOk; |
150 | } |
151 | |
152 | TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name) { |
153 | return TfLiteRegistration{nullptr, |
154 | nullptr, |
155 | nullptr, |
156 | /*invoke*/ &UnresolvedOpInvoke, |
157 | nullptr, |
158 | BuiltinOperator_CUSTOM, |
159 | custom_op_name, |
160 | 1}; |
161 | } |
162 | |
163 | bool IsUnresolvedCustomOp(const TfLiteRegistration& registration) { |
164 | return registration.builtin_code == tflite::BuiltinOperator_CUSTOM && |
165 | registration.invoke == &UnresolvedOpInvoke; |
166 | } |
167 | |
168 | std::string GetOpNameByRegistration(const TfLiteRegistration& registration) { |
169 | auto op = registration.builtin_code; |
170 | std::string result = |
171 | EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op)); |
172 | if ((op == kTfLiteBuiltinCustom || op == kTfLiteBuiltinDelegate) && |
173 | registration.custom_name) { |
174 | result += " " + std::string(registration.custom_name); |
175 | } |
176 | return result; |
177 | } |
178 | |
179 | bool IsValidationSubgraph(const char* name) { |
180 | // NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed. |
181 | return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0; |
182 | } |
183 | |
184 | TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) { |
185 | // Multiplying a * b where a and b are size_t cannot result in overflow in a |
186 | // size_t accumulator if both numbers have no non-zero bits in their upper |
187 | // half. |
188 | constexpr size_t size_t_bits = 8 * sizeof(size_t); |
189 | constexpr size_t overflow_upper_half_bit_position = size_t_bits / 2; |
190 | *product = a * b; |
191 | // If neither integers have non-zero bits past 32 bits can't overflow. |
192 | // Otherwise check using slow devision. |
193 | if (TFLITE_EXPECT_FALSE((a | b) >> overflow_upper_half_bit_position != 0)) { |
194 | if (a != 0 && *product / a != b) return kTfLiteError; |
195 | } |
196 | return kTfLiteOk; |
197 | } |
198 | } // namespace tflite |
199 | |