/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace space_to_batch_nd {

// This file has two implementations of SpaceToBatchND.
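// kReference dispatches to the portable reference_ops implementation and
// kGenericOptimized to optimized_ops; both share the same Prepare logic and
// are selected through the Eval<> template parameter at registration time.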
enum KernelType {
  kReference,
  kGenericOptimized,
};

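// Collects the node's tensors: the data input, the 1D block_shape, the 2D
// paddings, and the single output.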
struct SpaceToBatchNDContext {
  SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    block_shape = GetInput(context, node, 1);
    paddings = GetInput(context, node, 2);
    output = GetOutput(context, node, 0);
  }
  const TfLiteTensor* input;
  const TfLiteTensor* block_shape;
  const TfLiteTensor* paddings;
  TfLiteTensor* output;
};

// Currently, only 3D NHC and 4D NHWC input/output tensors are supported.
// A 3D input is extended to 4D NHWC by adding a W dimension of size 1.
// The 4D array needs to have exactly 2 spatial dimensions.
// TODO(b/149952582): Support arbitrary dimension in SpaceToBatchND.
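// Example: a [1, 4, 4, 3] input with block_shape [2, 2] and zero paddings
// produces a [4, 2, 2, 3] output: the batch dimension is multiplied by the
// product of the block shape, each padded spatial dimension is divided by its
// block size, and the channel dimension is unchanged.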
const int kInputMinDimensionNum = 3;
const int kInputMaxDimensionNum = 4;

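// Computes the output shape from block_shape and paddings and resizes the
// output tensor accordingly. Each padded spatial dimension must be divisible
// by the corresponding (non-zero) block size.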
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                SpaceToBatchNDContext* op_context) {
  TfLiteIntArray* input_size = op_context->input->dims;
  const int32_t* block_shape = GetTensorData<int32_t>(op_context->block_shape);
  const int32_t* paddings_data = GetTensorData<int32_t>(op_context->paddings);

  int spatial_dims_num = input_size->size - 2;
  // block_shape should be a 1D tensor with dimension [spatial_dims_num].
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
  TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
                    spatial_dims_num);
  // paddings should be a 2D tensor with dimension [spatial_dims_num, 2].
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), 2);
  TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[0],
                    spatial_dims_num);
  TF_LITE_ENSURE_EQ(context, op_context->paddings->dims->data[1], 2);

  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);

  // Ensures each padded spatial dimension is a multiple of the corresponding
  // block shape dimension.
  int output_batch_size = input_size->data[0];
  for (int dim = 0; dim < spatial_dims_num; ++dim) {
    int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
                          paddings_data[dim * 2 + 1]);
    TF_LITE_ENSURE(context, block_shape[dim] != 0);
    TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
    output_size->data[dim + 1] = final_dim_size / block_shape[dim];
    output_batch_size *= block_shape[dim];
  }

  output_size->data[0] = output_batch_size;
  output_size->data[input_size->size - 1] =
      input_size->data[input_size->size - 1];

  return context->ResizeTensor(context, op_context->output, output_size);
}

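// Validates inputs/outputs and resizes the output when block_shape and
// paddings are constant tensors; otherwise the output is marked dynamic and
// resized at Eval time, once their values are available.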
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  SpaceToBatchNDContext op_context(context, node);
  TF_LITE_ENSURE(context,
                 NumDimensions(op_context.input) >= kInputMinDimensionNum);
  TF_LITE_ENSURE(context,
                 NumDimensions(op_context.input) <= kInputMaxDimensionNum);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
                          op_context.output->type);

  if (!IsConstantTensor(op_context.block_shape) ||
      !IsConstantTensor(op_context.paddings)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

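// Evaluates SpaceToBatchND using either the reference or the generic
// optimized kernel, selected through the kernel_type template parameter.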
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  SpaceToBatchNDContext op_context(context, node);

  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

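// The macro below forwards the tensors to the selected implementation.
// pad_value is stored in SpaceToBatchParams::output_offset and is used to
// fill the padded regions: 0 for float and integer types, and the output zero
// point for quantized uint8/int8 so that padding dequantizes to 0.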
#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value)             \
  tflite::SpaceToBatchParams op_params;                                \
  op_params.output_offset = pad_value;                                 \
  type::SpaceToBatchND(op_params, GetTensorShape(op_context.input),    \
                       GetTensorData<scalar>(op_context.input),        \
                       GetTensorShape(op_context.block_shape),         \
                       GetTensorData<int32_t>(op_context.block_shape), \
                       GetTensorShape(op_context.paddings),            \
                       GetTensorData<int32_t>(op_context.paddings),    \
                       GetTensorShape(op_context.output),              \
                       GetTensorData<scalar>(op_context.output))
  switch (op_context.input->type) {  // Already know in/out types are the same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t,
                                  op_context.output->params.zero_point);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t,
                                  op_context.output->params.zero_point);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t,
                                  op_context.output->params.zero_point);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t,
                                  op_context.output->params.zero_point);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0);
      }
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Type %d is currently not supported by SpaceToBatch.",
                         op_context.input->type);
      return kTfLiteError;
  }
#undef TF_LITE_SPACE_TO_BATCH_ND
  return kTfLiteOk;
}

}  // namespace space_to_batch_nd

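// TfLiteRegistration is initialized positionally as {init, free, prepare,
// invoke}; this kernel keeps no per-node state, so init and free are null.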
TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_batch_nd::Prepare,
      space_to_batch_nd::Eval<space_to_batch_nd::kReference>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_batch_nd::Prepare,
      space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_BATCH_ND() {
  // return Register_SPACE_TO_BATCH_ND_REF();
  return Register_SPACE_TO_BATCH_ND_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite