1#pragma once
#include <ATen/jit_macros.h>

// Included unconditionally: both preprocessor branches below name these types
// and macros, so neither branch may depend on the other's includes.
#include <ATen/core/Tensor.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>

#include <string>
#include <vector>

#if AT_USE_JITERATOR()
12
13namespace at {
14namespace cuda {
15
16TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
17 const std::string& code_string,
18 const std::string& kernel_name,
19 const int num_outputs,
20 const c10::SmallVector<at::Tensor>& tensors,
21 const c10::SmallVector<at::Scalar>& extra_args,
22 bool return_by_ref);
23
24}} // namespace at::cuda
25
26#else
27
28namespace at { namespace cuda {
29TORCH_CUDA_CPP_API c10::SmallVector<at::Tensor> CompileAndLaunchKernel(
30 const std::string& code_string,
31 const std::string& kernel_name,
32 const int num_outputs,
33 const c10::SmallVector<at::Tensor>& tensors,
34 const c10::SmallVector<at::Scalar>& extra_args,
35 bool return_by_ref) {
36 TORCH_CHECK(false, "Jiterator is not supported");
37 }
38}} // namespace at::cuda
39
40#endif // AT_USE_JITERATOR()
41