/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace pack {
namespace {

constexpr int kOutputTensor = 0;

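// Validates that all inputs share one shape and type and resizes the output
// from rank R to rank R + 1: e.g. packing `values_count` tensors of shape
// [A, B] along axis 0 yields [values_count, A, B], and along axis 1 yields
// [A, values_count, B].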
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TfLitePackParams* data =
      reinterpret_cast<TfLitePackParams*>(node->builtin_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input0;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input0));
  const int dimension_size = NumDimensions(input0) + 1;
  if (data->axis < 0) {
    data->axis += dimension_size;
  }
  TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis);
  TF_LITE_ENSURE(context, data->axis >= 0);

  if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32 &&
      input0->type != kTfLiteUInt8 && input0->type != kTfLiteInt8 &&
      input0->type != kTfLiteInt16 && input0->type != kTfLiteInt64) {
    TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
                       TfLiteTypeGetName(input0->type));
    return kTfLiteError;
  }
  // Make sure all inputs have the same shape and type.
  for (int i = 1; i < data->values_count; ++i) {
    const TfLiteTensor* input;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
    TF_LITE_ENSURE(context, HaveSameShapes(input0, input));
    TF_LITE_ENSURE_TYPES_EQ(context, input0->type, input->type);
  }

  // Resize output. rank R will become rank R + 1
  const TfLiteIntArray* input_shape = input0->dims;
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size);
  int i = 0;
  for (int index = 0; index < dimension_size; ++index) {
    if (index == data->axis) {
      output_shape->data[index] = data->values_count;
    } else {
      output_shape->data[index] = input_shape->data[i++];
    }
  }

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, input0->type);

  // Guarantee input/output quantization params match as we do not support
  // packing quantized tensors.
  for (int i = 0; i < data->values_count; i++) {
    const TfLiteTensor* input;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
    TF_LITE_ENSURE_EQ(context, input->params.zero_point,
                      output->params.zero_point);
    TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
  }

  return context->ResizeTensor(context, output, output_shape);
}

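// Type-specialized copy: VectorOfTensors gathers the shapes and typed data
// pointers of every tensor listed in node->inputs, and the actual packing is
// delegated to reference_ops::Pack.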
template <typename T>
TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node,
                      TfLiteTensor* output, int values_count, int axis) {
  TF_LITE_ENSURE(context, axis >= 0);

  VectorOfTensors<T> all_inputs(*context, *node->inputs);
  tflite::PackParams op_params;
  op_params.axis = axis;
  op_params.inputs_count = values_count;

  reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
                         GetTensorShape(output), GetTensorData<T>(output));
  return kTfLiteOk;
}

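// Dispatches to PackImpl<T> on the output type; Prepare has already checked
// that every input shares this type.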
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLitePackParams* data =
      reinterpret_cast<TfLitePackParams*>(node->builtin_data);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  switch (output->type) {
    case kTfLiteFloat32: {
      return PackImpl<float>(context, node, output, data->values_count,
                             data->axis);
    }
    case kTfLiteUInt8: {
      return PackImpl<uint8_t>(context, node, output, data->values_count,
                               data->axis);
    }
    case kTfLiteInt8: {
      return PackImpl<int8_t>(context, node, output, data->values_count,
                              data->axis);
    }
    case kTfLiteInt16: {
      return PackImpl<int16_t>(context, node, output, data->values_count,
                               data->axis);
    }
    case kTfLiteInt32: {
      return PackImpl<int32_t>(context, node, output, data->values_count,
                               data->axis);
    }
    case kTfLiteInt64: {
      return PackImpl<int64_t>(context, node, output, data->values_count,
                               data->axis);
    }
    default: {
      TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by pack.",
                         TfLiteTypeGetName(output->type));
      return kTfLiteError;
    }
  }
}

}  // namespace
}  // namespace pack

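// Registration entries are {init, free, prepare, invoke}. PACK keeps no
// per-node state, so init and free are left null.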
TfLiteRegistration* Register_PACK() {
  static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite