1#pragma once
2
3#include <c10/core/SymInt.h>
4#include <c10/util/Exception.h>
5
6namespace c10 {
7
8namespace detail {
9// This template can only be specialized at int64_t and c10::SymInt;
10// you'll get linker errors otherwise
11template <typename T>
12C10_API T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar);
13} // namespace detail
14
15template <typename T>
16T _maybe_wrap_dim(T dim, T dim_post_expr, bool wrap_scalar = true) {
17 // Inline the fast paths
18 if (C10_LIKELY(dim_post_expr * -1 <= dim && dim < dim_post_expr)) {
19 // For SymInts, we want an explicit control flow to trigger a guard, so we
20 // may as well branch too.
21 if (dim < 0) {
22 return dim + dim_post_expr;
23 }
24 return dim;
25 }
26 // Check edge-cases out-of-line (wrapping scalars and out-of-bounds errors)
27 return c10::detail::maybe_wrap_dim_slow<T>(
28 std::move(dim), std::move(dim_post_expr), wrap_scalar);
29}
30
31inline int64_t maybe_wrap_dim(
32 int64_t dim,
33 int64_t dim_post_expr,
34 bool wrap_scalar = true) {
35 return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar);
36}
37
38inline c10::SymInt maybe_wrap_dim(
39 c10::SymInt dim,
40 c10::SymInt dim_post_expr,
41 bool wrap_scalar = true) {
42 return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar);
43}
44
45} // namespace c10
46