1// This file is MACHINE GENERATED! Do not edit.
2
3#ifndef TENSORFLOW_CC_OPS_LINALG_OPS_INTERNAL_H_
4#define TENSORFLOW_CC_OPS_LINALG_OPS_INTERNAL_H_
5
6// This file is MACHINE GENERATED! Do not edit.
7
8#include "tensorflow/cc/framework/ops.h"
9#include "tensorflow/cc/framework/scope.h"
10#include "tensorflow/core/framework/tensor.h"
11#include "tensorflow/core/framework/tensor_shape.h"
12#include "tensorflow/core/framework/types.h"
13#include "tensorflow/core/lib/gtl/array_slice.h"
14
15namespace tensorflow {
16namespace ops {
17namespace internal {
18// NOTE: This namespace has internal TensorFlow details that
19// are not part of TensorFlow's public API.
20
21/// @defgroup linalg_ops_internal Linalg Ops Internal
22/// @{
23
24/// TODO: add doc.
25///
26/// Args:
27/// * scope: A Scope object
28///
29/// Returns:
30/// * `Output`: The output tensor.
31class BandedTriangularSolve {
32 public:
33 /// Optional attribute setters for BandedTriangularSolve
34 struct Attrs {
35 /// Defaults to true
36 TF_MUST_USE_RESULT Attrs Lower(bool x) {
37 Attrs ret = *this;
38 ret.lower_ = x;
39 return ret;
40 }
41
42 /// Defaults to false
43 TF_MUST_USE_RESULT Attrs Adjoint(bool x) {
44 Attrs ret = *this;
45 ret.adjoint_ = x;
46 return ret;
47 }
48
49 bool lower_ = true;
50 bool adjoint_ = false;
51 };
52 BandedTriangularSolve(const ::tensorflow::Scope& scope, ::tensorflow::Input
53 matrix, ::tensorflow::Input rhs);
54 BandedTriangularSolve(const ::tensorflow::Scope& scope, ::tensorflow::Input
55 matrix, ::tensorflow::Input rhs, const
56 BandedTriangularSolve::Attrs& attrs);
57 operator ::tensorflow::Output() const { return output; }
58 operator ::tensorflow::Input() const { return output; }
59 ::tensorflow::Node* node() const { return output.node(); }
60
61 static Attrs Lower(bool x) {
62 return Attrs().Lower(x);
63 }
64 static Attrs Adjoint(bool x) {
65 return Attrs().Adjoint(x);
66 }
67
68 Operation operation;
69 ::tensorflow::Output output;
70};
71
72/// Computes the matrix logarithm of one or more square matrices:
73///
74///
75/// \\(log(exp(A)) = A\\)
76///
77/// This op is only defined for complex matrices. If A is positive-definite and
78/// real, then casting to a complex matrix, taking the logarithm and casting back
79/// to a real matrix will give the correct result.
80///
81/// This function computes the matrix logarithm using the Schur-Parlett algorithm.
82/// Details of the algorithm can be found in Section 11.6.2 of:
83/// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008.
84/// ISBN 978-0-898716-46-7.
85///
86/// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
87/// form square matrices. The output is a tensor of the same shape as the input
88/// containing the exponential for all input submatrices `[..., :, :]`.
89///
90/// Args:
91/// * scope: A Scope object
92/// * input: Shape is `[..., M, M]`.
93///
94/// Returns:
95/// * `Output`: Shape is `[..., M, M]`.
96///
97/// @compatibility(scipy)
98/// Equivalent to scipy.linalg.logm
99/// @end_compatibility
100class MatrixLogarithm {
101 public:
102 MatrixLogarithm(const ::tensorflow::Scope& scope, ::tensorflow::Input input);
103 operator ::tensorflow::Output() const { return output; }
104 operator ::tensorflow::Input() const { return output; }
105 ::tensorflow::Node* node() const { return output.node(); }
106
107 Operation operation;
108 ::tensorflow::Output output;
109};
110
111/// Calculate product with tridiagonal matrix.
112///
113/// Calculates product of two matrices, where left matrix is a tridiagonal matrix.
114///
115/// Args:
116/// * scope: A Scope object
117/// * superdiag: Tensor of shape `[..., 1, M]`, representing superdiagonals of
118/// tri-diagonal matrices to the left of multiplication. Last element is ignored.
119/// * maindiag: Tensor of shape `[..., 1, M]`, representing main diagonals of tri-diagonal
120/// matrices to the left of multiplication.
121/// * subdiag: Tensor of shape `[..., 1, M]`, representing subdiagonals of tri-diagonal
122/// matrices to the left of multiplication. First element is ignored.
123/// * rhs: Tensor of shape `[..., M, N]`, representing MxN matrices to the right of
124/// multiplication.
125///
126/// Returns:
127/// * `Output`: Tensor of shape `[..., M, N]` containing the product.
128class TridiagonalMatMul {
129 public:
130 TridiagonalMatMul(const ::tensorflow::Scope& scope, ::tensorflow::Input
131 superdiag, ::tensorflow::Input maindiag, ::tensorflow::Input
132 subdiag, ::tensorflow::Input rhs);
133 operator ::tensorflow::Output() const { return output; }
134 operator ::tensorflow::Input() const { return output; }
135 ::tensorflow::Node* node() const { return output.node(); }
136
137 Operation operation;
138 ::tensorflow::Output output;
139};
140
141/// Solves tridiagonal systems of equations.
142///
143/// Solves tridiagonal systems of equations.
144/// Supports batch dimensions and multiple right-hand sides per each left-hand
145/// side.
146/// On CPU, solution is computed via Gaussian elimination with or without partial
147/// pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE
148/// library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
149/// Partial pivoting is not yet supported by XLA backends.
150///
151/// Args:
152/// * scope: A Scope object
153/// * diagonals: Tensor of shape `[..., 3, M]` whose innermost 2 dimensions represent the
154/// tridiagonal matrices with three rows being the superdiagonal, diagonals, and
155/// subdiagonals, in order. The last element of the superdiagonal and the first
156/// element of the subdiagonal is ignored.
157/// * rhs: Tensor of shape `[..., M, K]`, representing K right-hand sides per each
158/// left-hand side.
159///
160/// Optional attributes (see `Attrs`):
161/// * partial_pivoting: Whether to apply partial pivoting. Partial pivoting makes the procedure more
162/// stable, but slower.
163///
164/// Returns:
165/// * `Output`: Tensor of shape `[..., M, K]` containing the solutions
166class TridiagonalSolve {
167 public:
168 /// Optional attribute setters for TridiagonalSolve
169 struct Attrs {
170 /// Whether to apply partial pivoting. Partial pivoting makes the procedure more
171 /// stable, but slower.
172 ///
173 /// Defaults to true
174 TF_MUST_USE_RESULT Attrs PartialPivoting(bool x) {
175 Attrs ret = *this;
176 ret.partial_pivoting_ = x;
177 return ret;
178 }
179
180 /// Defaults to false
181 TF_MUST_USE_RESULT Attrs PerturbSingular(bool x) {
182 Attrs ret = *this;
183 ret.perturb_singular_ = x;
184 return ret;
185 }
186
187 bool partial_pivoting_ = true;
188 bool perturb_singular_ = false;
189 };
190 TridiagonalSolve(const ::tensorflow::Scope& scope, ::tensorflow::Input
191 diagonals, ::tensorflow::Input rhs);
192 TridiagonalSolve(const ::tensorflow::Scope& scope, ::tensorflow::Input
193 diagonals, ::tensorflow::Input rhs, const
194 TridiagonalSolve::Attrs& attrs);
195 operator ::tensorflow::Output() const { return output; }
196 operator ::tensorflow::Input() const { return output; }
197 ::tensorflow::Node* node() const { return output.node(); }
198
199 static Attrs PartialPivoting(bool x) {
200 return Attrs().PartialPivoting(x);
201 }
202 static Attrs PerturbSingular(bool x) {
203 return Attrs().PerturbSingular(x);
204 }
205
206 Operation operation;
207 ::tensorflow::Output output;
208};
209
210} // namespace internal
211} // namespace ops
212} // namespace tensorflow
213
214#endif // TENSORFLOW_CC_OPS_LINALG_OPS_INTERNAL_H_
215