1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
25
26namespace tflite {
27namespace ops {
28namespace builtin {
29namespace space_to_depth {
30
// SpaceToDepth has two implementations in this file: a reference kernel and a
// generic optimized kernel, chosen at registration time via KernelType. The op
// only accepts 4D input tensors.
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,
};

// Indices of the node's single input tensor and single output tensor.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
40
41TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42 auto* params =
43 reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
44
45 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
46 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
47
48 const TfLiteTensor* input;
49 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
50 TfLiteTensor* output;
51 TF_LITE_ENSURE_OK(context,
52 GetOutputSafe(context, node, kOutputTensor, &output));
53
54 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
55
56 auto data_type = output->type;
57 TF_LITE_ENSURE(context,
58 data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
59 data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
60 data_type == kTfLiteInt64);
61 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
62
63 const int block_size = params->block_size;
64 TF_LITE_ENSURE(context, block_size > 0);
65 const int input_height = input->dims->data[1];
66 const int input_width = input->dims->data[2];
67 int output_height = input_height / block_size;
68 int output_width = input_width / block_size;
69
70 TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size);
71 TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size);
72
73 TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
74 output_size->data[0] = input->dims->data[0];
75 output_size->data[1] = output_height;
76 output_size->data[2] = output_width;
77 output_size->data[3] = input->dims->data[3] * block_size * block_size;
78
79 return context->ResizeTensor(context, output, output_size);
80}
81
82template <KernelType kernel_type>
83TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
84 auto* params =
85 reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
86
87 const TfLiteTensor* input;
88 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
89 TfLiteTensor* output;
90 TF_LITE_ENSURE_OK(context,
91 GetOutputSafe(context, node, kOutputTensor, &output));
92
93#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
94 tflite::SpaceToDepthParams op_params; \
95 op_params.block_size = params->block_size; \
96 type::SpaceToDepth(op_params, GetTensorShape(input), \
97 GetTensorData<scalar>(input), GetTensorShape(output), \
98 GetTensorData<scalar>(output))
99 switch (input->type) { // Already know in/out types are same.
100 case kTfLiteFloat32:
101 if (kernel_type == kReference) {
102 TF_LITE_SPACE_TO_DEPTH(reference_ops, float);
103 } else {
104 TF_LITE_SPACE_TO_DEPTH(optimized_ops, float);
105 }
106 break;
107 case kTfLiteUInt8:
108 if (kernel_type == kReference) {
109 TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t);
110 } else {
111 TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
112 }
113 break;
114 case kTfLiteInt8:
115 if (kernel_type == kReference) {
116 TF_LITE_SPACE_TO_DEPTH(reference_ops, int8_t);
117 } else {
118 TF_LITE_SPACE_TO_DEPTH(optimized_ops, int8_t);
119 }
120 break;
121 case kTfLiteInt32:
122 if (kernel_type == kReference) {
123 TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);
124 } else {
125 TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t);
126 }
127 break;
128 case kTfLiteInt64:
129 if (kernel_type == kReference) {
130 TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t);
131 } else {
132 TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t);
133 }
134 break;
135 default:
136 TF_LITE_KERNEL_LOG(context, "Type '%s' not currently supported.",
137 TfLiteTypeGetName(input->type));
138 return kTfLiteError;
139 }
140#undef TF_LITE_SPACE_TO_DEPTH
141
142 return kTfLiteOk;
143}
144
145} // namespace space_to_depth
146
147TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() {
148 static TfLiteRegistration r = {
149 nullptr, nullptr, space_to_depth::Prepare,
150 space_to_depth::Eval<space_to_depth::kReference>};
151 return &r;
152}
153
154TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() {
155 static TfLiteRegistration r = {
156 nullptr, nullptr, space_to_depth::Prepare,
157 space_to_depth::Eval<space_to_depth::kGenericOptimized>};
158 return &r;
159}
160
161TfLiteRegistration* Register_SPACE_TO_DEPTH() {
162 return Register_SPACE_TO_DEPTH_GENERIC_OPT();
163}
164
165} // namespace builtin
166} // namespace ops
167} // namespace tflite
168