1 | #include <ATen/native/MathBitsFallback.h> |
2 | #include <ATen/native/MathBitFallThroughLists.h> |
3 | |
4 | namespace at::native { |
5 | struct ConjFallback : MathOpFallback { |
6 | ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate" ) {} |
7 | bool is_bit_set(const Tensor& tensor) override { |
8 | return tensor.is_conj(); |
9 | } |
10 | }; |
11 | |
12 | void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { |
13 | ConjFallback object; |
14 | object.fallback_impl(op, dispatch_keys, stack); |
15 | } |
16 | |
17 | TORCH_LIBRARY_IMPL(_, Conjugate, m) { |
18 | m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>()); |
19 | } |
20 | |
21 | TORCH_LIBRARY_IMPL(aten, Conjugate, m) { |
22 | m.impl("set_.source_Storage_storage_offset" , torch::CppFunction::makeFallthrough()); |
23 | m.impl("set_.source_Tensor" , torch::CppFunction::makeFallthrough()); |
24 | m.impl("set_" , torch::CppFunction::makeFallthrough()); |
25 | m.impl("copy_" , torch::CppFunction::makeFallthrough()); |
26 | m.impl("clone" , torch::CppFunction::makeFallthrough()); |
27 | m.impl("_conj_physical" , torch::CppFunction::makeFallthrough()); |
28 | m.impl("conj_physical" , torch::CppFunction::makeFallthrough()); |
29 | m.impl("conj_physical_" , torch::CppFunction::makeFallthrough()); |
30 | m.impl("resolve_conj" , torch::CppFunction::makeFallthrough()); |
31 | m.impl("resolve_neg" , torch::CppFunction::makeFallthrough()); |
32 | m.impl("repeat_interleave.Tensor" , torch::CppFunction::makeFallthrough()); |
33 | m.impl("repeat_interleave.self_Tensor" , torch::CppFunction::makeFallthrough()); |
34 | m.impl("repeat_interleave.self_int" , torch::CppFunction::makeFallthrough()); |
35 | |
36 | // See test_metadata_check_when_primal_has_conj_bit in test_autograd.py |
37 | m.impl("_has_same_storage_numel" , torch::CppFunction::makeFallthrough()); |
38 | m.impl("_new_zeros_with_same_feature_meta" , torch::CppFunction::makeFallthrough()); |
39 | |
40 | // linear algebra functions |
41 | m.impl("dot" , torch::CppFunction::makeFallthrough()); |
42 | m.impl("vdot" , torch::CppFunction::makeFallthrough()); |
43 | m.impl("dot.out" , torch::CppFunction::makeFallthrough()); |
44 | m.impl("vdot.out" , torch::CppFunction::makeFallthrough()); |
45 | m.impl("mm" , torch::CppFunction::makeFallthrough()); |
46 | m.impl("linalg_solve_triangular" , torch::CppFunction::makeFallthrough()); |
47 | m.impl("linalg_solve_triangular.out" , torch::CppFunction::makeFallthrough()); |
48 | m.impl("mm.out" , torch::CppFunction::makeFallthrough()); |
49 | m.impl("addmm" , torch::CppFunction::makeFallthrough()); |
50 | m.impl("addmm_" , torch::CppFunction::makeFallthrough()); |
51 | m.impl("addmm.out" , torch::CppFunction::makeFallthrough()); |
52 | m.impl("bmm" , torch::CppFunction::makeFallthrough()); |
53 | m.impl("bmm.out" , torch::CppFunction::makeFallthrough()); |
54 | m.impl("baddbmm" , torch::CppFunction::makeFallthrough()); |
55 | m.impl("baddbmm_" , torch::CppFunction::makeFallthrough()); |
56 | m.impl("baddbmm.out" , torch::CppFunction::makeFallthrough()); |
57 | m.impl("linalg_svd" , torch::CppFunction::makeFallthrough()); |
58 | m.impl("linalg_svd.U" , torch::CppFunction::makeFallthrough()); |
59 | |
60 | TORCH_VIEW_FNS(m) |
61 | TENSOR_UTILITIES_AND_CONSTRUCTORS(m) |
62 | TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) |
63 | } |
64 | |
65 | } // namespace at::native |
66 | |