1 | // This file is MACHINE GENERATED! Do not edit. |
2 | |
3 | #ifndef TENSORFLOW_CC_OPS_LINALG_OPS_INTERNAL_H_ |
4 | #define TENSORFLOW_CC_OPS_LINALG_OPS_INTERNAL_H_ |
5 | |
6 | // This file is MACHINE GENERATED! Do not edit. |
7 | |
8 | #include "tensorflow/cc/framework/ops.h" |
9 | #include "tensorflow/cc/framework/scope.h" |
10 | #include "tensorflow/core/framework/tensor.h" |
11 | #include "tensorflow/core/framework/tensor_shape.h" |
12 | #include "tensorflow/core/framework/types.h" |
13 | #include "tensorflow/core/lib/gtl/array_slice.h" |
14 | |
15 | namespace tensorflow { |
16 | namespace ops { |
17 | namespace internal { |
18 | // NOTE: This namespace has internal TensorFlow details that |
19 | // are not part of TensorFlow's public API. |
20 | |
21 | /// @defgroup linalg_ops_internal Linalg Ops Internal |
22 | /// @{ |
23 | |
24 | /// TODO: add doc. |
25 | /// |
26 | /// Args: |
27 | /// * scope: A Scope object |
28 | /// |
29 | /// Returns: |
30 | /// * `Output`: The output tensor. |
31 | class BandedTriangularSolve { |
32 | public: |
33 | /// Optional attribute setters for BandedTriangularSolve |
34 | struct Attrs { |
35 | /// Defaults to true |
36 | TF_MUST_USE_RESULT Attrs Lower(bool x) { |
37 | Attrs ret = *this; |
38 | ret.lower_ = x; |
39 | return ret; |
40 | } |
41 | |
42 | /// Defaults to false |
43 | TF_MUST_USE_RESULT Attrs Adjoint(bool x) { |
44 | Attrs ret = *this; |
45 | ret.adjoint_ = x; |
46 | return ret; |
47 | } |
48 | |
49 | bool lower_ = true; |
50 | bool adjoint_ = false; |
51 | }; |
52 | BandedTriangularSolve(const ::tensorflow::Scope& scope, ::tensorflow::Input |
53 | matrix, ::tensorflow::Input rhs); |
54 | BandedTriangularSolve(const ::tensorflow::Scope& scope, ::tensorflow::Input |
55 | matrix, ::tensorflow::Input rhs, const |
56 | BandedTriangularSolve::Attrs& attrs); |
57 | operator ::tensorflow::Output() const { return output; } |
58 | operator ::tensorflow::Input() const { return output; } |
59 | ::tensorflow::Node* node() const { return output.node(); } |
60 | |
61 | static Attrs Lower(bool x) { |
62 | return Attrs().Lower(x); |
63 | } |
64 | static Attrs Adjoint(bool x) { |
65 | return Attrs().Adjoint(x); |
66 | } |
67 | |
68 | Operation operation; |
69 | ::tensorflow::Output output; |
70 | }; |
71 | |
72 | /// Computes the matrix logarithm of one or more square matrices: |
73 | /// |
74 | /// |
75 | /// \\(log(exp(A)) = A\\) |
76 | /// |
77 | /// This op is only defined for complex matrices. If A is positive-definite and |
78 | /// real, then casting to a complex matrix, taking the logarithm and casting back |
79 | /// to a real matrix will give the correct result. |
80 | /// |
81 | /// This function computes the matrix logarithm using the Schur-Parlett algorithm. |
82 | /// Details of the algorithm can be found in Section 11.6.2 of: |
83 | /// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008. |
84 | /// ISBN 978-0-898716-46-7. |
85 | /// |
86 | /// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions |
87 | /// form square matrices. The output is a tensor of the same shape as the input |
88 | /// containing the exponential for all input submatrices `[..., :, :]`. |
89 | /// |
90 | /// Args: |
91 | /// * scope: A Scope object |
92 | /// * input: Shape is `[..., M, M]`. |
93 | /// |
94 | /// Returns: |
95 | /// * `Output`: Shape is `[..., M, M]`. |
96 | /// |
97 | /// @compatibility(scipy) |
98 | /// Equivalent to scipy.linalg.logm |
99 | /// @end_compatibility |
100 | class MatrixLogarithm { |
101 | public: |
102 | MatrixLogarithm(const ::tensorflow::Scope& scope, ::tensorflow::Input input); |
103 | operator ::tensorflow::Output() const { return output; } |
104 | operator ::tensorflow::Input() const { return output; } |
105 | ::tensorflow::Node* node() const { return output.node(); } |
106 | |
107 | Operation operation; |
108 | ::tensorflow::Output output; |
109 | }; |
110 | |
111 | /// Calculate product with tridiagonal matrix. |
112 | /// |
113 | /// Calculates product of two matrices, where left matrix is a tridiagonal matrix. |
114 | /// |
115 | /// Args: |
116 | /// * scope: A Scope object |
117 | /// * superdiag: Tensor of shape `[..., 1, M]`, representing superdiagonals of |
118 | /// tri-diagonal matrices to the left of multiplication. Last element is ignored. |
119 | /// * maindiag: Tensor of shape `[..., 1, M]`, representing main diagonals of tri-diagonal |
120 | /// matrices to the left of multiplication. |
121 | /// * subdiag: Tensor of shape `[..., 1, M]`, representing subdiagonals of tri-diagonal |
122 | /// matrices to the left of multiplication. First element is ignored. |
123 | /// * rhs: Tensor of shape `[..., M, N]`, representing MxN matrices to the right of |
124 | /// multiplication. |
125 | /// |
126 | /// Returns: |
127 | /// * `Output`: Tensor of shape `[..., M, N]` containing the product. |
128 | class TridiagonalMatMul { |
129 | public: |
130 | TridiagonalMatMul(const ::tensorflow::Scope& scope, ::tensorflow::Input |
131 | superdiag, ::tensorflow::Input maindiag, ::tensorflow::Input |
132 | subdiag, ::tensorflow::Input rhs); |
133 | operator ::tensorflow::Output() const { return output; } |
134 | operator ::tensorflow::Input() const { return output; } |
135 | ::tensorflow::Node* node() const { return output.node(); } |
136 | |
137 | Operation operation; |
138 | ::tensorflow::Output output; |
139 | }; |
140 | |
141 | /// Solves tridiagonal systems of equations. |
142 | /// |
143 | /// Solves tridiagonal systems of equations. |
144 | /// Supports batch dimensions and multiple right-hand sides per each left-hand |
145 | /// side. |
146 | /// On CPU, solution is computed via Gaussian elimination with or without partial |
147 | /// pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE |
148 | /// library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv |
149 | /// Partial pivoting is not yet supported by XLA backends. |
150 | /// |
151 | /// Args: |
152 | /// * scope: A Scope object |
153 | /// * diagonals: Tensor of shape `[..., 3, M]` whose innermost 2 dimensions represent the |
154 | /// tridiagonal matrices with three rows being the superdiagonal, diagonals, and |
155 | /// subdiagonals, in order. The last element of the superdiagonal and the first |
156 | /// element of the subdiagonal is ignored. |
157 | /// * rhs: Tensor of shape `[..., M, K]`, representing K right-hand sides per each |
158 | /// left-hand side. |
159 | /// |
160 | /// Optional attributes (see `Attrs`): |
161 | /// * partial_pivoting: Whether to apply partial pivoting. Partial pivoting makes the procedure more |
162 | /// stable, but slower. |
163 | /// |
164 | /// Returns: |
165 | /// * `Output`: Tensor of shape `[..., M, K]` containing the solutions |
166 | class TridiagonalSolve { |
167 | public: |
168 | /// Optional attribute setters for TridiagonalSolve |
169 | struct Attrs { |
170 | /// Whether to apply partial pivoting. Partial pivoting makes the procedure more |
171 | /// stable, but slower. |
172 | /// |
173 | /// Defaults to true |
174 | TF_MUST_USE_RESULT Attrs PartialPivoting(bool x) { |
175 | Attrs ret = *this; |
176 | ret.partial_pivoting_ = x; |
177 | return ret; |
178 | } |
179 | |
180 | /// Defaults to false |
181 | TF_MUST_USE_RESULT Attrs PerturbSingular(bool x) { |
182 | Attrs ret = *this; |
183 | ret.perturb_singular_ = x; |
184 | return ret; |
185 | } |
186 | |
187 | bool partial_pivoting_ = true; |
188 | bool perturb_singular_ = false; |
189 | }; |
190 | TridiagonalSolve(const ::tensorflow::Scope& scope, ::tensorflow::Input |
191 | diagonals, ::tensorflow::Input rhs); |
192 | TridiagonalSolve(const ::tensorflow::Scope& scope, ::tensorflow::Input |
193 | diagonals, ::tensorflow::Input rhs, const |
194 | TridiagonalSolve::Attrs& attrs); |
195 | operator ::tensorflow::Output() const { return output; } |
196 | operator ::tensorflow::Input() const { return output; } |
197 | ::tensorflow::Node* node() const { return output.node(); } |
198 | |
199 | static Attrs PartialPivoting(bool x) { |
200 | return Attrs().PartialPivoting(x); |
201 | } |
202 | static Attrs PerturbSingular(bool x) { |
203 | return Attrs().PerturbSingular(x); |
204 | } |
205 | |
206 | Operation operation; |
207 | ::tensorflow::Output output; |
208 | }; |
209 | |
210 | } // namespace internal |
211 | } // namespace ops |
212 | } // namespace tensorflow |
213 | |
214 | #endif // TENSORFLOW_CC_OPS_LINALG_OPS_INTERNAL_H_ |
215 | |