#pragma once

#include <c10/macros/Export.h>

#include <ir_interface_nodes.h>
#include <type.h>
#include <type_promotion.h>

class Val;

/*
 * The operations defined in this header are intended as user-facing functions.
 * Generally, users should not directly instantiate temporary TensorViews;
 * they should instead use the functions below, which automatically create IR
 * nodes and return a resulting TensorView with correctly tracked shapes.
 */

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Insert a cast op to dtype and return the new resulting val
TORCH_CUDA_CU_API Val* castOp(DataType dtype, Val* v1);
TORCH_CUDA_CU_API TensorView* castOp(DataType dtype, TensorView* v1);

TORCH_CUDA_CU_API Val* bitCastOp(DataType dtype, Val* v1);
TORCH_CUDA_CU_API TensorView* bitCastOp(DataType dtype, TensorView* v1);
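// Example usage (a sketch, not part of the API; assumes `tv0` is an existing
// Float TensorView registered as an input of the active fusion):
//   TensorView* tv_half = castOp(DataType::Half, tv0); // value-converting cast
//   TensorView* tv_bits = bitCastOp(DataType::Int32, tv0); // reinterpret bits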

// Perform the given unary op and return the output
TORCH_CUDA_CU_API Val* unaryOp(UnaryOpType type, Val* v1);
TORCH_CUDA_CU_API TensorView* unaryOp(UnaryOpType type, TensorView* v1);
TORCH_CUDA_CU_API Val* unaryIsOp(UnaryOpType type, Val* v1);
TORCH_CUDA_CU_API TensorView* unaryIsOp(UnaryOpType type, TensorView* v1);
TORCH_CUDA_CU_API Val* unaryOp(
    UnaryOpType type,
    Val* v1,
    const TypePromotionConfig& config);
TORCH_CUDA_CU_API TensorView* unaryOp(
    UnaryOpType type,
    TensorView* v1,
    const TypePromotionConfig& config);

// Perform the given binary op on v1 and v2 and return a type-promoted output.
// Mod, CeilDiv, and LT are considered Int only output operations for now.
TORCH_CUDA_CU_API Val* binaryOp(
    BinaryOpType type,
    Val* v1,
    Val* v2,
    DataType out_dtype = DataType::Null);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    Val* v2,
    DataType out_dtype = DataType::Null);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    Val* v1,
    TensorView* v2,
    DataType out_dtype = DataType::Null);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    TensorView* v2,
    DataType out_dtype = DataType::Null);

TORCH_CUDA_CU_API Val* binaryOp(
    BinaryOpType type,
    Val* v1,
    Val* v2,
    const TypePromotionConfig& config);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    Val* v2,
    const TypePromotionConfig& config);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    Val* v1,
    TensorView* v2,
    const TypePromotionConfig& config);
TORCH_CUDA_CU_API TensorView* binaryOp(
    BinaryOpType type,
    TensorView* v1,
    TensorView* v2,
    const TypePromotionConfig& config);

// Perform a reduction operation on v1. The initial value for the reduction is
// init, the reduction is performed across axes, and the reduction operation is
// defined by reduction_op_type.
TORCH_CUDA_CU_API TensorView* reductionOp(
    BinaryOpType reduction_op_type,
    const std::vector<int>& axes,
    Val* init,
    TensorView* v1,
    bool keep_dim = false,
    DataType dtype = DataType::Null);
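// Example usage (a sketch; assumes `tv0` is a 2-D Float input TensorView and
// that IrBuilder is available, e.g. via ir_builder.h):
//   // Add-reduce the inner axis starting from an initial value of 0.0;
//   // equivalent to sum(tv0, {1}).
//   TensorView* out = reductionOp(
//       BinaryOpType::Add, {1}, IrBuilder::create<Double>(0.0), tv0);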

//! Auxiliary struct holding the result of
//! a single welford op in TensorView form
class TORCH_CUDA_CU_API WelfordResult {
 public:
  TensorView* avg;
  TensorView* var_sum;
  TensorView* n;

  explicit WelfordResult(
      TensorView* in_avg,
      TensorView* in_var_sum,
      TensorView* in_n);
};

//! Welford operator on specified axes. This is currently the only scan op with
//! multiple outputs that is supported. May consider generalization if more scan
//! ops are added.
TORCH_CUDA_CU_API WelfordResult Welford(
    TensorView* tv,
    const std::vector<int>& axes,
    TensorView* init_avg = nullptr,
    TensorView* init_var = nullptr,
    // Initializes to 0 in function definition, doing this so we don't have to
    // import IrBuilder just for this one interface.
    Int* init_N = nullptr);
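// Example usage (a sketch; assumes `tv0` is a 2-D Float input TensorView):
//   WelfordResult wr = Welford(tv0, {1}); // running mean/variance over axis 1
//   TensorView* mean = wr.avg;
//   TensorView* var = div(wr.var_sum, wr.n); // biased variance = var_sum / n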

// RNG OPERATIONS
TORCH_CUDA_CU_API TensorView* rand(
    const std::vector<Val*>& shape,
    DataType dtype);
TORCH_CUDA_CU_API Val* rand_like(Val*);
TORCH_CUDA_CU_API TensorView* rand_like(TensorView*);

TORCH_CUDA_CU_API TensorView* uniform(
    const std::vector<Val*>& shape,
    Val* low,
    Val* high,
    DataType dtype);
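// Example usage (a sketch; assumes IrBuilder is available, e.g. via
// ir_builder.h; extents are passed as Val*):
//   std::vector<Val*> shape = {
//       IrBuilder::create<Int>(4), IrBuilder::create<Int>(8)};
//   TensorView* r = rand(shape, DataType::Float); // samples from U[0, 1)
//   TensorView* u = uniform(
//       shape,
//       IrBuilder::create<Double>(-1.0),
//       IrBuilder::create<Double>(1.0),
//       DataType::Float); // samples from U[-1, 1)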

// TENSOR FACTORIES
TORCH_CUDA_CU_API TensorView* full(
    const std::vector<Val*>& shape,
    Val* fill_value,
    DataType dtype);
TORCH_CUDA_CU_API TensorView* full_like(TensorView* tv, Val* fill_value);
TORCH_CUDA_CU_API Val* full_like(Val* tv, Val* fill_value);
TORCH_CUDA_CU_API TensorView* zeros(
    const std::vector<Val*>& shape,
    DataType dtype);
TORCH_CUDA_CU_API TensorView* zeros_like(TensorView*);
TORCH_CUDA_CU_API Val* zeros_like(Val*);
TORCH_CUDA_CU_API TensorView* ones(
    const std::vector<Val*>& shape,
    DataType dtype);
TORCH_CUDA_CU_API TensorView* ones_like(TensorView*);
TORCH_CUDA_CU_API Val* ones_like(Val*);
//! WARNING: giving invalid combinations of the start, end and step
//! arguments can result in undefined behavior. Specifically, the
//! signs of `end - start` and step must be the same.
TORCH_CUDA_CU_API TensorView* arange(Val* end, DataType dtype = DataType::Int);
TORCH_CUDA_CU_API TensorView* arange(
    Val* start,
    Val* end,
    DataType dtype = DataType::Int);
TORCH_CUDA_CU_API TensorView* arange(
    Val* start,
    Val* end,
    Val* step,
    DataType dtype = DataType::Int);
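// Example usage (a sketch; assumes IrBuilder is available, e.g. via
// ir_builder.h):
//   // 1-D Int tensor holding {0, 2, 4, 6, 8}.
//   TensorView* t = arange(
//       IrBuilder::create<Int>(0),
//       IrBuilder::create<Int>(10),
//       IrBuilder::create<Int>(2));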
TORCH_CUDA_CU_API TensorView* eye(Val* size, DataType dtype);
TORCH_CUDA_CU_API TensorView* eye(Val* rows, Val* cols, DataType dtype);

// UNARY OPERATIONS
// abs
TORCH_CUDA_CU_API Val* abs(Val*);
TORCH_CUDA_CU_API TensorView* abs(TensorView*);
// acos
TORCH_CUDA_CU_API Val* acos(Val*);
TORCH_CUDA_CU_API TensorView* acos(TensorView*);
// asin
TORCH_CUDA_CU_API Val* asin(Val*);
TORCH_CUDA_CU_API TensorView* asin(TensorView*);
// atan
TORCH_CUDA_CU_API Val* atan(Val*);
TORCH_CUDA_CU_API TensorView* atan(TensorView*);
// atanh
TORCH_CUDA_CU_API Val* atanh(Val*);
TORCH_CUDA_CU_API TensorView* atanh(TensorView*);
// ceil
TORCH_CUDA_CU_API Val* ceil(Val*);
TORCH_CUDA_CU_API TensorView* ceil(TensorView*);
// cos
TORCH_CUDA_CU_API Val* cos(Val*);
TORCH_CUDA_CU_API TensorView* cos(TensorView*);
// cosh
TORCH_CUDA_CU_API Val* cosh(Val*);
TORCH_CUDA_CU_API TensorView* cosh(TensorView*);
// exp
TORCH_CUDA_CU_API Val* exp(Val*);
TORCH_CUDA_CU_API TensorView* exp(TensorView*);
// expm1
TORCH_CUDA_CU_API Val* expm1(Val*);
TORCH_CUDA_CU_API TensorView* expm1(TensorView*);
// erf
TORCH_CUDA_CU_API Val* erf(Val*);
TORCH_CUDA_CU_API TensorView* erf(TensorView*);
// erfc
TORCH_CUDA_CU_API Val* erfc(Val*);
TORCH_CUDA_CU_API TensorView* erfc(TensorView*);
// floor
TORCH_CUDA_CU_API Val* floor(Val*);
TORCH_CUDA_CU_API TensorView* floor(TensorView*);
// frac
TORCH_CUDA_CU_API Val* frac(Val*);
TORCH_CUDA_CU_API TensorView* frac(TensorView*);
// silu
TORCH_CUDA_CU_API Val* silu(Val*);
TORCH_CUDA_CU_API TensorView* silu(TensorView*);
// lgamma
TORCH_CUDA_CU_API Val* lgamma(Val*);
TORCH_CUDA_CU_API TensorView* lgamma(TensorView*);
// log
TORCH_CUDA_CU_API Val* log(Val*);
TORCH_CUDA_CU_API TensorView* log(TensorView*);
// log10
TORCH_CUDA_CU_API Val* log10(Val*);
TORCH_CUDA_CU_API TensorView* log10(TensorView*);
// log1p
TORCH_CUDA_CU_API Val* log1p(Val*);
TORCH_CUDA_CU_API TensorView* log1p(TensorView*);
// log2
TORCH_CUDA_CU_API Val* log2(Val*);
TORCH_CUDA_CU_API TensorView* log2(TensorView*);
// neg
TORCH_CUDA_CU_API Val* neg(Val*);
TORCH_CUDA_CU_API TensorView* neg(TensorView*);
// real
TORCH_CUDA_CU_API Val* real(Val*);
TORCH_CUDA_CU_API TensorView* real(TensorView*);
// reciprocal
TORCH_CUDA_CU_API Val* reciprocal(Val*);
TORCH_CUDA_CU_API TensorView* reciprocal(TensorView*);
// relu
TORCH_CUDA_CU_API Val* relu(Val*);
TORCH_CUDA_CU_API TensorView* relu(TensorView*);
// rsqrt
TORCH_CUDA_CU_API Val* rsqrt(Val*);
TORCH_CUDA_CU_API TensorView* rsqrt(TensorView*);
// round
TORCH_CUDA_CU_API Val* round(Val*);
TORCH_CUDA_CU_API TensorView* round(TensorView*);
// set
TORCH_CUDA_CU_API Val* set(Val*);
TORCH_CUDA_CU_API TensorView* set(TensorView*);
// sigmoid
TORCH_CUDA_CU_API Val* sigmoid(Val*);
TORCH_CUDA_CU_API TensorView* sigmoid(TensorView*);
// sin
TORCH_CUDA_CU_API Val* sin(Val*);
TORCH_CUDA_CU_API TensorView* sin(TensorView*);
// sinh
TORCH_CUDA_CU_API Val* sinh(Val*);
TORCH_CUDA_CU_API TensorView* sinh(TensorView*);
// sqrt
TORCH_CUDA_CU_API Val* sqrt(Val*);
TORCH_CUDA_CU_API TensorView* sqrt(TensorView*);
// tan
TORCH_CUDA_CU_API Val* tan(Val*);
TORCH_CUDA_CU_API TensorView* tan(TensorView*);
// tanh
TORCH_CUDA_CU_API Val* tanh(Val*);
TORCH_CUDA_CU_API TensorView* tanh(TensorView*);
// trunc
TORCH_CUDA_CU_API Val* trunc(Val*);
TORCH_CUDA_CU_API TensorView* trunc(TensorView*);
// bitwise_not
TORCH_CUDA_CU_API Val* bitwise_not(Val*);
TORCH_CUDA_CU_API TensorView* bitwise_not(TensorView*);
// imag
TORCH_CUDA_CU_API Val* imag(Val*);
TORCH_CUDA_CU_API TensorView* imag(TensorView*);
// isfinite
TORCH_CUDA_CU_API Val* isfinite(Val*);
TORCH_CUDA_CU_API TensorView* isfinite(TensorView*);
// isinf
TORCH_CUDA_CU_API Val* isinf(Val*);
TORCH_CUDA_CU_API TensorView* isinf(TensorView*);
// isnan
TORCH_CUDA_CU_API Val* isnan(Val*);
TORCH_CUDA_CU_API TensorView* isnan(TensorView*);
// isneginf
TORCH_CUDA_CU_API Val* isneginf(Val*);
TORCH_CUDA_CU_API TensorView* isneginf(TensorView*);
// isposinf
TORCH_CUDA_CU_API Val* isposinf(Val*);
TORCH_CUDA_CU_API TensorView* isposinf(TensorView*);
// isreal
TORCH_CUDA_CU_API Val* isreal(Val*);
TORCH_CUDA_CU_API TensorView* isreal(TensorView*);
// print
TORCH_CUDA_CU_API Val* print(Val*);
TORCH_CUDA_CU_API TensorView* print(TensorView*);

// Broadcasts inp based on a bool vector. The size of the bool vector should be
// the number of dims desired in the broadcasted output tensor. An entry should
// be true if the corresponding output dim should be a new broadcast dim, and
// false if it maps to an existing input dim. The number of false entries must
// match the number of input dims.
TORCH_CUDA_CU_API TensorView* broadcast(
    TensorView* inp,
    const std::vector<bool>& is_broadcast_dim);
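// Example usage (a sketch; assumes `tv0` is a 2-D TensorView [I0, I1]):
//   // Produce a 4-D output [B, I0, I1, B]; the two false entries line up with
//   // the two input dims.
//   TensorView* b = broadcast(tv0, {true, false, false, true});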

// Expands input based on provided sizes. expanded_sizes may contain more
// entries than the input's root domain (really rfactor), in which case the
// input is broadcast on inner dimensions. An entry should be -1 for any
// dimension that should keep its symbolic size, and 1 for any dimension that
// should remain a broadcast after the expand. Any dimension being expanded to
// a provided constant size must be marked as a broadcast in the input. A
// dimension that is symbolic in the input but specified with a non -1 value
// will be set to that constant value.
TORCH_CUDA_CU_API TensorView* expand(
    TensorView* inp,
    const std::vector<Val*>& expanded_sizes);
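// Example usage (a sketch; assumes `tv0` has shape [B(1), I1], i.e. a broadcast
// dim followed by a symbolic dim, and IrBuilder is available, e.g. via
// ir_builder.h):
//   // Expand the broadcast dim to extent 16 and keep I1 at its symbolic size.
//   TensorView* e = expand(
//       tv0, {IrBuilder::create<Int>(16), IrBuilder::create<Int>(-1)});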

// Expands inp based on other. For any dimension of inp that is a broadcast and
// whose matching entry in other is either a broadcast with an expanded extent
// or a non-broadcast iter domain, inp will be expanded to other's size.
TORCH_CUDA_CU_API TensorView* expand_as(TensorView* inp, TensorView* other);

// BINARY OPERATIONS
// add
TORCH_CUDA_CU_API Val* add(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* add(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* add(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* add(TensorView* v1, TensorView* v2);
// atan2
TORCH_CUDA_CU_API Val* atan2(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* atan2(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* atan2(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* atan2(TensorView* v1, TensorView* v2);
// div
TORCH_CUDA_CU_API Val* div(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* div(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, TensorView* v2);
// fmod
TORCH_CUDA_CU_API Val* fmod(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* fmod(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* fmod(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* fmod(TensorView* v1, TensorView* v2);
// mul
TORCH_CUDA_CU_API Val* mul(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* mul(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* mul(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* mul(TensorView* v1, TensorView* v2);
// pow
TORCH_CUDA_CU_API Val* pow(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* pow(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* pow(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* pow(TensorView* v1, TensorView* v2);
// remainder
TORCH_CUDA_CU_API Val* remainder(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* remainder(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* remainder(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* remainder(TensorView* v1, TensorView* v2);
// sub
TORCH_CUDA_CU_API Val* sub(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* sub(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* sub(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* sub(TensorView* v1, TensorView* v2);
// Integer binary ops
// mod
TORCH_CUDA_CU_API Val* mod(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* mod(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* mod(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* mod(TensorView* v1, TensorView* v2);
// ceilDiv
TORCH_CUDA_CU_API Val* ceilDiv(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ceilDiv(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ceilDiv(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* ceilDiv(TensorView* v1, TensorView* v2);
// Bitwise binary ops
// bitwise_and
TORCH_CUDA_CU_API Val* bitwise_and(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_and(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_and(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_and(TensorView* v1, TensorView* v2);
// bitwise_left_shift
TORCH_CUDA_CU_API Val* bitwise_left_shift(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_left_shift(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_left_shift(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_left_shift(
    TensorView* v1,
    TensorView* v2);
// bitwise_right_shift
TORCH_CUDA_CU_API Val* bitwise_right_shift(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_right_shift(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_right_shift(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_right_shift(
    TensorView* v1,
    TensorView* v2);
// bitwise_or
TORCH_CUDA_CU_API Val* bitwise_or(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_or(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_or(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_or(TensorView* v1, TensorView* v2);
// bitwise_xor
TORCH_CUDA_CU_API Val* bitwise_xor(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_xor(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* bitwise_xor(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* bitwise_xor(TensorView* v1, TensorView* v2);
// Logical binary ops
// eq
TORCH_CUDA_CU_API Val* eq(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* eq(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* eq(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* eq(TensorView* v1, TensorView* v2);
// ge
TORCH_CUDA_CU_API Val* ge(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ge(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ge(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* ge(TensorView* v1, TensorView* v2);
// gt
TORCH_CUDA_CU_API Val* gt(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* gt(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* gt(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* gt(TensorView* v1, TensorView* v2);
// le
TORCH_CUDA_CU_API Val* le(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* le(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* le(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* le(TensorView* v1, TensorView* v2);
// lt
TORCH_CUDA_CU_API Val* lt(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* lt(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* lt(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* lt(TensorView* v1, TensorView* v2);
// ne
TORCH_CUDA_CU_API Val* ne(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ne(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* ne(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* ne(TensorView* v1, TensorView* v2);

// REDUCTION OPERATIONS
TORCH_CUDA_CU_API TensorView* sum(
    TensorView* v1,
    const std::vector<int>& reduction_axes,
    bool keep_dim = false,
    DataType dtype = DataType::Null);

TORCH_CUDA_CU_API TensorView* max(
    TensorView* v1,
    const std::vector<int>& reduction_axes,
    bool keep_dim = false,
    DataType dtype = DataType::Null);

TORCH_CUDA_CU_API TensorView* min(
    TensorView* v1,
    const std::vector<int>& reduction_axes,
    bool keep_dim = false,
    DataType dtype = DataType::Null);

// COMPOUND OPERATIONS
// add_alpha
TORCH_CUDA_CU_API Val* add_alpha(Val* v1, Val* v2, Val* s);
TORCH_CUDA_CU_API TensorView* add_alpha(TensorView* v1, Val* v2, Val* s);
TORCH_CUDA_CU_API TensorView* add_alpha(Val* v1, TensorView* v2, Val* s);
TORCH_CUDA_CU_API TensorView* add_alpha(TensorView* v1, TensorView* v2, Val* s);
// sub_alpha
TORCH_CUDA_CU_API Val* sub_alpha(Val* v1, Val* v2, Val* s);
TORCH_CUDA_CU_API TensorView* sub_alpha(TensorView* v1, Val* v2, Val* s);
TORCH_CUDA_CU_API TensorView* sub_alpha(Val* v1, TensorView* v2, Val* s);
TORCH_CUDA_CU_API TensorView* sub_alpha(TensorView* v1, TensorView* v2, Val* s);
// lerp
TORCH_CUDA_CU_API Val* lerp(Val* start, Val* end, Val* weight);
TORCH_CUDA_CU_API TensorView* lerp(TensorView* start, Val* end, Val* weight);
TORCH_CUDA_CU_API TensorView* lerp(Val* start, TensorView* end, Val* weight);
TORCH_CUDA_CU_API TensorView* lerp(Val* start, Val* end, TensorView* weight);
TORCH_CUDA_CU_API TensorView* lerp(
    TensorView* start,
    TensorView* end,
    Val* weight);
TORCH_CUDA_CU_API TensorView* lerp(
    TensorView* start,
    Val* end,
    TensorView* weight);
TORCH_CUDA_CU_API TensorView* lerp(
    Val* start,
    TensorView* end,
    TensorView* weight);
TORCH_CUDA_CU_API TensorView* lerp(
    TensorView* start,
    TensorView* end,
    TensorView* weight);
// addcmul
TORCH_CUDA_CU_API Val* addcmul(Val* v1, Val* v2, Val* v3, Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(TensorView* v1, Val* v2, Val* v3, Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(Val* v1, TensorView* v2, Val* v3, Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(Val* v1, Val* v2, TensorView* v3, Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(
    TensorView* v1,
    TensorView* v2,
    Val* v3,
    Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(
    TensorView* v1,
    Val* v2,
    TensorView* v3,
    Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(
    Val* v1,
    TensorView* v2,
    TensorView* v3,
    Val* s);
TORCH_CUDA_CU_API TensorView* addcmul(
    TensorView* v1,
    TensorView* v2,
    TensorView* v3,
    Val* s);

// TERNARY OPERATIONS
// where
TORCH_CUDA_CU_API Val* where(Val* c, Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* where(TensorView* c, Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* where(Val* c, TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* where(Val* c, Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* where(TensorView* c, TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* where(TensorView* c, Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* where(Val* c, TensorView* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* where(
    TensorView* c,
    TensorView* v1,
    TensorView* v2);
// threshold
TORCH_CUDA_CU_API Val* threshold(Val* in, Val* thresh, Val* value);
TORCH_CUDA_CU_API TensorView* threshold(
    TensorView* in,
    Val* thresh,
    Val* value);
// clamp
TORCH_CUDA_CU_API Val* clamp(Val* in, Val* min_val, Val* max_val);
TORCH_CUDA_CU_API TensorView* clamp(TensorView* in, Val* min_val, Val* max_val);
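// Example usage (a sketch; assumes `tv0` and `tv1` are Float TensorViews of the
// same shape and IrBuilder is available, e.g. via ir_builder.h):
//   TensorView* mask = gt(tv0, IrBuilder::create<Double>(0.0)); // tv0 > 0
//   TensorView* sel = where(mask, tv0, tv1); // pick tv0 where positive
//   TensorView* clipped = clamp(
//       sel, IrBuilder::create<Double>(0.0), IrBuilder::create<Double>(6.0));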

//! Internal operator for supporting backward graphs
//!
//! example:
//!   v1 = T1 [I0(10),I1(20),I2(30),I3(40)]
//!   v2 = sum_to(v1,{30,1}) ------> v2 = T2[I2,R3 (keep_dim)]
//!
//! This operator will return v1 directly if the sizes of the v1 root domain
//! are already the same as shape.
//!
//! The name sum_to differs from the usual NV fuser naming;
//! this is to align with the operator name of at::sum_to.

TORCH_CUDA_CU_API TensorView* sum_to(
    TensorView* v1,
    const std::vector<Int*>& sum_to_size);

TORCH_CUDA_CU_API TensorView* sum_to(
    TensorView* v1,
    const std::vector<int64_t>& sum_to_size);
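// Example usage (a sketch; assumes `tv0` is a Float input TensorView of shape
// [10, 20, 30, 40]):
//   // Reduce the leading dims so the result matches shape {30, 1}.
//   TensorView* out = sum_to(tv0, std::vector<int64_t>{30, 1});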

//! Shift a tensor to a direction specified by offsets.
//!
//! Example:
//!   t0: 2D tensor of size N by M
//!   t1 = shift(t0, {1, -1});
//!
//! then:
//!   t1[i, j] = t0[i-1, j+1] for 1 <= i < N and 0 <= j < M-1.
//!   t1[i, j] = 0, otherwise
//!
//! The pad option controls how out-of-boundary accesses are
//! handled. It specifies how many zeros are logically padded. If no
//! pad option is given, it automatically pads the input tensor so
//! that the output tensor has the same extent for each axis.
//!
//! When a padding value is smaller than the absolute value of a shift
//! offset, the output axis still has the same extent but its start or
//! stop offset is moved inward to signify those outside of the offset
//! are invalid.
//!
//! It is not allowed to use padding values that are larger than shift
//! offsets, which would mean output extents would be larger than
//! input extents.
TORCH_CUDA_CU_API TensorView* shift(
    TensorView* inp,
    const std::vector<int>& offsets,
    const std::vector<int>& pad_width = {});

TORCH_CUDA_CU_API TensorView* shift(
    TensorView* inp,
    const std::vector<int>& offsets,
    bool pad);

//! Gather a window of nearby elements for each element.
//!
//! Each window of size window_shape is stored as an additional
//! innermost domain, meaning that the number of dimensions of the
//! output tensor doubles. The pad_width parameter specifies the
//! padding width of each side of each axis. The strides parameter
//! specifies striding of the operation. Non-unit striding is
//! implemented with strided split, whose outer output domain becomes
//! the root domain for subsequent consumers. The inner output domain
//! becomes a Stride domain, which is ignored by subsequent consumers.
//! Only valid input ranges are fed into strided splits.
//!
//! When trim_out_of_bounds is true, the values at the first and last
//! ends that are outside of the start and stop offsets are
//! effectively trimmed by partial split by 1.
//!
//! Example 1:
//!   t0: 2D tensor of [N, M]
//!   t1 = gather(t0, {1, 3}, {{0, 0}, {1, 1}});
//!
//! then:
//!   t1: [N, M, 1, 3]
//!   t1[i, j, k, l] = The value at the window position of [k, l]
//!                    for t0[i, j]
//!
//! Example 2.1 (without trimming):
//!   t0: 2D tensor of [N, M]
//!   t1 = gather(t0, {2, 2}, {{0, 0}, {0, 0}});
//!
//! then:
//!   t1: [N (stop offset: 1), M (stop offset: 1), 2, 2]
//!
//! Example 2.2 (with trimming):
//!   t0: 2D tensor of [N, M]
//!   t1 = gather(t0, {2, 2}, {{0, 0}, {0, 0}}, true);
//!
//! then:
//!   t1: [ceilDiv(N - 1, 1), ceilDiv(M - 1, 1), 2, 2]
//!
//! Example 3:
//!   t0: 2D tensor of [N, M]
//!   t1 = gather(t0, {3, 3}, {{0, 0}, {0, 0}}, {3, 3});
//!
//! then:
//!   t1: [ceilDiv(N - 2, 3), ceilDiv(M - 2, 3), 3, 3]
//!
TORCH_CUDA_CU_API TensorView* gather(
    TensorView* inp,
    const std::vector<int>& window_shape,
    const std::vector<std::vector<int>>& pad_width,
    const std::vector<int>& strides = {},
    bool trim_out_of_bounds = false);

// Append a new IterDomain to the end of a TensorView to allow
// iterating on a vector type. The input tensor must have
// vector dtype.
TORCH_CUDA_CU_API TensorView* viewAsScalar(TensorView* inp);

//! A fused pointwise multiply and sum
//! operator that instantiates the following
//! fused pattern:
//!     c = mul(tv_a, tv_b);
//!     return sum(c, axes)
//!
//! \param tv_a first multiply operand
//! \param tv_b second multiply operand
//! \param axes axes to sum over
//! \param init sum initial value
//!
//! Note & TODO:
//!   Currently, this interface only supports lowering to a mma op,
//!   and only supports fp16 inputs. Support for converting back to a
//!   separate multiply and reduce will be added in a follow-up.
TORCH_CUDA_CU_API TensorView* fusedMultiplySum(
    TensorView* tv_a,
    TensorView* tv_b,
    const std::vector<int>& axes,
    Val* init = nullptr);
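// Example usage (a sketch; assumes `tv_a` and `tv_b` are fp16 (DataType::Half)
// TensorViews whose shared contraction axis is the last dimension):
//   // Pointwise multiply then sum over the last axis, e.g. as part of a
//   // matmul-like pattern.
//   TensorView* out = fusedMultiplySum(tv_a, tv_b, {-1});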
672
673} // namespace cuda
674} // namespace fuser
675} // namespace jit
676} // namespace torch
677