1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/broadcast_to.h"

#include <string.h>

#include <cstdint>
#include <limits>
#include <memory>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
25
26namespace tflite {
27namespace ops {
28namespace builtin {
29namespace broadcastto {
30
// Positions of the tensors in the node's input list.
constexpr int kInputTensor{0};
constexpr int kShapeTensor{1};
// Position of the single tensor in the node's output list.
constexpr int kOutputTensor{0};
// Highest rank handled by this kernel; also the template argument passed to
// reference_ops::BroadcastTo.
constexpr int kMaxDims{8};
35
36struct BroadcastToContext {
37 BroadcastToContext(TfLiteContext* context, TfLiteNode* node) {
38 input = GetInput(context, node, kInputTensor);
39 shape = GetInput(context, node, kShapeTensor);
40 output = GetOutput(context, node, kOutputTensor);
41 }
42 const TfLiteTensor* input;
43 const TfLiteTensor* shape;
44 TfLiteTensor* output;
45};
46
47TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
48 BroadcastToContext* op_context) {
49 // Ensures the shape is 1D tensor.
50 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1);
51
52 // Ensure output dims is not less than input dims.
53 int input_num_dims = NumDimensions(op_context->input);
54 int output_num_dims = SizeOfDimension(op_context->shape, 0);
55 TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims,
56 "Output shape must be broadcastable from input shape.");
57 TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims,
58 "BroadcastTo only supports 1-8D tensor.");
59
60 // Check if output shape is broadcastable from input shape.
61 auto get_shape_data = [op_context](int i) -> int32_t {
62 if (op_context->shape->type == kTfLiteInt32) {
63 return GetTensorData<int32_t>(op_context->shape)[i];
64 } else {
65 return GetTensorData<int64_t>(op_context->shape)[i];
66 }
67 };
68
69 int extending_dims = output_num_dims - input_num_dims;
70 for (int idx = 0; idx < input_num_dims; ++idx) {
71 TF_LITE_ENSURE_MSG(context,
72 (SizeOfDimension(op_context->input, idx) == 1 ||
73 SizeOfDimension(op_context->input, idx) ==
74 get_shape_data(extending_dims + idx)),
75 "Output shape must be broadcastable from input shape.");
76 }
77 // Resizing the shape of the output tensor.
78 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_num_dims);
79 std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
80 scoped_output_shape(output_shape, TfLiteIntArrayFree);
81 for (int idx = 0; idx < output_num_dims; ++idx) {
82 output_shape->data[idx] = get_shape_data(idx);
83 }
84
85 return context->ResizeTensor(context, op_context->output,
86 scoped_output_shape.release());
87}
88
89TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
90 TF_LITE_ENSURE(context, NumInputs(node) == 2);
91 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
92 TF_LITE_ENSURE_MSG(context,
93 (NumDimensions(GetInput(context, node, 0)) <= kMaxDims),
94 "BroadcastTo only supports 1-8D tensor.");
95
96 BroadcastToContext op_context(context, node);
97 TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 ||
98 op_context.shape->type == kTfLiteInt64);
99 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
100
101 // Not yet support string type due to the use of memcopy with fixed size.
102 TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString);
103
104 if (IsConstantTensor(op_context.shape)) {
105 return ResizeOutputTensor(context, &op_context);
106 }
107
108 SetTensorToDynamic(op_context.output);
109 return kTfLiteOk;
110}
111
112TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
113 BroadcastToContext op_context(context, node);
114 if (IsDynamicTensor(op_context.output)) {
115 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
116 }
117
118 // BroadcastTo op support upto 8 dims, matching the support of Tensorflow.
119 reference_ops::BroadcastTo<kMaxDims>(
120 GetTensorShape(op_context.input), op_context.input->data.raw,
121 GetTensorShape(op_context.output), op_context.output->data.raw,
122 op_context.input->type);
123 return kTfLiteOk;
124}
125
126} // namespace broadcastto
127
128TfLiteRegistration* Register_BROADCAST_TO() {
129 static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare,
130 broadcastto::Eval};
131 return &r;
132}
133
134} // namespace builtin
135} // namespace ops
136} // namespace tflite
137