/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace one_hot {

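// Tensor indices for the four inputs (indices, depth, on_value, off_value)
// and the single output of the ONE_HOT op.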
constexpr int kIndicesTensor = 0;
constexpr int kDepthTensor = 1;
constexpr int kOnValueTensor = 2;
constexpr int kOffValueTensor = 3;
constexpr int kOutputTensor = 0;

// Convenience utility for destructuring a node into the appropriate tensors
// and data for the op. Note that this destructuring is quite cheap, so we can
// avoid allocating op-specific, persistent data on the heap.
struct OneHotContext {
  OneHotContext(TfLiteContext* context, TfLiteNode* node) {
    indices = GetInput(context, node, kIndicesTensor);
    depth = GetInput(context, node, kDepthTensor);
    on_value = GetInput(context, node, kOnValueTensor);
    off_value = GetInput(context, node, kOffValueTensor);
    output = GetOutput(context, node, kOutputTensor);

    const auto* params =
        reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
    const int indices_dims = indices->dims->size;
    axis = (params->axis == -1) ? indices_dims : params->axis;
    output_dims = indices_dims + 1;
    dtype = on_value->type;
  }

  const TfLiteTensor* indices;
  const TfLiteTensor* depth;
  const TfLiteTensor* on_value;
  const TfLiteTensor* off_value;
  TfLiteTensor* output;
  int axis;
  int output_dims;
  TfLiteType dtype;
};

template <typename T, typename TI>
void OneHotComputeImpl(const OneHotContext& op_context) {
  // prefix_dim_size == # of elements before the axis
  // depth == # of elements per axis
  // suffix_dim_size == # of elements after the axis
  int prefix_dim_size = 1;
  for (int i = 0; i < op_context.axis; ++i) {
    prefix_dim_size *= op_context.indices->dims->data[i];
  }
  if (prefix_dim_size == 0) {
    // If indices tensor is degenerate, return a degenerate tensor, just like
    // TensorFlow does.
    return;
  }
  const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
  const int depth = *op_context.depth->data.i32;

  const T on_value = *GetTensorData<T>(op_context.on_value);
  const T off_value = *GetTensorData<T>(op_context.off_value);

  // View the indices as a matrix of size:
  //   prefix_dim_size x suffix_dim_size
  // View the output as a matrix of size:
  //   prefix_dim_size x depth x suffix_dim_size
  // Then the output is:
  //   output(i, j, k) == (indices(i, k) == j) ? on : off
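  // As a sketch of the semantics (not a test case from this file): with
  // indices = [[0, 2], [1, -1]], depth = 3, on = 1, off = 0 and the default
  // (innermost) axis, the output is
  //   [[[1, 0, 0], [0, 0, 1]],
  //    [[0, 1, 0], [0, 0, 0]]]
  // i.e. any index outside [0, depth) selects off_value for its entire row.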
  T* output = GetTensorData<T>(op_context.output);
  const TI* indices = GetTensorData<TI>(op_context.indices);
  for (int i = 0; i < prefix_dim_size; ++i) {
    for (int j = 0; j < depth; ++j) {
      for (int k = 0; k < suffix_dim_size; ++k, ++output) {
        *output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
                      ? on_value
                      : off_value;
      }
    }
  }
}

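// Dispatches on the indices type; Prepare() ensures it is int32 or int64.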
template <typename T>
void OneHotCompute(const OneHotContext& op_context) {
  if (op_context.indices->type == kTfLiteInt64) {
    OneHotComputeImpl<T, int64_t>(op_context);
  } else {
    OneHotComputeImpl<T, int>(op_context);
  }
}

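// Resizes the output to shape
//   indices.shape[:axis] + [depth] + indices.shape[axis:].
// For example (illustrative only): indices of shape [2, 3] with depth = 4 and
// axis = 1 yield an output of shape [2, 4, 3].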
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                const OneHotContext& op_context) {
  TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
  for (int i = 0; i < op_context.output_dims; ++i) {
    if (i < op_context.axis) {
      output_size->data[i] = op_context.indices->dims->data[i];
    } else if (i == op_context.axis) {
      output_size->data[i] = *op_context.depth->data.i32;
    } else {
      output_size->data[i] = op_context.indices->dims->data[i - 1];
    }
  }
  return context->ResizeTensor(context, op_context.output, output_size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OneHotContext op_context{context, node};
  switch (op_context.dtype) {
    // TODO(b/111744875): Support uint8 and quantization.
    case kTfLiteFloat32:
    case kTfLiteInt16:
    case kTfLiteInt32:
    case kTfLiteInt64:
    case kTfLiteInt8:
    case kTfLiteUInt8:
    case kTfLiteBool:
      op_context.output->type = op_context.dtype;
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Unknown output data type: %s",
                         TfLiteTypeGetName(op_context.dtype));
      return kTfLiteError;
  }

  TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
                              op_context.indices->type == kTfLiteInt64);
  TF_LITE_ENSURE(context, op_context.axis >= 0 &&
                              op_context.axis < op_context.output_dims);
  TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
  TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
  TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.on_value->type, op_context.dtype);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.off_value->type,
                          op_context.dtype);

  if (!IsConstantTensor(op_context.depth)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }

  return ResizeOutputTensor(context, op_context);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OneHotContext op_context{context, node};

  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
  }

  switch (op_context.output->type) {
    case kTfLiteFloat32:
      OneHotCompute<float>(op_context);
      break;
    case kTfLiteInt16:
      OneHotCompute<int16_t>(op_context);
      break;
    case kTfLiteInt32:
      OneHotCompute<int>(op_context);
      break;
    case kTfLiteInt64:
      OneHotCompute<int64_t>(op_context);
      break;
    case kTfLiteInt8:
      OneHotCompute<int8_t>(op_context);
      break;
    case kTfLiteUInt8:
      OneHotCompute<uint8_t>(op_context);
      break;
    case kTfLiteBool:
      OneHotCompute<bool>(op_context);
      break;
    default:
      return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace one_hot

TfLiteRegistration* Register_ONE_HOT() {
  static TfLiteRegistration r = {
      /*init=*/nullptr,
      /*free=*/nullptr,
      /*prepare=*/one_hot::Prepare,
      /*invoke=*/one_hot::Eval,
  };
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite