1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/internal/reference/broadcast_args.h"
16
17#include <algorithm>
18#include <cstdint>
19#include <memory>
20
21#include "tensorflow/lite/c/common.h"
22#include "tensorflow/lite/kernels/internal/tensor.h"
23#include "tensorflow/lite/kernels/kernel_util.h"
24
25namespace tflite {
26namespace ops {
27namespace builtin {
28namespace broadcast_args {
29
// Positions of this op's tensors within the node: the two shape operands are
// inputs 0 and 1, and the single result is output 0.
constexpr int kShape1Tensor{0};
constexpr int kShape2Tensor{1};
constexpr int kOutputTensor{0};
33
34struct BroadcastArgsContext {
35 BroadcastArgsContext(TfLiteContext* context, TfLiteNode* node) {
36 shape1 = GetInput(context, node, kShape1Tensor);
37 shape2 = GetInput(context, node, kShape2Tensor);
38 output = GetOutput(context, node, kOutputTensor);
39 }
40 const TfLiteTensor* shape1;
41 const TfLiteTensor* shape2;
42 TfLiteTensor* output;
43};
44
45TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
46 TF_LITE_ENSURE(context, NumInputs(node) == 2);
47 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
48
49 BroadcastArgsContext op_context(context, node);
50 TF_LITE_ENSURE(context, op_context.shape1->type == kTfLiteInt32 ||
51 op_context.shape1->type == kTfLiteInt64);
52 TF_LITE_ENSURE_EQ(context, op_context.shape1->type, op_context.shape2->type);
53 TF_LITE_ENSURE_EQ(context, op_context.shape1->type, op_context.output->type);
54
55 // Ensures the shapes are 1D tensor.
56 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.shape1), 1);
57 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.shape2), 1);
58
59 // Resizing the shape of the output tensor.
60 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(1);
61 output_shape->data[0] = std::max(SizeOfDimension(op_context.shape1, 0),
62 SizeOfDimension(op_context.shape2, 0));
63 return context->ResizeTensor(context, op_context.output, output_shape);
64}
65
66TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
67 BroadcastArgsContext op_context(context, node);
68
69#define TF_LITE_BROADCAST_ARG(data_type) \
70 reference_ops::BroadcastArgs(GetTensorShape(op_context.shape1), \
71 GetTensorData<data_type>(op_context.shape1), \
72 GetTensorShape(op_context.shape2), \
73 GetTensorData<data_type>(op_context.shape2), \
74 GetTensorShape(op_context.output), \
75 GetTensorData<data_type>(op_context.output))
76
77 if (op_context.output->type == kTfLiteInt32) {
78 TF_LITE_BROADCAST_ARG(int32_t);
79 } else {
80 TF_LITE_BROADCAST_ARG(int64_t);
81 }
82#undef TF_LITE_BROADCAST_ARG
83
84 return kTfLiteOk;
85}
86
87} // namespace broadcast_args
88
89TfLiteRegistration* Register_BROADCAST_ARGS() {
90 static TfLiteRegistration r = {nullptr, nullptr, broadcast_args::Prepare,
91 broadcast_args::Eval};
92 return &r;
93}
94
95} // namespace builtin
96} // namespace ops
97} // namespace tflite
98