1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <stdint.h>

#include <memory>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
25 | |
26 | namespace tflite { |
27 | namespace ops { |
28 | namespace builtin { |
29 | namespace unpack { |
30 | namespace { |
31 | |
// Index of the single input tensor in the node's input list.
constexpr int kInputTensor = 0;
33 | |
34 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
35 | const TfLiteUnpackParams* data = |
36 | reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data); |
37 | |
38 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
39 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num); |
40 | |
41 | const TfLiteTensor* input; |
42 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
43 | TF_LITE_ENSURE(context, NumElements(input) > 0); |
44 | int axis = data->axis; |
45 | if (axis < 0) { |
46 | axis += NumDimensions(input); |
47 | } |
48 | TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input)); |
49 | if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 && |
50 | input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 && |
51 | input->type != kTfLiteInt16 && input->type != kTfLiteBool) { |
52 | TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by unpack." , |
53 | TfLiteTypeGetName(input->type)); |
54 | return kTfLiteError; |
55 | } |
56 | |
57 | const TfLiteIntArray* input_shape = input->dims; |
58 | // Num should be equal to the shape[axis]. |
59 | // Resize outputs. rank will be R - 1. |
60 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1); |
61 | int o = 0; |
62 | for (int index = 0; index < NumDimensions(input); ++index) { |
63 | if (index != axis) { |
64 | output_shape->data[o++] = input_shape->data[index]; |
65 | } |
66 | } |
67 | |
68 | TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[axis]); |
69 | for (int i = 0; i < data->num; ++i) { |
70 | TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape); |
71 | TfLiteTensor* output; |
72 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); |
73 | TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type); |
74 | // Guarantee input/output quantization params match as we do not support |
75 | // rescaling of unpacked quantized tensors. |
76 | TF_LITE_ENSURE_EQ(context, input->params.zero_point, |
77 | output->params.zero_point); |
78 | TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); |
79 | TF_LITE_ENSURE_OK( |
80 | context, context->ResizeTensor(context, output, copied_output_shape)); |
81 | } |
82 | |
83 | TfLiteIntArrayFree(output_shape); |
84 | return kTfLiteOk; |
85 | } |
86 | |
87 | template <typename T> |
88 | void UnpackImpl(TfLiteContext* context, TfLiteNode* node, |
89 | const TfLiteTensor* input, int output_count, int axis) { |
90 | tflite::UnpackParams op_params; |
91 | op_params.axis = axis; |
92 | op_params.num_split = output_count; |
93 | VectorOfTensors<T> all_outputs(*context, *node->outputs); |
94 | reference_ops::Unpack<T>(op_params, GetTensorShape(input), |
95 | GetTensorData<T>(input), **all_outputs.shapes(), |
96 | all_outputs.data()); |
97 | } |
98 | |
99 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
100 | const TfLiteUnpackParams* data = |
101 | reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data); |
102 | |
103 | const TfLiteTensor* input; |
104 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
105 | switch (input->type) { |
106 | case kTfLiteFloat32: { |
107 | UnpackImpl<float>(context, node, input, data->num, data->axis); |
108 | break; |
109 | } |
110 | case kTfLiteInt32: { |
111 | UnpackImpl<int32_t>(context, node, input, data->num, data->axis); |
112 | break; |
113 | } |
114 | case kTfLiteUInt8: { |
115 | UnpackImpl<uint8_t>(context, node, input, data->num, data->axis); |
116 | break; |
117 | } |
118 | case kTfLiteInt8: { |
119 | UnpackImpl<int8_t>(context, node, input, data->num, data->axis); |
120 | break; |
121 | } |
122 | case kTfLiteBool: { |
123 | UnpackImpl<bool>(context, node, input, data->num, data->axis); |
124 | break; |
125 | } |
126 | case kTfLiteInt16: { |
127 | UnpackImpl<int16_t>(context, node, input, data->num, data->axis); |
128 | break; |
129 | } |
130 | default: { |
131 | TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by unpack." , |
132 | TfLiteTypeGetName(input->type)); |
133 | return kTfLiteError; |
134 | } |
135 | } |
136 | |
137 | return kTfLiteOk; |
138 | } |
139 | } // namespace |
140 | } // namespace unpack |
141 | |
142 | TfLiteRegistration* Register_UNPACK() { |
143 | static TfLiteRegistration r = {nullptr, nullptr, unpack::Prepare, |
144 | unpack::Eval}; |
145 | return &r; |
146 | } |
147 | |
148 | } // namespace builtin |
149 | } // namespace ops |
150 | } // namespace tflite |
151 | |