1 | #pragma once |
2 | |
3 | #include <c10/core/SymInt.h> |
4 | #include <c10/util/Exception.h> |
5 | |
6 | namespace c10 { |
7 | |
8 | namespace detail { |
9 | // This template can only be specialized at int64_t and c10::SymInt; |
10 | // you'll get linker errors otherwise |
11 | template <typename T> |
12 | C10_API T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar); |
13 | } // namespace detail |
14 | |
15 | template <typename T> |
16 | T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) { |
17 | // Inline the fast paths |
18 | if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) { |
19 | // For SymInts, we want an explicit control flow to trigger a guard, so we |
20 | // may as well branch too. |
21 | if (dim < 0) { |
22 | return dim + dim_post_expr; |
23 | } |
24 | return dim; |
25 | } |
26 | // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors) |
27 | return c10::detail::maybe_wrap_dim_slow<T>( |
28 | std::move(dim), std::move(dim_post_expr), wrap_scalar); |
29 | } |
30 | |
31 | inline int64_t maybe_wrap_dim( |
32 | int64_t dim, |
33 | int64_t dim_post_expr, |
34 | bool wrap_scalar = true) { |
35 | return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); |
36 | } |
37 | |
38 | inline c10::SymInt maybe_wrap_dim( |
39 | c10::SymInt dim, |
40 | c10::SymInt dim_post_expr, |
41 | bool wrap_scalar = true) { |
42 | return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar); |
43 | } |
44 | |
45 | } // namespace c10 |
46 | |