1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <stdint.h>
16
17#include "tensorflow/lite/c/common.h"
18#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20#include "tensorflow/lite/kernels/internal/tensor.h"
21#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22#include "tensorflow/lite/kernels/kernel_util.h"
23
24namespace tflite {
25namespace ops {
26namespace builtin {
27namespace matrix_set_diag {
28
// Positions of this op's tensors within the node's input and output lists.
constexpr int kInputTensor{0};     // Matrix (or batch of matrices) to copy.
constexpr int kDiagonalTensor{1};  // Values written onto the main diagonal.
constexpr int kOutputTensor{0};    // Result tensor.
32
33TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
34 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
35 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
36 const TfLiteTensor* input;
37 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
38 TfLiteIntArray* input_dims = input->dims;
39 int input_dims_size = input_dims->size;
40 TF_LITE_ENSURE(context, input_dims_size >= 2);
41
42 TfLiteTensor* output;
43 TF_LITE_ENSURE_OK(context,
44 GetOutputSafe(context, node, kOutputTensor, &output));
45
46 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(input_dims_size);
47 for (int i = 0; i < input_dims_size; i++) {
48 output_shape->data[i] = input_dims->data[i];
49 }
50
51 // Resize the output tensor to the same size as the input tensor.
52 output->type = input->type;
53 TF_LITE_ENSURE_OK(context,
54 context->ResizeTensor(context, output, output_shape));
55
56 return kTfLiteOk;
57}
58
// Copies a batch of row-major matrices from `in` to `out`, replacing each
// main-diagonal entry (row index == column index) with the next unread value
// from `diag`. Diagonal values are consumed in order across the whole batch,
// so `diag` supplies min(row_size, col_size) values per matrix.
//
//   in, out     batch_size consecutive matrices of row_size * col_size
//               elements each.
//   diag        flat list of diagonal values for all matrices.
//
// TODO(b/128636574) Move to reference_ops.
template <typename T>
void FillDiagImpl(const T* in, const T* diag, T* out, const int batch_size,
                  const int row_size, const int col_size) {
  const T* src = in;
  const T* next_diag = diag;
  T* dst = out;
  for (int batch = 0; batch < batch_size; ++batch) {
    for (int row = 0; row < row_size; ++row) {
      // src and dst walk the matrices element by element in row-major order.
      for (int col = 0; col < col_size; ++col, ++src, ++dst) {
        *dst = (row == col) ? *next_diag++ : *src;
      }
    }
  }
}
83
84template <typename T>
85void FillDiag(const TfLiteTensor* input, const TfLiteTensor* diag,
86 TfLiteTensor* output, const int batch_size, const int row_size,
87 const int col_size) {
88 FillDiagImpl<T>(GetTensorData<T>(input), GetTensorData<T>(diag),
89 GetTensorData<T>(output), batch_size, row_size, col_size);
90}
91
92// Fill a tensor with given "diag" values on the diagonal, input values
93// elsewhere.
94void FillDiagHelper(const TfLiteTensor* input, const TfLiteTensor* diag,
95 TfLiteTensor* output) {
96 const int num_output_dims = output->dims->size;
97 int batch_size = 1;
98 for (int i = 0; i < num_output_dims - 2; ++i) {
99 batch_size *= output->dims->data[i];
100 }
101
102 const int row_size = output->dims->data[num_output_dims - 2];
103 const int col_size = output->dims->data[num_output_dims - 1];
104 switch (output->type) {
105 case kTfLiteInt64: {
106 return FillDiag<int64_t>(input, diag, output, batch_size, row_size,
107 col_size);
108 }
109 case kTfLiteInt32: {
110 return FillDiag<int32_t>(input, diag, output, batch_size, row_size,
111 col_size);
112 }
113 case kTfLiteInt16: {
114 return FillDiag<int16_t>(input, diag, output, batch_size, row_size,
115 col_size);
116 }
117 case kTfLiteInt8: {
118 return FillDiag<int8_t>(input, diag, output, batch_size, row_size,
119 col_size);
120 }
121 case kTfLiteUInt8: {
122 return FillDiag<uint8_t>(input, diag, output, batch_size, row_size,
123 col_size);
124 }
125 default:
126 return FillDiag<float>(input, diag, output, batch_size, row_size,
127 col_size);
128 }
129}
130
131TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
132 TfLiteTensor* output;
133 TF_LITE_ENSURE_OK(context,
134 GetOutputSafe(context, node, kOutputTensor, &output));
135 const TfLiteTensor* input;
136 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
137 const TfLiteTensor* diag;
138 TF_LITE_ENSURE_OK(context,
139 GetInputSafe(context, node, kDiagonalTensor, &diag));
140 FillDiagHelper(input, diag, output);
141 return kTfLiteOk;
142}
143
144} // namespace matrix_set_diag
145
146TfLiteRegistration* Register_MATRIX_SET_DIAG() {
147 static TfLiteRegistration r = {nullptr, nullptr, matrix_set_diag::Prepare,
148 matrix_set_diag::Eval};
149 return &r;
150}
151
152} // namespace builtin
153} // namespace ops
154} // namespace tflite
155