1#pragma once
2
3#include <ATen/ATen.h>
4
5namespace torch {
6namespace linalg {
7
8#ifndef DOXYGEN_SHOULD_SKIP_THIS
9namespace detail {
10
11inline Tensor cholesky(const Tensor& self) {
12 return torch::linalg_cholesky(self);
13}
14
15inline Tensor cholesky_out(Tensor& result, const Tensor& self) {
16 return torch::linalg_cholesky_out(result, self);
17}
18
19inline Tensor det(const Tensor& self) {
20 return torch::linalg_det(self);
21}
22
23inline std::tuple<Tensor, Tensor> slogdet(const Tensor& input) {
24 return torch::linalg_slogdet(input);
25}
26
27inline std::tuple<Tensor&, Tensor&> slogdet_out(
28 Tensor& sign,
29 Tensor& logabsdet,
30 const Tensor& input) {
31 return torch::linalg_slogdet_out(sign, logabsdet, input);
32}
33
34inline std::tuple<Tensor, Tensor> eig(const Tensor& self) {
35 return torch::linalg_eig(self);
36}
37
38inline std::tuple<Tensor&, Tensor&> eig_out(
39 Tensor& eigvals,
40 Tensor& eigvecs,
41 const Tensor& self) {
42 return torch::linalg_eig_out(eigvals, eigvecs, self);
43}
44
45inline Tensor eigvals(const Tensor& self) {
46 return torch::linalg_eigvals(self);
47}
48
49inline Tensor& eigvals_out(Tensor& result, const Tensor& self) {
50 return torch::linalg_eigvals_out(result, self);
51}
52
53inline std::tuple<Tensor, Tensor> eigh(
54 const Tensor& self,
55 c10::string_view uplo) {
56 return torch::linalg_eigh(self, uplo);
57}
58
59inline std::tuple<Tensor&, Tensor&> eigh_out(
60 Tensor& eigvals,
61 Tensor& eigvecs,
62 const Tensor& self,
63 c10::string_view uplo) {
64 return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo);
65}
66
67inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) {
68 return torch::linalg_eigvalsh(self, uplo);
69}
70
71inline Tensor& eigvalsh_out(
72 Tensor& result,
73 const Tensor& self,
74 c10::string_view uplo) {
75 return torch::linalg_eigvalsh_out(result, self, uplo);
76}
77
78inline Tensor householder_product(const Tensor& input, const Tensor& tau) {
79 return torch::linalg_householder_product(input, tau);
80}
81
82inline Tensor& householder_product_out(
83 Tensor& result,
84 const Tensor& input,
85 const Tensor& tau) {
86 return torch::linalg_householder_product_out(result, input, tau);
87}
88
89inline std::tuple<Tensor, Tensor> lu_factor(
90 const Tensor& self,
91 const bool pivot) {
92 return torch::linalg_lu_factor(self, pivot);
93}
94
95inline std::tuple<Tensor&, Tensor&> lu_factor_out(
96 Tensor& LU,
97 Tensor& pivots,
98 const Tensor& self,
99 const bool pivot) {
100 return torch::linalg_lu_factor_out(LU, pivots, self, pivot);
101}
102
103inline std::tuple<Tensor, Tensor, Tensor> lu(
104 const Tensor& self,
105 const bool pivot) {
106 return torch::linalg_lu(self, pivot);
107}
108
109inline std::tuple<Tensor&, Tensor&, Tensor&> lu_out(
110 Tensor& P,
111 Tensor& L,
112 Tensor& U,
113 const Tensor& self,
114 const bool pivot) {
115 return torch::linalg_lu_out(P, L, U, self, pivot);
116}
117
118inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq(
119 const Tensor& self,
120 const Tensor& b,
121 c10::optional<double> cond,
122 c10::optional<c10::string_view> driver) {
123 return torch::linalg_lstsq(self, b, cond, driver);
124}
125
126inline Tensor matrix_exp(const Tensor& self) {
127 return torch::linalg_matrix_exp(self);
128}
129
130inline Tensor norm(
131 const Tensor& self,
132 const optional<Scalar>& opt_ord,
133 OptionalIntArrayRef opt_dim,
134 bool keepdim,
135 optional<ScalarType> opt_dtype) {
136 return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
137}
138
139inline Tensor norm(
140 const Tensor& self,
141 c10::string_view ord,
142 OptionalIntArrayRef opt_dim,
143 bool keepdim,
144 optional<ScalarType> opt_dtype) {
145 return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype);
146}
147
148inline Tensor& norm_out(
149 Tensor& result,
150 const Tensor& self,
151 const optional<Scalar>& opt_ord,
152 OptionalIntArrayRef opt_dim,
153 bool keepdim,
154 optional<ScalarType> opt_dtype) {
155 return torch::linalg_norm_out(
156 result, self, opt_ord, opt_dim, keepdim, opt_dtype);
157}
158
159inline Tensor& norm_out(
160 Tensor& result,
161 const Tensor& self,
162 c10::string_view ord,
163 OptionalIntArrayRef opt_dim,
164 bool keepdim,
165 optional<ScalarType> opt_dtype) {
166 return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
167}
168
169inline Tensor vector_norm(
170 const Tensor& self,
171 Scalar ord,
172 OptionalIntArrayRef opt_dim,
173 bool keepdim,
174 optional<ScalarType> opt_dtype) {
175 return torch::linalg_vector_norm(self, ord, opt_dim, keepdim, opt_dtype);
176}
177
178inline Tensor& vector_norm_out(
179 Tensor& result,
180 const Tensor& self,
181 Scalar ord,
182 OptionalIntArrayRef opt_dim,
183 bool keepdim,
184 optional<ScalarType> opt_dtype) {
185 return torch::linalg_vector_norm_out(
186 result, self, ord, opt_dim, keepdim, opt_dtype);
187}
188
189inline Tensor matrix_norm(
190 const Tensor& self,
191 const Scalar& ord,
192 IntArrayRef dim,
193 bool keepdim,
194 optional<ScalarType> dtype) {
195 return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype);
196}
197
198inline Tensor& matrix_norm_out(
199 const Tensor& self,
200 const Scalar& ord,
201 IntArrayRef dim,
202 bool keepdim,
203 optional<ScalarType> dtype,
204 Tensor& result) {
205 return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype);
206}
207
208inline Tensor matrix_norm(
209 const Tensor& self,
210 std::string ord,
211 IntArrayRef dim,
212 bool keepdim,
213 optional<ScalarType> dtype) {
214 return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype);
215}
216
217inline Tensor& matrix_norm_out(
218 const Tensor& self,
219 std::string ord,
220 IntArrayRef dim,
221 bool keepdim,
222 optional<ScalarType> dtype,
223 Tensor& result) {
224 return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype);
225}
226
227inline Tensor matrix_power(const Tensor& self, int64_t n) {
228 return torch::linalg_matrix_power(self, n);
229}
230
231inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
232 return torch::linalg_matrix_power_out(result, self, n);
233}
234
235inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) {
236 return torch::linalg_matrix_rank(input, tol, hermitian);
237}
238
239inline Tensor matrix_rank(
240 const Tensor& input,
241 const Tensor& tol,
242 bool hermitian) {
243 return torch::linalg_matrix_rank(input, tol, hermitian);
244}
245
246inline Tensor matrix_rank(
247 const Tensor& input,
248 c10::optional<double> atol,
249 c10::optional<double> rtol,
250 bool hermitian) {
251 return torch::linalg_matrix_rank(input, atol, rtol, hermitian);
252}
253
254inline Tensor matrix_rank(
255 const Tensor& input,
256 const c10::optional<Tensor>& atol,
257 const c10::optional<Tensor>& rtol,
258 bool hermitian) {
259 return torch::linalg_matrix_rank(input, atol, rtol, hermitian);
260}
261
262inline Tensor& matrix_rank_out(
263 Tensor& result,
264 const Tensor& input,
265 double tol,
266 bool hermitian) {
267 return torch::linalg_matrix_rank_out(result, input, tol, hermitian);
268}
269
270inline Tensor& matrix_rank_out(
271 Tensor& result,
272 const Tensor& input,
273 const Tensor& tol,
274 bool hermitian) {
275 return torch::linalg_matrix_rank_out(result, input, tol, hermitian);
276}
277
278inline Tensor& matrix_rank_out(
279 Tensor& result,
280 const Tensor& input,
281 c10::optional<double> atol,
282 c10::optional<double> rtol,
283 bool hermitian) {
284 return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian);
285}
286
287inline Tensor& matrix_rank_out(
288 Tensor& result,
289 const Tensor& input,
290 const c10::optional<Tensor>& atol,
291 const c10::optional<Tensor>& rtol,
292 bool hermitian) {
293 return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian);
294}
295
296inline Tensor multi_dot(TensorList tensors) {
297 return torch::linalg_multi_dot(tensors);
298}
299
300inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
301 return torch::linalg_multi_dot_out(result, tensors);
302}
303
304inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) {
305 return torch::linalg_pinv(input, rcond, hermitian);
306}
307
308inline Tensor& pinv_out(
309 Tensor& result,
310 const Tensor& input,
311 double rcond,
312 bool hermitian) {
313 return torch::linalg_pinv_out(result, input, rcond, hermitian);
314}
315
316inline std::tuple<Tensor, Tensor> qr(
317 const Tensor& input,
318 c10::string_view mode) {
319 return torch::linalg_qr(input, mode);
320}
321
322inline std::tuple<Tensor&, Tensor&> qr_out(
323 Tensor& Q,
324 Tensor& R,
325 const Tensor& input,
326 c10::string_view mode) {
327 return torch::linalg_qr_out(Q, R, input, mode);
328}
329
330inline std::tuple<Tensor, Tensor> solve_ex(
331 const Tensor& input,
332 const Tensor& other,
333 bool left,
334 bool check_errors) {
335 return torch::linalg_solve_ex(input, other, left, check_errors);
336}
337
338inline std::tuple<Tensor&, Tensor&> solve_ex_out(
339 Tensor& result,
340 Tensor& info,
341 const Tensor& input,
342 const Tensor& other,
343 bool left,
344 bool check_errors) {
345 return torch::linalg_solve_ex_out(
346 result, info, input, other, left, check_errors);
347}
348
349inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
350 return torch::linalg_solve(input, other, left);
351}
352
353inline Tensor& solve_out(
354 Tensor& result,
355 const Tensor& input,
356 const Tensor& other,
357 bool left) {
358 return torch::linalg_solve_out(result, input, other, left);
359}
360
361inline Tensor solve_triangular(
362 const Tensor& input,
363 const Tensor& other,
364 bool upper,
365 bool left,
366 bool unitriangular) {
367 return torch::linalg_solve_triangular(
368 input, other, upper, left, unitriangular);
369}
370
371inline Tensor& solve_triangular_out(
372 Tensor& result,
373 const Tensor& input,
374 const Tensor& other,
375 bool upper,
376 bool left,
377 bool unitriangular) {
378 return torch::linalg_solve_triangular_out(
379 result, input, other, upper, left, unitriangular);
380}
381
382inline std::tuple<Tensor, Tensor, Tensor> svd(
383 const Tensor& input,
384 bool full_matrices,
385 c10::optional<c10::string_view> driver) {
386 return torch::linalg_svd(input, full_matrices, driver);
387}
388
389inline std::tuple<Tensor&, Tensor&, Tensor&> svd_out(
390 Tensor& U,
391 Tensor& S,
392 Tensor& Vh,
393 const Tensor& input,
394 bool full_matrices,
395 c10::optional<c10::string_view> driver) {
396 return torch::linalg_svd_out(U, S, Vh, input, full_matrices, driver);
397}
398
399inline Tensor svdvals(
400 const Tensor& input,
401 c10::optional<c10::string_view> driver) {
402 return torch::linalg_svdvals(input, driver);
403}
404
405inline Tensor& svdvals_out(
406 Tensor& result,
407 const Tensor& input,
408 c10::optional<c10::string_view> driver) {
409 return torch::linalg_svdvals_out(result, input, driver);
410}
411
412inline Tensor tensorinv(const Tensor& self, int64_t ind) {
413 return torch::linalg_tensorinv(self, ind);
414}
415
416inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) {
417 return torch::linalg_tensorinv_out(result, self, ind);
418}
419
420inline Tensor tensorsolve(
421 const Tensor& self,
422 const Tensor& other,
423 OptionalIntArrayRef dims) {
424 return torch::linalg_tensorsolve(self, other, dims);
425}
426
427inline Tensor& tensorsolve_out(
428 Tensor& result,
429 const Tensor& self,
430 const Tensor& other,
431 OptionalIntArrayRef dims) {
432 return torch::linalg_tensorsolve_out(result, self, other, dims);
433}
434
435inline Tensor inv(const Tensor& input) {
436 return torch::linalg_inv(input);
437}
438
439inline Tensor& inv_out(Tensor& result, const Tensor& input) {
440 return torch::linalg_inv_out(result, input);
441}
442
443} // namespace detail
444#endif /* DOXYGEN_SHOULD_SKIP_THIS */
445
446/// Cholesky decomposition
447///
448/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.cholesky
449///
450/// Example:
451/// ```
452/// auto A = torch::randn({4, 4});
453/// auto A = torch::matmul(A, A.t());
454/// auto L = torch::linalg::cholesky(A);
455/// assert(torch::allclose(torch::matmul(L, L.t()), A));
456/// ```
457inline Tensor cholesky(const Tensor& self) {
458 return detail::cholesky(self);
459}
460
461inline Tensor cholesky_out(Tensor& result, const Tensor& self) {
462 return detail::cholesky_out(result, self);
463}
464
465// C10_DEPRECATED_MESSAGE("linalg_det is deprecated, use det instead.")
466inline Tensor linalg_det(const Tensor& self) {
467 return detail::det(self);
468}
469
470/// See the documentation of torch.linalg.det
471inline Tensor det(const Tensor& self) {
472 return detail::det(self);
473}
474
475/// Computes the sign and (natural) logarithm of the determinant
476///
477/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.slogdet
478inline std::tuple<Tensor, Tensor> slogdet(const Tensor& input) {
479 return detail::slogdet(input);
480}
481
482inline std::tuple<Tensor&, Tensor&> slogdet_out(
483 Tensor& sign,
484 Tensor& logabsdet,
485 const Tensor& input) {
486 return detail::slogdet_out(sign, logabsdet, input);
487}
488
489/// Computes eigenvalues and eigenvectors of non-symmetric/non-hermitian
490/// matrices
491///
492/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eig
493inline std::tuple<Tensor, Tensor> eig(const Tensor& self) {
494 return detail::eig(self);
495}
496
497inline std::tuple<Tensor&, Tensor&> eig_out(
498 Tensor& eigvals,
499 Tensor& eigvecs,
500 const Tensor& self) {
501 return detail::eig_out(eigvals, eigvecs, self);
502}
503
504/// Computes eigenvalues of non-symmetric/non-hermitian matrices
505///
506/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigvals
507inline Tensor eigvals(const Tensor& self) {
508 return detail::eigvals(self);
509}
510
511inline Tensor& eigvals_out(Tensor& result, const Tensor& self) {
512 return detail::eigvals_out(result, self);
513}
514
515/// Computes eigenvalues and eigenvectors
516///
517/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigh
518inline std::tuple<Tensor, Tensor> eigh(
519 const Tensor& self,
520 c10::string_view uplo) {
521 return detail::eigh(self, uplo);
522}
523
524inline std::tuple<Tensor&, Tensor&> eigh_out(
525 Tensor& eigvals,
526 Tensor& eigvecs,
527 const Tensor& self,
528 c10::string_view uplo) {
529 return detail::eigh_out(eigvals, eigvecs, self, uplo);
530}
531
532/// Computes eigenvalues
533///
534/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigvalsh
535inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) {
536 return detail::eigvalsh(self, uplo);
537}
538
539inline Tensor& eigvalsh_out(
540 Tensor& result,
541 const Tensor& self,
542 c10::string_view uplo) {
543 return detail::eigvalsh_out(result, self, uplo);
544}
545
546/// Computes the product of Householder matrices
547///
548/// See
549/// https://pytorch.org/docs/master/linalg.html#torch.linalg.householder_product
550inline Tensor householder_product(const Tensor& input, const Tensor& tau) {
551 return detail::householder_product(input, tau);
552}
553
554inline Tensor& householder_product_out(
555 Tensor& result,
556 const Tensor& input,
557 const Tensor& tau) {
558 return detail::householder_product_out(result, input, tau);
559}
560
561inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq(
562 const Tensor& self,
563 const Tensor& b,
564 c10::optional<double> cond,
565 c10::optional<c10::string_view> driver) {
566 return detail::lstsq(self, b, cond, driver);
567}
568
569/// Computes the matrix exponential
570///
571/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_exp
572inline Tensor matrix_exp(const Tensor& input) {
573 return detail::matrix_exp(input);
574}
575
576// C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.")
577inline Tensor linalg_norm(
578 const Tensor& self,
579 const optional<Scalar>& opt_ord,
580 OptionalIntArrayRef opt_dim,
581 bool keepdim,
582 optional<ScalarType> opt_dtype) {
583 return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
584}
585
586// C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.")
587inline Tensor linalg_norm(
588 const Tensor& self,
589 c10::string_view ord,
590 OptionalIntArrayRef opt_dim,
591 bool keepdim,
592 optional<ScalarType> opt_dtype) {
593 return detail::norm(self, ord, opt_dim, keepdim, opt_dtype);
594}
595
596// C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out
597// instead.")
598inline Tensor& linalg_norm_out(
599 Tensor& result,
600 const Tensor& self,
601 const optional<Scalar>& opt_ord,
602 OptionalIntArrayRef opt_dim,
603 bool keepdim,
604 optional<ScalarType> opt_dtype) {
605 return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
606}
607
608// C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out
609// instead.")
610inline Tensor& linalg_norm_out(
611 Tensor& result,
612 const Tensor& self,
613 c10::string_view ord,
614 OptionalIntArrayRef opt_dim,
615 bool keepdim,
616 optional<ScalarType> opt_dtype) {
617 return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
618}
619
620/// Computes the LU factorization with partial pivoting
621///
622/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.lu_factor
623inline std::tuple<Tensor, Tensor> lu_factor(
624 const Tensor& input,
625 const bool pivot = true) {
626 return detail::lu_factor(input, pivot);
627}
628
629inline std::tuple<Tensor&, Tensor&> lu_factor_out(
630 Tensor& LU,
631 Tensor& pivots,
632 const Tensor& self,
633 const bool pivot = true) {
634 return detail::lu_factor_out(LU, pivots, self, pivot);
635}
636
637/// Computes the LU factorization with partial pivoting
638///
639/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.lu
640inline std::tuple<Tensor, Tensor, Tensor> lu(
641 const Tensor& input,
642 const bool pivot = true) {
643 return detail::lu(input, pivot);
644}
645
646inline std::tuple<Tensor&, Tensor&, Tensor&> lu_out(
647 Tensor& P,
648 Tensor& L,
649 Tensor& U,
650 const Tensor& self,
651 const bool pivot = true) {
652 return detail::lu_out(P, L, U, self, pivot);
653}
654
655inline Tensor norm(
656 const Tensor& self,
657 const optional<Scalar>& opt_ord,
658 OptionalIntArrayRef opt_dim,
659 bool keepdim,
660 optional<ScalarType> opt_dtype) {
661 return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
662}
663
664inline Tensor norm(
665 const Tensor& self,
666 std::string ord,
667 OptionalIntArrayRef opt_dim,
668 bool keepdim,
669 optional<ScalarType> opt_dtype) {
670 return detail::norm(self, ord, opt_dim, keepdim, opt_dtype);
671}
672
673inline Tensor& norm_out(
674 Tensor& result,
675 const Tensor& self,
676 const optional<Scalar>& opt_ord,
677 OptionalIntArrayRef opt_dim,
678 bool keepdim,
679 optional<ScalarType> opt_dtype) {
680 return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
681}
682
683inline Tensor& norm_out(
684 Tensor& result,
685 const Tensor& self,
686 std::string ord,
687 OptionalIntArrayRef opt_dim,
688 bool keepdim,
689 optional<ScalarType> opt_dtype) {
690 return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
691}
692
693/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.vector_norm
694inline Tensor vector_norm(
695 const Tensor& self,
696 Scalar ord,
697 OptionalIntArrayRef opt_dim,
698 bool keepdim,
699 optional<ScalarType> opt_dtype) {
700 return detail::vector_norm(self, ord, opt_dim, keepdim, opt_dtype);
701}
702
703inline Tensor& vector_norm_out(
704 Tensor& result,
705 const Tensor& self,
706 Scalar ord,
707 OptionalIntArrayRef opt_dim,
708 bool keepdim,
709 optional<ScalarType> opt_dtype) {
710 return detail::vector_norm_out(
711 result, self, ord, opt_dim, keepdim, opt_dtype);
712}
713
714/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_norm
715inline Tensor matrix_norm(
716 const Tensor& self,
717 const Scalar& ord,
718 IntArrayRef dim,
719 bool keepdim,
720 optional<ScalarType> dtype) {
721 return detail::matrix_norm(self, ord, dim, keepdim, dtype);
722}
723
724inline Tensor& matrix_norm_out(
725 const Tensor& self,
726 const Scalar& ord,
727 IntArrayRef dim,
728 bool keepdim,
729 optional<ScalarType> dtype,
730 Tensor& result) {
731 return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result);
732}
733
734inline Tensor matrix_norm(
735 const Tensor& self,
736 std::string ord,
737 IntArrayRef dim,
738 bool keepdim,
739 optional<ScalarType> dtype) {
740 return detail::matrix_norm(self, ord, dim, keepdim, dtype);
741}
742
743inline Tensor& matrix_norm_out(
744 const Tensor& self,
745 std::string ord,
746 IntArrayRef dim,
747 bool keepdim,
748 optional<ScalarType> dtype,
749 Tensor& result) {
750 return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result);
751}
752
753/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_power
754inline Tensor matrix_power(const Tensor& self, int64_t n) {
755 return detail::matrix_power(self, n);
756}
757
758inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
759 return detail::matrix_power_out(self, n, result);
760}
761
762/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_rank
763inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) {
764 return detail::matrix_rank(input, tol, hermitian);
765}
766
767inline Tensor matrix_rank(
768 const Tensor& input,
769 const Tensor& tol,
770 bool hermitian) {
771 return detail::matrix_rank(input, tol, hermitian);
772}
773
774inline Tensor matrix_rank(
775 const Tensor& input,
776 c10::optional<double> atol,
777 c10::optional<double> rtol,
778 bool hermitian) {
779 return detail::matrix_rank(input, atol, rtol, hermitian);
780}
781
782inline Tensor matrix_rank(
783 const Tensor& input,
784 const c10::optional<Tensor>& atol,
785 const c10::optional<Tensor>& rtol,
786 bool hermitian) {
787 return detail::matrix_rank(input, atol, rtol, hermitian);
788}
789
790inline Tensor& matrix_rank_out(
791 Tensor& result,
792 const Tensor& input,
793 double tol,
794 bool hermitian) {
795 return detail::matrix_rank_out(result, input, tol, hermitian);
796}
797
798inline Tensor& matrix_rank_out(
799 Tensor& result,
800 const Tensor& input,
801 const Tensor& tol,
802 bool hermitian) {
803 return detail::matrix_rank_out(result, input, tol, hermitian);
804}
805
806inline Tensor& matrix_rank_out(
807 Tensor& result,
808 const Tensor& input,
809 c10::optional<double> atol,
810 c10::optional<double> rtol,
811 bool hermitian) {
812 return detail::matrix_rank_out(result, input, atol, rtol, hermitian);
813}
814
815inline Tensor& matrix_rank_out(
816 Tensor& result,
817 const Tensor& input,
818 const c10::optional<Tensor>& atol,
819 const c10::optional<Tensor>& rtol,
820 bool hermitian) {
821 return detail::matrix_rank_out(result, input, atol, rtol, hermitian);
822}
823
824/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.multi_dot
825inline Tensor multi_dot(TensorList tensors) {
826 return detail::multi_dot(tensors);
827}
828
829inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
830 return detail::multi_dot_out(tensors, result);
831}
832
833/// Computes the pseudo-inverse
834///
835/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.pinv
836inline Tensor pinv(
837 const Tensor& input,
838 double rcond = 1e-15,
839 bool hermitian = false) {
840 return detail::pinv(input, rcond, hermitian);
841}
842
843inline Tensor& pinv_out(
844 Tensor& result,
845 const Tensor& input,
846 double rcond = 1e-15,
847 bool hermitian = false) {
848 return detail::pinv_out(result, input, rcond, hermitian);
849}
850
851/// Computes the QR decomposition
852///
853/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.qr
854inline std::tuple<Tensor, Tensor> qr(
855 const Tensor& input,
856 c10::string_view mode = "reduced") {
857 // C++17 Change the initialisation to "reduced"sv
858 // Same for qr_out
859 return detail::qr(input, mode);
860}
861
862inline std::tuple<Tensor&, Tensor&> qr_out(
863 Tensor& Q,
864 Tensor& R,
865 const Tensor& input,
866 c10::string_view mode = "reduced") {
867 return detail::qr_out(Q, R, input, mode);
868}
869
870/// Computes the LDL decomposition
871///
872/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.ldl_factor_ex
873inline std::tuple<Tensor, Tensor, Tensor> ldl_factor_ex(
874 const Tensor& input,
875 bool hermitian,
876 bool check_errors) {
877 return torch::linalg_ldl_factor_ex(input, hermitian, check_errors);
878}
879
880inline std::tuple<Tensor&, Tensor&, Tensor&> ldl_factor_ex_out(
881 Tensor& LD,
882 Tensor& pivots,
883 Tensor& info,
884 const Tensor& input,
885 bool hermitian,
886 bool check_errors) {
887 return torch::linalg_ldl_factor_ex_out(
888 LD, pivots, info, input, hermitian, check_errors);
889}
890
891/// Solve a system of linear equations using the LDL decomposition
892///
893/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.ldl_solve
894inline Tensor ldl_solve(
895 const Tensor& LD,
896 const Tensor& pivots,
897 const Tensor& B,
898 bool hermitian) {
899 return torch::linalg_ldl_solve(LD, pivots, B, hermitian);
900}
901
902inline Tensor& ldl_solve_out(
903 Tensor& result,
904 const Tensor& LD,
905 const Tensor& pivots,
906 const Tensor& B,
907 bool hermitian) {
908 return torch::linalg_ldl_solve_out(result, LD, pivots, B, hermitian);
909}
910
911/// Solves a system linear system AX = B
912///
913/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_ex
914inline std::tuple<Tensor, Tensor> solve_ex(
915 const Tensor& input,
916 const Tensor& other,
917 bool left,
918 bool check_errors) {
919 return detail::solve_ex(input, other, left, check_errors);
920}
921
922inline std::tuple<Tensor&, Tensor&> solve_ex_out(
923 Tensor& result,
924 Tensor& info,
925 const Tensor& input,
926 const Tensor& other,
927 bool left,
928 bool check_errors) {
929 return detail::solve_ex_out(result, info, input, other, left, check_errors);
930}
931
932/// Computes a tensor `x` such that `matmul(input, x) = other`.
933///
934/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve
935inline Tensor solve(const Tensor& input, const Tensor& other, bool left) {
936 return detail::solve(input, other, left);
937}
938
939inline Tensor& solve_out(
940 Tensor& result,
941 const Tensor& input,
942 const Tensor& other,
943 bool left) {
944 return detail::solve_out(result, input, other, left);
945}
946
947/// Computes a solution of a linear system AX = B for input = A and other = B
948/// whenever A is square upper or lower triangular and does not have zeros in
949/// the diagonal
950///
951/// See
952/// https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_triangular
953inline Tensor solve_triangular(
954 const Tensor& input,
955 const Tensor& other,
956 bool upper,
957 bool left,
958 bool unitriangular) {
959 return detail::solve_triangular(input, other, upper, left, unitriangular);
960}
961
962inline Tensor& solve_triangular_out(
963 Tensor& result,
964 const Tensor& input,
965 const Tensor& other,
966 bool upper,
967 bool left,
968 bool unitriangular) {
969 return detail::solve_triangular_out(
970 result, input, other, upper, left, unitriangular);
971}
972
973/// Computes the singular values and singular vectors
974///
975/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.svd
976inline std::tuple<Tensor, Tensor, Tensor> svd(
977 const Tensor& input,
978 bool full_matrices,
979 c10::optional<c10::string_view> driver) {
980 return detail::svd(input, full_matrices, driver);
981}
982
983inline std::tuple<Tensor&, Tensor&, Tensor&> svd_out(
984 Tensor& U,
985 Tensor& S,
986 Tensor& Vh,
987 const Tensor& input,
988 bool full_matrices,
989 c10::optional<c10::string_view> driver) {
990 return detail::svd_out(U, S, Vh, input, full_matrices, driver);
991}
992
993/// Computes the singular values
994///
995/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.svdvals
996inline Tensor svdvals(
997 const Tensor& input,
998 c10::optional<c10::string_view> driver) {
999 return detail::svdvals(input, driver);
1000}
1001
1002inline Tensor& svdvals_out(
1003 Tensor& result,
1004 const Tensor& input,
1005 c10::optional<c10::string_view> driver) {
1006 return detail::svdvals_out(result, input, driver);
1007}
1008
1009/// Computes the inverse of a tensor
1010///
1011/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorinv
1012///
1013/// Example:
1014/// ```
1015/// auto a = torch::eye(4*6).reshape({4, 6, 8, 3});
1016/// int64_t ind = 2;
1017/// auto ainv = torch::linalg::tensorinv(a, ind);
1018/// ```
1019inline Tensor tensorinv(const Tensor& self, int64_t ind) {
1020 return detail::tensorinv(self, ind);
1021}
1022
1023inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) {
1024 return detail::tensorinv_out(result, self, ind);
1025}
1026
1027/// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`.
1028///
1029/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorsolve
1030///
1031/// Example:
1032/// ```
1033/// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4});
1034/// auto b = torch::randn(2*3, 4);
1035/// auto x = torch::linalg::tensorsolve(a, b);
1036/// ```
1037inline Tensor tensorsolve(
1038 const Tensor& input,
1039 const Tensor& other,
1040 OptionalIntArrayRef dims) {
1041 return detail::tensorsolve(input, other, dims);
1042}
1043
1044inline Tensor& tensorsolve_out(
1045 Tensor& result,
1046 const Tensor& input,
1047 const Tensor& other,
1048 OptionalIntArrayRef dims) {
1049 return detail::tensorsolve_out(result, input, other, dims);
1050}
1051
1052/// Computes a tensor `inverse_input` such that `dot(input, inverse_input) =
1053/// eye(input.size(0))`.
1054///
1055/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.inv
1056inline Tensor inv(const Tensor& input) {
1057 return detail::inv(input);
1058}
1059
1060inline Tensor& inv_out(Tensor& result, const Tensor& input) {
1061 return detail::inv_out(result, input);
1062}
1063
1064} // namespace linalg
1065} // namespace torch
1066