/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/util.h"

#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <complex>
#include <cstring>
#include <initializer_list>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace {

TfLiteStatus UnresolvedOpInvoke(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_KERNEL_LOG(context,
                     "Encountered an unresolved custom op. Did you miss "
                     "a custom op or delegate?");
  return kTfLiteError;
}

}  // namespace

bool IsFlexOp(const char* custom_name) {
  return custom_name && strncmp(custom_name, kFlexCustomCodePrefix,
                                strlen(kFlexCustomCodePrefix)) == 0;
}

std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
    const std::vector<int>& data) {
  std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
      TfLiteIntArrayCreate(data.size()));
  std::copy(data.begin(), data.end(), result->data);
  return result;
}
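
// Illustrative usage sketch (the shape values below are arbitrary): the
// returned array is released by TfLiteIntArrayDeleter when the unique_ptr
// goes out of scope.
//   auto dims = BuildTfLiteIntArray({1, 224, 224, 3});
//   // dims->size == 4, dims->data[1] == 224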

TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
  return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()),
                                      input.data());
}

TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int ndims, const int* dims) {
  TfLiteIntArray* output = TfLiteIntArrayCreate(ndims);
  // Use an int index to match the signed `ndims` and avoid a signed/unsigned
  // comparison.
  for (int i = 0; i < ndims; i++) {
    output->data[i] = dims[i];
  }
  return output;
}

bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
                                 const int* b) {
  if (!a) return false;
  if (a->size != b_size) return false;
  for (int i = 0; i < a->size; ++i) {
    if (a->data[i] != b[i]) return false;
  }
  return true;
}

size_t CombineHashes(std::initializer_list<size_t> hashes) {
  size_t result = 0;
  // Hash combiner used by TensorFlow core.
  for (size_t hash : hashes) {
    result = result ^
             (hash + 0x9e3779b97f4a7800ULL + (result << 10) + (result >> 4));
  }
  return result;
}
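
// Illustrative sketch (the inputs are arbitrary): combining two precomputed
// hashes. The combiner is order-sensitive, so CombineHashes({h1, h2})
// generally differs from CombineHashes({h2, h1}).
//   size_t h1 = std::hash<int>()(kTfLiteFloat32);  // assumes <functional>
//   size_t h2 = std::hash<int>()(4);
//   size_t key = CombineHashes({h1, h2});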

TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
                           size_t* bytes) {
  // TODO(levp): remove the default case so that new types produce a
  // compilation error.
  switch (type) {
    case kTfLiteFloat32:
      *bytes = sizeof(float);
      break;
    case kTfLiteInt32:
      *bytes = sizeof(int32_t);
      break;
    case kTfLiteUInt32:
      *bytes = sizeof(uint32_t);
      break;
    case kTfLiteUInt8:
      *bytes = sizeof(uint8_t);
      break;
    case kTfLiteInt64:
      *bytes = sizeof(int64_t);
      break;
    case kTfLiteUInt64:
      *bytes = sizeof(uint64_t);
      break;
    case kTfLiteBool:
      *bytes = sizeof(bool);
      break;
    case kTfLiteComplex64:
      *bytes = sizeof(std::complex<float>);
      break;
    case kTfLiteComplex128:
      *bytes = sizeof(std::complex<double>);
      break;
    case kTfLiteUInt16:
      *bytes = sizeof(uint16_t);
      break;
    case kTfLiteInt16:
      *bytes = sizeof(int16_t);
      break;
    case kTfLiteInt8:
      *bytes = sizeof(int8_t);
      break;
    case kTfLiteFloat16:
      *bytes = sizeof(TfLiteFloat16);
      break;
    case kTfLiteFloat64:
      *bytes = sizeof(double);
      break;
    default:
      if (context) {
        TF_LITE_KERNEL_LOG(
            context,
            "Type %d is unsupported. Only float16, float32, float64, int8, "
            "int16, int32, int64, uint8, uint16, uint32, uint64, bool, "
            "complex64 and complex128 are supported currently.",
            type);
      }
      return kTfLiteError;
  }
  return kTfLiteOk;
}
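
// Illustrative usage sketch: query the element size before allocating a
// buffer. Passing a null context is allowed; errors simply go unlogged.
//   size_t type_size = 0;
//   if (GetSizeOfType(/*context=*/nullptr, kTfLiteInt8, &type_size) ==
//       kTfLiteOk) {
//     // type_size == 1
//   }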

TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name) {
  return TfLiteRegistration{/*init=*/nullptr,
                            /*free=*/nullptr,
                            /*prepare=*/nullptr,
                            /*invoke=*/&UnresolvedOpInvoke,
                            /*profiling_string=*/nullptr,
                            /*builtin_code=*/BuiltinOperator_CUSTOM,
                            /*custom_name=*/custom_op_name,
                            /*version=*/1};
}
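
// Illustrative sketch: a placeholder registration for a hypothetical custom
// op named "MyOp". Invoking it reports the "unresolved custom op" error
// above, and IsUnresolvedCustomOp() below returns true for it.
//   TfLiteRegistration reg = CreateUnresolvedCustomOp("MyOp");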

bool IsUnresolvedCustomOp(const TfLiteRegistration& registration) {
  return registration.builtin_code == tflite::BuiltinOperator_CUSTOM &&
         registration.invoke == &UnresolvedOpInvoke;
}

std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
  auto op = registration.builtin_code;
  std::string result =
      EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
  if ((op == kTfLiteBuiltinCustom || op == kTfLiteBuiltinDelegate) &&
      registration.custom_name) {
    result += " " + std::string(registration.custom_name);
  }
  return result;
}
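
// Illustrative sketch: for the hypothetical "MyOp" registration above this
// returns "CUSTOM MyOp", i.e. the builtin enum name plus the custom name for
// custom and delegate ops.
//   std::string name =
//       GetOpNameByRegistration(CreateUnresolvedCustomOp("MyOp"));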

bool IsValidationSubgraph(const char* name) {
  // NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed.
  return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0;
}

TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) {
  // Multiplying a * b where a and b are size_t cannot result in overflow in a
  // size_t accumulator if both numbers have no non-zero bits in their upper
  // half.
  constexpr size_t size_t_bits = 8 * sizeof(size_t);
  constexpr size_t overflow_upper_half_bit_position = size_t_bits / 2;
  *product = a * b;
  // If neither integer has non-zero bits in its upper half, the product can't
  // overflow. Otherwise, check using slow division.
  if (TFLITE_EXPECT_FALSE((a | b) >> overflow_upper_half_bit_position != 0)) {
    if (a != 0 && *product / a != b) return kTfLiteError;
  }
  return kTfLiteOk;
}
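
// Worked example (illustrative, assuming a 64-bit size_t): for a = b = 1 << 33
// the slow path is taken, the product wraps to 0, and 0 / a != b, so
// kTfLiteError is returned.
//   size_t product = 0;
//   TfLiteStatus status =
//       MultiplyAndCheckOverflow(size_t{1} << 33, size_t{1} << 33, &product);
//   // status == kTfLiteError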
}  // namespace tflite