1 | #pragma once |
2 | |
3 | #include <ATen/ATen.h> |
4 | |
5 | #include <c10/util/Optional.h> |
6 | |
7 | #include <torch/csrc/autograd/generated/variable_factories.h> |
8 | #include <torch/csrc/autograd/variable.h> |
9 | |
// TODO: The two includes below don't really belong in this header, but the
// torchvision builds in CI still rely on them. Remove them once the
// torchvision version compiled in CI has been updated.
12 | #include <ATen/core/dispatch/Dispatcher.h> |
13 | #include <torch/library.h> |
14 | |
15 | namespace torch { |
16 | |
17 | // NOTE [ Exposing declarations in `at::` to `torch::` ] |
18 | // |
19 | // The following line `using namespace at;` is responsible for exposing all |
20 | // declarations in `at::` namespace to `torch::` namespace. |
21 | // |
22 | // According to the rules laid out in |
23 | // https://en.cppreference.com/w/cpp/language/qualified_lookup, section |
24 | // "Namespace members": |
25 | // ``` |
26 | // Qualified lookup within the scope of a namespace N first considers all |
27 | // declarations that are located in N and all declarations that are located in |
28 | // the inline namespace members of N (and, transitively, in their inline |
29 | // namespace members). If there are no declarations in that set then it |
30 | // considers declarations in all namespaces named by using-directives found in N |
31 | // and in all transitive inline namespace members of N. |
32 | // ``` |
33 | // |
34 | // This means that if both `at::` and `torch::` namespaces have a function with |
35 | // the same signature (e.g. both `at::func()` and `torch::func()` exist), after |
36 | // `namespace torch { using namespace at; }`, when we call `torch::func()`, the |
37 | // `func()` function defined in `torch::` namespace will always be called, and |
38 | // the `func()` function defined in `at::` namespace is always hidden. |
39 | using namespace at; // NOLINT |
40 | |
41 | using c10::nullopt; |
42 | using c10::optional; |
43 | |
44 | using Dtype = at::ScalarType; |
45 | |
46 | /// Fixed width dtypes. |
47 | constexpr auto kUInt8 = at::kByte; |
48 | constexpr auto kInt8 = at::kChar; |
49 | constexpr auto kInt16 = at::kShort; |
50 | constexpr auto kInt32 = at::kInt; |
51 | constexpr auto kInt64 = at::kLong; |
52 | constexpr auto kFloat16 = at::kHalf; |
53 | constexpr auto kFloat32 = at::kFloat; |
54 | constexpr auto kFloat64 = at::kDouble; |
55 | |
56 | /// Rust-style short dtypes. |
57 | constexpr auto kU8 = kUInt8; |
58 | constexpr auto kI8 = kInt8; |
59 | constexpr auto kI16 = kInt16; |
60 | constexpr auto kI32 = kInt32; |
61 | constexpr auto kI64 = kInt64; |
62 | constexpr auto kF16 = kFloat16; |
63 | constexpr auto kF32 = kFloat32; |
64 | constexpr auto kF64 = kFloat64; |
65 | } // namespace torch |
66 | |