1 | #pragma once |
2 | |
3 | #include <ATen/ATen.h> |
4 | |
5 | namespace torch { |
6 | namespace linalg { |
7 | |
8 | #ifndef DOXYGEN_SHOULD_SKIP_THIS |
9 | namespace detail { |
10 | |
11 | inline Tensor cholesky(const Tensor& self) { |
12 | return torch::linalg_cholesky(self); |
13 | } |
14 | |
15 | inline Tensor cholesky_out(Tensor& result, const Tensor& self) { |
16 | return torch::linalg_cholesky_out(result, self); |
17 | } |
18 | |
19 | inline Tensor det(const Tensor& self) { |
20 | return torch::linalg_det(self); |
21 | } |
22 | |
23 | inline std::tuple<Tensor, Tensor> slogdet(const Tensor& input) { |
24 | return torch::linalg_slogdet(input); |
25 | } |
26 | |
27 | inline std::tuple<Tensor&, Tensor&> slogdet_out( |
28 | Tensor& sign, |
29 | Tensor& logabsdet, |
30 | const Tensor& input) { |
31 | return torch::linalg_slogdet_out(sign, logabsdet, input); |
32 | } |
33 | |
34 | inline std::tuple<Tensor, Tensor> eig(const Tensor& self) { |
35 | return torch::linalg_eig(self); |
36 | } |
37 | |
38 | inline std::tuple<Tensor&, Tensor&> eig_out( |
39 | Tensor& eigvals, |
40 | Tensor& eigvecs, |
41 | const Tensor& self) { |
42 | return torch::linalg_eig_out(eigvals, eigvecs, self); |
43 | } |
44 | |
45 | inline Tensor eigvals(const Tensor& self) { |
46 | return torch::linalg_eigvals(self); |
47 | } |
48 | |
49 | inline Tensor& eigvals_out(Tensor& result, const Tensor& self) { |
50 | return torch::linalg_eigvals_out(result, self); |
51 | } |
52 | |
53 | inline std::tuple<Tensor, Tensor> eigh( |
54 | const Tensor& self, |
55 | c10::string_view uplo) { |
56 | return torch::linalg_eigh(self, uplo); |
57 | } |
58 | |
59 | inline std::tuple<Tensor&, Tensor&> eigh_out( |
60 | Tensor& eigvals, |
61 | Tensor& eigvecs, |
62 | const Tensor& self, |
63 | c10::string_view uplo) { |
64 | return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo); |
65 | } |
66 | |
67 | inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) { |
68 | return torch::linalg_eigvalsh(self, uplo); |
69 | } |
70 | |
71 | inline Tensor& eigvalsh_out( |
72 | Tensor& result, |
73 | const Tensor& self, |
74 | c10::string_view uplo) { |
75 | return torch::linalg_eigvalsh_out(result, self, uplo); |
76 | } |
77 | |
78 | inline Tensor householder_product(const Tensor& input, const Tensor& tau) { |
79 | return torch::linalg_householder_product(input, tau); |
80 | } |
81 | |
82 | inline Tensor& householder_product_out( |
83 | Tensor& result, |
84 | const Tensor& input, |
85 | const Tensor& tau) { |
86 | return torch::linalg_householder_product_out(result, input, tau); |
87 | } |
88 | |
89 | inline std::tuple<Tensor, Tensor> lu_factor( |
90 | const Tensor& self, |
91 | const bool pivot) { |
92 | return torch::linalg_lu_factor(self, pivot); |
93 | } |
94 | |
95 | inline std::tuple<Tensor&, Tensor&> lu_factor_out( |
96 | Tensor& LU, |
97 | Tensor& pivots, |
98 | const Tensor& self, |
99 | const bool pivot) { |
100 | return torch::linalg_lu_factor_out(LU, pivots, self, pivot); |
101 | } |
102 | |
103 | inline std::tuple<Tensor, Tensor, Tensor> lu( |
104 | const Tensor& self, |
105 | const bool pivot) { |
106 | return torch::linalg_lu(self, pivot); |
107 | } |
108 | |
109 | inline std::tuple<Tensor&, Tensor&, Tensor&> lu_out( |
110 | Tensor& P, |
111 | Tensor& L, |
112 | Tensor& U, |
113 | const Tensor& self, |
114 | const bool pivot) { |
115 | return torch::linalg_lu_out(P, L, U, self, pivot); |
116 | } |
117 | |
118 | inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq( |
119 | const Tensor& self, |
120 | const Tensor& b, |
121 | c10::optional<double> cond, |
122 | c10::optional<c10::string_view> driver) { |
123 | return torch::linalg_lstsq(self, b, cond, driver); |
124 | } |
125 | |
126 | inline Tensor matrix_exp(const Tensor& self) { |
127 | return torch::linalg_matrix_exp(self); |
128 | } |
129 | |
130 | inline Tensor norm( |
131 | const Tensor& self, |
132 | const optional<Scalar>& opt_ord, |
133 | OptionalIntArrayRef opt_dim, |
134 | bool keepdim, |
135 | optional<ScalarType> opt_dtype) { |
136 | return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype); |
137 | } |
138 | |
139 | inline Tensor norm( |
140 | const Tensor& self, |
141 | c10::string_view ord, |
142 | OptionalIntArrayRef opt_dim, |
143 | bool keepdim, |
144 | optional<ScalarType> opt_dtype) { |
145 | return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype); |
146 | } |
147 | |
148 | inline Tensor& norm_out( |
149 | Tensor& result, |
150 | const Tensor& self, |
151 | const optional<Scalar>& opt_ord, |
152 | OptionalIntArrayRef opt_dim, |
153 | bool keepdim, |
154 | optional<ScalarType> opt_dtype) { |
155 | return torch::linalg_norm_out( |
156 | result, self, opt_ord, opt_dim, keepdim, opt_dtype); |
157 | } |
158 | |
159 | inline Tensor& norm_out( |
160 | Tensor& result, |
161 | const Tensor& self, |
162 | c10::string_view ord, |
163 | OptionalIntArrayRef opt_dim, |
164 | bool keepdim, |
165 | optional<ScalarType> opt_dtype) { |
166 | return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); |
167 | } |
168 | |
169 | inline Tensor vector_norm( |
170 | const Tensor& self, |
171 | Scalar ord, |
172 | OptionalIntArrayRef opt_dim, |
173 | bool keepdim, |
174 | optional<ScalarType> opt_dtype) { |
175 | return torch::linalg_vector_norm(self, ord, opt_dim, keepdim, opt_dtype); |
176 | } |
177 | |
178 | inline Tensor& vector_norm_out( |
179 | Tensor& result, |
180 | const Tensor& self, |
181 | Scalar ord, |
182 | OptionalIntArrayRef opt_dim, |
183 | bool keepdim, |
184 | optional<ScalarType> opt_dtype) { |
185 | return torch::linalg_vector_norm_out( |
186 | result, self, ord, opt_dim, keepdim, opt_dtype); |
187 | } |
188 | |
189 | inline Tensor matrix_norm( |
190 | const Tensor& self, |
191 | const Scalar& ord, |
192 | IntArrayRef dim, |
193 | bool keepdim, |
194 | optional<ScalarType> dtype) { |
195 | return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype); |
196 | } |
197 | |
198 | inline Tensor& matrix_norm_out( |
199 | const Tensor& self, |
200 | const Scalar& ord, |
201 | IntArrayRef dim, |
202 | bool keepdim, |
203 | optional<ScalarType> dtype, |
204 | Tensor& result) { |
205 | return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype); |
206 | } |
207 | |
208 | inline Tensor matrix_norm( |
209 | const Tensor& self, |
210 | std::string ord, |
211 | IntArrayRef dim, |
212 | bool keepdim, |
213 | optional<ScalarType> dtype) { |
214 | return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype); |
215 | } |
216 | |
217 | inline Tensor& matrix_norm_out( |
218 | const Tensor& self, |
219 | std::string ord, |
220 | IntArrayRef dim, |
221 | bool keepdim, |
222 | optional<ScalarType> dtype, |
223 | Tensor& result) { |
224 | return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype); |
225 | } |
226 | |
227 | inline Tensor matrix_power(const Tensor& self, int64_t n) { |
228 | return torch::linalg_matrix_power(self, n); |
229 | } |
230 | |
231 | inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) { |
232 | return torch::linalg_matrix_power_out(result, self, n); |
233 | } |
234 | |
235 | inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) { |
236 | return torch::linalg_matrix_rank(input, tol, hermitian); |
237 | } |
238 | |
239 | inline Tensor matrix_rank( |
240 | const Tensor& input, |
241 | const Tensor& tol, |
242 | bool hermitian) { |
243 | return torch::linalg_matrix_rank(input, tol, hermitian); |
244 | } |
245 | |
246 | inline Tensor matrix_rank( |
247 | const Tensor& input, |
248 | c10::optional<double> atol, |
249 | c10::optional<double> rtol, |
250 | bool hermitian) { |
251 | return torch::linalg_matrix_rank(input, atol, rtol, hermitian); |
252 | } |
253 | |
254 | inline Tensor matrix_rank( |
255 | const Tensor& input, |
256 | const c10::optional<Tensor>& atol, |
257 | const c10::optional<Tensor>& rtol, |
258 | bool hermitian) { |
259 | return torch::linalg_matrix_rank(input, atol, rtol, hermitian); |
260 | } |
261 | |
262 | inline Tensor& matrix_rank_out( |
263 | Tensor& result, |
264 | const Tensor& input, |
265 | double tol, |
266 | bool hermitian) { |
267 | return torch::linalg_matrix_rank_out(result, input, tol, hermitian); |
268 | } |
269 | |
270 | inline Tensor& matrix_rank_out( |
271 | Tensor& result, |
272 | const Tensor& input, |
273 | const Tensor& tol, |
274 | bool hermitian) { |
275 | return torch::linalg_matrix_rank_out(result, input, tol, hermitian); |
276 | } |
277 | |
278 | inline Tensor& matrix_rank_out( |
279 | Tensor& result, |
280 | const Tensor& input, |
281 | c10::optional<double> atol, |
282 | c10::optional<double> rtol, |
283 | bool hermitian) { |
284 | return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian); |
285 | } |
286 | |
287 | inline Tensor& matrix_rank_out( |
288 | Tensor& result, |
289 | const Tensor& input, |
290 | const c10::optional<Tensor>& atol, |
291 | const c10::optional<Tensor>& rtol, |
292 | bool hermitian) { |
293 | return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian); |
294 | } |
295 | |
296 | inline Tensor multi_dot(TensorList tensors) { |
297 | return torch::linalg_multi_dot(tensors); |
298 | } |
299 | |
300 | inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) { |
301 | return torch::linalg_multi_dot_out(result, tensors); |
302 | } |
303 | |
304 | inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) { |
305 | return torch::linalg_pinv(input, rcond, hermitian); |
306 | } |
307 | |
308 | inline Tensor& pinv_out( |
309 | Tensor& result, |
310 | const Tensor& input, |
311 | double rcond, |
312 | bool hermitian) { |
313 | return torch::linalg_pinv_out(result, input, rcond, hermitian); |
314 | } |
315 | |
316 | inline std::tuple<Tensor, Tensor> qr( |
317 | const Tensor& input, |
318 | c10::string_view mode) { |
319 | return torch::linalg_qr(input, mode); |
320 | } |
321 | |
322 | inline std::tuple<Tensor&, Tensor&> qr_out( |
323 | Tensor& Q, |
324 | Tensor& R, |
325 | const Tensor& input, |
326 | c10::string_view mode) { |
327 | return torch::linalg_qr_out(Q, R, input, mode); |
328 | } |
329 | |
330 | inline std::tuple<Tensor, Tensor> solve_ex( |
331 | const Tensor& input, |
332 | const Tensor& other, |
333 | bool left, |
334 | bool check_errors) { |
335 | return torch::linalg_solve_ex(input, other, left, check_errors); |
336 | } |
337 | |
338 | inline std::tuple<Tensor&, Tensor&> solve_ex_out( |
339 | Tensor& result, |
340 | Tensor& info, |
341 | const Tensor& input, |
342 | const Tensor& other, |
343 | bool left, |
344 | bool check_errors) { |
345 | return torch::linalg_solve_ex_out( |
346 | result, info, input, other, left, check_errors); |
347 | } |
348 | |
349 | inline Tensor solve(const Tensor& input, const Tensor& other, bool left) { |
350 | return torch::linalg_solve(input, other, left); |
351 | } |
352 | |
353 | inline Tensor& solve_out( |
354 | Tensor& result, |
355 | const Tensor& input, |
356 | const Tensor& other, |
357 | bool left) { |
358 | return torch::linalg_solve_out(result, input, other, left); |
359 | } |
360 | |
361 | inline Tensor solve_triangular( |
362 | const Tensor& input, |
363 | const Tensor& other, |
364 | bool upper, |
365 | bool left, |
366 | bool unitriangular) { |
367 | return torch::linalg_solve_triangular( |
368 | input, other, upper, left, unitriangular); |
369 | } |
370 | |
371 | inline Tensor& solve_triangular_out( |
372 | Tensor& result, |
373 | const Tensor& input, |
374 | const Tensor& other, |
375 | bool upper, |
376 | bool left, |
377 | bool unitriangular) { |
378 | return torch::linalg_solve_triangular_out( |
379 | result, input, other, upper, left, unitriangular); |
380 | } |
381 | |
382 | inline std::tuple<Tensor, Tensor, Tensor> svd( |
383 | const Tensor& input, |
384 | bool full_matrices, |
385 | c10::optional<c10::string_view> driver) { |
386 | return torch::linalg_svd(input, full_matrices, driver); |
387 | } |
388 | |
389 | inline std::tuple<Tensor&, Tensor&, Tensor&> svd_out( |
390 | Tensor& U, |
391 | Tensor& S, |
392 | Tensor& Vh, |
393 | const Tensor& input, |
394 | bool full_matrices, |
395 | c10::optional<c10::string_view> driver) { |
396 | return torch::linalg_svd_out(U, S, Vh, input, full_matrices, driver); |
397 | } |
398 | |
399 | inline Tensor svdvals( |
400 | const Tensor& input, |
401 | c10::optional<c10::string_view> driver) { |
402 | return torch::linalg_svdvals(input, driver); |
403 | } |
404 | |
405 | inline Tensor& svdvals_out( |
406 | Tensor& result, |
407 | const Tensor& input, |
408 | c10::optional<c10::string_view> driver) { |
409 | return torch::linalg_svdvals_out(result, input, driver); |
410 | } |
411 | |
412 | inline Tensor tensorinv(const Tensor& self, int64_t ind) { |
413 | return torch::linalg_tensorinv(self, ind); |
414 | } |
415 | |
416 | inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) { |
417 | return torch::linalg_tensorinv_out(result, self, ind); |
418 | } |
419 | |
420 | inline Tensor tensorsolve( |
421 | const Tensor& self, |
422 | const Tensor& other, |
423 | OptionalIntArrayRef dims) { |
424 | return torch::linalg_tensorsolve(self, other, dims); |
425 | } |
426 | |
427 | inline Tensor& tensorsolve_out( |
428 | Tensor& result, |
429 | const Tensor& self, |
430 | const Tensor& other, |
431 | OptionalIntArrayRef dims) { |
432 | return torch::linalg_tensorsolve_out(result, self, other, dims); |
433 | } |
434 | |
435 | inline Tensor inv(const Tensor& input) { |
436 | return torch::linalg_inv(input); |
437 | } |
438 | |
439 | inline Tensor& inv_out(Tensor& result, const Tensor& input) { |
440 | return torch::linalg_inv_out(result, input); |
441 | } |
442 | |
443 | } // namespace detail |
444 | #endif /* DOXYGEN_SHOULD_SKIP_THIS */ |
445 | |
446 | /// Cholesky decomposition |
447 | /// |
448 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.cholesky |
449 | /// |
450 | /// Example: |
451 | /// ``` |
452 | /// auto A = torch::randn({4, 4}); |
453 | /// auto A = torch::matmul(A, A.t()); |
454 | /// auto L = torch::linalg::cholesky(A); |
455 | /// assert(torch::allclose(torch::matmul(L, L.t()), A)); |
456 | /// ``` |
457 | inline Tensor cholesky(const Tensor& self) { |
458 | return detail::cholesky(self); |
459 | } |
460 | |
461 | inline Tensor cholesky_out(Tensor& result, const Tensor& self) { |
462 | return detail::cholesky_out(result, self); |
463 | } |
464 | |
465 | // C10_DEPRECATED_MESSAGE("linalg_det is deprecated, use det instead.") |
466 | inline Tensor linalg_det(const Tensor& self) { |
467 | return detail::det(self); |
468 | } |
469 | |
470 | /// See the documentation of torch.linalg.det |
471 | inline Tensor det(const Tensor& self) { |
472 | return detail::det(self); |
473 | } |
474 | |
475 | /// Computes the sign and (natural) logarithm of the determinant |
476 | /// |
477 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.slogdet |
478 | inline std::tuple<Tensor, Tensor> slogdet(const Tensor& input) { |
479 | return detail::slogdet(input); |
480 | } |
481 | |
482 | inline std::tuple<Tensor&, Tensor&> slogdet_out( |
483 | Tensor& sign, |
484 | Tensor& logabsdet, |
485 | const Tensor& input) { |
486 | return detail::slogdet_out(sign, logabsdet, input); |
487 | } |
488 | |
489 | /// Computes eigenvalues and eigenvectors of non-symmetric/non-hermitian |
490 | /// matrices |
491 | /// |
492 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eig |
493 | inline std::tuple<Tensor, Tensor> eig(const Tensor& self) { |
494 | return detail::eig(self); |
495 | } |
496 | |
497 | inline std::tuple<Tensor&, Tensor&> eig_out( |
498 | Tensor& eigvals, |
499 | Tensor& eigvecs, |
500 | const Tensor& self) { |
501 | return detail::eig_out(eigvals, eigvecs, self); |
502 | } |
503 | |
504 | /// Computes eigenvalues of non-symmetric/non-hermitian matrices |
505 | /// |
506 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigvals |
507 | inline Tensor eigvals(const Tensor& self) { |
508 | return detail::eigvals(self); |
509 | } |
510 | |
511 | inline Tensor& eigvals_out(Tensor& result, const Tensor& self) { |
512 | return detail::eigvals_out(result, self); |
513 | } |
514 | |
515 | /// Computes eigenvalues and eigenvectors |
516 | /// |
517 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigh |
518 | inline std::tuple<Tensor, Tensor> eigh( |
519 | const Tensor& self, |
520 | c10::string_view uplo) { |
521 | return detail::eigh(self, uplo); |
522 | } |
523 | |
524 | inline std::tuple<Tensor&, Tensor&> eigh_out( |
525 | Tensor& eigvals, |
526 | Tensor& eigvecs, |
527 | const Tensor& self, |
528 | c10::string_view uplo) { |
529 | return detail::eigh_out(eigvals, eigvecs, self, uplo); |
530 | } |
531 | |
532 | /// Computes eigenvalues |
533 | /// |
534 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.eigvalsh |
535 | inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) { |
536 | return detail::eigvalsh(self, uplo); |
537 | } |
538 | |
539 | inline Tensor& eigvalsh_out( |
540 | Tensor& result, |
541 | const Tensor& self, |
542 | c10::string_view uplo) { |
543 | return detail::eigvalsh_out(result, self, uplo); |
544 | } |
545 | |
546 | /// Computes the product of Householder matrices |
547 | /// |
548 | /// See |
549 | /// https://pytorch.org/docs/master/linalg.html#torch.linalg.householder_product |
550 | inline Tensor householder_product(const Tensor& input, const Tensor& tau) { |
551 | return detail::householder_product(input, tau); |
552 | } |
553 | |
554 | inline Tensor& householder_product_out( |
555 | Tensor& result, |
556 | const Tensor& input, |
557 | const Tensor& tau) { |
558 | return detail::householder_product_out(result, input, tau); |
559 | } |
560 | |
561 | inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq( |
562 | const Tensor& self, |
563 | const Tensor& b, |
564 | c10::optional<double> cond, |
565 | c10::optional<c10::string_view> driver) { |
566 | return detail::lstsq(self, b, cond, driver); |
567 | } |
568 | |
569 | /// Computes the matrix exponential |
570 | /// |
571 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_exp |
572 | inline Tensor matrix_exp(const Tensor& input) { |
573 | return detail::matrix_exp(input); |
574 | } |
575 | |
576 | // C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.") |
577 | inline Tensor linalg_norm( |
578 | const Tensor& self, |
579 | const optional<Scalar>& opt_ord, |
580 | OptionalIntArrayRef opt_dim, |
581 | bool keepdim, |
582 | optional<ScalarType> opt_dtype) { |
583 | return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); |
584 | } |
585 | |
586 | // C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.") |
587 | inline Tensor linalg_norm( |
588 | const Tensor& self, |
589 | c10::string_view ord, |
590 | OptionalIntArrayRef opt_dim, |
591 | bool keepdim, |
592 | optional<ScalarType> opt_dtype) { |
593 | return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); |
594 | } |
595 | |
596 | // C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out |
597 | // instead.") |
598 | inline Tensor& linalg_norm_out( |
599 | Tensor& result, |
600 | const Tensor& self, |
601 | const optional<Scalar>& opt_ord, |
602 | OptionalIntArrayRef opt_dim, |
603 | bool keepdim, |
604 | optional<ScalarType> opt_dtype) { |
605 | return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); |
606 | } |
607 | |
608 | // C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out |
609 | // instead.") |
610 | inline Tensor& linalg_norm_out( |
611 | Tensor& result, |
612 | const Tensor& self, |
613 | c10::string_view ord, |
614 | OptionalIntArrayRef opt_dim, |
615 | bool keepdim, |
616 | optional<ScalarType> opt_dtype) { |
617 | return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); |
618 | } |
619 | |
620 | /// Computes the LU factorization with partial pivoting |
621 | /// |
622 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.lu_factor |
623 | inline std::tuple<Tensor, Tensor> lu_factor( |
624 | const Tensor& input, |
625 | const bool pivot = true) { |
626 | return detail::lu_factor(input, pivot); |
627 | } |
628 | |
629 | inline std::tuple<Tensor&, Tensor&> lu_factor_out( |
630 | Tensor& LU, |
631 | Tensor& pivots, |
632 | const Tensor& self, |
633 | const bool pivot = true) { |
634 | return detail::lu_factor_out(LU, pivots, self, pivot); |
635 | } |
636 | |
637 | /// Computes the LU factorization with partial pivoting |
638 | /// |
639 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.lu |
640 | inline std::tuple<Tensor, Tensor, Tensor> lu( |
641 | const Tensor& input, |
642 | const bool pivot = true) { |
643 | return detail::lu(input, pivot); |
644 | } |
645 | |
646 | inline std::tuple<Tensor&, Tensor&, Tensor&> lu_out( |
647 | Tensor& P, |
648 | Tensor& L, |
649 | Tensor& U, |
650 | const Tensor& self, |
651 | const bool pivot = true) { |
652 | return detail::lu_out(P, L, U, self, pivot); |
653 | } |
654 | |
655 | inline Tensor norm( |
656 | const Tensor& self, |
657 | const optional<Scalar>& opt_ord, |
658 | OptionalIntArrayRef opt_dim, |
659 | bool keepdim, |
660 | optional<ScalarType> opt_dtype) { |
661 | return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); |
662 | } |
663 | |
664 | inline Tensor norm( |
665 | const Tensor& self, |
666 | std::string ord, |
667 | OptionalIntArrayRef opt_dim, |
668 | bool keepdim, |
669 | optional<ScalarType> opt_dtype) { |
670 | return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); |
671 | } |
672 | |
673 | inline Tensor& norm_out( |
674 | Tensor& result, |
675 | const Tensor& self, |
676 | const optional<Scalar>& opt_ord, |
677 | OptionalIntArrayRef opt_dim, |
678 | bool keepdim, |
679 | optional<ScalarType> opt_dtype) { |
680 | return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); |
681 | } |
682 | |
683 | inline Tensor& norm_out( |
684 | Tensor& result, |
685 | const Tensor& self, |
686 | std::string ord, |
687 | OptionalIntArrayRef opt_dim, |
688 | bool keepdim, |
689 | optional<ScalarType> opt_dtype) { |
690 | return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); |
691 | } |
692 | |
693 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.vector_norm |
694 | inline Tensor vector_norm( |
695 | const Tensor& self, |
696 | Scalar ord, |
697 | OptionalIntArrayRef opt_dim, |
698 | bool keepdim, |
699 | optional<ScalarType> opt_dtype) { |
700 | return detail::vector_norm(self, ord, opt_dim, keepdim, opt_dtype); |
701 | } |
702 | |
703 | inline Tensor& vector_norm_out( |
704 | Tensor& result, |
705 | const Tensor& self, |
706 | Scalar ord, |
707 | OptionalIntArrayRef opt_dim, |
708 | bool keepdim, |
709 | optional<ScalarType> opt_dtype) { |
710 | return detail::vector_norm_out( |
711 | result, self, ord, opt_dim, keepdim, opt_dtype); |
712 | } |
713 | |
714 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_norm |
715 | inline Tensor matrix_norm( |
716 | const Tensor& self, |
717 | const Scalar& ord, |
718 | IntArrayRef dim, |
719 | bool keepdim, |
720 | optional<ScalarType> dtype) { |
721 | return detail::matrix_norm(self, ord, dim, keepdim, dtype); |
722 | } |
723 | |
724 | inline Tensor& matrix_norm_out( |
725 | const Tensor& self, |
726 | const Scalar& ord, |
727 | IntArrayRef dim, |
728 | bool keepdim, |
729 | optional<ScalarType> dtype, |
730 | Tensor& result) { |
731 | return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result); |
732 | } |
733 | |
734 | inline Tensor matrix_norm( |
735 | const Tensor& self, |
736 | std::string ord, |
737 | IntArrayRef dim, |
738 | bool keepdim, |
739 | optional<ScalarType> dtype) { |
740 | return detail::matrix_norm(self, ord, dim, keepdim, dtype); |
741 | } |
742 | |
743 | inline Tensor& matrix_norm_out( |
744 | const Tensor& self, |
745 | std::string ord, |
746 | IntArrayRef dim, |
747 | bool keepdim, |
748 | optional<ScalarType> dtype, |
749 | Tensor& result) { |
750 | return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result); |
751 | } |
752 | |
753 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_power |
754 | inline Tensor matrix_power(const Tensor& self, int64_t n) { |
755 | return detail::matrix_power(self, n); |
756 | } |
757 | |
758 | inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) { |
759 | return detail::matrix_power_out(self, n, result); |
760 | } |
761 | |
762 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_rank |
763 | inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) { |
764 | return detail::matrix_rank(input, tol, hermitian); |
765 | } |
766 | |
767 | inline Tensor matrix_rank( |
768 | const Tensor& input, |
769 | const Tensor& tol, |
770 | bool hermitian) { |
771 | return detail::matrix_rank(input, tol, hermitian); |
772 | } |
773 | |
774 | inline Tensor matrix_rank( |
775 | const Tensor& input, |
776 | c10::optional<double> atol, |
777 | c10::optional<double> rtol, |
778 | bool hermitian) { |
779 | return detail::matrix_rank(input, atol, rtol, hermitian); |
780 | } |
781 | |
782 | inline Tensor matrix_rank( |
783 | const Tensor& input, |
784 | const c10::optional<Tensor>& atol, |
785 | const c10::optional<Tensor>& rtol, |
786 | bool hermitian) { |
787 | return detail::matrix_rank(input, atol, rtol, hermitian); |
788 | } |
789 | |
790 | inline Tensor& matrix_rank_out( |
791 | Tensor& result, |
792 | const Tensor& input, |
793 | double tol, |
794 | bool hermitian) { |
795 | return detail::matrix_rank_out(result, input, tol, hermitian); |
796 | } |
797 | |
798 | inline Tensor& matrix_rank_out( |
799 | Tensor& result, |
800 | const Tensor& input, |
801 | const Tensor& tol, |
802 | bool hermitian) { |
803 | return detail::matrix_rank_out(result, input, tol, hermitian); |
804 | } |
805 | |
806 | inline Tensor& matrix_rank_out( |
807 | Tensor& result, |
808 | const Tensor& input, |
809 | c10::optional<double> atol, |
810 | c10::optional<double> rtol, |
811 | bool hermitian) { |
812 | return detail::matrix_rank_out(result, input, atol, rtol, hermitian); |
813 | } |
814 | |
815 | inline Tensor& matrix_rank_out( |
816 | Tensor& result, |
817 | const Tensor& input, |
818 | const c10::optional<Tensor>& atol, |
819 | const c10::optional<Tensor>& rtol, |
820 | bool hermitian) { |
821 | return detail::matrix_rank_out(result, input, atol, rtol, hermitian); |
822 | } |
823 | |
824 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.multi_dot |
825 | inline Tensor multi_dot(TensorList tensors) { |
826 | return detail::multi_dot(tensors); |
827 | } |
828 | |
829 | inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) { |
830 | return detail::multi_dot_out(tensors, result); |
831 | } |
832 | |
833 | /// Computes the pseudo-inverse |
834 | /// |
835 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.pinv |
836 | inline Tensor pinv( |
837 | const Tensor& input, |
838 | double rcond = 1e-15, |
839 | bool hermitian = false) { |
840 | return detail::pinv(input, rcond, hermitian); |
841 | } |
842 | |
843 | inline Tensor& pinv_out( |
844 | Tensor& result, |
845 | const Tensor& input, |
846 | double rcond = 1e-15, |
847 | bool hermitian = false) { |
848 | return detail::pinv_out(result, input, rcond, hermitian); |
849 | } |
850 | |
851 | /// Computes the QR decomposition |
852 | /// |
853 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.qr |
854 | inline std::tuple<Tensor, Tensor> qr( |
855 | const Tensor& input, |
856 | c10::string_view mode = "reduced" ) { |
857 | // C++17 Change the initialisation to "reduced"sv |
858 | // Same for qr_out |
859 | return detail::qr(input, mode); |
860 | } |
861 | |
862 | inline std::tuple<Tensor&, Tensor&> qr_out( |
863 | Tensor& Q, |
864 | Tensor& R, |
865 | const Tensor& input, |
866 | c10::string_view mode = "reduced" ) { |
867 | return detail::qr_out(Q, R, input, mode); |
868 | } |
869 | |
870 | /// Computes the LDL decomposition |
871 | /// |
872 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.ldl_factor_ex |
873 | inline std::tuple<Tensor, Tensor, Tensor> ldl_factor_ex( |
874 | const Tensor& input, |
875 | bool hermitian, |
876 | bool check_errors) { |
877 | return torch::linalg_ldl_factor_ex(input, hermitian, check_errors); |
878 | } |
879 | |
880 | inline std::tuple<Tensor&, Tensor&, Tensor&> ldl_factor_ex_out( |
881 | Tensor& LD, |
882 | Tensor& pivots, |
883 | Tensor& info, |
884 | const Tensor& input, |
885 | bool hermitian, |
886 | bool check_errors) { |
887 | return torch::linalg_ldl_factor_ex_out( |
888 | LD, pivots, info, input, hermitian, check_errors); |
889 | } |
890 | |
891 | /// Solve a system of linear equations using the LDL decomposition |
892 | /// |
893 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.ldl_solve |
894 | inline Tensor ldl_solve( |
895 | const Tensor& LD, |
896 | const Tensor& pivots, |
897 | const Tensor& B, |
898 | bool hermitian) { |
899 | return torch::linalg_ldl_solve(LD, pivots, B, hermitian); |
900 | } |
901 | |
902 | inline Tensor& ldl_solve_out( |
903 | Tensor& result, |
904 | const Tensor& LD, |
905 | const Tensor& pivots, |
906 | const Tensor& B, |
907 | bool hermitian) { |
908 | return torch::linalg_ldl_solve_out(result, LD, pivots, B, hermitian); |
909 | } |
910 | |
911 | /// Solves a system linear system AX = B |
912 | /// |
913 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_ex |
914 | inline std::tuple<Tensor, Tensor> solve_ex( |
915 | const Tensor& input, |
916 | const Tensor& other, |
917 | bool left, |
918 | bool check_errors) { |
919 | return detail::solve_ex(input, other, left, check_errors); |
920 | } |
921 | |
922 | inline std::tuple<Tensor&, Tensor&> solve_ex_out( |
923 | Tensor& result, |
924 | Tensor& info, |
925 | const Tensor& input, |
926 | const Tensor& other, |
927 | bool left, |
928 | bool check_errors) { |
929 | return detail::solve_ex_out(result, info, input, other, left, check_errors); |
930 | } |
931 | |
932 | /// Computes a tensor `x` such that `matmul(input, x) = other`. |
933 | /// |
934 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.solve |
935 | inline Tensor solve(const Tensor& input, const Tensor& other, bool left) { |
936 | return detail::solve(input, other, left); |
937 | } |
938 | |
939 | inline Tensor& solve_out( |
940 | Tensor& result, |
941 | const Tensor& input, |
942 | const Tensor& other, |
943 | bool left) { |
944 | return detail::solve_out(result, input, other, left); |
945 | } |
946 | |
947 | /// Computes a solution of a linear system AX = B for input = A and other = B |
948 | /// whenever A is square upper or lower triangular and does not have zeros in |
949 | /// the diagonal |
950 | /// |
951 | /// See |
952 | /// https://pytorch.org/docs/master/linalg.html#torch.linalg.solve_triangular |
953 | inline Tensor solve_triangular( |
954 | const Tensor& input, |
955 | const Tensor& other, |
956 | bool upper, |
957 | bool left, |
958 | bool unitriangular) { |
959 | return detail::solve_triangular(input, other, upper, left, unitriangular); |
960 | } |
961 | |
962 | inline Tensor& solve_triangular_out( |
963 | Tensor& result, |
964 | const Tensor& input, |
965 | const Tensor& other, |
966 | bool upper, |
967 | bool left, |
968 | bool unitriangular) { |
969 | return detail::solve_triangular_out( |
970 | result, input, other, upper, left, unitriangular); |
971 | } |
972 | |
973 | /// Computes the singular values and singular vectors |
974 | /// |
975 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.svd |
976 | inline std::tuple<Tensor, Tensor, Tensor> svd( |
977 | const Tensor& input, |
978 | bool full_matrices, |
979 | c10::optional<c10::string_view> driver) { |
980 | return detail::svd(input, full_matrices, driver); |
981 | } |
982 | |
983 | inline std::tuple<Tensor&, Tensor&, Tensor&> svd_out( |
984 | Tensor& U, |
985 | Tensor& S, |
986 | Tensor& Vh, |
987 | const Tensor& input, |
988 | bool full_matrices, |
989 | c10::optional<c10::string_view> driver) { |
990 | return detail::svd_out(U, S, Vh, input, full_matrices, driver); |
991 | } |
992 | |
993 | /// Computes the singular values |
994 | /// |
995 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.svdvals |
996 | inline Tensor svdvals( |
997 | const Tensor& input, |
998 | c10::optional<c10::string_view> driver) { |
999 | return detail::svdvals(input, driver); |
1000 | } |
1001 | |
1002 | inline Tensor& svdvals_out( |
1003 | Tensor& result, |
1004 | const Tensor& input, |
1005 | c10::optional<c10::string_view> driver) { |
1006 | return detail::svdvals_out(result, input, driver); |
1007 | } |
1008 | |
1009 | /// Computes the inverse of a tensor |
1010 | /// |
1011 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorinv |
1012 | /// |
1013 | /// Example: |
1014 | /// ``` |
1015 | /// auto a = torch::eye(4*6).reshape({4, 6, 8, 3}); |
1016 | /// int64_t ind = 2; |
1017 | /// auto ainv = torch::linalg::tensorinv(a, ind); |
1018 | /// ``` |
1019 | inline Tensor tensorinv(const Tensor& self, int64_t ind) { |
1020 | return detail::tensorinv(self, ind); |
1021 | } |
1022 | |
1023 | inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) { |
1024 | return detail::tensorinv_out(result, self, ind); |
1025 | } |
1026 | |
1027 | /// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`. |
1028 | /// |
1029 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorsolve |
1030 | /// |
1031 | /// Example: |
1032 | /// ``` |
1033 | /// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4}); |
1034 | /// auto b = torch::randn(2*3, 4); |
1035 | /// auto x = torch::linalg::tensorsolve(a, b); |
1036 | /// ``` |
1037 | inline Tensor tensorsolve( |
1038 | const Tensor& input, |
1039 | const Tensor& other, |
1040 | OptionalIntArrayRef dims) { |
1041 | return detail::tensorsolve(input, other, dims); |
1042 | } |
1043 | |
1044 | inline Tensor& tensorsolve_out( |
1045 | Tensor& result, |
1046 | const Tensor& input, |
1047 | const Tensor& other, |
1048 | OptionalIntArrayRef dims) { |
1049 | return detail::tensorsolve_out(result, input, other, dims); |
1050 | } |
1051 | |
1052 | /// Computes a tensor `inverse_input` such that `dot(input, inverse_input) = |
1053 | /// eye(input.size(0))`. |
1054 | /// |
1055 | /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.inv |
1056 | inline Tensor inv(const Tensor& input) { |
1057 | return detail::inv(input); |
1058 | } |
1059 | |
1060 | inline Tensor& inv_out(Tensor& result, const Tensor& input) { |
1061 | return detail::inv_out(result, input); |
1062 | } |
1063 | |
1064 | } // namespace linalg |
1065 | } // namespace torch |
1066 | |