1 | #pragma once |
2 | |
3 | #include <c10/core/SymInt.h> |
4 | #include <c10/util/ArrayRef.h> |
5 | #include <c10/util/Exception.h> |
6 | #include <c10/util/Optional.h> |
7 | |
8 | namespace c10 { |
9 | using SymIntArrayRef = ArrayRef<SymInt>; |
10 | |
11 | inline at::IntArrayRef asIntArrayRefUnchecked(c10::SymIntArrayRef ar) { |
12 | return IntArrayRef(reinterpret_cast<const int64_t*>(ar.data()), ar.size()); |
13 | } |
14 | |
15 | inline c10::optional<at::IntArrayRef> asIntArrayRefSlowOpt( |
16 | c10::SymIntArrayRef ar) { |
17 | for (const c10::SymInt& sci : ar) { |
18 | if (sci.is_symbolic()) { |
19 | return c10::nullopt; |
20 | } |
21 | } |
22 | |
23 | return {asIntArrayRefUnchecked(ar)}; |
24 | } |
25 | |
26 | inline at::IntArrayRef asIntArrayRefSlow( |
27 | c10::SymIntArrayRef ar, |
28 | const char* file, |
29 | int64_t line) { |
30 | for (const c10::SymInt& sci : ar) { |
31 | TORCH_CHECK( |
32 | !sci.is_symbolic(), |
33 | file, |
34 | ":" , |
35 | line, |
36 | ": SymIntArrayRef expected to contain only concrete integers" ); |
37 | } |
38 | return asIntArrayRefUnchecked(ar); |
39 | } |
40 | |
41 | #define C10_AS_INTARRAYREF_SLOW(a) c10::asIntArrayRefSlow(a, __FILE__, __LINE__) |
42 | |
43 | // Prefer using a more semantic constructor, like |
44 | // fromIntArrayRefKnownNonNegative |
45 | inline SymIntArrayRef fromIntArrayRefUnchecked(IntArrayRef array_ref) { |
46 | return SymIntArrayRef( |
47 | reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size()); |
48 | } |
49 | |
50 | inline SymIntArrayRef fromIntArrayRefKnownNonNegative(IntArrayRef array_ref) { |
51 | return fromIntArrayRefUnchecked(array_ref); |
52 | } |
53 | |
54 | inline SymIntArrayRef fromIntArrayRefSlow(IntArrayRef array_ref) { |
55 | for (long i : array_ref) { |
56 | TORCH_CHECK( |
57 | SymInt::check_range(i), |
58 | "IntArrayRef contains an int that cannot be represented as a SymInt: " , |
59 | i); |
60 | } |
61 | return SymIntArrayRef( |
62 | reinterpret_cast<const SymInt*>(array_ref.data()), array_ref.size()); |
63 | } |
64 | |
65 | } // namespace c10 |
66 | |