1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include "tensorflow/lite/c/common.h" |
18 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
19 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
20 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor.h" |
22 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
23 | #include "tensorflow/lite/kernels/internal/types.h" |
24 | #include "tensorflow/lite/kernels/kernel_util.h" |
25 | |
26 | namespace tflite { |
27 | namespace ops { |
28 | namespace builtin { |
29 | namespace transpose { |
30 | |
31 | // This file has two implementations of Transpose. |
// This file has two implementations of Transpose.
enum KernelType {
  kReference,         // Portable reference_ops implementation.
  kGenericOptimized,  // Faster optimized_ops implementation.
};
36 | |
37 | struct TransposeContext { |
38 | TransposeContext(TfLiteContext* context, TfLiteNode* node) { |
39 | input = GetInput(context, node, 0); |
40 | perm = GetInput(context, node, 1); |
41 | output = GetOutput(context, node, 0); |
42 | } |
43 | const TfLiteTensor* input; |
44 | const TfLiteTensor* perm; |
45 | TfLiteTensor* output; |
46 | }; |
47 | |
48 | TfLiteStatus ResizeOutputTensor(TfLiteContext* context, |
49 | TransposeContext* op_context) { |
50 | int dims = NumDimensions(op_context->input); |
51 | const int* perm_data = GetTensorData<int32_t>(op_context->perm); |
52 | |
53 | // Ensure validity of the permutations tensor as a 1D tensor. |
54 | TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1); |
55 | TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims); |
56 | for (int idx = 0; idx < dims; ++idx) { |
57 | TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims), |
58 | "Transpose op permutations array is out of bounds." ); |
59 | } |
60 | |
61 | // Determine size of output tensor. |
62 | TfLiteIntArray* input_size = op_context->input->dims; |
63 | TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); |
64 | for (int idx = 0; idx < dims; ++idx) { |
65 | output_size->data[idx] = input_size->data[perm_data[idx]]; |
66 | } |
67 | |
68 | return context->ResizeTensor(context, op_context->output, output_size); |
69 | } |
70 | |
71 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
72 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
73 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
74 | |
75 | TransposeContext op_context(context, node); |
76 | |
77 | // Ensure validity of input tensor. |
78 | TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 5, |
79 | "Transpose op only supports 1D-5D input arrays." ); |
80 | TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type, |
81 | op_context.output->type); |
82 | |
83 | if (!IsConstantTensor(op_context.perm)) { |
84 | SetTensorToDynamic(op_context.output); |
85 | return kTfLiteOk; |
86 | } |
87 | return ResizeOutputTensor(context, &op_context); |
88 | } |
89 | |
90 | template <KernelType kernel_type> |
91 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
92 | TransposeContext op_context(context, node); |
93 | |
94 | // Resize the output tensor if the output tensor is dynamic. |
95 | if (IsDynamicTensor(op_context.output)) { |
96 | TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); |
97 | } |
98 | |
99 | const int* perm_data = GetTensorData<int32_t>(op_context.perm); |
100 | const int size = op_context.perm->dims->data[0]; |
101 | TransposeParams params; |
102 | params.perm_count = size; |
103 | for (int i = 0; i < size; ++i) { |
104 | params.perm[i] = perm_data[i]; |
105 | } |
106 | |
107 | #define TF_LITE_TRANSPOSE(type, scalar) \ |
108 | type::Transpose(params, GetTensorShape(op_context.input), \ |
109 | GetTensorData<scalar>(op_context.input), \ |
110 | GetTensorShape(op_context.output), \ |
111 | GetTensorData<scalar>(op_context.output)) |
112 | |
113 | // Transpose kernel only does rearranging values not numeric evaluations on |
114 | // each cell. It's safe to implement per size of scalar type and this trick |
115 | // keeps the total code size in a reasonable range. |
116 | switch (op_context.input->type) { |
117 | case kTfLiteFloat32: |
118 | case kTfLiteInt32: |
119 | if (kernel_type == kGenericOptimized) { |
120 | TF_LITE_TRANSPOSE(optimized_ops, int32_t); |
121 | } else { |
122 | TF_LITE_TRANSPOSE(reference_ops, int32_t); |
123 | } |
124 | break; |
125 | case kTfLiteUInt8: |
126 | case kTfLiteInt8: |
127 | if (kernel_type == kGenericOptimized) { |
128 | TF_LITE_TRANSPOSE(optimized_ops, int8_t); |
129 | } else { |
130 | TF_LITE_TRANSPOSE(reference_ops, int8_t); |
131 | } |
132 | break; |
133 | case kTfLiteInt16: |
134 | TF_LITE_TRANSPOSE(reference_ops, int16_t); |
135 | break; |
136 | case kTfLiteInt64: |
137 | TF_LITE_TRANSPOSE(reference_ops, int64_t); |
138 | break; |
139 | case kTfLiteBool: |
140 | if (sizeof(bool) == 1) { |
141 | if (kernel_type == kGenericOptimized) { |
142 | TF_LITE_TRANSPOSE(optimized_ops, int8_t); |
143 | } else { |
144 | TF_LITE_TRANSPOSE(reference_ops, int8_t); |
145 | } |
146 | } else { |
147 | TF_LITE_TRANSPOSE(reference_ops, bool); |
148 | } |
149 | break; |
150 | default: |
151 | TF_LITE_KERNEL_LOG(context, |
152 | "Type %s is currently not supported by Transpose." , |
153 | TfLiteTypeGetName(op_context.input->type)); |
154 | return kTfLiteError; |
155 | } |
156 | #undef TF_LITE_TRANSPOSE |
157 | |
158 | return kTfLiteOk; |
159 | } |
160 | |
161 | } // namespace transpose |
162 | |
163 | TfLiteRegistration* Register_TRANSPOSE_REF() { |
164 | static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare, |
165 | transpose::Eval<transpose::kReference>}; |
166 | return &r; |
167 | } |
168 | |
169 | TfLiteRegistration* Register_TRANSPOSE_GENERIC_OPTIMIZED() { |
170 | static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare, |
171 | transpose::Eval<transpose::kGenericOptimized>}; |
172 | return &r; |
173 | } |
174 | |
// Default Transpose registration: uses the generic optimized kernel.
TfLiteRegistration* Register_TRANSPOSE() {
  return Register_TRANSPOSE_GENERIC_OPTIMIZED();
}
178 | |
179 | } // namespace builtin |
180 | } // namespace ops |
181 | } // namespace tflite |
182 | |