1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | #include <string.h> |
17 | |
18 | #include "tensorflow/lite/c/common.h" |
19 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
20 | #include "tensorflow/lite/kernels/internal/tensor.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
22 | #include "tensorflow/lite/kernels/kernel_util.h" |
23 | |
24 | namespace tflite { |
25 | namespace ops { |
26 | namespace builtin { |
27 | namespace expand_dims { |
28 | |
// Positions of this op's input tensors, as passed to GetInputSafe():
// the tensor being reshaped, then the scalar axis tensor.
enum { kInput = 0, kAxis = 1 };
31 | |
32 | namespace { |
33 | TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input, |
34 | int axis, TfLiteTensor* output) { |
35 | const TfLiteIntArray& input_dims = *input.dims; |
36 | if (axis < 0) { |
37 | axis = input_dims.size + 1 + axis; |
38 | } |
39 | TF_LITE_ENSURE(context, axis <= input_dims.size); |
40 | TF_LITE_ENSURE(context, axis >= 0); |
41 | |
42 | TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1); |
43 | for (int i = 0; i < output_dims->size; ++i) { |
44 | if (i < axis) { |
45 | output_dims->data[i] = input_dims.data[i]; |
46 | } else if (i == axis) { |
47 | output_dims->data[i] = 1; |
48 | } else { |
49 | output_dims->data[i] = input_dims.data[i - 1]; |
50 | } |
51 | } |
52 | |
53 | return context->ResizeTensor(context, output, output_dims); |
54 | } |
55 | |
56 | TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context, |
57 | const TfLiteTensor& axis, int* axis_value) { |
58 | TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1); |
59 | switch (axis.type) { |
60 | case kTfLiteInt32: |
61 | *axis_value = *GetTensorData<int32_t>(&axis); |
62 | return kTfLiteOk; |
63 | case kTfLiteInt64: |
64 | *axis_value = *GetTensorData<int64_t>(&axis); |
65 | return kTfLiteOk; |
66 | default: |
67 | return kTfLiteError; |
68 | } |
69 | } |
70 | |
71 | } // namespace |
72 | |
73 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
74 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
75 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
76 | |
77 | const TfLiteTensor* input; |
78 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input)); |
79 | const TfLiteTensor* axis; |
80 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis)); |
81 | TfLiteTensor* output; |
82 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
83 | |
84 | output->type = input->type; |
85 | TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); |
86 | TF_LITE_ENSURE_EQ(context, input->params.zero_point, |
87 | output->params.zero_point); |
88 | if (input->type == kTfLiteInt16) { |
89 | TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); |
90 | } |
91 | |
92 | if (IsConstantTensor(axis)) { |
93 | int axis_value; |
94 | TF_LITE_ENSURE_OK(context, |
95 | GetAxisValueFromTensor(context, *axis, &axis_value)); |
96 | return ExpandTensorDim(context, *input, axis_value, output); |
97 | } |
98 | SetTensorToDynamic(output); |
99 | |
100 | return kTfLiteOk; |
101 | } |
102 | |
103 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
104 | // Just copy input to output. |
105 | const TfLiteTensor* input; |
106 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input)); |
107 | TfLiteTensor* output; |
108 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
109 | const TfLiteTensor* axis; |
110 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis)); |
111 | if (IsDynamicTensor(output)) { |
112 | int axis_value; |
113 | TF_LITE_ENSURE_OK(context, |
114 | GetAxisValueFromTensor(context, *axis, &axis_value)); |
115 | TF_LITE_ENSURE_OK(context, |
116 | ExpandTensorDim(context, *input, axis_value, output)); |
117 | } |
118 | if (output->type == kTfLiteString) { |
119 | TfLiteTensorRealloc(input->bytes, output); |
120 | } |
121 | memcpy(output->data.raw, input->data.raw, input->bytes); |
122 | return kTfLiteOk; |
123 | } |
124 | |
125 | } // namespace expand_dims |
126 | TfLiteRegistration* Register_EXPAND_DIMS() { |
127 | static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare, |
128 | expand_dims::Eval}; |
129 | return &r; |
130 | } |
131 | } // namespace builtin |
132 | } // namespace ops |
133 | } // namespace tflite |
134 | |