1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/broadcast_to.h" |
16 | |
17 | #include <string.h> |
18 | |
19 | #include <cstdint> |
20 | #include <memory> |
21 | |
22 | #include "tensorflow/lite/c/common.h" |
23 | #include "tensorflow/lite/kernels/internal/tensor.h" |
24 | #include "tensorflow/lite/kernels/kernel_util.h" |
25 | |
26 | namespace tflite { |
27 | namespace ops { |
28 | namespace builtin { |
29 | namespace broadcastto { |
30 | |
// Indices into the node's input list: the tensor to broadcast and the 1-D
// tensor holding the requested output shape.
constexpr int kInputTensor{0};
constexpr int kShapeTensor{1};
// Index into the node's output list.
constexpr int kOutputTensor{0};
// Upper bound on the rank accepted for both the input and the requested
// output shape; also the template argument handed to the reference kernel.
constexpr int kMaxDims{8};
35 | |
36 | struct BroadcastToContext { |
37 | BroadcastToContext(TfLiteContext* context, TfLiteNode* node) { |
38 | input = GetInput(context, node, kInputTensor); |
39 | shape = GetInput(context, node, kShapeTensor); |
40 | output = GetOutput(context, node, kOutputTensor); |
41 | } |
42 | const TfLiteTensor* input; |
43 | const TfLiteTensor* shape; |
44 | TfLiteTensor* output; |
45 | }; |
46 | |
47 | TfLiteStatus ResizeOutputTensor(TfLiteContext* context, |
48 | BroadcastToContext* op_context) { |
49 | // Ensures the shape is 1D tensor. |
50 | TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->shape), 1); |
51 | |
52 | // Ensure output dims is not less than input dims. |
53 | int input_num_dims = NumDimensions(op_context->input); |
54 | int output_num_dims = SizeOfDimension(op_context->shape, 0); |
55 | TF_LITE_ENSURE_MSG(context, input_num_dims <= output_num_dims, |
56 | "Output shape must be broadcastable from input shape." ); |
57 | TF_LITE_ENSURE_MSG(context, output_num_dims <= kMaxDims, |
58 | "BroadcastTo only supports 1-8D tensor." ); |
59 | |
60 | // Check if output shape is broadcastable from input shape. |
61 | auto get_shape_data = [op_context](int i) -> int32_t { |
62 | if (op_context->shape->type == kTfLiteInt32) { |
63 | return GetTensorData<int32_t>(op_context->shape)[i]; |
64 | } else { |
65 | return GetTensorData<int64_t>(op_context->shape)[i]; |
66 | } |
67 | }; |
68 | |
69 | int extending_dims = output_num_dims - input_num_dims; |
70 | for (int idx = 0; idx < input_num_dims; ++idx) { |
71 | TF_LITE_ENSURE_MSG(context, |
72 | (SizeOfDimension(op_context->input, idx) == 1 || |
73 | SizeOfDimension(op_context->input, idx) == |
74 | get_shape_data(extending_dims + idx)), |
75 | "Output shape must be broadcastable from input shape." ); |
76 | } |
77 | // Resizing the shape of the output tensor. |
78 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_num_dims); |
79 | std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> |
80 | scoped_output_shape(output_shape, TfLiteIntArrayFree); |
81 | for (int idx = 0; idx < output_num_dims; ++idx) { |
82 | output_shape->data[idx] = get_shape_data(idx); |
83 | } |
84 | |
85 | return context->ResizeTensor(context, op_context->output, |
86 | scoped_output_shape.release()); |
87 | } |
88 | |
89 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
90 | TF_LITE_ENSURE(context, NumInputs(node) == 2); |
91 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
92 | TF_LITE_ENSURE_MSG(context, |
93 | (NumDimensions(GetInput(context, node, 0)) <= kMaxDims), |
94 | "BroadcastTo only supports 1-8D tensor." ); |
95 | |
96 | BroadcastToContext op_context(context, node); |
97 | TF_LITE_ENSURE(context, op_context.shape->type == kTfLiteInt32 || |
98 | op_context.shape->type == kTfLiteInt64); |
99 | TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); |
100 | |
101 | // Not yet support string type due to the use of memcopy with fixed size. |
102 | TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString); |
103 | |
104 | if (IsConstantTensor(op_context.shape)) { |
105 | return ResizeOutputTensor(context, &op_context); |
106 | } |
107 | |
108 | SetTensorToDynamic(op_context.output); |
109 | return kTfLiteOk; |
110 | } |
111 | |
112 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
113 | BroadcastToContext op_context(context, node); |
114 | if (IsDynamicTensor(op_context.output)) { |
115 | TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); |
116 | } |
117 | |
118 | // BroadcastTo op support upto 8 dims, matching the support of Tensorflow. |
119 | reference_ops::BroadcastTo<kMaxDims>( |
120 | GetTensorShape(op_context.input), op_context.input->data.raw, |
121 | GetTensorShape(op_context.output), op_context.output->data.raw, |
122 | op_context.input->type); |
123 | return kTfLiteOk; |
124 | } |
125 | |
126 | } // namespace broadcastto |
127 | |
128 | TfLiteRegistration* Register_BROADCAST_TO() { |
129 | static TfLiteRegistration r = {nullptr, nullptr, broadcastto::Prepare, |
130 | broadcastto::Eval}; |
131 | return &r; |
132 | } |
133 | |
134 | } // namespace builtin |
135 | } // namespace ops |
136 | } // namespace tflite |
137 | |