#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/ir/ir.h>

#include <array>
#include <cstdio>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// This should match the tensor used in the code generation (almost exactly)
template <typename T, int N, typename nvfuser_index_t>
struct TensorArgCodegen {
  T& operator[](nvfuser_index_t ind) {
    return data[ind];
  }

  T* data;
  std::array<nvfuser_index_t, N> size;
  std::array<nvfuser_index_t, N> stride;
  constexpr int nDims() const {
    return N;
  }
  void setSize(int i, nvfuser_index_t s) {
    size[i] = s;
  }
  void setStride(int i, nvfuser_index_t s) {
    stride[i] = s;
  }
  nvfuser_index_t getSize(int i) const {
    return size[i];
  }
  nvfuser_index_t getStride(int i) const {
    return stride[i];
  }
};
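
// Example (illustrative only; `dev_ptr` and the sizes/strides are made up):
// describing a contiguous 128x64 float tensor with 32-bit indexing.
//
//   TensorArgCodegen<float, 2, int> t;
//   t.data = dev_ptr; // device pointer, e.g. from at::Tensor::data_ptr()
//   t.setSize(0, 128);
//   t.setSize(1, 64);
//   t.setStride(0, 64); // row-major: stride of dim 0 == size of dim 1
//   t.setStride(1, 1);
//
// The struct is meant to have the same layout as the Tensor struct emitted by
// code generation, so it can be copied into the kernel launch parameters by
// value.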

// 0-Dim GPU based tensor
template <typename T, typename nvfuser_index_t>
struct TensorArgCodegen<T, 0, nvfuser_index_t> {
  T& operator[](nvfuser_index_t ind) {
    return data[ind];
  }

  T* data;
  constexpr int nDims() const {
    return 0;
  }
  void setSize(int, nvfuser_index_t) {
    TORCH_INTERNAL_ASSERT(false, "Tried to set size of a 0-dim tensor");
  }
  void setStride(int, nvfuser_index_t) {
    TORCH_INTERNAL_ASSERT(false, "Tried to set stride of a 0-dim tensor");
  }
  nvfuser_index_t getSize(int) const {
    TORCH_INTERNAL_ASSERT(false, "Tried to get size of a 0-dim tensor");
  }
  nvfuser_index_t getStride(int) const {
    TORCH_INTERNAL_ASSERT(false, "Tried to get stride of a 0-dim tensor");
  }
};

// Specialization for the 0-dim case, which makes it easy to pass in a
// CPU based scalar tensor without a memcpy
template <typename T>
struct CpuScalarTensorCodegen {
  T& operator[](int) {
    return data;
  }

  T data;
};

// TODO: macro this and the printer below (one possible X-macro shape is
// sketched after argTypeToString)
enum class ArgType {
  PhiloxCudaState,
  Long,
  Double,
  ComplexDouble,
  Bool,
  Tensor,
  CpuScalarTensor
};

inline std::string argTypeToString(ArgType type) {
  std::string ret;
  switch (type) {
    case ArgType::PhiloxCudaState:
      ret = "PhiloxCudaState";
      break;
    case ArgType::Long:
      ret = "Long";
      break;
    case ArgType::Double:
      ret = "Double";
      break;
    case ArgType::ComplexDouble:
      ret = "ComplexDouble";
      break;
    case ArgType::Bool:
      ret = "Bool";
      break;
    case ArgType::Tensor:
      ret = "Tensor";
      break;
    case ArgType::CpuScalarTensor:
      ret = "CpuScalarTensor";
      break;
  }
  return ret;
}
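
// One possible shape for the TODO above: a single X-macro list that generates
// both the enum and the printer (a sketch, not what this file currently does):
//
//   #define ALL_ARG_TYPES(_) \
//     _(PhiloxCudaState)     \
//     _(Long)                \
//     _(Double)              \
//     _(ComplexDouble)       \
//     _(Bool)                \
//     _(Tensor)              \
//     _(CpuScalarTensor)
//
//   enum class ArgType {
//   #define DEF_ARG_ENUM(t) t,
//     ALL_ARG_TYPES(DEF_ARG_ENUM)
//   #undef DEF_ARG_ENUM
//   };
//
//   inline std::string argTypeToString(ArgType type) {
//     switch (type) {
//   #define DEF_ARG_CASE(t) \
//     case ArgType::t:      \
//       return #t;
//       ALL_ARG_TYPES(DEF_ARG_CASE)
//   #undef DEF_ARG_CASE
//     }
//     return "";
//   }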

struct ArgAbstract {
  virtual ~ArgAbstract() = default;
  virtual const void* arg() const = 0;
  virtual void* arg() = 0;
  virtual bool isType(ArgType type) const = 0;
  virtual ArgType type() const = 0;
  virtual std::unique_ptr<ArgAbstract> copy_unique_ptr() const = 0;
  virtual void print() const {
    printf("input type: %s\n", argTypeToString(type()).c_str());
  }
};

#define DEF_HELPEE_FUNC(TARGET_TYPE, ARG_NAME)                    \
  bool isType(ArgType type) const override {                      \
    return ArgType::TARGET_TYPE == type;                          \
  }                                                               \
  ArgType type() const override {                                 \
    return ArgType::TARGET_TYPE;                                  \
  }                                                               \
  const void* arg() const override {                              \
    return &ARG_NAME;                                             \
  }                                                               \
  void* arg() override {                                          \
    return &ARG_NAME;                                             \
  }                                                               \
  std::unique_ptr<ArgAbstract> copy_unique_ptr() const override { \
    return std::make_unique<TARGET_TYPE##Arg>(*this);             \
  }

#define DEF_PRINT_FUNC              \
  void print() const override {     \
    std::cout << val_ << std::endl; \
  }
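
// For reference, DEF_HELPEE_FUNC(Long, val_) in LongArg below expands
// (roughly) to:
//
//   bool isType(ArgType type) const override {
//     return ArgType::Long == type;
//   }
//   ArgType type() const override {
//     return ArgType::Long;
//   }
//   const void* arg() const override {
//     return &val_;
//   }
//   void* arg() override {
//     return &val_;
//   }
//   std::unique_ptr<ArgAbstract> copy_unique_ptr() const override {
//     return std::make_unique<LongArg>(*this);
//   }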

struct PhiloxCudaStateArg : public ArgAbstract {
  at::PhiloxCudaState val_;
  PhiloxCudaStateArg(at::PhiloxCudaState _val) : val_(_val) {}
  DEF_HELPEE_FUNC(PhiloxCudaState, val_)
};

struct LongArg : public ArgAbstract {
  int64_t val_;
  explicit LongArg(int64_t _val) : val_(_val) {}
  DEF_HELPEE_FUNC(Long, val_)
  DEF_PRINT_FUNC
};

struct DoubleArg : public ArgAbstract {
  double val_;
  explicit DoubleArg(double _val) : val_(_val) {}
  DEF_HELPEE_FUNC(Double, val_)
  DEF_PRINT_FUNC
};

struct ComplexDoubleArg : public ArgAbstract {
  c10::complex<double> val_;
  explicit ComplexDoubleArg(c10::complex<double> _val) : val_(_val) {}
  DEF_HELPEE_FUNC(ComplexDouble, val_)
  DEF_PRINT_FUNC
};

struct BoolArg : public ArgAbstract {
  bool val_;
  explicit BoolArg(bool _val) : val_(_val) {}
  DEF_HELPEE_FUNC(Bool, val_)
  DEF_PRINT_FUNC
};

struct TensorArgAbstract : ArgAbstract {
  virtual void setSize(int i, int64_t size) = 0;
  virtual void setStride(int i, int64_t stride) = 0;
  virtual void setPointer(void* ptr) = 0;
  virtual void setDataType(DataType data_type) = 0;
  virtual void setTensor(at::Tensor tensor) = 0;

  virtual int64_t getRank() const = 0;
  virtual int64_t getSize(int i) const = 0;
  virtual int64_t getStride(int i) const = 0;
  virtual void* getPointer() const = 0;
  virtual DataType getDataType() const = 0;
  virtual int64_t numel() const = 0;
  virtual at::Tensor getTensor() const = 0;

  // TODO: clean up the output format
  void print() const override {
    auto rank = getRank();
    std::cout << "tensor dtype: " << getDataType() << " sizes: (";
    for (auto i = 0; i < rank; i++) {
      std::cout << getSize(i) << ", ";
    }
    std::cout << ") strides: (";
    for (auto i = 0; i < rank; i++) {
      std::cout << getStride(i) << ", ";
    }
    std::cout << ") pointer: " << getPointer() << std::endl;
  }
};

template <typename TENSOR_TYPE, typename nvfuser_index_t>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct TensorArg : public TensorArgAbstract {
  TENSOR_TYPE instance_;
  // TODO: this is ugly, we should be extracting the data type from
  // `instance_` instead
  DataType data_type_ = DataType::Null;
  at::Tensor tensor_;

  void setSize(int i, int64_t size) override {
    instance_.setSize(i, (nvfuser_index_t)size);
  }
  void setStride(int i, int64_t stride) override {
    instance_.setStride(i, (nvfuser_index_t)stride);
  }
  void setPointer(void* ptr) override {
    instance_.data = static_cast<decltype(TENSOR_TYPE::data)>(ptr);
  }
  void setDataType(DataType data_type) override {
    data_type_ = data_type;
  }
  void setTensor(at::Tensor tensor) override {
    tensor_ = tensor;
  }

  int64_t getSize(int i) const override {
    return instance_.getSize(i);
  }
  int64_t getStride(int i) const override {
    return instance_.getStride(i);
  }
  int64_t getRank() const override {
    return instance_.nDims();
  }
  void* getPointer() const override {
    return instance_.data;
  }
  DataType getDataType() const override {
    return data_type_;
  }
  at::Tensor getTensor() const override {
    return tensor_;
  }
  int64_t numel() const override {
    int64_t ret = 1;
    for (auto i : c10::irange(instance_.nDims())) {
      ret *= instance_.getSize(i);
    }
    return ret;
  }

  DEF_HELPEE_FUNC(Tensor, instance_)
};
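
// Example (illustrative only; `dev_ptr`, `sizes`, and `strides` are
// placeholders): the concrete argument type for a 2-D float tensor with
// 32-bit indexing would be
//
//   TensorArg<TensorArgCodegen<float, 2, int>, int> arg;
//   arg.setPointer(dev_ptr);
//   arg.setDataType(DataType::Float);
//   for (int i = 0; i < 2; i++) {
//     arg.setSize(i, sizes[i]);
//     arg.setStride(i, strides[i]);
//   }
//
// arg.arg() then points at `instance_`, i.e. exactly the bytes that are
// passed to the generated kernel.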

template <typename CPU_TENSOR_TYPE>
struct CpuScalarTensorArg : public ArgAbstract {
  CPU_TENSOR_TYPE instance_;

  CpuScalarTensorArg() = delete;

  explicit CpuScalarTensorArg(decltype(CPU_TENSOR_TYPE::data) _data) {
    instance_.data = _data;
  }

  DEF_HELPEE_FUNC(CpuScalarTensor, instance_)
};
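
// Example (illustrative only): wrapping the value of a 0-dim CPU float
// tensor so it can be passed to the kernel by value, without any
// host-to-device memcpy:
//
//   CpuScalarTensorArg<CpuScalarTensorCodegen<float>> arg(3.0f);
//   // arg.arg() points at instance_, whose single `data` member holds 3.0f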

// TODO: This class needs some further cleanup and refactoring
//! KernelArgumentHolder copies meta information from kernel inputs, including
//! tensor sizes/strides/dtype/memory_ptr, and copies scalar inputs by value.
//! It is used for both compilation and kernel execution. The important point
//! is that KernelArgumentHolder does not take ownership of the tensors, so
//! that during async compilation we are not unnecessarily holding memory that
//! is not needed.
class TORCH_CUDA_CU_API KernelArgumentHolder {
 public:
  //! Create a KernelArgumentHolder from c10 inputs. Note that this does not
  //! take ownership of the memory backing the original inputs; it only
  //! records their meta data for kernel execution/compilation.
  static KernelArgumentHolder createKernelArgumentHolder(
      const c10::ArrayRef<c10::IValue>& inputs);

  KernelIndexMode getIndexMode() const {
    return index_mode_;
  }

  explicit KernelArgumentHolder(KernelIndexMode index_mode)
      : index_mode_(index_mode) {}

  KernelArgumentHolder(const KernelArgumentHolder& self)
      : device_index_(self.getDeviceIndex()),
        cache_id_(self.getCacheId()),
        index_mode_(self.getIndexMode()) {
    for (const auto& arg : self.arguments_) {
      push(arg.get());
    }
  }

  KernelArgumentHolder& operator=(const KernelArgumentHolder& self) {
    if (this == &self) {
      return *this;
    }
    device_index_ = self.getDeviceIndex();
    cache_id_ = self.getCacheId();
    index_mode_ = self.getIndexMode();
    arguments_.clear();
    for (const auto& arg : self.arguments_) {
      push(arg.get());
    }
    return *this;
  }

  // Push a tensor to the arguments
  void push(const at::Tensor& tensor);

  // Push a scalar (e.g. integer, double, complex, bool) to the arguments
  void push(const IValue& val);

  void push(const at::PhiloxCudaState& val);

  // Create a buffer, flatten the arguments into it aligned to 8 bytes, and
  // return the pointers to each argument within that buffer
  void** getBuffer();

  void push(const c10::ArrayRef<c10::IValue>& args);

  void push(const std::vector<at::Tensor>& tensors);

  void push(const ArgAbstract* arg);

  void swap(int i, const ArgAbstract* arg);

  // Push an int64 scalar
  void push(int64_t val);

  const ArgAbstract* back() const {
    return arguments_.back().get();
  }

  void appendPhiloxRNGSeed(uint64_t rand_offset);

  const ArgAbstract* operator[](int ind) const {
    return arguments_.at(ind).get();
  }

  size_t size() const {
    return arguments_.size();
  }

  bool empty() const {
    return arguments_.empty();
  }

  void setDeviceIndex(int index) {
    device_index_ = index;
  }

  int getDeviceIndex() const {
    return device_index_;
  }

  void setCacheId(size_t id) {
    cache_id_ = id;
  }

  c10::optional<size_t> getCacheId() const {
    return cache_id_;
  }

  void print() const {
    for (const auto& arg : arguments_) {
      arg->print();
    }
  }

 private:
  std::vector<std::unique_ptr<ArgAbstract>> arguments_;
  std::vector<void*> void_ptrs_;
  bool changed_ = true;

  int device_index_ = 0;
  c10::optional<size_t> cache_id_ = c10::nullopt;
  KernelIndexMode index_mode_ = KernelIndexMode::INT64;
};
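
// Example (illustrative only; `inputs` and `rand_offset` are placeholders):
// a typical flow when preparing launch arguments for a compiled kernel:
//
//   auto args = KernelArgumentHolder::createKernelArgumentHolder(inputs);
//   args.setDeviceIndex(0);
//   args.appendPhiloxRNGSeed(rand_offset); // only if the kernel uses RNG
//   void** kernel_args = args.getBuffer();
//
// kernel_args can then be handed to the CUDA driver, e.g. as the
// kernelParams argument of cuLaunchKernel.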

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch