1#include <python_frontend/python_bindings.h>
2
3#include <c10/util/ArrayRef.h>
4#include <c10/util/Optional.h>
5#include <c10/util/irange.h>
6#include <arith.h>
7#include <instrumentation.h>
8#include <ir_all_nodes.h>
9#include <ir_builder.h>
10#include <ops/composite.h>
11#include <python_frontend/fusion_cache.h>
12#include <python_frontend/fusion_definition.h>
13#include <python_frontend/fusion_interface.h>
14#include <python_frontend/fusion_record.h>
15#include <python_frontend/python_bindings.h>
16#include <torch/csrc/jit/python/pybind_utils.h>
17#include <iostream>
18#include <tuple>
19
20namespace torch {
21namespace jit {
22
23void initNvFuserPythonBindings(PyObject* module) {
24 auto nvfuser = py::handle(module).cast<py::module>();
25
26 //! DataTypes supported by nvFuser in the FusionDefinition
27 py::enum_<Nvf::DataType>(nvfuser, "DataType")
28 .value("Double", Nvf::DataType::Double)
29 .value("Float", Nvf::DataType::Float)
30 .value("Half", Nvf::DataType::Half)
31 .value("Int", Nvf::DataType::Int)
32 .value("Int32", Nvf::DataType::Int32)
33 .value("Bool", Nvf::DataType::Bool)
34 .value("BFloat16", Nvf::DataType::BFloat16)
35 .value("ComplexFloat", Nvf::DataType::ComplexFloat)
36 .value("ComplexDouble", Nvf::DataType::ComplexDouble)
37 .value("Null", Nvf::DataType::Null);
38
39 nvfuser.def(
40 "compute_contiguity",
41 [](const std::vector<int64_t>& sizes,
42 const std::vector<int64_t>& strides) {
43 py::tuple contiguity(sizes.size());
44 TORCH_CHECK(
45 sizes.size() == strides.size(),
46 "compute_contiguity: Sizes and strides must have the same number of dimensions");
47 if (sizes.size() == 0) {
48 return contiguity;
49 }
50 contiguity[sizes.size() - 1] = strides.back() == 1;
51 for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) {
52 contiguity[i] = strides[i] == strides[i + 1] * sizes[i + 1];
53 }
54 return contiguity;
55 });
56
57 //! Binding the FusionCache that holds a cache of Fusions
58 //! This is only bound to provide an interface to get the number of fusions
59 //! that are cached.
60 py::class_<nvfuser::FusionCache> fusion_cache(nvfuser, "FusionCache");
61 fusion_cache
62 .def_static(
63 "get",
64 &nvfuser::FusionCache::get,
65 py::arg("max_fusions") = int(8192),
66 py::return_value_policy::reference)
67 .def("num_fusions", &nvfuser::FusionCache::numFusions)
68 .def("print_stats", [](nvfuser::FusionCache& self) {
69 self.print(std::cout);
70 });
71
72 py::class_<nvfuser::FusionInterface> fusion(nvfuser, "Fusion");
73 fusion.def(py::init<>())
74 .def(py::init<size_t>(), py::arg("fusion_id"))
75 .def("define", &nvfuser::FusionInterface::define)
76 .def("defined", &nvfuser::FusionInterface::defined)
77 .def(
78 "execute",
79 [](nvfuser::FusionInterface& self, const py::iterable& iter) {
80 std::vector<IValue> inputs;
81 for (py::handle obj : iter) {
82 inputs.push_back(toIValue(obj, c10::AnyType::get()));
83 }
84 return self.execute(inputs);
85 },
86 py::return_value_policy::reference)
87 .def("id", &nvfuser::FusionInterface::id)
88 .def("print", &nvfuser::FusionInterface::print);
89
90 //! These are the FusionDefinition supported object types that are either
91 //! defined as inputs or the output of an operation.
92 py::class_<nvfuser::Tensor>(nvfuser, "Tensor");
93 py::class_<nvfuser::Scalar>(nvfuser, "Scalar");
94
95 //! The FusionDefinition is a context manager in Python where the user will
96 //! define the set the operations and connections between operations for
97 //! nvFuser to create.
98 py::class_<nvfuser::FusionDefinition> fusion_def(nvfuser, "FusionDefinition");
99 fusion_def
100 .def(
101 py::init<nvfuser::FusionInterface*, int>(),
102 py::arg("fusion"),
103 py::arg("max_length") = int(1024))
104 .def_readwrite("ops", &nvfuser::FusionDefinition::ops)
105 .def(
106 "__enter__",
107 [](nvfuser::FusionDefinition& self) -> nvfuser::FusionDefinition* {
108 // Instrumentation to mark the beginning of a FusionDefinition
109 Nvf::inst::Trace::instance()->beginEvent(
110 "FusionDefinition Context Manager");
111 return self.enter();
112 })
113 .def(
114 "__exit__",
115 [](nvfuser::FusionDefinition& self,
116 void* exc_type,
117 void* exc_value,
118 void* traceback) {
119 self.exit();
120 // Mark the end of a FusionDefinition Context Manager
121 Nvf::inst::Trace::instance()->endEvent(nullptr);
122 })
123 .def(
124 "__str__",
125 [](nvfuser::FusionDefinition& self) {
126 std::stringstream ss;
127 self.print(ss);
128 return ss.str();
129 })
130 .def(
131 "add_output",
132 [](nvfuser::FusionDefinition& self, nvfuser::Scalar output) {
133 FUSER_PERF_SCOPE("FusionDefinition.add_output (scalar)");
134 self.defineRecord(new nvfuser::OutputRecord<Nvf::Val>(
135 {self.recordingState(output())}));
136 })
137 .def(
138 "add_output",
139 [](nvfuser::FusionDefinition& self, nvfuser::Tensor output) {
140 FUSER_PERF_SCOPE("FusionDefinition.add_output (tensor)");
141 self.defineRecord(new nvfuser::OutputRecord<Nvf::TensorView>(
142 {self.recordingState(output())}));
143 })
144 .def(
145 "define_tensor",
146 [](nvfuser::FusionDefinition& self,
147 size_t ndims,
148 Nvf::DataType dtype = Nvf::DataType::Float,
149 bool is_cpu = false) -> nvfuser::Tensor {
150 FUSER_PERF_SCOPE("FusionDefinition.define_tensor (simple)");
151 std::vector<int64_t> maybe_symbolic_sizes(ndims, -1);
152 ;
153 std::vector<bool> contig_info(ndims, false);
154
155 nvfuser::Tensor out = self.defineTensor();
156 self.defineRecord(new nvfuser::TensorRecord(
157 {self.recordingState(out())},
158 std::move(maybe_symbolic_sizes),
159 std::move(contig_info),
160 dtype,
161 is_cpu));
162
163 return out;
164 },
165 py::arg("ndims"),
166 py::arg("dtype") = Nvf::DataType::Float,
167 py::arg("is_cpu") = false,
168 py::return_value_policy::reference)
169 .def(
170 "define_tensor",
171 [](nvfuser::FusionDefinition& self,
172 std::vector<int64_t>& symbolic_sizes,
173 std::vector<bool>& contiguous,
174 Nvf::DataType dtype = Nvf::DataType::Float,
175 bool is_cpu = false) -> nvfuser::Tensor {
176 FUSER_PERF_SCOPE("FusionDefinition.define_tensor (default)");
177
178 for (size_t i = 0; i < symbolic_sizes.size(); ++i) {
179 TORCH_CHECK(
180 symbolic_sizes[i] == -1 || symbolic_sizes[i] == 1,
181 "The value ",
182 symbolic_sizes[i],
183 " at index ",
184 i,
185 " was neither broadcast(1) or symbolic(-1).");
186 }
187
188 nvfuser::Tensor out = self.defineTensor();
189 self.defineRecord(new nvfuser::TensorRecord(
190 {self.recordingState(out())},
191 symbolic_sizes,
192 contiguous,
193 dtype,
194 is_cpu));
195
196 return out;
197 },
198 py::arg("symbolic_sizes"),
199 py::arg("contiguous"),
200 py::arg("dtype") = Nvf::DataType::Float,
201 py::arg("is_cpu") = false,
202 py::return_value_policy::reference)
203 .def(
204 "define_tensor",
205 [](nvfuser::FusionDefinition& self,
206 std::vector<int64_t>& sizes,
207 std::vector<int64_t>& strides,
208 Nvf::DataType dtype = Nvf::DataType::Float,
209 bool is_cpu = false) -> nvfuser::Tensor {
210 FUSER_PERF_SCOPE("FusionDefinition.define_tensor (integration)");
211 TORCH_CHECK(
212 sizes.size() == strides.size(),
213 "The number of sizes does not match the number of strides.",
214 sizes.size(),
215 strides.size());
216
217 // TensorViewBuilder assumes any dim with a compile time constant
218 // size == 1 is a "maybe broadcast" axis, symbolic sizes are
219 // identified by -1, and size == 0 is not supported.
220
221 // Translate to TensorViewBuilder's view of the world.
222 std::vector<int64_t> maybe_symbolic_sizes;
223 maybe_symbolic_sizes.reserve(sizes.size());
224 for (const auto i : c10::irange(sizes.size())) {
225 TORCH_INTERNAL_ASSERT(
226 sizes[i] > 0,
227 "Size of ",
228 sizes[i],
229 " is not supported in nvFuser. Expected size > 0.");
230 if (sizes[i] == 1) {
231 maybe_symbolic_sizes.push_back(1);
232 } else {
233 maybe_symbolic_sizes.push_back(-1);
234 }
235 }
236
237 std::vector<bool> contig_info(strides.size(), false);
238 for (int i = contig_info.size() - 1; i >= 0; --i) {
239 if (i == static_cast<int>(contig_info.size() - 1)) {
240 contig_info[i] = (strides[i] == 1);
241 } else {
242 contig_info[i] =
243 (strides[i] == (strides[i + 1] * sizes[i + 1]));
244 }
245 }
246
247 nvfuser::Tensor out = self.defineTensor();
248 self.defineRecord(new nvfuser::TensorRecord(
249 {self.recordingState(out())},
250 std::move(maybe_symbolic_sizes),
251 std::move(contig_info),
252 dtype,
253 is_cpu));
254
255 return out;
256 },
257 py::arg("sizes"),
258 py::arg("strides"),
259 py::arg("dtype") = Nvf::DataType::Float,
260 py::arg("is_cpu") = false,
261 py::return_value_policy::reference)
262 .def(
263 "define_constant",
264 [](nvfuser::FusionDefinition& self, double val) -> nvfuser::Scalar {
265 FUSER_PERF_SCOPE("FusionDefinition.define_constant (double)");
266 nvfuser::Scalar out = self.defineScalar();
267 self.defineRecord(new nvfuser::ConstantRecord<Nvf::Double, double>(
268 {self.recordingState(out())}, val));
269 return out;
270 },
271 py::return_value_policy::reference)
272 .def(
273 "define_constant",
274 [](nvfuser::FusionDefinition& self,
275 std::complex<double> val) -> nvfuser::Scalar {
276 FUSER_PERF_SCOPE("FusionDefinition.define_constant (complex)");
277 nvfuser::Scalar out = self.defineScalar();
278 self.defineRecord(
279 new nvfuser::
280 ConstantRecord<Nvf::ComplexDouble, c10::complex<double>>(
281 {self.recordingState(out())},
282 static_cast<c10::complex<double>>(val)));
283 return out;
284 },
285 py::return_value_policy::reference)
286 .def(
287 "define_constant",
288 [](nvfuser::FusionDefinition& self, bool val) -> nvfuser::Scalar {
289 FUSER_PERF_SCOPE("FusionDefinition.define_constant (bool)");
290 nvfuser::Scalar out = self.defineScalar();
291 self.defineRecord(new nvfuser::ConstantRecord<Nvf::Bool, bool>(
292 {self.recordingState(out())}, val));
293 return out;
294 },
295 py::return_value_policy::reference)
296 .def(
297 "define_constant",
298 [](nvfuser::FusionDefinition& self, int64_t val) -> nvfuser::Scalar {
299 FUSER_PERF_SCOPE("FusionDefinition.define_constant (int)");
300 nvfuser::Scalar out = self.defineScalar();
301 self.defineRecord(new nvfuser::ConstantRecord<Nvf::Int, int64_t>(
302 {self.recordingState(out())}, val));
303 return out;
304 },
305 py::return_value_policy::reference)
306 .def(
307 "define_scalar",
308 [](nvfuser::FusionDefinition& self,
309 Nvf::DataType dtype = Nvf::DataType::Double) -> nvfuser::Scalar {
310 FUSER_PERF_SCOPE("FusionDefinition.define_scalar");
311 nvfuser::Scalar out = self.defineScalar();
312 self.defineRecord(
313 new nvfuser::ScalarRecord({self.recordingState(out())}, dtype));
314 return out;
315 },
316 py::arg("dtype") = Nvf::DataType::Double,
317 py::return_value_policy::reference);
318
319 //! The Operators class is a nested class of FusionDefinition to allow the
320 //! user to query the class for the list of operators.
321 //!
322 //! Example:
323 //! help(FusionDefinition.Operators)
324 //!
325 //! Additional operators are expected to be defined below as needed. They
326 //! may require defining a new RecordFunctor child class if they are unique.
327 py::class_<nvfuser::FusionDefinition::Operators> nvf_ops(
328 fusion_def, "Operators");
329 nvf_ops.def(py::init<nvfuser::FusionDefinition*>());
330
331 // ******************** INSERT OP BINDINGS BELOW HERE ********************
332#define OP_PREFIX "Operators."
333#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name) \
334 nvf_ops.def( \
335 op_str, \
336 [](nvfuser::FusionDefinition::Operators& self, \
337 nvfuser::Tensor input) -> nvfuser::Tensor { \
338 FUSER_PERF_SCOPE("Operators." op_str); \
339 nvfuser::FusionDefinition* fd = self.fusion_definition; \
340 nvfuser::Tensor output = fd->defineTensor(); \
341 fd->defineRecord( \
342 new nvfuser::OpRecord<Nvf::TensorView*, Nvf::TensorView*>( \
343 {fd->recordingState(input())}, \
344 {fd->recordingState(output())}, \
345 ("ops." op_str), \
346 static_cast<Nvf::TensorView* (*)(Nvf::TensorView*)>( \
347 Nvf::op_name))); \
348 return output; \
349 }, \
350 py::return_value_policy::reference); \
351 nvf_ops.def( \
352 op_str, \
353 [](nvfuser::FusionDefinition::Operators& self, \
354 nvfuser::Scalar input) -> nvfuser::Scalar { \
355 FUSER_PERF_SCOPE("Operators." op_str); \
356 nvfuser::FusionDefinition* fd = self.fusion_definition; \
357 nvfuser::Scalar output = fd->defineScalar(); \
358 fd->defineRecord(new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*>( \
359 {fd->recordingState(input())}, \
360 {fd->recordingState(output())}, \
361 ("ops." op_str), \
362 static_cast<Nvf::Val* (*)(Nvf::Val*)>(Nvf::op_name))); \
363 return output; \
364 }, \
365 py::return_value_policy::reference);
366
367 NVFUSER_PYTHON_BINDING_UNARY_OP("abs", abs)
368 NVFUSER_PYTHON_BINDING_UNARY_OP("acos", acos)
369 NVFUSER_PYTHON_BINDING_UNARY_OP("asin", asin)
370 NVFUSER_PYTHON_BINDING_UNARY_OP("atan", atan)
371 NVFUSER_PYTHON_BINDING_UNARY_OP("atanh", atanh)
372 NVFUSER_PYTHON_BINDING_UNARY_OP("ceil", ceil)
373 NVFUSER_PYTHON_BINDING_UNARY_OP("cos", cos)
374 NVFUSER_PYTHON_BINDING_UNARY_OP("cosh", cosh)
375 NVFUSER_PYTHON_BINDING_UNARY_OP("exp", exp)
376 NVFUSER_PYTHON_BINDING_UNARY_OP("expm1", expm1)
377 NVFUSER_PYTHON_BINDING_UNARY_OP("erf", erf)
378 NVFUSER_PYTHON_BINDING_UNARY_OP("erfc", erfc)
379 NVFUSER_PYTHON_BINDING_UNARY_OP("floor", floor)
380 NVFUSER_PYTHON_BINDING_UNARY_OP("frac", frac)
381 NVFUSER_PYTHON_BINDING_UNARY_OP("lgamma", lgamma)
382 NVFUSER_PYTHON_BINDING_UNARY_OP("log", log)
383 NVFUSER_PYTHON_BINDING_UNARY_OP("log10", log10)
384 NVFUSER_PYTHON_BINDING_UNARY_OP("log1p", log1p)
385 NVFUSER_PYTHON_BINDING_UNARY_OP("log2", log2)
386 NVFUSER_PYTHON_BINDING_UNARY_OP("neg", neg)
387 NVFUSER_PYTHON_BINDING_UNARY_OP("bitwise_not", bitwise_not)
388 NVFUSER_PYTHON_BINDING_UNARY_OP("relu", relu)
389 NVFUSER_PYTHON_BINDING_UNARY_OP("rand_like", rand_like)
390 NVFUSER_PYTHON_BINDING_UNARY_OP("reciprocal", reciprocal)
391 NVFUSER_PYTHON_BINDING_UNARY_OP("round", round)
392 NVFUSER_PYTHON_BINDING_UNARY_OP("rsqrt", rsqrt)
393 NVFUSER_PYTHON_BINDING_UNARY_OP("set", set)
394 NVFUSER_PYTHON_BINDING_UNARY_OP("sign", sign)
395 NVFUSER_PYTHON_BINDING_UNARY_OP("sigmoid", sigmoid)
396 NVFUSER_PYTHON_BINDING_UNARY_OP("silu", silu)
397 NVFUSER_PYTHON_BINDING_UNARY_OP("sin", sin)
398 NVFUSER_PYTHON_BINDING_UNARY_OP("sinh", sinh)
399 NVFUSER_PYTHON_BINDING_UNARY_OP("sqrt", sqrt)
400 NVFUSER_PYTHON_BINDING_UNARY_OP("tan", tan)
401 NVFUSER_PYTHON_BINDING_UNARY_OP("tanh", tanh)
402 NVFUSER_PYTHON_BINDING_UNARY_OP("trunc", trunc)
403 NVFUSER_PYTHON_BINDING_UNARY_OP("isfinite", isfinite)
404 NVFUSER_PYTHON_BINDING_UNARY_OP("isinf", isinf)
405 NVFUSER_PYTHON_BINDING_UNARY_OP("isnan", isnan)
406 NVFUSER_PYTHON_BINDING_UNARY_OP("isneginf", isneginf)
407 NVFUSER_PYTHON_BINDING_UNARY_OP("isposinf", isposinf)
408 NVFUSER_PYTHON_BINDING_UNARY_OP("isreal", isreal)
409 NVFUSER_PYTHON_BINDING_UNARY_OP("real", real)
410 NVFUSER_PYTHON_BINDING_UNARY_OP("imag", imag)
411#undef NVFUSER_PYTHON_BINDING_UNARY_OP
412
//! Binds a binary operator on FusionDefinition.Operators under the Python
//! name `op_str`. Four overloads are registered, in this order:
//!   (Tensor, Tensor) -> Tensor
//!   (Tensor, Scalar) -> Tensor
//!   (Scalar, Tensor) -> Tensor
//!   (Scalar, Scalar) -> Scalar
//! Each overload defines a new output on the owning FusionDefinition and
//! records an OpRecord named "ops.<op_str>" that holds the matching
//! `Nvf::op_name` overload.
#define NVFUSER_PYTHON_BINDING_BINARY_OP(op_str, op_name) \
  nvf_ops.def( \
      op_str, \
      [](nvfuser::FusionDefinition::Operators& self, \
         nvfuser::Tensor lhs, \
         nvfuser::Tensor rhs) -> nvfuser::Tensor { \
        FUSER_PERF_SCOPE("Operators." op_str); \
        using Record = nvfuser:: \
            OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::TensorView*>; \
        using OpFn = Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*); \
        nvfuser::FusionDefinition* definition = self.fusion_definition; \
        nvfuser::Tensor result = definition->defineTensor(); \
        definition->defineRecord(new Record( \
            {definition->recordingState(lhs()), \
             definition->recordingState(rhs())}, \
            {definition->recordingState(result())}, \
            ("ops." op_str), \
            static_cast<OpFn>(Nvf::op_name))); \
        return result; \
      }, \
      py::return_value_policy::reference); \
  nvf_ops.def( \
      op_str, \
      [](nvfuser::FusionDefinition::Operators& self, \
         nvfuser::Tensor lhs, \
         nvfuser::Scalar rhs) -> nvfuser::Tensor { \
        FUSER_PERF_SCOPE("Operators." op_str); \
        using Record = \
            nvfuser::OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*>; \
        using OpFn = Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*); \
        nvfuser::FusionDefinition* definition = self.fusion_definition; \
        nvfuser::Tensor result = definition->defineTensor(); \
        definition->defineRecord(new Record( \
            {definition->recordingState(lhs()), \
             definition->recordingState(rhs())}, \
            {definition->recordingState(result())}, \
            ("ops." op_str), \
            static_cast<OpFn>(Nvf::op_name))); \
        return result; \
      }, \
      py::return_value_policy::reference); \
  nvf_ops.def( \
      op_str, \
      [](nvfuser::FusionDefinition::Operators& self, \
         nvfuser::Scalar lhs, \
         nvfuser::Tensor rhs) -> nvfuser::Tensor { \
        FUSER_PERF_SCOPE("Operators." op_str); \
        using Record = \
            nvfuser::OpRecord<Nvf::TensorView*, Nvf::Val*, Nvf::TensorView*>; \
        using OpFn = Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*); \
        nvfuser::FusionDefinition* definition = self.fusion_definition; \
        nvfuser::Tensor result = definition->defineTensor(); \
        definition->defineRecord(new Record( \
            {definition->recordingState(lhs()), \
             definition->recordingState(rhs())}, \
            {definition->recordingState(result())}, \
            ("ops." op_str), \
            static_cast<OpFn>(Nvf::op_name))); \
        return result; \
      }, \
      py::return_value_policy::reference); \
  nvf_ops.def( \
      op_str, \
      [](nvfuser::FusionDefinition::Operators& self, \
         nvfuser::Scalar lhs, \
         nvfuser::Scalar rhs) -> nvfuser::Scalar { \
        FUSER_PERF_SCOPE("Operators." op_str); \
        using Record = nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*>; \
        using OpFn = Nvf::Val* (*)(Nvf::Val*, Nvf::Val*); \
        nvfuser::FusionDefinition* definition = self.fusion_definition; \
        nvfuser::Scalar result = definition->defineScalar(); \
        definition->defineRecord(new Record( \
            {definition->recordingState(lhs()), \
             definition->recordingState(rhs())}, \
            {definition->recordingState(result())}, \
            ("ops." op_str), \
            static_cast<OpFn>(Nvf::op_name))); \
        return result; \
      }, \
      py::return_value_policy::reference);
493
494 NVFUSER_PYTHON_BINDING_BINARY_OP("add", add)
495 NVFUSER_PYTHON_BINDING_BINARY_OP("atan2", atan2)
496 NVFUSER_PYTHON_BINDING_BINARY_OP("div", div)
497 NVFUSER_PYTHON_BINDING_BINARY_OP("fmod", fmod)
498 NVFUSER_PYTHON_BINDING_BINARY_OP("mul", mul)
499 NVFUSER_PYTHON_BINDING_BINARY_OP("pow", pow)
500 NVFUSER_PYTHON_BINDING_BINARY_OP("remainder", remainder)
501 NVFUSER_PYTHON_BINDING_BINARY_OP("sub", sub)
502 NVFUSER_PYTHON_BINDING_BINARY_OP("mod", mod)
503 NVFUSER_PYTHON_BINDING_BINARY_OP("eq", eq)
504 NVFUSER_PYTHON_BINDING_BINARY_OP("ge", ge)
505 NVFUSER_PYTHON_BINDING_BINARY_OP("gt", gt)
506 NVFUSER_PYTHON_BINDING_BINARY_OP("le", le)
507 NVFUSER_PYTHON_BINDING_BINARY_OP("lt", lt)
508 NVFUSER_PYTHON_BINDING_BINARY_OP("ne", ne)
509 NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_and", bitwise_and)
510 NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_or", bitwise_or)
511 NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_xor", bitwise_xor)
512 NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_left_shift", bitwise_left_shift)
513 NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_right_shift", bitwise_left_shift)
514#undef NVFUSER_PYTHON_BINDING_BINARY_OP
515
516#define NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP(op_str, op_name) \
517 nvf_ops.def( \
518 op_str, \
519 [](nvfuser::FusionDefinition::Operators& self, \
520 nvfuser::Tensor arg1, \
521 nvfuser::Tensor arg2, \
522 nvfuser::Scalar arg3) -> nvfuser::Tensor { \
523 FUSER_PERF_SCOPE("Operators." op_str); \
524 nvfuser::FusionDefinition* fd = self.fusion_definition; \
525 nvfuser::Tensor output = fd->defineTensor(); \
526 fd->defineRecord(new nvfuser::OpRecord< \
527 Nvf::TensorView*, \
528 Nvf::TensorView*, \
529 Nvf::TensorView*, \
530 Nvf::Val*>( \
531 {fd->recordingState(arg1()), \
532 fd->recordingState(arg2()), \
533 fd->recordingState(arg3())}, \
534 {fd->recordingState(output())}, \
535 ("ops." op_str), \
536 static_cast< \
537 Nvf:: \
538 TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \
539 Nvf::op_name))); \
540 return output; \
541 }, \
542 py::return_value_policy::reference); \
543 nvf_ops.def( \
544 op_str, \
545 [](nvfuser::FusionDefinition::Operators& self, \
546 nvfuser::Tensor arg1, \
547 nvfuser::Scalar arg2, \
548 nvfuser::Scalar arg3) -> nvfuser::Tensor { \
549 FUSER_PERF_SCOPE("Operators." op_str); \
550 nvfuser::FusionDefinition* fd = self.fusion_definition; \
551 nvfuser::Tensor output = fd->defineTensor(); \
552 fd->defineRecord(new nvfuser::OpRecord< \
553 Nvf::TensorView*, \
554 Nvf::TensorView*, \
555 Nvf::Val*, \
556 Nvf::Val*>( \
557 {fd->recordingState(arg1()), \
558 fd->recordingState(arg2()), \
559 fd->recordingState(arg3())}, \
560 {fd->recordingState(output())}, \
561 ("ops." op_str), \
562 static_cast< \
563 Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \
564 Nvf::op_name))); \
565 return output; \
566 }, \
567 py::return_value_policy::reference); \
568 nvf_ops.def( \
569 op_str, \
570 [](nvfuser::FusionDefinition::Operators& self, \
571 nvfuser::Scalar arg1, \
572 nvfuser::Tensor arg2, \
573 nvfuser::Scalar arg3) -> nvfuser::Tensor { \
574 FUSER_PERF_SCOPE("Operators." op_str); \
575 nvfuser::FusionDefinition* fd = self.fusion_definition; \
576 nvfuser::Tensor output = fd->defineTensor(); \
577 fd->defineRecord(new nvfuser::OpRecord< \
578 Nvf::TensorView*, \
579 Nvf::Val*, \
580 Nvf::TensorView*, \
581 Nvf::Val*>( \
582 {fd->recordingState(arg1()), \
583 fd->recordingState(arg2()), \
584 fd->recordingState(arg3())}, \
585 {fd->recordingState(output())}, \
586 ("ops." op_str), \
587 static_cast< \
588 Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>( \
589 Nvf::op_name))); \
590 return output; \
591 }, \
592 py::return_value_policy::reference); \
593 nvf_ops.def( \
594 op_str, \
595 [](nvfuser::FusionDefinition::Operators& self, \
596 nvfuser::Scalar arg1, \
597 nvfuser::Scalar arg2, \
598 nvfuser::Scalar arg3) -> nvfuser::Scalar { \
599 FUSER_PERF_SCOPE("Operators." op_str); \
600 nvfuser::FusionDefinition* fd = self.fusion_definition; \
601 nvfuser::Scalar output = fd->defineScalar(); \
602 fd->defineRecord( \
603 new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>( \
604 {fd->recordingState(arg1()), \
605 fd->recordingState(arg2()), \
606 fd->recordingState(arg3())}, \
607 {fd->recordingState(output())}, \
608 ("ops." op_str), \
609 static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \
610 Nvf::op_name))); \
611 return output; \
612 }, \
613 py::return_value_policy::reference);
614
615 NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("add_alpha", add_alpha)
616 NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("sub_alpha", sub_alpha)
617#undef NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP
618
619#define NVFUSER_PYTHON_BINDING_TERNARY_OP(op_str, op_name) \
620 nvf_ops.def( \
621 op_str, \
622 [](nvfuser::FusionDefinition::Operators& self, \
623 nvfuser::Scalar arg1, \
624 nvfuser::Scalar arg2, \
625 nvfuser::Scalar arg3) -> nvfuser::Scalar { \
626 FUSER_PERF_SCOPE("Operators." op_str); \
627 nvfuser::FusionDefinition* fd = self.fusion_definition; \
628 nvfuser::Scalar output = fd->defineScalar(); \
629 fd->defineRecord( \
630 new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>( \
631 {fd->recordingState(arg1()), \
632 fd->recordingState(arg2()), \
633 fd->recordingState(arg3())}, \
634 {fd->recordingState(output())}, \
635 ("ops." op_str), \
636 static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \
637 Nvf::op_name))); \
638 return output; \
639 }, \
640 py::return_value_policy::reference); \
641 nvf_ops.def( \
642 op_str, \
643 [](nvfuser::FusionDefinition::Operators& self, \
644 nvfuser::Tensor arg1, \
645 nvfuser::Tensor arg2, \
646 nvfuser::Tensor arg3) -> nvfuser::Tensor { \
647 FUSER_PERF_SCOPE("Operators." op_str); \
648 nvfuser::FusionDefinition* fd = self.fusion_definition; \
649 nvfuser::Tensor output = fd->defineTensor(); \
650 fd->defineRecord(new nvfuser::OpRecord< \
651 Nvf::TensorView*, \
652 Nvf::TensorView*, \
653 Nvf::TensorView*, \
654 Nvf::TensorView*>( \
655 {fd->recordingState(arg1()), \
656 fd->recordingState(arg2()), \
657 fd->recordingState(arg3())}, \
658 {fd->recordingState(output())}, \
659 ("ops." op_str), \
660 static_cast< \
661 Nvf:: \
662 TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::TensorView*)>( \
663 Nvf::op_name))); \
664 return output; \
665 }, \
666 py::return_value_policy::reference); \
667 nvf_ops.def( \
668 op_str, \
669 [](nvfuser::FusionDefinition::Operators& self, \
670 nvfuser::Tensor arg1, \
671 nvfuser::Tensor arg2, \
672 nvfuser::Scalar arg3) -> nvfuser::Tensor { \
673 FUSER_PERF_SCOPE("Operators." op_str); \
674 nvfuser::FusionDefinition* fd = self.fusion_definition; \
675 nvfuser::Tensor output = fd->defineTensor(); \
676 fd->defineRecord(new nvfuser::OpRecord< \
677 Nvf::TensorView*, \
678 Nvf::TensorView*, \
679 Nvf::TensorView*, \
680 Nvf::Val*>( \
681 {fd->recordingState(arg1()), \
682 fd->recordingState(arg2()), \
683 fd->recordingState(arg3())}, \
684 {fd->recordingState(output())}, \
685 ("ops." op_str), \
686 static_cast< \
687 Nvf:: \
688 TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \
689 Nvf::op_name))); \
690 return output; \
691 }, \
692 py::return_value_policy::reference); \
693 nvf_ops.def( \
694 op_str, \
695 [](nvfuser::FusionDefinition::Operators& self, \
696 nvfuser::Tensor arg1, \
697 nvfuser::Scalar arg2, \
698 nvfuser::Tensor arg3) -> nvfuser::Tensor { \
699 FUSER_PERF_SCOPE("Operators." op_str); \
700 nvfuser::FusionDefinition* fd = self.fusion_definition; \
701 nvfuser::Tensor output = fd->defineTensor(); \
702 fd->defineRecord(new nvfuser::OpRecord< \
703 Nvf::TensorView*, \
704 Nvf::TensorView*, \
705 Nvf::Val*, \
706 Nvf::TensorView*>( \
707 {fd->recordingState(arg1()), \
708 fd->recordingState(arg2()), \
709 fd->recordingState(arg3())}, \
710 {fd->recordingState(output())}, \
711 ("ops." op_str), \
712 static_cast< \
713 Nvf:: \
714 TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::TensorView*)>( \
715 Nvf::op_name))); \
716 return output; \
717 }, \
718 py::return_value_policy::reference); \
719 nvf_ops.def( \
720 op_str, \
721 [](nvfuser::FusionDefinition::Operators& self, \
722 nvfuser::Scalar arg1, \
723 nvfuser::Tensor arg2, \
724 nvfuser::Tensor arg3) -> nvfuser::Tensor { \
725 FUSER_PERF_SCOPE("Operators." op_str); \
726 nvfuser::FusionDefinition* fd = self.fusion_definition; \
727 nvfuser::Tensor output = fd->defineTensor(); \
728 fd->defineRecord(new nvfuser::OpRecord< \
729 Nvf::TensorView*, \
730 Nvf::Val*, \
731 Nvf::TensorView*, \
732 Nvf::TensorView*>( \
733 {fd->recordingState(arg1()), \
734 fd->recordingState(arg2()), \
735 fd->recordingState(arg3())}, \
736 {fd->recordingState(output())}, \
737 ("ops." op_str), \
738 static_cast< \
739 Nvf:: \
740 TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::TensorView*)>( \
741 Nvf::op_name))); \
742 return output; \
743 }, \
744 py::return_value_policy::reference); \
745 nvf_ops.def( \
746 op_str, \
747 [](nvfuser::FusionDefinition::Operators& self, \
748 nvfuser::Scalar arg1, \
749 nvfuser::Scalar arg2, \
750 nvfuser::Tensor arg3) -> nvfuser::Tensor { \
751 FUSER_PERF_SCOPE("Operators." op_str); \
752 nvfuser::FusionDefinition* fd = self.fusion_definition; \
753 nvfuser::Tensor output = fd->defineTensor(); \
754 fd->defineRecord(new nvfuser::OpRecord< \
755 Nvf::TensorView*, \
756 Nvf::Val*, \
757 Nvf::Val*, \
758 Nvf::TensorView*>( \
759 {fd->recordingState(arg1()), \
760 fd->recordingState(arg2()), \
761 fd->recordingState(arg3())}, \
762 {fd->recordingState(output())}, \
763 ("ops." op_str), \
764 static_cast< \
765 Nvf::TensorView* (*)(Nvf::Val*, Nvf::Val*, Nvf::TensorView*)>( \
766 Nvf::op_name))); \
767 return output; \
768 }, \
769 py::return_value_policy::reference); \
770 nvf_ops.def( \
771 op_str, \
772 [](nvfuser::FusionDefinition::Operators& self, \
773 nvfuser::Tensor arg1, \
774 nvfuser::Scalar arg2, \
775 nvfuser::Scalar arg3) -> nvfuser::Tensor { \
776 FUSER_PERF_SCOPE("Operators." op_str); \
777 nvfuser::FusionDefinition* fd = self.fusion_definition; \
778 nvfuser::Tensor output = fd->defineTensor(); \
779 fd->defineRecord(new nvfuser::OpRecord< \
780 Nvf::TensorView*, \
781 Nvf::TensorView*, \
782 Nvf::Val*, \
783 Nvf::Val*>( \
784 {fd->recordingState(arg1()), \
785 fd->recordingState(arg2()), \
786 fd->recordingState(arg3())}, \
787 {fd->recordingState(output())}, \
788 ("ops." op_str), \
789 static_cast< \
790 Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \
791 Nvf::op_name))); \
792 return output; \
793 }, \
794 py::return_value_policy::reference); \
795 nvf_ops.def( \
796 op_str, \
797 [](nvfuser::FusionDefinition::Operators& self, \
798 nvfuser::Scalar arg1, \
799 nvfuser::Tensor arg2, \
800 nvfuser::Scalar arg3) -> nvfuser::Tensor { \
801 FUSER_PERF_SCOPE("Operators." op_str); \
802 nvfuser::FusionDefinition* fd = self.fusion_definition; \
803 nvfuser::Tensor output = fd->defineTensor(); \
804 fd->defineRecord(new nvfuser::OpRecord< \
805 Nvf::TensorView*, \
806 Nvf::Val*, \
807 Nvf::TensorView*, \
808 Nvf::Val*>( \
809 {fd->recordingState(arg1()), \
810 fd->recordingState(arg2()), \
811 fd->recordingState(arg3())}, \
812 {fd->recordingState(output())}, \
813 ("ops." op_str), \
814 static_cast< \
815 Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>( \
816 Nvf::op_name))); \
817 return output; \
818 }, \
819 py::return_value_policy::reference);
820
821 NVFUSER_PYTHON_BINDING_TERNARY_OP("lerp", lerp)
822 NVFUSER_PYTHON_BINDING_TERNARY_OP("where", where)
823#undef NVFUSER_PYTHON_BINDING_TERNARY_OP
824
825#define NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP(op_str, op_name) \
826 nvf_ops.def( \
827 op_str, \
828 [](nvfuser::FusionDefinition::Operators& self, \
829 nvfuser::Scalar arg1, \
830 nvfuser::Scalar arg2, \
831 nvfuser::Scalar arg3) -> nvfuser::Scalar { \
832 FUSER_PERF_SCOPE("Operators." op_str); \
833 nvfuser::FusionDefinition* fd = self.fusion_definition; \
834 nvfuser::Scalar output = fd->defineScalar(); \
835 fd->defineRecord( \
836 new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>( \
837 {fd->recordingState(arg1()), \
838 fd->recordingState(arg2()), \
839 fd->recordingState(arg3())}, \
840 {fd->recordingState(output())}, \
841 ("ops." op_str), \
842 static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \
843 Nvf::op_name))); \
844 return output; \
845 }, \
846 py::return_value_policy::reference); \
847 nvf_ops.def( \
848 op_str, \
849 [](nvfuser::FusionDefinition::Operators& self, \
850 nvfuser::Tensor arg1, \
851 nvfuser::Scalar arg2, \
852 nvfuser::Scalar arg3) -> nvfuser::Tensor { \
853 FUSER_PERF_SCOPE("Operators." op_str); \
854 nvfuser::FusionDefinition* fd = self.fusion_definition; \
855 nvfuser::Tensor output = fd->defineTensor(); \
856 fd->defineRecord(new nvfuser::OpRecord< \
857 Nvf::TensorView*, \
858 Nvf::TensorView*, \
859 Nvf::Val*, \
860 Nvf::Val*>( \
861 {fd->recordingState(arg1()), \
862 fd->recordingState(arg2()), \
863 fd->recordingState(arg3())}, \
864 {fd->recordingState(output())}, \
865 ("ops." op_str), \
866 static_cast< \
867 Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \
868 Nvf::op_name))); \
869 return output; \
870 }, \
871 py::return_value_policy::reference);
872
873 NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("clamp", clamp)
874 NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("threshold", threshold)
875#undef NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP
876
877#define NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP(op_str, op_name) \
878 nvf_ops.def( \
879 op_str, \
880 [](nvfuser::FusionDefinition::Operators& self, \
881 nvfuser::Scalar arg1, \
882 nvfuser::Scalar arg2, \
883 nvfuser::Scalar arg3, \
884 nvfuser::Scalar arg4) -> nvfuser::Scalar { \
885 FUSER_PERF_SCOPE("Operators." op_str); \
886 nvfuser::FusionDefinition* fd = self.fusion_definition; \
887 nvfuser::Scalar output = fd->defineScalar(); \
888 fd->defineRecord(new nvfuser::OpRecord< \
889 Nvf::Val*, \
890 Nvf::Val*, \
891 Nvf::Val*, \
892 Nvf::Val*, \
893 Nvf::Val*>( \
894 {fd->recordingState(arg1()), \
895 fd->recordingState(arg2()), \
896 fd->recordingState(arg3()), \
897 fd->recordingState(arg4())}, \
898 {fd->recordingState(output())}, \
899 ("ops." op_str), \
900 static_cast< \
901 Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \
902 Nvf::op_name))); \
903 return output; \
904 }, \
905 py::return_value_policy::reference); \
906 nvf_ops.def( \
907 op_str, \
908 [](nvfuser::FusionDefinition::Operators& self, \
909 nvfuser::Tensor arg1, \
910 nvfuser::Tensor arg2, \
911 nvfuser::Tensor arg3, \
912 nvfuser::Scalar arg4) -> nvfuser::Tensor { \
913 FUSER_PERF_SCOPE("Operators." op_str); \
914 nvfuser::FusionDefinition* fd = self.fusion_definition; \
915 nvfuser::Tensor output = fd->defineTensor(); \
916 fd->defineRecord(new nvfuser::OpRecord< \
917 Nvf::TensorView*, \
918 Nvf::TensorView*, \
919 Nvf::TensorView*, \
920 Nvf::TensorView*, \
921 Nvf::TensorView*>( \
922 {fd->recordingState(arg1()), \
923 fd->recordingState(arg2()), \
924 fd->recordingState(arg3()), \
925 fd->recordingState(arg4())}, \
926 {fd->recordingState(output())}, \
927 ("ops." op_str), \
928 static_cast< \
929 Nvf:: \
930 TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \
931 Nvf::op_name))); \
932 return output; \
933 }, \
934 py::return_value_policy::reference); \
935 nvf_ops.def( \
936 op_str, \
937 [](nvfuser::FusionDefinition::Operators& self, \
938 nvfuser::Tensor arg1, \
939 nvfuser::Tensor arg2, \
940 nvfuser::Scalar arg3, \
941 nvfuser::Scalar arg4) -> nvfuser::Tensor { \
942 FUSER_PERF_SCOPE("Operators." op_str); \
943 nvfuser::FusionDefinition* fd = self.fusion_definition; \
944 nvfuser::Tensor output = fd->defineTensor(); \
945 fd->defineRecord(new nvfuser::OpRecord< \
946 Nvf::TensorView*, \
947 Nvf::TensorView*, \
948 Nvf::TensorView*, \
949 Nvf::Val*, \
950 Nvf::Val*>( \
951 {fd->recordingState(arg1()), \
952 fd->recordingState(arg2()), \
953 fd->recordingState(arg3()), \
954 fd->recordingState(arg4())}, \
955 {fd->recordingState(output())}, \
956 ("ops." op_str), \
957 static_cast< \
958 Nvf:: \
959 TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \
960 Nvf::op_name))); \
961 return output; \
962 }, \
963 py::return_value_policy::reference); \
964 nvf_ops.def( \
965 op_str, \
966 [](nvfuser::FusionDefinition::Operators& self, \
967 nvfuser::Tensor arg1, \
968 nvfuser::Scalar arg2, \
969 nvfuser::Tensor arg3, \
970 nvfuser::Scalar arg4) -> nvfuser::Tensor { \
971 FUSER_PERF_SCOPE("Operators." op_str); \
972 nvfuser::FusionDefinition* fd = self.fusion_definition; \
973 nvfuser::Tensor output = fd->defineTensor(); \
974 fd->defineRecord(new nvfuser::OpRecord< \
975 Nvf::TensorView*, \
976 Nvf::TensorView*, \
977 Nvf::Val*, \
978 Nvf::TensorView*, \
979 Nvf::Val*>( \
980 {fd->recordingState(arg1()), \
981 fd->recordingState(arg2()), \
982 fd->recordingState(arg3()), \
983 fd->recordingState(arg4())}, \
984 {fd->recordingState(output())}, \
985 ("ops." op_str), \
986 static_cast< \
987 Nvf:: \
988 TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>( \
989 Nvf::op_name))); \
990 return output; \
991 }, \
992 py::return_value_policy::reference); \
993 nvf_ops.def( \
994 op_str, \
995 [](nvfuser::FusionDefinition::Operators& self, \
996 nvfuser::Scalar arg1, \
997 nvfuser::Tensor arg2, \
998 nvfuser::Tensor arg3, \
999 nvfuser::Scalar arg4) -> nvfuser::Tensor { \
1000 FUSER_PERF_SCOPE("Operators." op_str); \
1001 nvfuser::FusionDefinition* fd = self.fusion_definition; \
1002 nvfuser::Tensor output = fd->defineTensor(); \
1003 fd->defineRecord(new nvfuser::OpRecord< \
1004 Nvf::TensorView*, \
1005 Nvf::Val*, \
1006 Nvf::TensorView*, \
1007 Nvf::TensorView*, \
1008 Nvf::Val*>( \
1009 {fd->recordingState(arg1()), \
1010 fd->recordingState(arg2()), \
1011 fd->recordingState(arg3()), \
1012 fd->recordingState(arg4())}, \
1013 {fd->recordingState(output())}, \
1014 ("ops." op_str), \
1015 static_cast< \
1016 Nvf:: \
1017 TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \
1018 Nvf::op_name))); \
1019 return output; \
1020 }, \
1021 py::return_value_policy::reference); \
1022 nvf_ops.def( \
1023 op_str, \
1024 [](nvfuser::FusionDefinition::Operators& self, \
1025 nvfuser::Scalar arg1, \
1026 nvfuser::Scalar arg2, \
1027 nvfuser::Tensor arg3, \
1028 nvfuser::Scalar arg4) -> nvfuser::Tensor { \
1029 FUSER_PERF_SCOPE("Operators." op_str); \
1030 nvfuser::FusionDefinition* fd = self.fusion_definition; \
1031 nvfuser::Tensor output = fd->defineTensor(); \
1032 fd->defineRecord(new nvfuser::OpRecord< \
1033 Nvf::TensorView*, \
1034 Nvf::Val*, \
1035 Nvf::Val*, \
1036 Nvf::TensorView*, \
1037 Nvf::Val*>( \
1038 {fd->recordingState(arg1()), \
1039 fd->recordingState(arg2()), \
1040 fd->recordingState(arg3()), \
1041 fd->recordingState(arg4())}, \
1042 {fd->recordingState(output())}, \
1043 ("ops." op_str), \
1044 static_cast< \
1045 Nvf:: \
1046 TensorView* (*)(Nvf::Val*, Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>( \
1047 Nvf::op_name))); \
1048 return output; \
1049 }, \
1050 py::return_value_policy::reference); \
1051 nvf_ops.def( \
1052 op_str, \
1053 [](nvfuser::FusionDefinition::Operators& self, \
1054 nvfuser::Tensor arg1, \
1055 nvfuser::Scalar arg2, \
1056 nvfuser::Scalar arg3, \
1057 nvfuser::Scalar arg4) -> nvfuser::Tensor { \
1058 FUSER_PERF_SCOPE("Operators." op_str); \
1059 nvfuser::FusionDefinition* fd = self.fusion_definition; \
1060 nvfuser::Tensor output = fd->defineTensor(); \
1061 fd->defineRecord(new nvfuser::OpRecord< \
1062 Nvf::TensorView*, \
1063 Nvf::TensorView*, \
1064 Nvf::Val*, \
1065 Nvf::Val*, \
1066 Nvf::Val*>( \
1067 {fd->recordingState(arg1()), \
1068 fd->recordingState(arg2()), \
1069 fd->recordingState(arg3()), \
1070 fd->recordingState(arg4())}, \
1071 {fd->recordingState(output())}, \
1072 ("ops." op_str), \
1073 static_cast< \
1074 Nvf:: \
1075 TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \
1076 Nvf::op_name))); \
1077 return output; \
1078 }, \
1079 py::return_value_policy::reference); \
1080 nvf_ops.def( \
1081 op_str, \
1082 [](nvfuser::FusionDefinition::Operators& self, \
1083 nvfuser::Scalar arg1, \
1084 nvfuser::Tensor arg2, \
1085 nvfuser::Scalar arg3, \
1086 nvfuser::Scalar arg4) -> nvfuser::Tensor { \
1087 FUSER_PERF_SCOPE("Operators." op_str); \
1088 nvfuser::FusionDefinition* fd = self.fusion_definition; \
1089 nvfuser::Tensor output = fd->defineTensor(); \
1090 fd->defineRecord(new nvfuser::OpRecord< \
1091 Nvf::TensorView*, \
1092 Nvf::Val*, \
1093 Nvf::TensorView*, \
1094 Nvf::Val*, \
1095 Nvf::Val*>( \
1096 {fd->recordingState(arg1()), \
1097 fd->recordingState(arg2()), \
1098 fd->recordingState(arg3()), \
1099 fd->recordingState(arg4())}, \
1100 {fd->recordingState(output())}, \
1101 ("ops." op_str), \
1102 static_cast< \
1103 Nvf:: \
1104 TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \
1105 Nvf::op_name))); \
1106 return output; \
1107 }, \
1108 py::return_value_policy::reference);
1109
1110 NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul", addcmul)
1111#undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP
1112
1113#define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name) \
1114 nvf_ops.def( \
1115 op_str, \
1116 [](nvfuser::FusionDefinition::Operators& self, \
1117 nvfuser::Tensor arg, \
1118 const std::vector<int>& axes, \
1119 bool keepdim, \
1120 Nvf::DataType dtype) -> nvfuser::Tensor { \
1121 FUSER_PERF_SCOPE("Operators." op_str); \
1122 nvfuser::FusionDefinition* fd = self.fusion_definition; \
1123 nvfuser::Tensor output = fd->defineTensor(); \
1124 fd->defineRecord(new nvfuser::ReductionOpRecord( \
1125 {fd->recordingState(arg())}, \
1126 {fd->recordingState(output())}, \
1127 ("ops." op_str), \
1128 static_cast< \
1129 Nvf:: \
1130 TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>( \
1131 Nvf::op_name), \
1132 axes, \
1133 keepdim, \
1134 dtype)); \
1135 return output; \
1136 }, \
1137 py::arg("arg"), \
1138 py::arg("axes"), \
1139 py::arg("keepdim") = false, \
1140 py::arg("dtype") = Nvf::DataType::Null, \
1141 py::return_value_policy::reference);
1142
1143 NVFUSER_PYTHON_BINDING_REDUCTION_OP("sum", sum)
1144 NVFUSER_PYTHON_BINDING_REDUCTION_OP("max", max)
1145 NVFUSER_PYTHON_BINDING_REDUCTION_OP("min", min)
1146#undef NVFUSER_PYTHON_BINDING_REDUCTION_OP
1147
1148#define NVFUSER_PYTHON_BINDING_CAST_OP(op_str, op_name) \
1149 nvf_ops.def( \
1150 op_str, \
1151 [](nvfuser::FusionDefinition::Operators& self, \
1152 nvfuser::Tensor arg, \
1153 Nvf::DataType dtype) -> nvfuser::Tensor { \
1154 FUSER_PERF_SCOPE("Operators." op_str); \
1155 nvfuser::FusionDefinition* fd = self.fusion_definition; \
1156 nvfuser::Tensor output = fd->defineTensor(); \
1157 fd->defineRecord( \
1158 new nvfuser::CastOpRecord<Nvf::TensorView*, Nvf::TensorView*>( \
1159 {fd->recordingState(arg())}, \
1160 {fd->recordingState(output())}, \
1161 ("ops." op_str), \
1162 static_cast< \
1163 Nvf::TensorView* (*)(Nvf::DataType, Nvf::TensorView*)>( \
1164 Nvf::op_name), \
1165 dtype)); \
1166 return output; \
1167 }, \
1168 py::arg("arg"), \
1169 py::arg("dtype"), \
1170 py::return_value_policy::reference); \
1171 nvf_ops.def( \
1172 op_str, \
1173 [](nvfuser::FusionDefinition::Operators& self, \
1174 nvfuser::Scalar arg, \
1175 Nvf::DataType dtype) -> nvfuser::Scalar { \
1176 FUSER_PERF_SCOPE("Operators." op_str); \
1177 nvfuser::FusionDefinition* fd = self.fusion_definition; \
1178 nvfuser::Scalar output = fd->defineScalar(); \
1179 fd->defineRecord(new nvfuser::CastOpRecord<Nvf::Val*, Nvf::Val*>( \
1180 {fd->recordingState(arg())}, \
1181 {fd->recordingState(output())}, \
1182 ("ops." op_str), \
1183 static_cast<Nvf::Val* (*)(Nvf::DataType, Nvf::Val*)>( \
1184 Nvf::op_name), \
1185 dtype)); \
1186 return output; \
1187 }, \
1188 py::arg("arg"), \
1189 py::arg("dtype"), \
1190 py::return_value_policy::reference);
1191
1192 NVFUSER_PYTHON_BINDING_CAST_OP("cast", castOp)
1193#undef NVFUSER_PYTHON_BINDING_CAST_OP
1194
1195 nvf_ops.def(
1196 "permute",
1197 [](nvfuser::FusionDefinition::Operators& self,
1198 nvfuser::Tensor arg,
1199 std::vector<int64_t>& dims) -> nvfuser::Tensor {
1200 nvfuser::FusionDefinition* fd = self.fusion_definition;
1201 nvfuser::Tensor output = fd->defineTensor();
1202 self.fusion_definition->defineRecord(new nvfuser::PermuteOpRecord(
1203 {fd->recordingState(arg())}, {fd->recordingState(output())}, dims));
1204 return output;
1205 },
1206 py::arg("arg"),
1207 py::arg("dims"),
1208 py::return_value_policy::reference);
1209 nvf_ops.def(
1210 "squeeze",
1211 [](nvfuser::FusionDefinition::Operators& self,
1212 nvfuser::Tensor arg,
1213 std::vector<int64_t>& original_shape,
1214 int64_t dim) -> nvfuser::Tensor {
1215 FUSER_PERF_SCOPE("Operators.squeeze");
1216 nvfuser::FusionDefinition* fd = self.fusion_definition;
1217 nvfuser::Tensor output = fd->defineTensor();
1218 fd->defineRecord(new nvfuser::SqueezeOpRecord(
1219 {fd->recordingState(arg())},
1220 {fd->recordingState(output())},
1221 original_shape,
1222 dim));
1223 return output;
1224 },
1225 py::arg("arg"),
1226 py::arg("original_shape"),
1227 py::arg("dim"),
1228 py::return_value_policy::reference);
1229 nvf_ops.def(
1230 "view",
1231 [](nvfuser::FusionDefinition::Operators& self,
1232 nvfuser::Tensor arg,
1233 std::vector<int64_t>& original_shape,
1234 std::vector<int64_t>& new_shape) -> nvfuser::Tensor {
1235 nvfuser::FusionDefinition* fd = self.fusion_definition;
1236 nvfuser::Tensor output = fd->defineTensor();
1237 self.fusion_definition->defineRecord(new nvfuser::ViewOpRecord(
1238 {fd->recordingState(arg())},
1239 {fd->recordingState(output())},
1240 original_shape,
1241 new_shape));
1242 return output;
1243 },
1244 py::arg("arg"),
1245 py::arg("original_shape"),
1246 py::arg("new_shape"),
1247 py::return_value_policy::reference);
1248 nvf_ops.def(
1249 "full",
1250 [](nvfuser::FusionDefinition::Operators& self,
1251 std::vector<int64_t>& size,
1252 nvfuser::Scalar arg,
1253 Nvf::DataType dtype) -> nvfuser::Tensor {
1254 nvfuser::FusionDefinition* fd = self.fusion_definition;
1255 nvfuser::Tensor output = fd->defineTensor();
1256 fd->defineRecord(new nvfuser::FullOpRecord(
1257 {fd->recordingState(arg())},
1258 {fd->recordingState(output())},
1259 size,
1260 dtype));
1261 return output;
1262 },
1263 py::arg("size"),
1264 py::arg("arg"),
1265 py::arg("dtype"),
1266 py::return_value_policy::reference);
1267 nvf_ops.def(
1268 "var",
1269 [](nvfuser::FusionDefinition::Operators& self,
1270 nvfuser::Tensor arg,
1271 std::vector<int>& axes,
1272 int64_t correction,
1273 bool keepdim) -> nvfuser::Tensor {
1274 FUSER_PERF_SCOPE("Operators.var");
1275 nvfuser::FusionDefinition* fd = self.fusion_definition;
1276 nvfuser::Tensor output = fd->defineTensor();
1277 fd->defineRecord(new nvfuser::VarianceOpRecord(
1278 {fd->recordingState(arg())},
1279 {fd->recordingState(output())},
1280 axes,
1281 correction,
1282 keepdim));
1283 return output;
1284 },
1285 py::arg("arg"),
1286 py::arg("axes"),
1287 py::arg("correction"),
1288 py::arg("keepdim") = false,
1289 py::return_value_policy::reference);
1290 nvf_ops.def(
1291 "var_mean",
1292 [](nvfuser::FusionDefinition::Operators& self,
1293 nvfuser::Tensor arg,
1294 std::vector<int>& axes,
1295 int64_t correction,
1296 bool keepdim) -> decltype(auto) {
1297 FUSER_PERF_SCOPE("Operators.var_mean");
1298 nvfuser::FusionDefinition* fd = self.fusion_definition;
1299 nvfuser::Tensor var = fd->defineTensor();
1300 nvfuser::Tensor mean = fd->defineTensor();
1301 fd->defineRecord(new nvfuser::VarianceMeanOpRecord(
1302 {fd->recordingState(arg())},
1303 {fd->recordingState(var()), fd->recordingState(mean())},
1304 axes,
1305 correction,
1306 keepdim));
1307 return std::make_tuple(var, mean);
1308 },
1309 py::arg("arg"),
1310 py::arg("axes"),
1311 py::arg("correction"),
1312 py::arg("keepdim") = false,
1313 py::return_value_policy::reference);
1314 nvf_ops.def(
1315 "batch_norm",
1316 [](nvfuser::FusionDefinition::Operators& self,
1317 nvfuser::Tensor arg,
1318 c10::optional<nvfuser::Tensor> weight,
1319 c10::optional<nvfuser::Tensor> bias,
1320 c10::optional<nvfuser::Tensor> running_mean,
1321 c10::optional<nvfuser::Tensor> running_var,
1322 nvfuser::Scalar momentum,
1323 nvfuser::Scalar eps,
1324 bool training,
1325 bool channels_last) -> decltype(auto) {
1326 FUSER_PERF_SCOPE("Operators.batch_norm");
1327 nvfuser::FusionDefinition* fd = self.fusion_definition;
1328 nvfuser::Tensor output = fd->defineTensor();
1329 nvfuser::Tensor mean = fd->defineTensor();
1330 nvfuser::Tensor invstd = fd->defineTensor();
1331 auto weight_state = weight.has_value()
1332 ? fd->recordingState(weight.value()())
1333 : nvfuser::State(0, nvfuser::StateType::None);
1334 auto bias_state = bias.has_value()
1335 ? fd->recordingState(bias.value()())
1336 : nvfuser::State(0, nvfuser::StateType::None);
1337 auto running_mean_state = running_mean.has_value()
1338 ? fd->recordingState(running_mean.value()())
1339 : nvfuser::State(0, nvfuser::StateType::None);
1340 auto running_var_state = running_var.has_value()
1341 ? fd->recordingState(running_var.value()())
1342 : nvfuser::State(0, nvfuser::StateType::None);
1343 fd->defineRecord(new nvfuser::BatchNormOpRecord(
1344 {fd->recordingState(arg()),
1345 weight_state,
1346 bias_state,
1347 running_mean_state,
1348 running_var_state,
1349 fd->recordingState(momentum()),
1350 fd->recordingState(eps())},
1351 {fd->recordingState(output()),
1352 fd->recordingState(mean()),
1353 fd->recordingState(invstd())},
1354 training,
1355 channels_last));
1356 return std::make_tuple(output, mean, invstd);
1357 },
1358 py::arg("arg"),
1359 py::arg("weight").none(true),
1360 py::arg("bias").none(true),
1361 py::arg("running_mean").none(true),
1362 py::arg("running_var").none(true),
1363 py::arg("momentum"),
1364 py::arg("eps"),
1365 py::arg("training"),
1366 py::arg("channels_last") = false,
1367 py::return_value_policy::reference);
1368 nvf_ops.def(
1369 "broadcast_in_dim",
1370 [](nvfuser::FusionDefinition::Operators& self,
1371 nvfuser::Tensor arg,
1372 std::vector<int64_t>& output_shape,
1373 std::vector<int64_t>& broadcast_dims) -> nvfuser::Tensor {
1374 FUSER_PERF_SCOPE("Operators.broadcast_in_dim");
1375 nvfuser::FusionDefinition* fd = self.fusion_definition;
1376 TORCH_CHECK(
1377 output_shape.size() >= broadcast_dims.size(),
1378 "broadcast_dims vector size is too big for output shape!");
1379 nvfuser::Tensor output = fd->defineTensor();
1380 fd->defineRecord(new nvfuser::BroadcastInDimOpRecord(
1381 {fd->recordingState(arg())},
1382 {fd->recordingState(output())},
1383 "ops.broadcast_in_dim",
1384 output_shape,
1385 broadcast_dims));
1386 return output;
1387 },
1388 py::arg("arg"),
1389 py::arg("output_shape"),
1390 py::arg("broadcast_dims"),
1391 py::return_value_policy::reference);
1392 nvf_ops.def(
1393 "broadcast",
1394 [](nvfuser::FusionDefinition::Operators& self,
1395 nvfuser::Tensor arg,
1396 std::vector<bool>& is_broadcast_dim) -> nvfuser::Tensor {
1397 FUSER_PERF_SCOPE("Operators.broadcast");
1398 nvfuser::FusionDefinition* fd = self.fusion_definition;
1399 nvfuser::Tensor output = fd->defineTensor();
1400 fd->defineRecord(new nvfuser::BroadcastOpRecord(
1401 {fd->recordingState(arg())},
1402 {fd->recordingState(output())},
1403 "ops.broadcast",
1404 is_broadcast_dim));
1405 return output;
1406 },
1407 py::arg("arg"),
1408 py::arg("is_broadcast_dim"),
1409 py::return_value_policy::reference);
1410}
1411
1412} // namespace jit
1413} // namespace torch
1414