1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <stdint.h>

#include <memory>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
25
26namespace tflite {
27namespace ops {
28namespace builtin {
29namespace unpack {
30namespace {
31
// Position of the single input tensor in the node's input list.
constexpr int kInputTensor{0};
33
34TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
35 const TfLiteUnpackParams* data =
36 reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data);
37
38 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
39 TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
40
41 const TfLiteTensor* input;
42 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
43 TF_LITE_ENSURE(context, NumElements(input) > 0);
44 int axis = data->axis;
45 if (axis < 0) {
46 axis += NumDimensions(input);
47 }
48 TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
49 if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
50 input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
51 input->type != kTfLiteInt16 && input->type != kTfLiteBool) {
52 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by unpack.",
53 TfLiteTypeGetName(input->type));
54 return kTfLiteError;
55 }
56
57 const TfLiteIntArray* input_shape = input->dims;
58 // Num should be equal to the shape[axis].
59 // Resize outputs. rank will be R - 1.
60 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1);
61 int o = 0;
62 for (int index = 0; index < NumDimensions(input); ++index) {
63 if (index != axis) {
64 output_shape->data[o++] = input_shape->data[index];
65 }
66 }
67
68 TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[axis]);
69 for (int i = 0; i < data->num; ++i) {
70 TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
71 TfLiteTensor* output;
72 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
73 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
74 // Guarantee input/output quantization params match as we do not support
75 // rescaling of unpacked quantized tensors.
76 TF_LITE_ENSURE_EQ(context, input->params.zero_point,
77 output->params.zero_point);
78 TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
79 TF_LITE_ENSURE_OK(
80 context, context->ResizeTensor(context, output, copied_output_shape));
81 }
82
83 TfLiteIntArrayFree(output_shape);
84 return kTfLiteOk;
85}
86
87template <typename T>
88void UnpackImpl(TfLiteContext* context, TfLiteNode* node,
89 const TfLiteTensor* input, int output_count, int axis) {
90 tflite::UnpackParams op_params;
91 op_params.axis = axis;
92 op_params.num_split = output_count;
93 VectorOfTensors<T> all_outputs(*context, *node->outputs);
94 reference_ops::Unpack<T>(op_params, GetTensorShape(input),
95 GetTensorData<T>(input), **all_outputs.shapes(),
96 all_outputs.data());
97}
98
99TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
100 const TfLiteUnpackParams* data =
101 reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data);
102
103 const TfLiteTensor* input;
104 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
105 switch (input->type) {
106 case kTfLiteFloat32: {
107 UnpackImpl<float>(context, node, input, data->num, data->axis);
108 break;
109 }
110 case kTfLiteInt32: {
111 UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
112 break;
113 }
114 case kTfLiteUInt8: {
115 UnpackImpl<uint8_t>(context, node, input, data->num, data->axis);
116 break;
117 }
118 case kTfLiteInt8: {
119 UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
120 break;
121 }
122 case kTfLiteBool: {
123 UnpackImpl<bool>(context, node, input, data->num, data->axis);
124 break;
125 }
126 case kTfLiteInt16: {
127 UnpackImpl<int16_t>(context, node, input, data->num, data->axis);
128 break;
129 }
130 default: {
131 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by unpack.",
132 TfLiteTypeGetName(input->type));
133 return kTfLiteError;
134 }
135 }
136
137 return kTfLiteOk;
138}
139} // namespace
140} // namespace unpack
141
142TfLiteRegistration* Register_UNPACK() {
143 static TfLiteRegistration r = {nullptr, nullptr, unpack::Prepare,
144 unpack::Eval};
145 return &r;
146}
147
148} // namespace builtin
149} // namespace ops
150} // namespace tflite
151