/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace depth_to_space {

// This file has two implementations of DepthToSpace. Note that DepthToSpace
// only works on 4D tensors.
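// DepthToSpace rearranges data from the depth dimension into blocks of
// spatial data: an input of shape [batch, height, width, depth] with block
// size B produces an output of shape
// [batch, height * B, width * B, depth / (B * B)].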
enum KernelType {
  kReference,
  kGenericOptimized,
};

constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

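// Validates that the input is a supported 4D tensor and resizes the output
// tensor according to block_size.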
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteDepthToSpaceParams*>(node->builtin_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);

  auto data_type = output->type;
  TF_LITE_ENSURE(context,
                 data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
                     data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
                     data_type == kTfLiteInt64);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

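  // The spatial dimensions grow by block_size and the channel dimension
  // shrinks by block_size^2; e.g. a [1, 2, 2, 4] input with block_size 2
  // yields a [1, 4, 4, 1] output. The checks below also reject inputs whose
  // channel count is not a multiple of block_size * block_size.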
  const int block_size = params->block_size;
  TF_LITE_ENSURE(context, block_size > 0);
  const int input_height = input->dims->data[1];
  const int input_width = input->dims->data[2];
  const int input_channels = input->dims->data[3];
  int output_height = input_height * block_size;
  int output_width = input_width * block_size;
  int output_channels = input_channels / block_size / block_size;

  TF_LITE_ENSURE_EQ(context, input_height, output_height / block_size);
  TF_LITE_ENSURE_EQ(context, input_width, output_width / block_size);
  TF_LITE_ENSURE_EQ(context, input_channels,
                    output_channels * block_size * block_size);

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
  output_size->data[0] = input->dims->data[0];
  output_size->data[1] = output_height;
  output_size->data[2] = output_width;
  output_size->data[3] = output_channels;

  return context->ResizeTensor(context, output, output_size);
}

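// Runs either the reference or the generic optimized implementation,
// selected at compile time via the kernel_type template parameter.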
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteDepthToSpaceParams*>(node->builtin_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

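// Expands to a call to the selected implementation's DepthToSpace, with the
// scalar template argument matching the tensor's element type.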
#define TF_LITE_DEPTH_TO_SPACE(type, scalar)                               \
  tflite::DepthToSpaceParams op_params;                                    \
  op_params.block_size = params->block_size;                               \
  type::DepthToSpace(op_params, GetTensorShape(input),                     \
                     GetTensorData<scalar>(input), GetTensorShape(output), \
                     GetTensorData<scalar>(output))
  switch (input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_DEPTH_TO_SPACE(reference_ops, float);
      } else {
        TF_LITE_DEPTH_TO_SPACE(optimized_ops, float);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_DEPTH_TO_SPACE(reference_ops, uint8_t);
      } else {
        TF_LITE_DEPTH_TO_SPACE(optimized_ops, uint8_t);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_DEPTH_TO_SPACE(reference_ops, int8_t);
      } else {
        TF_LITE_DEPTH_TO_SPACE(optimized_ops, int8_t);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_DEPTH_TO_SPACE(reference_ops, int32_t);
      } else {
        TF_LITE_DEPTH_TO_SPACE(optimized_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_DEPTH_TO_SPACE(reference_ops, int64_t);
      } else {
        TF_LITE_DEPTH_TO_SPACE(optimized_ops, int64_t);
      }
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type '%s' not currently supported.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
#undef TF_LITE_DEPTH_TO_SPACE

  return kTfLiteOk;
}

}  // namespace depth_to_space

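// The registrations below fill in the init, free, prepare, and invoke entries
// of TfLiteRegistration; this op needs no init/free.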
TfLiteRegistration* Register_DEPTH_TO_SPACE_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, depth_to_space::Prepare,
      depth_to_space::Eval<depth_to_space::kReference>};
  return &r;
}

TfLiteRegistration* Register_DEPTH_TO_SPACE_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, depth_to_space::Prepare,
      depth_to_space::Eval<depth_to_space::kGenericOptimized>};
  return &r;
}

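// The default registration uses the generic optimized kernel.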
TfLiteRegistration* Register_DEPTH_TO_SPACE() {
  return Register_DEPTH_TO_SPACE_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite