#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/ir/ir.h>

#include <array>
#include <cstdio>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// This should match the tensor struct used in code generation (almost exactly)
template <typename T, int N, typename nvfuser_index_t>
struct TensorArgCodegen {
  T& operator[](nvfuser_index_t ind) {
    return data[ind];
  }

  T* data;
  std::array<nvfuser_index_t, N> size;
  std::array<nvfuser_index_t, N> stride;
  constexpr int nDims() const {
    return N;
  }
  void setSize(int i, nvfuser_index_t s) {
    size[i] = s;
  }
  void setStride(int i, nvfuser_index_t s) {
    stride[i] = s;
  }
  nvfuser_index_t getSize(int i) const {
    return size[i];
  }
  nvfuser_index_t getStride(int i) const {
    return stride[i];
  }
};

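// A minimal usage sketch (hypothetical values; `dev_ptr` is assumed to be a
// valid device allocation). The struct is filled out on the host and then
// passed by value as a kernel parameter:
//
//   TensorArgCodegen<float, 2, int> t;
//   t.data = dev_ptr;
//   t.setSize(0, 128);
//   t.setSize(1, 64);
//   t.setStride(0, 64);
//   t.setStride(1, 1);
//   // t.nDims() == 2, and inside the kernel t[i] reads t.data[i]
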
// 0-dim GPU-based tensor
template <typename T, typename nvfuser_index_t>
struct TensorArgCodegen<T, 0, nvfuser_index_t> {
  T& operator[](nvfuser_index_t ind) {
    return data[ind];
  }

  T* data;
  constexpr int nDims() const {
    return 0;
  }
  void setSize(int, nvfuser_index_t) {
    TORCH_INTERNAL_ASSERT(false, "Tried to set size of a 0-dim tensor");
  }
  void setStride(int, nvfuser_index_t) {
    TORCH_INTERNAL_ASSERT(false, "Tried to set stride of a 0-dim tensor");
  }
  nvfuser_index_t getSize(int) const {
    TORCH_INTERNAL_ASSERT(false, "Tried to get size of a 0-dim tensor");
  }
  nvfuser_index_t getStride(int) const {
    TORCH_INTERNAL_ASSERT(false, "Tried to get stride of a 0-dim tensor");
  }
};

// Specialization for the 0-dim case, which makes it easy to pass in a
// CPU-based scalar tensor without a memcpy
template <typename T>
struct CpuScalarTensorCodegen {
  T& operator[](int) {
    return data;
  }

  T data;
};

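// A sketch of the intended use (assuming `cpu_t` is a 0-dim CPU float
// tensor): the scalar value is copied directly into the kernel argument, so
// no device allocation or host-to-device copy is involved.
//
//   CpuScalarTensorCodegen<float> c;
//   c.data = cpu_t.item<float>();
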
// TODO: macro this and the printer below
enum class ArgType {
  PhiloxCudaState,
  Long,
  Double,
  ComplexDouble,
  Bool,
  Tensor,
  CpuScalarTensor
};

inline std::string argTypeToString(ArgType type) {
  std::string ret;
  switch (type) {
    case ArgType::PhiloxCudaState:
      ret = "PhiloxCudaState";
      break;
    case ArgType::Long:
      ret = "Long";
      break;
    case ArgType::Double:
      ret = "Double";
      break;
    case ArgType::ComplexDouble:
      ret = "ComplexDouble";
      break;
    case ArgType::Bool:
      ret = "Bool";
      break;
    case ArgType::Tensor:
      ret = "Tensor";
      break;
    case ArgType::CpuScalarTensor:
      ret = "CpuScalarTensor";
      break;
  }
  return ret;
}

struct ArgAbstract {
  virtual ~ArgAbstract() = default;
  virtual const void* arg() const = 0;
  virtual void* arg() = 0;
  virtual bool isType(ArgType type) const = 0;
  virtual ArgType type() const = 0;
  virtual std::unique_ptr<ArgAbstract> copy_unique_ptr() const = 0;
  virtual void print() const {
    printf("input type: %s\n", argTypeToString(type()).c_str());
  }
};

#define DEF_HELPEE_FUNC(TARGET_TYPE, ARG_NAME)                    \
  bool isType(ArgType type) const override {                      \
    return ArgType::TARGET_TYPE == type;                          \
  }                                                               \
  ArgType type() const override {                                 \
    return ArgType::TARGET_TYPE;                                  \
  }                                                               \
  const void* arg() const override {                              \
    return &ARG_NAME;                                             \
  }                                                               \
  void* arg() override {                                          \
    return &ARG_NAME;                                             \
  }                                                               \
  std::unique_ptr<ArgAbstract> copy_unique_ptr() const override { \
    return std::make_unique<TARGET_TYPE##Arg>(*this);             \
  }

#define DEF_PRINT_FUNC              \
  void print() const override {     \
    std::cout << val_ << std::endl; \
  }

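// For reference, `DEF_HELPEE_FUNC(Long, val_)` inside LongArg below expands
// to:
//
//   bool isType(ArgType type) const override {
//     return ArgType::Long == type;
//   }
//   ArgType type() const override {
//     return ArgType::Long;
//   }
//   const void* arg() const override {
//     return &val_;
//   }
//   void* arg() override {
//     return &val_;
//   }
//   std::unique_ptr<ArgAbstract> copy_unique_ptr() const override {
//     return std::make_unique<LongArg>(*this);
//   }
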
struct PhiloxCudaStateArg : public ArgAbstract {
  at::PhiloxCudaState val_;
  explicit PhiloxCudaStateArg(at::PhiloxCudaState _val) : val_(_val) {}
  DEF_HELPEE_FUNC(PhiloxCudaState, val_)
};

struct LongArg : public ArgAbstract {
  int64_t val_;
  explicit LongArg(int64_t _val) : val_(_val) {}
  DEF_HELPEE_FUNC(Long, val_)
  DEF_PRINT_FUNC
};

struct DoubleArg : public ArgAbstract {
  double val_;
  explicit DoubleArg(double _val) : val_(_val) {}
  DEF_HELPEE_FUNC(Double, val_)
  DEF_PRINT_FUNC
};

struct ComplexDoubleArg : public ArgAbstract {
  c10::complex<double> val_;
  explicit ComplexDoubleArg(c10::complex<double> _val) : val_(_val) {}
  DEF_HELPEE_FUNC(ComplexDouble, val_)
  DEF_PRINT_FUNC
};

struct BoolArg : public ArgAbstract {
  bool val_;
  explicit BoolArg(bool _val) : val_(_val) {}
  DEF_HELPEE_FUNC(Bool, val_)
  DEF_PRINT_FUNC
};

struct TensorArgAbstract : ArgAbstract {
  virtual void setSize(int i, int64_t size) = 0;
  virtual void setStride(int i, int64_t stride) = 0;
  virtual void setPointer(void* ptr) = 0;
  virtual void setDataType(DataType data_type) = 0;
  virtual void setTensor(at::Tensor tensor) = 0;

  virtual int64_t getRank() const = 0;
  virtual int64_t getSize(int i) const = 0;
  virtual int64_t getStride(int i) const = 0;
  virtual void* getPointer() const = 0;
  virtual DataType getDataType() const = 0;
  virtual int64_t numel() const = 0;
  virtual at::Tensor getTensor() const = 0;

  // TODO: clean it up and also print out dtype
  void print() const override {
    auto rank = getRank();
    std::cout << "tensor dtype: " << getDataType() << " sizes: (";
    for (int64_t i = 0; i < rank; i++) {
      std::cout << getSize(i) << ", ";
    }
    std::cout << ") stride: (";
    for (int64_t i = 0; i < rank; i++) {
      std::cout << getStride(i) << ", ";
    }
    std::cout << ") pointer: " << getPointer() << std::endl;
  }
};

template <typename TENSOR_TYPE, typename nvfuser_index_t>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct TensorArg : public TensorArgAbstract {
  TENSOR_TYPE instance_;
  // TODO: this is ugly; we should be extracting the data type from
  // `instance_` instead
  DataType data_type_ = DataType::Null;
  at::Tensor tensor_;

  void setSize(int i, int64_t size) override {
    instance_.setSize(i, static_cast<nvfuser_index_t>(size));
  }
  void setStride(int i, int64_t stride) override {
    instance_.setStride(i, static_cast<nvfuser_index_t>(stride));
  }
  void setPointer(void* ptr) override {
    instance_.data = static_cast<decltype(TENSOR_TYPE::data)>(ptr);
  }
  void setDataType(DataType data_type) override {
    data_type_ = data_type;
  }
  void setTensor(at::Tensor tensor) override {
    tensor_ = tensor;
  }

  int64_t getSize(int i) const override {
    return instance_.getSize(i);
  }
  int64_t getStride(int i) const override {
    return instance_.getStride(i);
  }
  int64_t getRank() const override {
    return instance_.nDims();
  }
  void* getPointer() const override {
    return instance_.data;
  }
  DataType getDataType() const override {
    return data_type_;
  }
  at::Tensor getTensor() const override {
    return tensor_;
  }
  int64_t numel() const override {
    int64_t ret = 1;
    for (auto i : c10::irange(instance_.nDims())) {
      ret *= instance_.getSize(i);
    }
    return ret;
  }

  DEF_HELPEE_FUNC(Tensor, instance_)
};

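// A sketch of a concrete instantiation (the element type, rank, and index
// type are chosen by the executor at runtime; `tensor` here stands for an
// arbitrary 2-D float at::Tensor):
//
//   TensorArg<TensorArgCodegen<float, 2, int>, int> arg;
//   arg.setDataType(DataType::Float);
//   arg.setTensor(tensor);
//   arg.setPointer(tensor.data_ptr());
//   arg.setSize(0, tensor.size(0));
//   arg.setStride(0, tensor.stride(0));
//   arg.setSize(1, tensor.size(1));
//   arg.setStride(1, tensor.stride(1));
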
template <typename CPU_TENSOR_TYPE>
struct CpuScalarTensorArg : public ArgAbstract {
  CPU_TENSOR_TYPE instance_;

  CpuScalarTensorArg() = delete;

  explicit CpuScalarTensorArg(decltype(CPU_TENSOR_TYPE::data) _data) {
    instance_.data = _data;
  }

  DEF_HELPEE_FUNC(CpuScalarTensor, instance_)
};

// TODO: This class needs some further cleanup and refactoring
//! KernelArgumentHolder copies meta information from kernel inputs, including
//! tensor sizes/strides/dtype/memory pointer, and copies scalar inputs by
//! value. It is used both for compilation and for kernel execution. The
//! important thing is to strip tensor ownership out of KernelArgumentHolder,
//! so that during async compilation we are not unnecessarily holding on to
//! memory that is no longer needed.
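//!
//! A typical flow, as a sketch (the exact call sites live in the executor):
//!
//!   auto args = KernelArgumentHolder::createKernelArgumentHolder(inputs);
//!   args.setDeviceIndex(device_index);
//!   args.appendPhiloxRNGSeed(rand_offset);
//!   void** buffer = args.getBuffer();
//!   // buffer[i] points at the i-th flattened argument, ready to be handed
//!   // to the kernel launch.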
class TORCH_CUDA_CU_API KernelArgumentHolder {
 public:
  //! Create a KernelArgumentHolder from c10 inputs. Note that we are not
  //! taking ownership of the memory from the original inputs; we just record
  //! their meta data for kernel execution/compilation.
  static KernelArgumentHolder createKernelArgumentHolder(
      const c10::ArrayRef<c10::IValue>& inputs);

  KernelIndexMode getIndexMode() const {
    return index_mode_;
  }

  explicit KernelArgumentHolder(KernelIndexMode index_mode)
      : index_mode_(index_mode) {}

  KernelArgumentHolder(const KernelArgumentHolder& self)
      : device_index_(self.getDeviceIndex()),
        cache_id_(self.getCacheId()),
        index_mode_(self.getIndexMode()) {
    for (const auto& arg : self.arguments_) {
      push(arg.get());
    }
  }

  KernelArgumentHolder& operator=(const KernelArgumentHolder& self) {
    if (this != &self) {
      device_index_ = self.getDeviceIndex();
      cache_id_ = self.getCacheId();
      index_mode_ = self.getIndexMode();
      arguments_.clear();
      for (const auto& arg : self.arguments_) {
        push(arg.get());
      }
    }
    return *this;
  }

  // Push a tensor to the arguments
  void push(const at::Tensor& tensor);

  // Push a scalar or integer to the arguments
  void push(const IValue& val);

  void push(const at::PhiloxCudaState& val);

  // Create the buffer, flatten arguments into it aligned to 8 bytes, and
  // return the pointers into the buffer
  void** getBuffer();

  void push(const c10::ArrayRef<c10::IValue>& args);

  void push(const std::vector<at::Tensor>& tensors);

  void push(const ArgAbstract* arg);

  void swap(int i, const ArgAbstract* arg);

  // push int64
  void push(int64_t val);

  const ArgAbstract* back() const {
    return arguments_.back().get();
  }

  void appendPhiloxRNGSeed(uint64_t rand_offset);

  const ArgAbstract* operator[](int ind) const {
    return arguments_.at(ind).get();
  }

  size_t size() const {
    return arguments_.size();
  }

  bool empty() const {
    return arguments_.empty();
  }

  void setDeviceIndex(int index) {
    device_index_ = index;
  }

  int getDeviceIndex() const {
    return device_index_;
  }

  void setCacheId(size_t id) {
    cache_id_ = id;
  }

  c10::optional<size_t> getCacheId() const {
    return cache_id_;
  }

  void print() const {
    for (const auto& arg : arguments_) {
      arg->print();
    }
  }

 private:
  std::vector<std::unique_ptr<ArgAbstract>> arguments_;
  std::vector<void*> void_ptrs_;
  bool changed_ = true;

  int device_index_ = 0;
  c10::optional<size_t> cache_id_ = c10::nullopt;
  KernelIndexMode index_mode_ = KernelIndexMode::INT64;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch