1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <string.h> |
16 | |
17 | #include "tensorflow/lite/c/builtin_op_data.h" |
18 | #include "tensorflow/lite/c/common.h" |
19 | #include "tensorflow/lite/kernels/internal/portable_tensor.h" |
20 | #include "tensorflow/lite/kernels/internal/tensor.h" |
21 | #include "tensorflow/lite/kernels/kernel_util.h" |
22 | |
23 | namespace tflite { |
24 | namespace ops { |
25 | namespace builtin { |
26 | namespace squeeze { |
27 | |
28 | struct SqueezeContext { |
29 | SqueezeContext(TfLiteContext* context, TfLiteNode* node) |
30 | : params(reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data)), |
31 | input(GetInput(context, node, 0)), |
32 | output(GetOutput(context, node, 0)) {} |
33 | TfLiteSqueezeParams* params; |
34 | const TfLiteTensor* const input; |
35 | TfLiteTensor* output; |
36 | }; |
37 | |
38 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
39 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
40 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
41 | |
42 | SqueezeContext op_context(context, node); |
43 | int input_num_dims = NumDimensions(op_context.input); |
44 | int num_squeeze_dims = op_context.params->num_squeeze_dims; |
45 | |
46 | // Determines number of dimensions of output tensor after squeeze. |
47 | const TfLiteIntArray* input_dims = op_context.input->dims; |
48 | const int* squeeze_dims = op_context.params->squeeze_dims; |
49 | TF_LITE_ENSURE(context, input_num_dims <= 8); |
50 | bool should_squeeze[8] = {false}; |
51 | int num_squeezed_dims = 0; |
52 | if (num_squeeze_dims == 0) { |
53 | for (int idx = 0; idx < input_num_dims; ++idx) { |
54 | if (input_dims->data[idx] == 1) { |
55 | should_squeeze[idx] = true; |
56 | ++num_squeezed_dims; |
57 | } |
58 | } |
59 | } else { |
60 | for (int idx = 0; idx < num_squeeze_dims; ++idx) { |
61 | int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + input_num_dims |
62 | : squeeze_dims[idx]; |
63 | TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims && |
64 | input_dims->data[current] == 1); |
65 | if (!should_squeeze[current]) ++num_squeezed_dims; |
66 | should_squeeze[current] = true; |
67 | } |
68 | } |
69 | // Sets output dimensions. |
70 | TfLiteIntArray* output_dims = |
71 | TfLiteIntArrayCreate(input_num_dims - num_squeezed_dims); |
72 | for (int in_idx = 0, out_idx = 0; in_idx < input_num_dims; ++in_idx) { |
73 | if (!should_squeeze[in_idx]) { |
74 | output_dims->data[out_idx++] = input_dims->data[in_idx]; |
75 | } |
76 | } |
77 | return context->ResizeTensor(context, op_context.output, output_dims); |
78 | } |
79 | |
80 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
81 | SqueezeContext op_context(context, node); |
82 | if (op_context.input->type == kTfLiteString) { |
83 | const int input_flat_size = GetTensorShape(op_context.input).FlatSize(); |
84 | const int output_flat_size = GetTensorShape(op_context.output).FlatSize(); |
85 | TF_LITE_ENSURE_EQ(context, input_flat_size, output_flat_size); |
86 | SequentialTensorWriter<string> writer(op_context.input, op_context.output); |
87 | for (int i = 0; i < input_flat_size; i++) { |
88 | writer.Write(i); |
89 | } |
90 | return kTfLiteOk; |
91 | } |
92 | |
93 | TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes); |
94 | memcpy(op_context.output->data.raw, op_context.input->data.raw, |
95 | op_context.input->bytes); |
96 | return kTfLiteOk; |
97 | } |
98 | |
99 | } // namespace squeeze |
100 | |
101 | TfLiteRegistration* Register_SQUEEZE() { |
102 | static TfLiteRegistration r = {nullptr, nullptr, squeeze::Prepare, |
103 | squeeze::Eval}; |
104 | return &r; |
105 | } |
106 | |
107 | } // namespace builtin |
108 | } // namespace ops |
109 | } // namespace tflite |
110 | |