1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <stdint.h>
16
17#include "tensorflow/lite/c/common.h"
18#include "tensorflow/lite/kernels/internal/compatibility.h"
19#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21#include "tensorflow/lite/kernels/internal/tensor.h"
22#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23#include "tensorflow/lite/kernels/kernel_util.h"
24
25namespace tflite {
26namespace ops {
27namespace builtin {
28namespace batch_to_space_nd {
29
// Selects which of the two BatchToSpaceND implementations in this file is
// used. Eval is instantiated with one of these values as its template
// argument: kReference dispatches to reference_ops, kGenericOptimized to
// optimized_ops.
enum KernelType {
  kReference = 0,         // reference_ops::BatchToSpaceND
  kGenericOptimized = 1,  // optimized_ops::BatchToSpaceND
};
35
36struct BatchToSpaceNDContext {
37 BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
38 input = GetInput(context, node, 0);
39 block_shape = GetInput(context, node, 1);
40 crops = GetInput(context, node, 2);
41 output = GetOutput(context, node, 0);
42 }
43 const TfLiteTensor* input;
44 const TfLiteTensor* block_shape;
45 const TfLiteTensor* crops;
46 TfLiteTensor* output;
47};
48
// Only 3D (NHC) and 4D (NHWC) inputs/outputs are currently supported.
// A 3D input is treated as 4D by inserting W=1, giving NH1C.
// A 4D array must have exactly 2 spatial dimensions.
// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
constexpr int kInputMinDimensionNum = 3;
constexpr int kInputMaxDimensionNum = 4;
55
56TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
57 BatchToSpaceNDContext* op_context) {
58 TfLiteIntArray* input_size = op_context->input->dims;
59 const int* block_shape = GetTensorData<int32>(op_context->block_shape);
60 const int* crops = GetTensorData<int32>(op_context->crops);
61
62 int spatial_dims_num = input_size->size - 2;
63 // Block_shape should be a 1D tensor with dimension [spatial_dims_num].
64 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
65 TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
66 spatial_dims_num);
67 // Crops should be a 2D tensor with dimension [spatial_dims_num, 2].
68 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), 2);
69 TF_LITE_ENSURE_EQ(context, op_context->crops->dims->data[0],
70 spatial_dims_num);
71 TF_LITE_ENSURE_EQ(context, op_context->crops->dims->data[1], 2);
72
73 for (int i = 0; i < spatial_dims_num * 2; ++i) {
74 TF_LITE_ENSURE(context, crops[i] >= 0);
75 }
76
77 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
78 int output_batch_size = input_size->data[0];
79 for (int dim = 0; dim < spatial_dims_num; ++dim) {
80 // Number of batch must be multiple of (block_shape[dim]).
81 TF_LITE_ENSURE(context, block_shape[dim] != 0);
82 TF_LITE_ENSURE_EQ(context, output_batch_size % block_shape[dim], 0);
83 output_batch_size = output_batch_size / block_shape[dim];
84 output_size->data[dim + 1] = input_size->data[dim + 1] * block_shape[dim] -
85 crops[dim * 2] - crops[dim * 2 + 1];
86 }
87
88 output_size->data[0] = output_batch_size;
89 output_size->data[input_size->size - 1] =
90 input_size->data[input_size->size - 1];
91
92 return context->ResizeTensor(context, op_context->output, output_size);
93}
94
95TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
96 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
97 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
98
99 BatchToSpaceNDContext op_context(context, node);
100 TF_LITE_ENSURE(context,
101 NumDimensions(op_context.input) >= kInputMinDimensionNum);
102 TF_LITE_ENSURE(context,
103 NumDimensions(op_context.input) <= kInputMaxDimensionNum);
104 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
105
106 if (!IsConstantTensor(op_context.block_shape) ||
107 !IsConstantTensor(op_context.crops)) {
108 SetTensorToDynamic(op_context.output);
109 return kTfLiteOk;
110 }
111 return ResizeOutputTensor(context, &op_context);
112}
113
114template <KernelType kernel_type>
115TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
116 BatchToSpaceNDContext op_context(context, node);
117
118 // Resize the output tensor if the output tensor is dynamic.
119 if (IsDynamicTensor(op_context.output)) {
120 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
121 }
122
123#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
124 type::BatchToSpaceND(GetTensorShape(op_context.input), \
125 GetTensorData<scalar>(op_context.input), \
126 GetTensorShape(op_context.block_shape), \
127 GetTensorData<int32_t>(op_context.block_shape), \
128 GetTensorShape(op_context.crops), \
129 GetTensorData<int32_t>(op_context.crops), \
130 GetTensorShape(op_context.output), \
131 GetTensorData<scalar>(op_context.output))
132 switch (op_context.input->type) { // Already know in/out types are same.
133 case kTfLiteFloat32:
134 if (kernel_type == kReference) {
135 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float);
136 } else {
137 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float);
138 }
139 break;
140 case kTfLiteUInt8:
141 if (kernel_type == kReference) {
142 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t);
143 } else {
144 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
145 }
146 break;
147 case kTfLiteInt8:
148 if (kernel_type == kReference) {
149 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t);
150 } else {
151 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t);
152 }
153 break;
154 case kTfLiteInt32:
155 if (kernel_type == kReference) {
156 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);
157 } else {
158 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t);
159 }
160 break;
161 case kTfLiteInt64:
162 if (kernel_type == kReference) {
163 TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t);
164 } else {
165 TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t);
166 }
167 break;
168 default:
169 TF_LITE_KERNEL_LOG(context,
170 "Type %d is currently not supported by BatchToSpace.",
171 op_context.input->type);
172 return kTfLiteError;
173 }
174#undef TF_LITE_BATCH_TO_SPACE_ND
175 return kTfLiteOk;
176}
177
178} // namespace batch_to_space_nd
179
180TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() {
181 static TfLiteRegistration r = {
182 nullptr, nullptr, batch_to_space_nd::Prepare,
183 batch_to_space_nd::Eval<batch_to_space_nd::kReference>};
184 return &r;
185}
186
187TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() {
188 static TfLiteRegistration r = {
189 nullptr, nullptr, batch_to_space_nd::Prepare,
190 batch_to_space_nd::Eval<batch_to_space_nd::kGenericOptimized>};
191 return &r;
192}
193
194TfLiteRegistration* Register_BATCH_TO_SPACE_ND() {
195 return Register_BATCH_TO_SPACE_ND_GENERIC_OPT();
196}
197
198} // namespace builtin
199} // namespace ops
200} // namespace tflite
201