#pragma once

#include <cstdint>
#include <cstdlib>
#include <string>
#include <vector>

#include <type.h>
3
4namespace torch {
5namespace jit {
6namespace fuser {
7namespace cuda {
8
9class TORCH_CUDA_CU_API LaunchParams {
10 public:
11 static constexpr int64_t UNINITIALIZED_VAL = -1;
12
13 LaunchParams(
14 int64_t gdimx = UNINITIALIZED_VAL,
15 int64_t gdimy = UNINITIALIZED_VAL,
16 int64_t gdimz = UNINITIALIZED_VAL,
17 int64_t bdimx = UNINITIALIZED_VAL,
18 int64_t bdimy = UNINITIALIZED_VAL,
19 int64_t bdimz = UNINITIALIZED_VAL)
20 : gdimx_(gdimx),
21 gdimy_(gdimy),
22 gdimz_(gdimz),
23 bdimx_(bdimx),
24 bdimy_(bdimy),
25 bdimz_(bdimz) {
26 assertValid();
27 }
28
29 void assertValid();
30
31 void setSmem(int64_t smem) {
32 smem_ = smem;
33 }
34
35 int64_t smem() const {
36 return smem_;
37 }
38
39 int64_t nBlocks() const {
40 return std::abs(gdimx_ * gdimy_ * gdimz_);
41 }
42
43 int64_t nThreads() const {
44 return std::abs(bdimx_ * bdimy_ * bdimz_);
45 }
46
47 int64_t bdimx() const {
48 return static_cast<int64_t>(bdimx_ == UNINITIALIZED_VAL ? 1 : bdimx_);
49 }
50
51 int64_t gdimx() const {
52 return static_cast<int64_t>(gdimx_ == UNINITIALIZED_VAL ? 1 : gdimx_);
53 }
54
55 int64_t bdimy() const {
56 return static_cast<int64_t>(bdimy_ == UNINITIALIZED_VAL ? 1 : bdimy_);
57 }
58
59 int64_t gdimy() const {
60 return static_cast<int64_t>(gdimy_ == UNINITIALIZED_VAL ? 1 : gdimy_);
61 }
62
63 int64_t bdimz() const {
64 return static_cast<int64_t>(bdimz_ == UNINITIALIZED_VAL ? 1 : bdimz_);
65 }
66
67 int64_t gdimz() const {
68 return static_cast<int64_t>(gdimz_ == UNINITIALIZED_VAL ? 1 : gdimz_);
69 }
70
71 void checkAndSet(
72 const int64_t incoming_val,
73 int64_t& class_val,
74 std::string val) {
75 TORCH_INTERNAL_ASSERT(
76 class_val == UNINITIALIZED_VAL || incoming_val == class_val,
77 "Tried to set ",
78 val,
79 " from ",
80 class_val,
81 " to ",
82 incoming_val,
83 ", but it was already set and new value does not match.",
84 " Thread dims all have to be bound to the same value.");
85 TORCH_CHECK(
86 incoming_val > 0,
87 "Received a thread binding on ",
88 val,
89 " that is ",
90 incoming_val,
91 ". Cannot create negative threads.");
92 if (class_val == UNINITIALIZED_VAL) {
93 class_val = incoming_val;
94 }
95 assertValid();
96 }
97
98 // Binds dim assocaited with p_type to val
99 void bind(int64_t val, ParallelType p_type);
100
101 // Adjusted value based on get functions above for each value
102 int64_t getDim(ParallelType p_type) const;
103
104 // Returns raw value which may be UNINITIALIZED_VAL
105 const int64_t& getRawVal(ParallelType p_type) const;
106
107 // Returns false if value associated with p_type == UNINITIALIZED_VAL
108 bool hasDim(ParallelType p_type) const;
109
110 bool operator==(const LaunchParams& other) const;
111
112 void print() const;
113
114 std::string toString() const;
115
116 private:
117 // Spell them out because I want signed ints to know if they were initialized
118 // or not.
119 // TODO: convert to c10::optional
120 int64_t gdimx_ = UNINITIALIZED_VAL;
121 int64_t gdimy_ = UNINITIALIZED_VAL;
122 int64_t gdimz_ = UNINITIALIZED_VAL;
123 int64_t bdimx_ = UNINITIALIZED_VAL;
124 int64_t bdimy_ = UNINITIALIZED_VAL;
125 int64_t bdimz_ = UNINITIALIZED_VAL;
126
127 int64_t smem_ = 0;
128
129 // TODO: Fill in output sizes
130 std::vector<std::vector<int64_t>> output_sizes;
131};
132
133} // namespace cuda
134} // namespace fuser
135} // namespace jit
136} // namespace torch
137