1 | #include <torch/library.h> |
2 | #include <ATen/core/boxing/KernelFunction.h> |
3 | |
4 | using torch::CppFunction; |
5 | |
6 | namespace at { |
7 | |
8 | // Note: [DispatchKey::VmapMode usage] |
9 | // Whenever we're inside a vmap, all Tensors dispatch on this key. At the moment, |
10 | // this key is used to disable random operations inside of vmap. If you are looking |
11 | // for Batching Rules, those are registered with DispatchKey::Batched instead. |
12 | // |
13 | // Note: [Ambiguity of random operations inside vmap] |
14 | // Random operations have an ambiguity where it isn't clear if they should |
15 | // apply the same randomness or apply different randomness. For example: |
16 | // |
17 | // >>> vmap(lambda t: torch.rand(1))(torch.zeros(5)) |
18 | // Should the above return the same random number 5 times, or a different one? |
19 | // |
20 | // We haven't made a decision on that yet so we are temporarily banning random |
21 | // operations inside of vmap while we gather user feedback. |
22 | |
23 | template <typename... Args> Tensor unsupportedRandomOp(Args... args) { |
24 | TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. " , |
25 | "Please perform random operations outside of vmap as a workaround" ); |
26 | } |
27 | |
28 | template <typename... Args> Tensor& unsupportedRandomOp_(Args... args) { |
29 | TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. " , |
30 | "Please perform random operations outside of vmap as a workaround" ); |
31 | } |
32 | |
33 | TORCH_LIBRARY_IMPL(_, VmapMode, m) { |
34 | m.fallback(torch::CppFunction::makeFallthrough()); |
35 | } |
36 | |
37 | TORCH_LIBRARY_IMPL(aten, VmapMode, m) { |
38 | // NB: I'd really like to register a special kernel like |
39 | // CppFunction::makeNamedNotSupported() to avoid listing out the types of everything. |
40 | // However, registering e.g. CppFunction::makeNamedNotSupported() as an implementation |
41 | // only works for operators that support boxing. |
42 | #define TENSOROPTIONS c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool> |
43 | |
44 | // random operations (out-of-place) |
45 | m.impl("bernoulli" , unsupportedRandomOp<const Tensor&, optional<Generator>>); |
46 | m.impl("bernoulli.out" , unsupportedRandomOp_<const Tensor&, optional<Generator>, Tensor&>); |
47 | m.impl("bernoulli.p" , unsupportedRandomOp<const Tensor&, double, optional<Generator>>); |
48 | m.impl("bernoulli_.Tensor" , unsupportedRandomOp_<Tensor&, const Tensor&, optional<Generator>>); |
49 | m.impl("bernoulli_.float" , unsupportedRandomOp_<Tensor&, double, optional<Generator>>); |
50 | |
51 | m.impl("cauchy_" , unsupportedRandomOp_<Tensor&, double, double, optional<Generator>>); |
52 | m.impl("exponential_" , unsupportedRandomOp_<Tensor&, double, optional<Generator>>); |
53 | m.impl("geometric_" , unsupportedRandomOp_<Tensor&, double, optional<Generator>>); |
54 | m.impl("log_normal_" , unsupportedRandomOp_<Tensor&, double, double, optional<Generator>>); |
55 | m.impl("multinomial" , unsupportedRandomOp<const Tensor&, int64_t, bool, optional<Generator>>); |
56 | m.impl("multinomial.out" , unsupportedRandomOp_<const Tensor&, int64_t, bool, optional<Generator>, Tensor&>); |
57 | |
58 | m.impl("normal.Tensor_float" , unsupportedRandomOp<const Tensor&, double, optional<Generator>>); |
59 | m.impl("normal.Tensor_float_out" , unsupportedRandomOp_<const Tensor&, double, optional<Generator>, Tensor&>); |
60 | m.impl("normal.float_Tensor_out" , unsupportedRandomOp_<double, const Tensor&, optional<Generator>, Tensor&>); |
61 | m.impl("normal.float_Tensor" , unsupportedRandomOp<double, const Tensor&, optional<Generator>>); |
62 | m.impl("normal.Tensor_Tensor" , unsupportedRandomOp<const Tensor&, const Tensor&, optional<Generator>>); |
63 | m.impl("normal.Tensor_Tensor_out" , unsupportedRandomOp_<const Tensor&, const Tensor&, optional<Generator>, Tensor&>); |
64 | m.impl("normal.float_float" , unsupportedRandomOp<double, double, IntArrayRef, optional<Generator>, TENSOROPTIONS>); |
65 | m.impl("normal.float_float_out" , unsupportedRandomOp_<double, double, IntArrayRef, optional<Generator>, Tensor&>); |
66 | m.impl("normal_" , unsupportedRandomOp_<Tensor&, double, double, optional<Generator>>); |
67 | |
68 | m.impl("poisson" , unsupportedRandomOp<const Tensor&, optional<Generator>>); |
69 | |
70 | m.impl("random_.from" , unsupportedRandomOp_<Tensor&, int64_t, optional<int64_t>, optional<Generator>>); |
71 | m.impl("random_.to" , unsupportedRandomOp_<Tensor&, int64_t, optional<Generator>>); |
72 | m.impl("random_" , unsupportedRandomOp_<Tensor&, optional<Generator>>); |
73 | |
74 | m.impl("rand_like" , unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>); |
75 | m.impl("randn_like" , unsupportedRandomOp<const Tensor&, TENSOROPTIONS, optional<MemoryFormat>>); |
76 | |
77 | m.impl("randint_like" , unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, optional<MemoryFormat>>); |
78 | m.impl("randint_like.low_dtype" , unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, optional<MemoryFormat>>); |
79 | |
80 | m.impl("rand" , unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>); |
81 | m.impl("rand.generator" , unsupportedRandomOp<IntArrayRef, optional<Generator>, TENSOROPTIONS>); |
82 | m.impl("rand.names" , unsupportedRandomOp<IntArrayRef, optional<DimnameList>, TENSOROPTIONS>); |
83 | m.impl("rand.generator_with_names" , unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, TENSOROPTIONS>); |
84 | m.impl("rand.out" , unsupportedRandomOp_<IntArrayRef, Tensor&>); |
85 | m.impl("rand.generator_out" , unsupportedRandomOp_<IntArrayRef, optional<Generator>, Tensor&>); |
86 | |
87 | m.impl("randn" , unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>); |
88 | m.impl("randn.generator" , unsupportedRandomOp<IntArrayRef, optional<Generator>, TENSOROPTIONS>); |
89 | m.impl("randn.names" , unsupportedRandomOp<IntArrayRef, optional<DimnameList>, TENSOROPTIONS>); |
90 | m.impl("randn.generator_with_names" , unsupportedRandomOp<IntArrayRef, optional<Generator>, optional<DimnameList>, TENSOROPTIONS>); |
91 | m.impl("randn.out" , unsupportedRandomOp_<IntArrayRef, Tensor&>); |
92 | m.impl("randn.generator_out" , unsupportedRandomOp_<IntArrayRef, optional<Generator>, Tensor&>); |
93 | |
94 | m.impl("randperm" , unsupportedRandomOp<int64_t, TENSOROPTIONS>); |
95 | m.impl("randperm.generator" , unsupportedRandomOp<int64_t, optional<Generator>, TENSOROPTIONS>); |
96 | m.impl("randperm.out" , unsupportedRandomOp_<int64_t, Tensor&>); |
97 | m.impl("randperm.generator_out" , unsupportedRandomOp_<int64_t, optional<Generator>, Tensor&>); |
98 | |
99 | m.impl("randint" , unsupportedRandomOp<int64_t, IntArrayRef, TENSOROPTIONS>); |
100 | m.impl("randint.generator" , unsupportedRandomOp<int64_t, IntArrayRef, optional<Generator>, TENSOROPTIONS>); |
101 | m.impl("randint.low" , unsupportedRandomOp<int64_t, int64_t, IntArrayRef, TENSOROPTIONS>); |
102 | m.impl("randint.low_generator" , unsupportedRandomOp<int64_t, int64_t, IntArrayRef, optional<Generator>, TENSOROPTIONS>); |
103 | m.impl("randint.out" , unsupportedRandomOp_<int64_t, IntArrayRef, Tensor&>); |
104 | m.impl("randint.generator_out" , unsupportedRandomOp_<int64_t, IntArrayRef, optional<Generator>, Tensor&>); |
105 | m.impl("randint.low_out" , unsupportedRandomOp_<int64_t, int64_t, IntArrayRef, Tensor&>); |
106 | m.impl("randint.low_generator_out" , unsupportedRandomOp_<int64_t, int64_t, IntArrayRef, optional<Generator>, Tensor&>); |
107 | |
108 | m.impl("uniform_" , unsupportedRandomOp_<Tensor&, double, double, optional<Generator>>); |
109 | |
110 | #undef TENSOROPTIONS |
111 | } |
112 | |
113 | |
114 | } // namespace at |
115 | |