1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <string.h>
16
17#include "tensorflow/lite/c/builtin_op_data.h"
18#include "tensorflow/lite/c/common.h"
19#include "tensorflow/lite/kernels/internal/portable_tensor.h"
20#include "tensorflow/lite/kernels/internal/tensor.h"
21#include "tensorflow/lite/kernels/kernel_util.h"
22
23namespace tflite {
24namespace ops {
25namespace builtin {
26namespace squeeze {
27
// Bundles the pointers both Prepare and Eval need for the SQUEEZE op:
// the builtin squeeze parameters and the single input/output tensor pair.
// NOTE(review): GetInput/GetOutput may return nullptr for a malformed node;
// callers rely on the standard kernel invariants checked in Prepare.
struct SqueezeContext {
  SqueezeContext(TfLiteContext* context, TfLiteNode* node)
      : params(reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data)),
        input(GetInput(context, node, 0)),
        output(GetOutput(context, node, 0)) {}
  TfLiteSqueezeParams* params;  // Axes to squeeze; owned by the node.
  const TfLiteTensor* const input;
  TfLiteTensor* output;
};
37
38TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
39 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
40 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
41
42 SqueezeContext op_context(context, node);
43 int input_num_dims = NumDimensions(op_context.input);
44 int num_squeeze_dims = op_context.params->num_squeeze_dims;
45
46 // Determines number of dimensions of output tensor after squeeze.
47 const TfLiteIntArray* input_dims = op_context.input->dims;
48 const int* squeeze_dims = op_context.params->squeeze_dims;
49 TF_LITE_ENSURE(context, input_num_dims <= 8);
50 bool should_squeeze[8] = {false};
51 int num_squeezed_dims = 0;
52 if (num_squeeze_dims == 0) {
53 for (int idx = 0; idx < input_num_dims; ++idx) {
54 if (input_dims->data[idx] == 1) {
55 should_squeeze[idx] = true;
56 ++num_squeezed_dims;
57 }
58 }
59 } else {
60 for (int idx = 0; idx < num_squeeze_dims; ++idx) {
61 int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + input_num_dims
62 : squeeze_dims[idx];
63 TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims &&
64 input_dims->data[current] == 1);
65 if (!should_squeeze[current]) ++num_squeezed_dims;
66 should_squeeze[current] = true;
67 }
68 }
69 // Sets output dimensions.
70 TfLiteIntArray* output_dims =
71 TfLiteIntArrayCreate(input_num_dims - num_squeezed_dims);
72 for (int in_idx = 0, out_idx = 0; in_idx < input_num_dims; ++in_idx) {
73 if (!should_squeeze[in_idx]) {
74 output_dims->data[out_idx++] = input_dims->data[in_idx];
75 }
76 }
77 return context->ResizeTensor(context, op_context.output, output_dims);
78}
79
80TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
81 SqueezeContext op_context(context, node);
82 if (op_context.input->type == kTfLiteString) {
83 const int input_flat_size = GetTensorShape(op_context.input).FlatSize();
84 const int output_flat_size = GetTensorShape(op_context.output).FlatSize();
85 TF_LITE_ENSURE_EQ(context, input_flat_size, output_flat_size);
86 SequentialTensorWriter<string> writer(op_context.input, op_context.output);
87 for (int i = 0; i < input_flat_size; i++) {
88 writer.Write(i);
89 }
90 return kTfLiteOk;
91 }
92
93 TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes);
94 memcpy(op_context.output->data.raw, op_context.input->data.raw,
95 op_context.input->bytes);
96 return kTfLiteOk;
97}
98
99} // namespace squeeze
100
101TfLiteRegistration* Register_SQUEEZE() {
102 static TfLiteRegistration r = {nullptr, nullptr, squeeze::Prepare,
103 squeeze::Eval};
104 return &r;
105}
106
107} // namespace builtin
108} // namespace ops
109} // namespace tflite
110