1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <stdint.h>
16
17#include "tensorflow/lite/c/common.h"
18#include "tensorflow/lite/kernels/internal/compatibility.h"
19#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21#include "tensorflow/lite/kernels/internal/tensor.h"
22#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23#include "tensorflow/lite/kernels/internal/types.h"
24#include "tensorflow/lite/kernels/kernel_util.h"
25
26namespace tflite {
27namespace ops {
28namespace builtin {
29namespace transpose {
30
// Selects which of this file's two Transpose implementations Eval<> runs:
// the reference kernels or the generic optimized kernels.
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,
};
36
37struct TransposeContext {
38 TransposeContext(TfLiteContext* context, TfLiteNode* node) {
39 input = GetInput(context, node, 0);
40 perm = GetInput(context, node, 1);
41 output = GetOutput(context, node, 0);
42 }
43 const TfLiteTensor* input;
44 const TfLiteTensor* perm;
45 TfLiteTensor* output;
46};
47
48TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
49 TransposeContext* op_context) {
50 int dims = NumDimensions(op_context->input);
51 const int* perm_data = GetTensorData<int32_t>(op_context->perm);
52
53 // Ensure validity of the permutations tensor as a 1D tensor.
54 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1);
55 TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims);
56 for (int idx = 0; idx < dims; ++idx) {
57 TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims),
58 "Transpose op permutations array is out of bounds.");
59 }
60
61 // Determine size of output tensor.
62 TfLiteIntArray* input_size = op_context->input->dims;
63 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
64 for (int idx = 0; idx < dims; ++idx) {
65 output_size->data[idx] = input_size->data[perm_data[idx]];
66 }
67
68 return context->ResizeTensor(context, op_context->output, output_size);
69}
70
71TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
72 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
73 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
74
75 TransposeContext op_context(context, node);
76
77 // Ensure validity of input tensor.
78 TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 5,
79 "Transpose op only supports 1D-5D input arrays.");
80 TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
81 op_context.output->type);
82
83 if (!IsConstantTensor(op_context.perm)) {
84 SetTensorToDynamic(op_context.output);
85 return kTfLiteOk;
86 }
87 return ResizeOutputTensor(context, &op_context);
88}
89
90template <KernelType kernel_type>
91TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
92 TransposeContext op_context(context, node);
93
94 // Resize the output tensor if the output tensor is dynamic.
95 if (IsDynamicTensor(op_context.output)) {
96 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
97 }
98
99 const int* perm_data = GetTensorData<int32_t>(op_context.perm);
100 const int size = op_context.perm->dims->data[0];
101 TransposeParams params;
102 params.perm_count = size;
103 for (int i = 0; i < size; ++i) {
104 params.perm[i] = perm_data[i];
105 }
106
107#define TF_LITE_TRANSPOSE(type, scalar) \
108 type::Transpose(params, GetTensorShape(op_context.input), \
109 GetTensorData<scalar>(op_context.input), \
110 GetTensorShape(op_context.output), \
111 GetTensorData<scalar>(op_context.output))
112
113 // Transpose kernel only does rearranging values not numeric evaluations on
114 // each cell. It's safe to implement per size of scalar type and this trick
115 // keeps the total code size in a reasonable range.
116 switch (op_context.input->type) {
117 case kTfLiteFloat32:
118 case kTfLiteInt32:
119 if (kernel_type == kGenericOptimized) {
120 TF_LITE_TRANSPOSE(optimized_ops, int32_t);
121 } else {
122 TF_LITE_TRANSPOSE(reference_ops, int32_t);
123 }
124 break;
125 case kTfLiteUInt8:
126 case kTfLiteInt8:
127 if (kernel_type == kGenericOptimized) {
128 TF_LITE_TRANSPOSE(optimized_ops, int8_t);
129 } else {
130 TF_LITE_TRANSPOSE(reference_ops, int8_t);
131 }
132 break;
133 case kTfLiteInt16:
134 TF_LITE_TRANSPOSE(reference_ops, int16_t);
135 break;
136 case kTfLiteInt64:
137 TF_LITE_TRANSPOSE(reference_ops, int64_t);
138 break;
139 case kTfLiteBool:
140 if (sizeof(bool) == 1) {
141 if (kernel_type == kGenericOptimized) {
142 TF_LITE_TRANSPOSE(optimized_ops, int8_t);
143 } else {
144 TF_LITE_TRANSPOSE(reference_ops, int8_t);
145 }
146 } else {
147 TF_LITE_TRANSPOSE(reference_ops, bool);
148 }
149 break;
150 default:
151 TF_LITE_KERNEL_LOG(context,
152 "Type %s is currently not supported by Transpose.",
153 TfLiteTypeGetName(op_context.input->type));
154 return kTfLiteError;
155 }
156#undef TF_LITE_TRANSPOSE
157
158 return kTfLiteOk;
159}
160
161} // namespace transpose
162
163TfLiteRegistration* Register_TRANSPOSE_REF() {
164 static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
165 transpose::Eval<transpose::kReference>};
166 return &r;
167}
168
169TfLiteRegistration* Register_TRANSPOSE_GENERIC_OPTIMIZED() {
170 static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
171 transpose::Eval<transpose::kGenericOptimized>};
172 return &r;
173}
174
175TfLiteRegistration* Register_TRANSPOSE() {
176 return Register_TRANSPOSE_GENERIC_OPTIMIZED();
177}
178
179} // namespace builtin
180} // namespace ops
181} // namespace tflite
182