/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace matrix_diag {

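// MATRIX_DIAG builds batched diagonal matrices: the values along the last
// dimension of the input are placed on the main diagonal of a square matrix
// and every other entry is set to zero. For example, an input of [1, 2, 3]
// produces [[1, 0, 0], [0, 2, 0], [0, 0, 3]].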
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteIntArray* input_dims = input->dims;
  int input_dims_size = input_dims->size;
  TF_LITE_ENSURE(context, input_dims_size >= 1);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  // Resize the output tensor.
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size + 1);
  for (int i = 0; i < input_dims_size; i++) {
    output_shape->data[i] = input_dims->data[i];
  }
  // Last dimension in the output is the same as the last dimension in the
  // input.
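  // For example, an input of shape [2, 4] is resized to an output of shape
  // [2, 4, 4].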
  output_shape->data[input_dims_size] = input_dims->data[input_dims_size - 1];
  output->type = input->type;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_shape));

  return kTfLiteOk;
}

// Fill the tensor to make a diagonal matrix in each batch, i.e., when
// row index and column index are the same, fill with the next input value.
// All other entries get zero.
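// For example, with batch_size = 1, row_size = col_size = 3, and input
// values {4, 5, 6}, the output is written as {4, 0, 0, 0, 5, 0, 0, 0, 6}.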
// TODO(b/128636574) Move to reference_ops.
template <typename T>
void FillDiagImpl(const T* in, T* out, const int batch_size, const int row_size,
                  const int col_size) {
  int idx = 0;
  for (int b = 0; b < batch_size; b++) {
    for (int i = 0; i < row_size; i++) {
      for (int j = 0; j < col_size; ++j) {
        // Input values go on the diagonal, 0 elsewhere.
        if (i == j) {
          out[i * col_size + j] = in[idx];
          idx++;
        } else {
          out[i * col_size + j] = 0;
        }
      }
    }
    out += row_size * col_size;
  }
}

template <typename T>
void FillDiag(const TfLiteTensor* input, TfLiteTensor* output,
              const int batch_size, const int row_size, const int col_size) {
  FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(output), batch_size,
                  row_size, col_size);
}

// Fill a tensor with the given input on the diagonal, zero elsewhere.
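// The output is treated as batch_size matrices of row_size x col_size, where
// batch_size is the product of all but the last two output dimensions; the
// element type is dispatched on output->type.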
void FillDiagHelper(const TfLiteTensor* input, TfLiteTensor* output) {
  const int num_output_dims = output->dims->size;
  int batch_size = 1;
  for (int i = 0; i < num_output_dims - 2; ++i) {
    batch_size *= output->dims->data[i];
  }

  const int row_size = output->dims->data[num_output_dims - 2];
  const int col_size = output->dims->data[num_output_dims - 1];
  switch (output->type) {
    case kTfLiteInt64: {
      return FillDiag<int64_t>(input, output, batch_size, row_size, col_size);
    }
    case kTfLiteInt32: {
      return FillDiag<int32_t>(input, output, batch_size, row_size, col_size);
    }
    case kTfLiteInt16: {
      return FillDiag<int16_t>(input, output, batch_size, row_size, col_size);
    }
    case kTfLiteInt8: {
      return FillDiag<int8_t>(input, output, batch_size, row_size, col_size);
    }
    case kTfLiteUInt8: {
      return FillDiag<uint8_t>(input, output, batch_size, row_size, col_size);
    }
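    // Any other type (in particular kTfLiteFloat32) falls back to the float
    // implementation.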
    default:
      return FillDiag<float>(input, output, batch_size, row_size, col_size);
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  FillDiagHelper(input, output);
  return kTfLiteOk;
}

}  // namespace matrix_diag

TfLiteRegistration* Register_MATRIX_DIAG() {
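  // This op keeps no per-node state, so the init and free callbacks are left
  // as nullptr.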
  static TfLiteRegistration r = {nullptr, nullptr, matrix_diag::Prepare,
                                 matrix_diag::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite