/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace cumsum {

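// Indices of the node's input, axis, and output tensors.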
static const int kInputTensor = 0;
static const int kAxisTensor = 1;
static const int kOutputTensor = 0;

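// Validates the node's inputs (supported input type, scalar int32 axis,
// input rank >= 1) and resizes the output to match the input's shape.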
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);

  TF_LITE_ENSURE(context, input->type == kTfLiteInt32 ||
                              input->type == kTfLiteFloat32 ||
                              input->type == kTfLiteInt64);
  TF_LITE_ENSURE_EQ(context, axis->type, kTfLiteInt32);

  TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);

  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
  return context->ResizeTensor(context, output, output_shape);
}

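// Reads the scalar axis, normalizes a negative axis against the input rank,
// and computes the cumulative sum for the input's element type.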
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  auto* params = reinterpret_cast<TfLiteCumsumParams*>(node->builtin_data);

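  // The axis input is a scalar; a negative value counts back from the last
  // dimension.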
  int axis = *GetTensorData<int>(axis_tensor);
  if (axis < 0) axis += NumDimensions(input);

  if (axis < 0 || axis >= NumDimensions(input)) {
    TF_LITE_KERNEL_LOG(context, "Invalid axis: %d", axis);
    return kTfLiteError;
  }

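  // Dispatch on the input element type.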
  switch (input->type) {
    case kTfLiteInt32: {
      optimized_ops::CumSum(GetTensorData<int>(input), GetTensorShape(input),
                            axis, params->exclusive, params->reverse,
                            GetTensorData<int>(output));
      break;
    }
    case kTfLiteInt64: {
      optimized_ops::CumSum(GetTensorData<int64_t>(input),
                            GetTensorShape(input), axis, params->exclusive,
                            params->reverse, GetTensorData<int64_t>(output));
      break;
    }
    case kTfLiteFloat32: {
      optimized_ops::CumSum(GetTensorData<float>(input), GetTensorShape(input),
                            axis, params->exclusive, params->reverse,
                            GetTensorData<float>(output));
      break;
    }
    default: {
      TF_LITE_KERNEL_LOG(
          context,
          "Unsupported input type, cumsum only supports int32, int64 & "
          "float32.");
      return kTfLiteError;
    }
  }

  return kTfLiteOk;
}

}  // namespace cumsum

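// Returns the kernel registration for CUMSUM; only Prepare and Eval are
// provided (no init/free hooks).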
TfLiteRegistration* Register_CUMSUM() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 cumsum::Prepare, cumsum::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite