1 | #pragma once |
2 | #include <ATen/jit_macros.h> |
3 | |
4 | #if AT_USE_JITERATOR() |
5 | |
6 | #include <c10/macros/Export.h> |
7 | #include <c10/util/SmallVector.h> |
8 | #include <ATen/core/Tensor.h> |
9 | |
10 | #include <string> |
11 | #include <vector> |
12 | |
13 | namespace at { |
14 | namespace cuda { |
15 | |
16 | TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel( |
17 | const std::string& code_string, |
18 | const std::string& kernel_name, |
19 | const int num_outputs, |
20 | const c10::SmallVector<at::Tensor>& tensors, |
21 | const c10::SmallVector<at::Scalar>& , |
22 | bool return_by_ref); |
23 | |
24 | }} // namespace at::cuda |
25 | |
26 | #else |
27 | |
28 | namespace at { namespace cuda { |
29 | TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel( |
30 | const std::string& code_string, |
31 | const std::string& kernel_name, |
32 | const int num_outputs, |
33 | const c10::SmallVector<at::Tensor>& tensors, |
34 | const c10::SmallVector<at::Scalar>& extra_args, |
35 | bool return_by_ref) { |
36 | TORCH_CHECK(false, "Jiterator is not supported" ); |
37 | } |
38 | }} // namespace at::cuda |
39 | |
40 | #endif // AT_USE_JITERATOR() |
41 | |