1 | #pragma once |
2 | #include <type.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | namespace fuser { |
7 | namespace cuda { |
8 | |
9 | class TORCH_CUDA_CU_API LaunchParams { |
10 | public: |
11 | static constexpr int64_t UNINITIALIZED_VAL = -1; |
12 | |
13 | LaunchParams( |
14 | int64_t gdimx = UNINITIALIZED_VAL, |
15 | int64_t gdimy = UNINITIALIZED_VAL, |
16 | int64_t gdimz = UNINITIALIZED_VAL, |
17 | int64_t bdimx = UNINITIALIZED_VAL, |
18 | int64_t bdimy = UNINITIALIZED_VAL, |
19 | int64_t bdimz = UNINITIALIZED_VAL) |
20 | : gdimx_(gdimx), |
21 | gdimy_(gdimy), |
22 | gdimz_(gdimz), |
23 | bdimx_(bdimx), |
24 | bdimy_(bdimy), |
25 | bdimz_(bdimz) { |
26 | assertValid(); |
27 | } |
28 | |
29 | void assertValid(); |
30 | |
31 | void setSmem(int64_t smem) { |
32 | smem_ = smem; |
33 | } |
34 | |
35 | int64_t smem() const { |
36 | return smem_; |
37 | } |
38 | |
39 | int64_t nBlocks() const { |
40 | return std::abs(gdimx_ * gdimy_ * gdimz_); |
41 | } |
42 | |
43 | int64_t nThreads() const { |
44 | return std::abs(bdimx_ * bdimy_ * bdimz_); |
45 | } |
46 | |
47 | int64_t bdimx() const { |
48 | return static_cast<int64_t>(bdimx_ == UNINITIALIZED_VAL ? 1 : bdimx_); |
49 | } |
50 | |
51 | int64_t gdimx() const { |
52 | return static_cast<int64_t>(gdimx_ == UNINITIALIZED_VAL ? 1 : gdimx_); |
53 | } |
54 | |
55 | int64_t bdimy() const { |
56 | return static_cast<int64_t>(bdimy_ == UNINITIALIZED_VAL ? 1 : bdimy_); |
57 | } |
58 | |
59 | int64_t gdimy() const { |
60 | return static_cast<int64_t>(gdimy_ == UNINITIALIZED_VAL ? 1 : gdimy_); |
61 | } |
62 | |
63 | int64_t bdimz() const { |
64 | return static_cast<int64_t>(bdimz_ == UNINITIALIZED_VAL ? 1 : bdimz_); |
65 | } |
66 | |
67 | int64_t gdimz() const { |
68 | return static_cast<int64_t>(gdimz_ == UNINITIALIZED_VAL ? 1 : gdimz_); |
69 | } |
70 | |
71 | void checkAndSet( |
72 | const int64_t incoming_val, |
73 | int64_t& class_val, |
74 | std::string val) { |
75 | TORCH_INTERNAL_ASSERT( |
76 | class_val == UNINITIALIZED_VAL || incoming_val == class_val, |
77 | "Tried to set " , |
78 | val, |
79 | " from " , |
80 | class_val, |
81 | " to " , |
82 | incoming_val, |
83 | ", but it was already set and new value does not match." , |
84 | " Thread dims all have to be bound to the same value." ); |
85 | TORCH_CHECK( |
86 | incoming_val > 0, |
87 | "Received a thread binding on " , |
88 | val, |
89 | " that is " , |
90 | incoming_val, |
91 | ". Cannot create negative threads." ); |
92 | if (class_val == UNINITIALIZED_VAL) { |
93 | class_val = incoming_val; |
94 | } |
95 | assertValid(); |
96 | } |
97 | |
98 | // Binds dim assocaited with p_type to val |
99 | void bind(int64_t val, ParallelType p_type); |
100 | |
101 | // Adjusted value based on get functions above for each value |
102 | int64_t getDim(ParallelType p_type) const; |
103 | |
104 | // Returns raw value which may be UNINITIALIZED_VAL |
105 | const int64_t& getRawVal(ParallelType p_type) const; |
106 | |
107 | // Returns false if value associated with p_type == UNINITIALIZED_VAL |
108 | bool hasDim(ParallelType p_type) const; |
109 | |
110 | bool operator==(const LaunchParams& other) const; |
111 | |
112 | void print() const; |
113 | |
114 | std::string toString() const; |
115 | |
116 | private: |
117 | // Spell them out because I want signed ints to know if they were initialized |
118 | // or not. |
119 | // TODO: convert to c10::optional |
120 | int64_t gdimx_ = UNINITIALIZED_VAL; |
121 | int64_t gdimy_ = UNINITIALIZED_VAL; |
122 | int64_t gdimz_ = UNINITIALIZED_VAL; |
123 | int64_t bdimx_ = UNINITIALIZED_VAL; |
124 | int64_t bdimy_ = UNINITIALIZED_VAL; |
125 | int64_t bdimz_ = UNINITIALIZED_VAL; |
126 | |
127 | int64_t smem_ = 0; |
128 | |
129 | // TODO: Fill in output sizes |
130 | std::vector<std::vector<int64_t>> output_sizes; |
131 | }; |
132 | |
133 | } // namespace cuda |
134 | } // namespace fuser |
135 | } // namespace jit |
136 | } // namespace torch |
137 | |