1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <ir_interface_nodes.h> |
6 | #include <type.h> |
7 | #include <type_promotion.h> |
8 | |
9 | class Val; |
10 | |
/*
 * The operations defined in this header are intended as user-facing functions.
 * Generally, users should not directly instantiate temporary TensorViews; they
 * should instead use the functions below, which will automatically create IR
 * nodes and return a resulting TensorView with correctly tracked shapes.
 */
17 | |
18 | namespace torch { |
19 | namespace jit { |
20 | namespace fuser { |
21 | namespace cuda { |
22 | |
// Insertion of casting op to dtype, returns new resulting val
TORCH_CUDA_CU_API Val* castOp(DataType dtype, Val* v1);
TORCH_CUDA_CU_API TensorView* castOp(DataType dtype, TensorView* v1);

// Reinterpret the bits of v1 as dtype without a value conversion
// (NOTE(review): presumably requires same-width dtypes — confirm in definition)
TORCH_CUDA_CU_API Val* bitCastOp(DataType dtype, Val* v1);
TORCH_CUDA_CU_API TensorView* bitCastOp(DataType dtype, TensorView* v1);

// Perform unary op type and return the output
TORCH_CUDA_CU_API Val* unaryOp(UnaryOpType type, Val* v1);
TORCH_CUDA_CU_API TensorView* unaryOp(UnaryOpType type, TensorView* v1);
// Unary predicate-style ops ("is" ops, e.g. isnan/isinf-style checks)
TORCH_CUDA_CU_API Val* unaryIsOp(UnaryOpType type, Val* v1);
TORCH_CUDA_CU_API TensorView* unaryIsOp(UnaryOpType type, TensorView* v1);
// Overloads taking a TypePromotionConfig that controls how the output
// dtype is derived from the input dtype
TORCH_CUDA_CU_API Val* unaryOp(
    UnaryOpType type,
    Val* v1,
    const TypePromotionConfig& config);
TORCH_CUDA_CU_API TensorView* unaryOp(
    UnaryOpType type,
    TensorView* v1,
    const TypePromotionConfig& config);
43 | |
// Perform binary op type on v1 and v2 and return a type promoted output.
// Mod, CeilDiv, and LT are considered Int only output operations for now.
// out_dtype == DataType::Null leaves the output dtype to the default
// promotion rules; four overloads cover the scalar/tensor operand matrix.
TORCH_CUDA_CU_API Val* binaryOp(
    BinaryOpType type,
    Val* v1,
    Val* v2,
    DataType out_dtype = DataType::Null);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    Val* v2,
    DataType out_dtype = DataType::Null);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    Val* v1,
    TensorView* v2,
    DataType out_dtype = DataType::Null);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    TensorView* v2,
    DataType out_dtype = DataType::Null);

// Overloads taking an explicit TypePromotionConfig instead of an output
// dtype; type promotion behavior is controlled by config.
TORCH_CUDA_CU_API Val* binaryOp(
    BinaryOpType type,
    Val* v1,
    Val* v2,
    const TypePromotionConfig& config);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    Val* v2,
    const TypePromotionConfig& config);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    Val* v1,
    TensorView* v2,
    const TypePromotionConfig& config);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    TensorView* v2,
    const TypePromotionConfig& config);
87 | |
// Perform a reduction operation on v1, initial value for reduction is init,
// reduces across axes, and reduction operation defined by BinaryOp.
// keep_dim: retain reduced axes as size-1 (broadcast) dims — mirrors the
// at::sum-style keepdim convention; confirm against the definition.
// dtype == DataType::Null keeps the default output dtype.
TORCH_CUDA_CU_API TensorView* reductionOp(
    BinaryOpType reduction_op_type,
    const std::vector<int>& axes,
    Val* init,
    TensorView* v1,
    bool keep_dim = false,
    DataType dtype = DataType::Null);
97 | |
//! Auxiliary struct holding the result of
//! a single Welford op in TensorView form
class TORCH_CUDA_CU_API WelfordResult {
 public:
  TensorView* avg; // running average over the reduced axes
  TensorView* var_sum; // sum of squared deviations (unnormalized variance)
  TensorView* n; // count of elements contributing to avg/var_sum

  explicit WelfordResult(
      TensorView* in_avg,
      TensorView* in_var_sum,
      TensorView* in_n);
};
111 | |
//! Welford operator on specified axes. This is currently the only scan op with
//! multiple outputs that is supported. May consider generalization if more scan
//! ops are added.
//!
//! init_avg / init_var optionally seed the running average and variance sum
//! (both nullptr means no initial partial result).
TORCH_CUDA_CU_API WelfordResult Welford(
    TensorView* tv,
    const std::vector<int>& axes,
    TensorView* init_avg = nullptr,
    TensorView* init_var = nullptr,
    // Initializes to 0 in function definition, doing this so we don't have to
    // import IrBuilder just for this one interface.
    Int* init_N = nullptr);
123 | |
// RNG OPERATIONS
// Random tensor of the given shape and dtype
// (NOTE(review): presumably uniform in [0, 1) like at::rand — confirm in
// the definition)
TORCH_CUDA_CU_API TensorView* rand(
    const std::vector<Val*>& shape,
    DataType dtype);
TORCH_CUDA_CU_API Val* rand_like(Val*);
TORCH_CUDA_CU_API TensorView* rand_like(TensorView*);

// Random tensor drawn between low and high
// (NOTE(review): presumably uniform in [low, high) — confirm in definition)
TORCH_CUDA_CU_API TensorView* uniform(
    const std::vector<Val*>& shape,
    Val* low,
    Val* high,
    DataType dtype);
136 | |
// TENSOR FACTORIES
// Tensor of the given shape with every element set to fill_value
TORCH_CUDA_CU_API TensorView* full(
    const std::vector<Val*>& shape,
    Val* fill_value,
    DataType dtype);
TORCH_CUDA_CU_API TensorView* full_like(TensorView* tv, Val* fill_value);
TORCH_CUDA_CU_API Val* full_like(Val* tv, Val* fill_value);
// Zero-filled tensor of the given shape and dtype
TORCH_CUDA_CU_API TensorView* zeros(
    const std::vector<Val*>& shape,
    DataType dtype);
TORCH_CUDA_CU_API TensorView* zeros_like(TensorView*);
TORCH_CUDA_CU_API Val* zeros_like(Val*);
// One-filled tensor of the given shape and dtype
TORCH_CUDA_CU_API TensorView* ones(
    const std::vector<Val*>& shape,
    DataType dtype);
TORCH_CUDA_CU_API TensorView* ones_like(TensorView*);
TORCH_CUDA_CU_API Val* ones_like(Val*);
//! WARNING: giving invalid combinations of the start, end and step
//! arguments can result in undefined behavior. Specifically, the
//! signs of `end - start` and step must be the same.
TORCH_CUDA_CU_API TensorView* arange(Val* end, DataType dtype = DataType::Int);
TORCH_CUDA_CU_API TensorView* arange(
    Val* start,
    Val* end,
    DataType dtype = DataType::Int);
TORCH_CUDA_CU_API TensorView* arange(
    Val* start,
    Val* end,
    Val* step,
    DataType dtype = DataType::Int);
// Identity-style matrix; the single-argument form is square (size x size)
TORCH_CUDA_CU_API TensorView* eye(Val* size, DataType dtype);
TORCH_CUDA_CU_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
169 | |
// UNARY OPERATIONS
// Each op below is declared both for scalar Vals and for TensorViews.
// abs
TORCH_CUDA_CU_API Val* abs(Val*);
TORCH_CUDA_CU_API TensorView* abs(TensorView*);
// acos
TORCH_CUDA_CU_API Val* acos(Val*);
TORCH_CUDA_CU_API TensorView* acos(TensorView*);
// asin
TORCH_CUDA_CU_API Val* asin(Val*);
TORCH_CUDA_CU_API TensorView* asin(TensorView*);
// atan
TORCH_CUDA_CU_API Val* atan(Val*);
TORCH_CUDA_CU_API TensorView* atan(TensorView*);
// atanh
TORCH_CUDA_CU_API Val* atanh(Val*);
TORCH_CUDA_CU_API TensorView* atanh(TensorView*);
// ceil
TORCH_CUDA_CU_API Val* ceil(Val*);
TORCH_CUDA_CU_API TensorView* ceil(TensorView*);
// cos
TORCH_CUDA_CU_API Val* cos(Val*);
TORCH_CUDA_CU_API TensorView* cos(TensorView*);
// cosh
TORCH_CUDA_CU_API Val* cosh(Val*);
TORCH_CUDA_CU_API TensorView* cosh(TensorView*);
// exp
TORCH_CUDA_CU_API Val* exp(Val*);
TORCH_CUDA_CU_API TensorView* exp(TensorView*);
// expm1
TORCH_CUDA_CU_API Val* expm1(Val*);
TORCH_CUDA_CU_API TensorView* expm1(TensorView*);
// erf
TORCH_CUDA_CU_API Val* erf(Val*);
TORCH_CUDA_CU_API TensorView* erf(TensorView*);
// erfc
TORCH_CUDA_CU_API Val* erfc(Val*);
TORCH_CUDA_CU_API TensorView* erfc(TensorView*);
// floor
TORCH_CUDA_CU_API Val* floor(Val*);
TORCH_CUDA_CU_API TensorView* floor(TensorView*);
// frac
TORCH_CUDA_CU_API Val* frac(Val*);
TORCH_CUDA_CU_API TensorView* frac(TensorView*);
// silu
TORCH_CUDA_CU_API Val* silu(Val*);
TORCH_CUDA_CU_API TensorView* silu(TensorView*);
// lgamma
TORCH_CUDA_CU_API Val* lgamma(Val*);
TORCH_CUDA_CU_API TensorView* lgamma(TensorView*);
// log
TORCH_CUDA_CU_API Val* log(Val*);
TORCH_CUDA_CU_API TensorView* log(TensorView*);
// log10
TORCH_CUDA_CU_API Val* log10(Val*);
TORCH_CUDA_CU_API TensorView* log10(TensorView*);
// log1p
TORCH_CUDA_CU_API Val* log1p(Val*);
TORCH_CUDA_CU_API TensorView* log1p(TensorView*);
// log2
TORCH_CUDA_CU_API Val* log2(Val*);
TORCH_CUDA_CU_API TensorView* log2(TensorView*);
// neg
TORCH_CUDA_CU_API Val* neg(Val*);
TORCH_CUDA_CU_API TensorView* neg(TensorView*);
// real
TORCH_CUDA_CU_API Val* real(Val*);
TORCH_CUDA_CU_API TensorView* real(TensorView*);
// reciprocal
TORCH_CUDA_CU_API Val* reciprocal(Val*);
TORCH_CUDA_CU_API TensorView* reciprocal(TensorView*);
// relu
TORCH_CUDA_CU_API Val* relu(Val*);
TORCH_CUDA_CU_API TensorView* relu(TensorView*);
// rsqrt
TORCH_CUDA_CU_API Val* rsqrt(Val*);
TORCH_CUDA_CU_API TensorView* rsqrt(TensorView*);
// round
TORCH_CUDA_CU_API Val* round(Val*);
TORCH_CUDA_CU_API TensorView* round(TensorView*);
// set
TORCH_CUDA_CU_API Val* set(Val*);
TORCH_CUDA_CU_API TensorView* set(TensorView*);
// sigmoid
TORCH_CUDA_CU_API Val* sigmoid(Val*);
TORCH_CUDA_CU_API TensorView* sigmoid(TensorView*);
// sin
TORCH_CUDA_CU_API Val* sin(Val*);
TORCH_CUDA_CU_API TensorView* sin(TensorView*);
// sinh
TORCH_CUDA_CU_API Val* sinh(Val*);
TORCH_CUDA_CU_API TensorView* sinh(TensorView*);
// sqrt
TORCH_CUDA_CU_API Val* sqrt(Val*);
TORCH_CUDA_CU_API TensorView* sqrt(TensorView*);
// tan
TORCH_CUDA_CU_API Val* tan(Val*);
TORCH_CUDA_CU_API TensorView* tan(TensorView*);
// tanh
TORCH_CUDA_CU_API Val* tanh(Val*);
TORCH_CUDA_CU_API TensorView* tanh(TensorView*);
// trunc
TORCH_CUDA_CU_API Val* trunc(Val*);
TORCH_CUDA_CU_API TensorView* trunc(TensorView*);
// bitwise_not
TORCH_CUDA_CU_API Val* bitwise_not(Val*);
TORCH_CUDA_CU_API TensorView* bitwise_not(TensorView*);
// imag
TORCH_CUDA_CU_API Val* imag(Val*);
TORCH_CUDA_CU_API TensorView* imag(TensorView*);
// isfinite
TORCH_CUDA_CU_API Val* isfinite(Val*);
TORCH_CUDA_CU_API TensorView* isfinite(TensorView*);
// isinf
TORCH_CUDA_CU_API Val* isinf(Val*);
TORCH_CUDA_CU_API TensorView* isinf(TensorView*);
// isnan
TORCH_CUDA_CU_API Val* isnan(Val*);
TORCH_CUDA_CU_API TensorView* isnan(TensorView*);
// isneginf
TORCH_CUDA_CU_API Val* isneginf(Val*);
TORCH_CUDA_CU_API TensorView* isneginf(TensorView*);
// isposinf
TORCH_CUDA_CU_API Val* isposinf(Val*);
TORCH_CUDA_CU_API TensorView* isposinf(TensorView*);
// isreal
TORCH_CUDA_CU_API Val* isreal(Val*);
TORCH_CUDA_CU_API TensorView* isreal(TensorView*);
// print
TORCH_CUDA_CU_API Val* print(Val*);
TORCH_CUDA_CU_API TensorView* print(TensorView*);
300 | |
// Broadcasts inp based on bool vector. Size of broadcast bool vector should be
// the number of dims desired in the broadcasted tensor. This vector should be
// true if output dim should be a broadcasted dim, and false if it is not a
// broadcasted dim. Number of false entries must match the number of input dims.
TORCH_CUDA_CU_API TensorView* broadcast(
    TensorView* inp,
    const std::vector<bool>& is_broadcast_dim);

// Expands input based on provided sizes. expand_sizes should be larger than
// the input's root domain (really rfactor) and will broadcast on inner
// dimensions. expand_sizes should be -1 for any dimension that should remain a
// symbolic size. For dimensions that remain broadcast after the expand should
// be set to 1, any dimension being expanded must be marked as a broadcast in
// the input and will be expanded to the provided constant size. Any dimension
// that's symbolic in the input but specified as a non -1 value will be set to
// that constant value.
TORCH_CUDA_CU_API TensorView* expand(
    TensorView* inp,
    const std::vector<Val*>& expanded_sizes);

// Expands input based on other. For dimensions in inp that are broadcast with a
// matching entry in other that's either a broadcast with expanded extent or a
// non broadcasted iter domain, inp will be expanded to other's size.
TORCH_CUDA_CU_API TensorView* expand_as(TensorView* inp, TensorView* other);
325 | |
// BINARY OPERATIONS
// Each op below has four overloads covering the scalar/tensor operand
// combinations (Val/Val, TensorView/Val, Val/TensorView, TensorView/TensorView).
// add
TORCH_CUDA_CU_API Val* add(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* add(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* add(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* add(TensorView* v1, TensorView* v2);
// atan2
TORCH_CUDA_CU_API Val* atan2(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* atan2(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* atan2(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* atan2(TensorView* v1, TensorView* v2);
// div
TORCH_CUDA_CU_API Val* div(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* div(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, TensorView* v2);
// fmod
TORCH_CUDA_CU_API Val* fmod(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* fmod(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* fmod(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* fmod(TensorView* v1, TensorView* v2);
// mul
TORCH_CUDA_CU_API Val* mul(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* mul(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* mul(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* mul(TensorView* v1, TensorView* v2);
// pow
TORCH_CUDA_CU_API Val* pow(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* pow(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* pow(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* pow(TensorView* v1, TensorView* v2);
// remainder
TORCH_CUDA_CU_API Val* remainder(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* remainder(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* remainder(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* remainder(TensorView* v1, TensorView* v2);
// sub
TORCH_CUDA_CU_API Val* sub(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* sub(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* sub(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* sub(TensorView* v1, TensorView* v2);
// Integer binary ops
// mod
TORCH_CUDA_CU_API Val* mod(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* mod(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* mod(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* mod(TensorView* v1, TensorView* v2);
// ceilDiv
TORCH_CUDA_CU_API Val* ceilDiv(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ceilDiv(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ceilDiv(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* ceilDiv(TensorView* v1, TensorView* v2);
// Bitwise binary ops
// bitwise_and
TORCH_CUDA_CU_API Val* bitwise_and(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_and(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_and(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_and(TensorView* v1, TensorView* v2);
// bitwise_left_shift
TORCH_CUDA_CU_API Val* bitwise_left_shift(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_left_shift(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_left_shift(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_left_shift(
    TensorView* v1,
    TensorView* v2);
// bitwise_right_shift
TORCH_CUDA_CU_API Val* bitwise_right_shift(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_right_shift(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_right_shift(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_right_shift(
    TensorView* v1,
    TensorView* v2);
// bitwise_or
TORCH_CUDA_CU_API Val* bitwise_or(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_or(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_or(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_or(TensorView* v1, TensorView* v2);
// bitwise_xor
TORCH_CUDA_CU_API Val* bitwise_xor(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_xor(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_xor(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_xor(TensorView* v1, TensorView* v2);
// Logical binary ops
// eq
TORCH_CUDA_CU_API Val* eq(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* eq(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* eq(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* eq(TensorView* v1, TensorView* v2);
// ge
TORCH_CUDA_CU_API Val* ge(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ge(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ge(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* ge(TensorView* v1, TensorView* v2);
// gt
TORCH_CUDA_CU_API Val* gt(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* gt(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* gt(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* gt(TensorView* v1, TensorView* v2);
// le
TORCH_CUDA_CU_API Val* le(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* le(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* le(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* le(TensorView* v1, TensorView* v2);
// lt
TORCH_CUDA_CU_API Val* lt(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* lt(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* lt(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* lt(TensorView* v1, TensorView* v2);
// ne
TORCH_CUDA_CU_API Val* ne(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ne(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ne(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* ne(TensorView* v1, TensorView* v2);
439 | |
// REDUCTION OPERATIONS
// Reduce v1 over reduction_axes. keep_dim and dtype follow the same
// conventions as reductionOp above.
TORCH_CUDA_CU_API TensorView* sum(
    TensorView* v1,
    const std::vector<int>& reduction_axes,
    bool keep_dim = false,
    DataType dtype = DataType::Null);

TORCH_CUDA_CU_API TensorView* max(
    TensorView* v1,
    const std::vector<int>& reduction_axes,
    bool keep_dim = false,
    DataType dtype = DataType::Null);

TORCH_CUDA_CU_API TensorView* min(
    TensorView* v1,
    const std::vector<int>& reduction_axes,
    bool keep_dim = false,
    DataType dtype = DataType::Null);
458 | |
// COMPOUND OPERATIONS
// add_alpha: v1 + v2 * s (matches at::add's alpha parameter — confirm in
// definition)
TORCH_CUDA_CU_API Val* add_alpha(Val* v1, Val* v2, Val* s);
TORCH_CUDA_CU_API TensorView* add_alpha(TensorView* v1, Val* v2, Val* s);
TORCH_CUDA_CU_API TensorView* add_alpha(Val* v1, TensorView* v2, Val* s);
TORCH_CUDA_CU_API TensorView* add_alpha(TensorView* v1, TensorView* v2, Val* s);
// sub_alpha: v1 - v2 * s (see note on add_alpha)
TORCH_CUDA_CU_API Val* sub_alpha(Val* v1, Val* v2, Val* s);
TORCH_CUDA_CU_API TensorView* sub_alpha(TensorView* v1, Val* v2, Val* s);
TORCH_CUDA_CU_API TensorView* sub_alpha(Val* v1, TensorView* v2, Val* s);
TORCH_CUDA_CU_API TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* s);
// lerp: linear interpolation between start and end by weight
TORCH_CUDA_CU_API Val* lerp(Val* start, Val* end, Val* weight);
TORCH_CUDA_CU_API TensorView* lerp(TensorView* start, Val* end, Val* weight);
TORCH_CUDA_CU_API TensorView* lerp(Val* start, TensorView* end, Val* weight);
TORCH_CUDA_CU_API TensorView* lerp(Val* start, Val* end, TensorView* weight);
TORCH_CUDA_CU_API TensorView* lerp(
    TensorView* start,
    TensorView* end,
    Val* weight);
TORCH_CUDA_CU_API TensorView* lerp(
    TensorView* start,
    Val* end,
    TensorView* weight);
TORCH_CUDA_CU_API TensorView* lerp(
    Val* start,
    TensorView* end,
    TensorView* weight);
TORCH_CUDA_CU_API TensorView* lerp(
    TensorView* start,
    TensorView* end,
    TensorView* weight);
// addcmul: v1 + v2 * v3 * s (mirrors at::addcmul — confirm in definition)
TORCH_CUDA_CU_API Val* addcmul(Val* v1, Val* v2, Val* v3, Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(TensorView* v1, Val* v2, Val* v3, Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(Val* v1, TensorView* v2, Val* v3, Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(Val* v1, Val* v2, TensorView* v3, Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(
    TensorView* v1,
    TensorView* v2,
    Val* v3,
    Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(
    TensorView* v1,
    Val* v2,
    TensorView* v3,
    Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(
    Val* v1,
    TensorView* v2,
    TensorView* v3,
    Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(
    TensorView* v1,
    TensorView* v2,
    TensorView* v3,
    Val* s);
516 | |
// TERNARY OPERATIONS
// where: element-wise select — v1 where condition c holds, else v2
TORCH_CUDA_CU_API Val* where(Val* c, Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* where(TensorView* c, Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* where(Val* c, TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* where(Val* c, Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* where(TensorView* c, TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* where(TensorView* c, Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* where(Val* c, TensorView* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* where(
    TensorView* c,
    TensorView* v1,
    TensorView* v2);
// threshold
TORCH_CUDA_CU_API Val* threshold(Val* in, Val* thresh, Val* value);
TORCH_CUDA_CU_API TensorView* threshold(
    TensorView* in,
    Val* thresh,
    Val* value);
// clamp: limit in to the range [min_val, max_val]
TORCH_CUDA_CU_API Val* clamp(Val* in, Val* min_val, Val* max_val);
TORCH_CUDA_CU_API TensorView* clamp(TensorView* in, Val* min_val, Val* max_val);
539 | |
//! Internal operator for supporting backward graphs
//!
//! example:
//!   v1 = T1 [I0(10),I1(20),I2(30),I3(40)]
//!   v2 = sum_to(v1,{30,1}) ------> v2 = T2[I2,R3 (keep_dim)]
//!
//! This operator will return v1* directly if the sizes of v1's root domain
//! are already the same as shape.
//!
//! Name of sum_to is different from NV fuser naming,
//! this is to align with the operator name of at::sum_to.

TORCH_CUDA_CU_API TensorView* sum_to(
    TensorView* v1,
    const std::vector<Int*>& sum_to_size);

TORCH_CUDA_CU_API TensorView* sum_to(
    TensorView* v1,
    const std::vector<int64_t>& sum_to_size);
559 | |
//! Shift a tensor to a direction specified by offsets.
//!
//! Example:
//!   t0: 2D tensor of size N by M
//!   t1 = shift(t0, {1, -1});
//!
//!   then:
//!     t1[i, j] = t0[i-1, j+1] for 1 <= i < N and 0 <= j < M-1.
//!     t1[i, j] = 0, otherwise
//!
//! The pad option controls how out-of-boundary accesses are
//! handled. It specifies how many zeros are logically padded. If no
//! pad option is given, it automatically pads the input tensor so
//! that the output tensor has the same extent for each axis.
//!
//! When a padding value is smaller than the absolute value of a shift
//! offset, the output axis still has the same extent but its start or
//! stop offset is moved inward to signify those outside of the offset
//! are invalid.
//!
//! It is not allowed to use padding values that are larger than shift
//! offsets, which would mean output extents would be larger than
//! input extents
TORCH_CUDA_CU_API TensorView* shift(
    TensorView* inp,
    const std::vector<int>& offsets,
    const std::vector<int>& pad_width = {});

// Convenience overload: pad == true pads symmetrically to preserve extents
// (see pad semantics above); confirm exact behavior in the definition.
TORCH_CUDA_CU_API TensorView* shift(
    TensorView* inp,
    const std::vector<int>& offsets,
    bool pad);
592 | |
//! Gather a window of nearby elements for each element.
//!
//! Each window of size window_shape is stored as an additional
//! innermost domain, meaning that the number of dimensions of the
//! output tensor doubles. The pad_width parameter specifies the
//! padding width of each side of each axis. The strides parameter
//! specifies striding of the operation. Non-unit striding is
//! implemented with strided split, whose outer output domain becomes
//! the root domain for subsequent consumers. The inner output domain
//! becomes a Stride domain, which is ignored by subsequent consumers.
//! Only valid input ranges are fed into strided splits.
//!
//! When trim_out_of_bounds is true, the values at the first and last
//! ends that are outside of the start and stop offsets are
//! effectively trimmed by partial split by 1.
//!
//! Example 1:
//!   t0: 2D tensor of [N, M]
//!   t1 = gather(t0, {1, 3}, {{0, 0}, {1, 1}});
//!
//!   then:
//!     t1: [N, M, 1, 3]
//!     t1[i, j, k, l] = The value at the window position of [k, l]
//!                      for t0[i, j]
//!
//! Example 2.1 (without trimming):
//!   t0: 2D tensor of [N, M]
//!   t1 = gather(t0, {2, 2}, {{0, 0}, {0, 0}});
//!
//!   then:
//!     t1: [N (stop offset: 1), M (stop offset: 1, 2, 2)]
//!
//! Example 2.2 (with trimming)
//!   t0: 2D tensor of [N, M]
//!   t1 = gather(t0, {2, 2}, {{0, 0}, {0, 0}}, true);
//!
//!   then:
//!     t1: [ceilDiv(N - 1, 1), ceilDiv(M - 1, 1), 2, 2]
//!
//! Example 3:
//!   t0: 2D tensor of [N, M]
//!   t1 = gather(t0, {3, 3}, {{0, 0}, {0, 0}}, {3, 3});
//!
//!   then:
//!     t1: [ceilDiv(N - 2, 3), ceilDiv(M - 2, 3), 2, 2]
//!
TORCH_CUDA_CU_API TensorView* gather(
    TensorView* inp,
    const std::vector<int>& window_shape,
    const std::vector<std::vector<int>>& pad_width,
    const std::vector<int>& strides = {},
    bool trim_out_of_bounds = false);
645 | |
// Append a new IterDomain to the end of a TensorView to allow
// iterating on a vector type. The input tensor must have
// vector dtype.
TORCH_CUDA_CU_API TensorView* viewAsScalar(TensorView* inp);
650 | |
//! A fused pointwise multiply and sum
//! operator that instantiates the following
//! fused pattern:
//!     c = mul(tv_a, tv_b);
//!     return sum(c, axes)
//!
//! \param tv_a first multiply operand
//! \param tv_b second multiply operand
//! \param axes axes to sum over
//! \param init sum initial value
//!
//! Note & TODO:
//!   currently only support lowering to a mma op
//!   through this interface and only support fp16 inputs.
//!   will support converting back to multiply and reduce in
//!   a follow up.
TORCH_CUDA_CU_API TensorView* fusedMultiplySum(
    TensorView* tv_a,
    TensorView* tv_b,
    const std::vector<int>& axes,
    Val* init = nullptr);
672 | |
673 | } // namespace cuda |
674 | } // namespace fuser |
675 | } // namespace jit |
676 | } // namespace torch |
677 | |