#include <c10/util/irange.h>

#include <kernel_cache.h>

#include <executor_kernel_arg.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

template <typename T, typename nvfuser_index_t>
std::unique_ptr<TensorArgAbstract> getTensorArg(int nDims) {
  switch (nDims) {
    case (0):
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 0, nvfuser_index_t>,
          nvfuser_index_t>>();
    case (1):
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 1, nvfuser_index_t>,
          nvfuser_index_t>>();
    case (2):
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 2, nvfuser_index_t>,
          nvfuser_index_t>>();
    case (3):
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 3, nvfuser_index_t>,
          nvfuser_index_t>>();
    case (4):
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 4, nvfuser_index_t>,
          nvfuser_index_t>>();
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    case (5):
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 5, nvfuser_index_t>,
          nvfuser_index_t>>();
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    case (6):
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 6, nvfuser_index_t>,
          nvfuser_index_t>>();
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    case (7):
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 7, nvfuser_index_t>,
          nvfuser_index_t>>();
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    case (8):
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
      return std::make_unique<TensorArg<
          TensorArgCodegen<T, 8, nvfuser_index_t>,
          nvfuser_index_t>>();
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to generate a tensor to run a generated kernel with ",
          nDims,
          " dimensions, however only 0 to 8 dimensional tensors are supported.");
  }
  return nullptr;
}

template <typename INDEX_MODE>
std::unique_ptr<TensorArgAbstract> getTensorArg(
    c10::ScalarType dtype,
    int nDims) {
  switch (dtype) {
    case c10::ScalarType::Double:
      return getTensorArg<double, INDEX_MODE>(nDims);
    case c10::ScalarType::Float:
      return getTensorArg<float, INDEX_MODE>(nDims);
    case c10::ScalarType::Half:
      return getTensorArg<at::Half, INDEX_MODE>(nDims);
    case c10::ScalarType::BFloat16:
      return getTensorArg<at::BFloat16, INDEX_MODE>(nDims);
    case c10::ScalarType::Bool:
      return getTensorArg<bool, INDEX_MODE>(nDims);
    case c10::ScalarType::Long:
      return getTensorArg<int64_t, INDEX_MODE>(nDims);
    case c10::ScalarType::Int:
      return getTensorArg<int32_t, INDEX_MODE>(nDims);
    case c10::ScalarType::ComplexFloat:
      return getTensorArg<c10::complex<float>, INDEX_MODE>(nDims);
    case c10::ScalarType::ComplexDouble:
      return getTensorArg<c10::complex<double>, INDEX_MODE>(nDims);
    default:
      TORCH_CHECK(
          false,
          "Dtype: ",
          dtype,
          " not currently supported in code generated kernels.");
  }
}

std::unique_ptr<TensorArgAbstract> getTensorArg(
    c10::ScalarType dtype,
    int nDims,
    KernelIndexMode index_mode) {
  switch (index_mode) {
    case KernelIndexMode::INT32:
      return getTensorArg<int>(dtype, nDims);
    case KernelIndexMode::INT64:
      return getTensorArg<int64_t>(dtype, nDims);
    default:
      break;
  }

  TORCH_INTERNAL_ASSERT(false, "unknown index mode");
  return nullptr;
}
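
// For reference, a sketch of how the three overloads above resolve (the
// concrete call is illustrative, not a quote of any call site): a float
// tensor with 2 dimensions running in 32-bit index mode dispatches through
// getTensorArg<int>(dtype, nDims) and getTensorArg<float, int>(nDims), and
// ends up holding a TensorArg<TensorArgCodegen<float, 2, int>, int>:
//
//   std::unique_ptr<TensorArgAbstract> arg =
//       getTensorArg(c10::ScalarType::Float, 2, KernelIndexMode::INT32);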

} // namespace

KernelArgumentHolder KernelArgumentHolder::createKernelArgumentHolder(
    const c10::ArrayRef<c10::IValue>& inputs) {
  if (inputs.empty()) {
    // default to int32 on device 0
    KernelArgumentHolder args(KernelIndexMode::INT32);
    args.setDeviceIndex(0);
    return args;
  }
  auto device_index = getCommonDeviceCUDA(inputs);
  auto index_mode = collectIndexMode(inputs);

  KernelArgumentHolder args(index_mode);
  args.setDeviceIndex(device_index);
  args.push(inputs);

  return args;
}
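
// Typical usage, as a sketch (`fusion_inputs` stands in for whatever aten
// input stack the caller holds; the name is illustrative):
//
//   auto args =
//       KernelArgumentHolder::createKernelArgumentHolder(fusion_inputs);
//
// After this call, args carries the common device index and the index mode
// (INT32/INT64) inferred from the inputs, plus one argument entry per input
// tensor or scalar.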

// Push a tensor to the arguments
void KernelArgumentHolder::push(const at::Tensor& tensor) {
  changed_ = true;
  if (is_cpu_scalar(tensor)) {
    switch (tensor.scalar_type()) {
      case c10::ScalarType::ComplexDouble:
        arguments_.push_back(std::make_unique<CpuScalarTensorArg<
            CpuScalarTensorCodegen<c10::complex<double>>>>(
            tensor.data_ptr<c10::complex<double>>()[0]));
        break;
      case c10::ScalarType::ComplexFloat:
        arguments_.push_back(std::make_unique<CpuScalarTensorArg<
            CpuScalarTensorCodegen<c10::complex<float>>>>(
            tensor.data_ptr<c10::complex<float>>()[0]));
        break;
      case c10::ScalarType::Double:
        arguments_.push_back(
            std::make_unique<
                CpuScalarTensorArg<CpuScalarTensorCodegen<double>>>(
                tensor.data_ptr<double>()[0]));
        break;
      case c10::ScalarType::Float:
        arguments_.push_back(
            std::make_unique<CpuScalarTensorArg<CpuScalarTensorCodegen<float>>>(
                tensor.data_ptr<float>()[0]));
        break;
      case c10::ScalarType::Half:
        arguments_.push_back(
            std::make_unique<
                CpuScalarTensorArg<CpuScalarTensorCodegen<at::Half>>>(
                tensor.data_ptr<at::Half>()[0]));
        break;
      case c10::ScalarType::BFloat16:
        arguments_.push_back(
            std::make_unique<
                CpuScalarTensorArg<CpuScalarTensorCodegen<at::BFloat16>>>(
                tensor.data_ptr<at::BFloat16>()[0]));
        break;
      case c10::ScalarType::Bool:
        arguments_.push_back(
            std::make_unique<CpuScalarTensorArg<CpuScalarTensorCodegen<bool>>>(
                tensor.data_ptr<bool>()[0]));
        break;
      case c10::ScalarType::Long:
        arguments_.push_back(
            std::make_unique<
                CpuScalarTensorArg<CpuScalarTensorCodegen<int64_t>>>(
                tensor.data_ptr<int64_t>()[0]));
        break;
      case c10::ScalarType::Int:
        arguments_.push_back(
            std::make_unique<
                CpuScalarTensorArg<CpuScalarTensorCodegen<int32_t>>>(
                tensor.data_ptr<int32_t>()[0]));
        break;
      default:
        TORCH_CHECK(
            false,
            "Dtype: ",
            tensor.scalar_type(),
            " not currently supported in code generated kernels.");
    }
  } else {
    int nDims = tensor.ndimension();

    c10::ScalarType dtype = tensor.scalar_type();
    std::unique_ptr<TensorArgAbstract> tensor_arg =
        getTensorArg(dtype, nDims, index_mode_);
    tensor_arg->setTensor(tensor);
    tensor_arg->setPointer(tensor.data_ptr());
    tensor_arg->setDataType(aten_to_data_type(dtype));
    // Extract the sizes and strides of the tensor into the argument proxy
    for (const auto i : c10::irange(nDims)) {
      tensor_arg->setSize(i, tensor.sizes()[i]);
      tensor_arg->setStride(i, tensor.strides()[i]);
    }
    arguments_.push_back(std::move(tensor_arg));
  }
}

// Push a scalar (bool, integer, double, or complex) to the arguments
void KernelArgumentHolder::push(const IValue& val) {
  changed_ = true;
  TORCH_INTERNAL_ASSERT(
      val.isScalar(),
      "Tried to push an arg to run in a fused kernel, expected a scalar but got: ",
      val);
  auto scalar_val = val.toScalar();
  switch (scalar_val.type()) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case c10::ScalarType::ComplexDouble:
      arguments_.push_back(
          std::make_unique<ComplexDoubleArg>(scalar_val.toComplexDouble()));
      return;
    case c10::ScalarType::Double:
      arguments_.push_back(std::make_unique<DoubleArg>(scalar_val.toDouble()));
      return;
    case c10::ScalarType::Long:
      arguments_.push_back(std::make_unique<LongArg>(scalar_val.toLong()));
      return;
    case c10::ScalarType::Bool:
      arguments_.push_back(std::make_unique<BoolArg>(scalar_val.toBool()));
      return;
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to create an argument to send to a fused kernel, but got an unexpected scalar type: ",
          scalar_val.type());
  }
  TORCH_INTERNAL_ASSERT(
      false,
      "Tried to create an argument to send to a fused kernel, but the scalar type was not handled.");
}

void KernelArgumentHolder::push(int64_t val) {
  // Mark the pointer buffer stale so getBuffer() rebuilds it, matching the
  // other push overloads.
  changed_ = true;
  arguments_.push_back(std::make_unique<LongArg>(val));
}

void KernelArgumentHolder::push(const at::PhiloxCudaState& val) {
  changed_ = true;
  arguments_.push_back(std::make_unique<PhiloxCudaStateArg>(val));
}

// Lazily (re)build the array of opaque argument pointers, one void* per pushed
// argument in push order, and return it; this is the void** argument array a
// kernel launch expects.
void** KernelArgumentHolder::getBuffer() {
  if (changed_) {
    void_ptrs_ = std::vector<void*>(arguments_.size(), nullptr);
    for (const auto i : c10::irange(arguments_.size())) {
      void_ptrs_[i] = static_cast<void*>(arguments_[i]->arg());
    }
    changed_ = false;
  }
  return void_ptrs_.data();
}
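
// As a sketch of how the buffer is consumed (the launch call below is
// illustrative of the CUDA driver API in general, not a quote of the executor
// code): each entry points at one pushed argument's data, so the array can be
// passed directly as the kernel-parameter array.
//
//   void** kernel_args = args.getBuffer();
//   // cuLaunchKernel(function, gdimx, gdimy, gdimz, bdimx, bdimy, bdimz,
//   //                smem_bytes, stream, kernel_args, nullptr);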

void KernelArgumentHolder::push(const c10::ArrayRef<c10::IValue>& args) {
  // Naive I/O setup; this ignores any potential transformations (i.e. the I/O
  // allocated here from the subgraph could be, and very likely is, different
  // from the I/O expected by the generated CUDA kernel).
  for (const auto& arg : args) {
    if (arg.isTensor()) {
      push(arg.toTensor());
    } else {
      push(arg);
    }
  }
}

void KernelArgumentHolder::push(const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    push(tensor);
  }
}

void KernelArgumentHolder::push(const ArgAbstract* arg) {
  changed_ = true;
  arguments_.emplace_back(arg->copy_unique_ptr());
}

void KernelArgumentHolder::swap(int i, const ArgAbstract* arg) {
  changed_ = true;
  auto holder = arg->copy_unique_ptr();
  arguments_[i].swap(holder);
}

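// appendPhiloxRNGSeed() below is how kernels that use RNG receive their Philox
// state: philox_cuda_state(rand_offset) returns the default CUDA generator's
// current (seed, offset) pair and advances the offset by rand_offset under the
// generator's mutex, so each launch consumes a disjoint slice of the random
// stream. The resulting PhiloxCudaState is pushed like any other argument and
// lands in the same void** buffer returned by getBuffer(). Illustrative call
// (the offset value is made up for the example):
//
//   args.appendPhiloxRNGSeed(/*rand_offset=*/4);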
void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) {
  at::PhiloxCudaState philox_engine_inputs;
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    philox_engine_inputs =
        at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_cuda_state(
            rand_offset);
  }
  push(philox_engine_inputs);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch