1 | #include <python_frontend/python_bindings.h> |
2 | |
3 | #include <c10/util/ArrayRef.h> |
4 | #include <c10/util/Optional.h> |
5 | #include <c10/util/irange.h> |
6 | #include <arith.h> |
7 | #include <instrumentation.h> |
8 | #include <ir_all_nodes.h> |
9 | #include <ir_builder.h> |
10 | #include <ops/composite.h> |
11 | #include <python_frontend/fusion_cache.h> |
12 | #include <python_frontend/fusion_definition.h> |
13 | #include <python_frontend/fusion_interface.h> |
14 | #include <python_frontend/fusion_record.h> |
15 | #include <python_frontend/python_bindings.h> |
16 | #include <torch/csrc/jit/python/pybind_utils.h> |
17 | #include <iostream> |
18 | #include <tuple> |
19 | |
20 | namespace torch { |
21 | namespace jit { |
22 | |
23 | void initNvFuserPythonBindings(PyObject* module) { |
24 | auto nvfuser = py::handle(module).cast<py::module>(); |
25 | |
26 | //! DataTypes supported by nvFuser in the FusionDefinition |
27 | py::enum_<Nvf::DataType>(nvfuser, "DataType" ) |
28 | .value("Double" , Nvf::DataType::Double) |
29 | .value("Float" , Nvf::DataType::Float) |
30 | .value("Half" , Nvf::DataType::Half) |
31 | .value("Int" , Nvf::DataType::Int) |
32 | .value("Int32" , Nvf::DataType::Int32) |
33 | .value("Bool" , Nvf::DataType::Bool) |
34 | .value("BFloat16" , Nvf::DataType::BFloat16) |
35 | .value("ComplexFloat" , Nvf::DataType::ComplexFloat) |
36 | .value("ComplexDouble" , Nvf::DataType::ComplexDouble) |
37 | .value("Null" , Nvf::DataType::Null); |
38 | |
39 | nvfuser.def( |
40 | "compute_contiguity" , |
41 | [](const std::vector<int64_t>& sizes, |
42 | const std::vector<int64_t>& strides) { |
43 | py::tuple contiguity(sizes.size()); |
44 | TORCH_CHECK( |
45 | sizes.size() == strides.size(), |
46 | "compute_contiguity: Sizes and strides must have the same number of dimensions" ); |
47 | if (sizes.size() == 0) { |
48 | return contiguity; |
49 | } |
50 | contiguity[sizes.size() - 1] = strides.back() == 1; |
51 | for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) { |
52 | contiguity[i] = strides[i] == strides[i + 1] * sizes[i + 1]; |
53 | } |
54 | return contiguity; |
55 | }); |
56 | |
57 | //! Binding the FusionCache that holds a cache of Fusions |
58 | //! This is only bound to provide an interface to get the number of fusions |
59 | //! that are cached. |
60 | py::class_<nvfuser::FusionCache> fusion_cache(nvfuser, "FusionCache" ); |
61 | fusion_cache |
62 | .def_static( |
63 | "get" , |
64 | &nvfuser::FusionCache::get, |
65 | py::arg("max_fusions" ) = int(8192), |
66 | py::return_value_policy::reference) |
67 | .def("num_fusions" , &nvfuser::FusionCache::numFusions) |
68 | .def("print_stats" , [](nvfuser::FusionCache& self) { |
69 | self.print(std::cout); |
70 | }); |
71 | |
72 | py::class_<nvfuser::FusionInterface> fusion(nvfuser, "Fusion" ); |
73 | fusion.def(py::init<>()) |
74 | .def(py::init<size_t>(), py::arg("fusion_id" )) |
75 | .def("define" , &nvfuser::FusionInterface::define) |
76 | .def("defined" , &nvfuser::FusionInterface::defined) |
77 | .def( |
78 | "execute" , |
79 | [](nvfuser::FusionInterface& self, const py::iterable& iter) { |
80 | std::vector<IValue> inputs; |
81 | for (py::handle obj : iter) { |
82 | inputs.push_back(toIValue(obj, c10::AnyType::get())); |
83 | } |
84 | return self.execute(inputs); |
85 | }, |
86 | py::return_value_policy::reference) |
87 | .def("id" , &nvfuser::FusionInterface::id) |
88 | .def("print" , &nvfuser::FusionInterface::print); |
89 | |
90 | //! These are the FusionDefinition supported object types that are either |
91 | //! defined as inputs or the output of an operation. |
92 | py::class_<nvfuser::Tensor>(nvfuser, "Tensor" ); |
93 | py::class_<nvfuser::Scalar>(nvfuser, "Scalar" ); |
94 | |
95 | //! The FusionDefinition is a context manager in Python where the user will |
96 | //! define the set the operations and connections between operations for |
97 | //! nvFuser to create. |
98 | py::class_<nvfuser::FusionDefinition> fusion_def(nvfuser, "FusionDefinition" ); |
99 | fusion_def |
100 | .def( |
101 | py::init<nvfuser::FusionInterface*, int>(), |
102 | py::arg("fusion" ), |
103 | py::arg("max_length" ) = int(1024)) |
104 | .def_readwrite("ops" , &nvfuser::FusionDefinition::ops) |
105 | .def( |
106 | "__enter__" , |
107 | [](nvfuser::FusionDefinition& self) -> nvfuser::FusionDefinition* { |
108 | // Instrumentation to mark the beginning of a FusionDefinition |
109 | Nvf::inst::Trace::instance()->beginEvent( |
110 | "FusionDefinition Context Manager" ); |
111 | return self.enter(); |
112 | }) |
113 | .def( |
114 | "__exit__" , |
115 | [](nvfuser::FusionDefinition& self, |
116 | void* exc_type, |
117 | void* exc_value, |
118 | void* traceback) { |
119 | self.exit(); |
120 | // Mark the end of a FusionDefinition Context Manager |
121 | Nvf::inst::Trace::instance()->endEvent(nullptr); |
122 | }) |
123 | .def( |
124 | "__str__" , |
125 | [](nvfuser::FusionDefinition& self) { |
126 | std::stringstream ss; |
127 | self.print(ss); |
128 | return ss.str(); |
129 | }) |
130 | .def( |
131 | "add_output" , |
132 | [](nvfuser::FusionDefinition& self, nvfuser::Scalar output) { |
133 | FUSER_PERF_SCOPE("FusionDefinition.add_output (scalar)" ); |
134 | self.defineRecord(new nvfuser::OutputRecord<Nvf::Val>( |
135 | {self.recordingState(output())})); |
136 | }) |
137 | .def( |
138 | "add_output" , |
139 | [](nvfuser::FusionDefinition& self, nvfuser::Tensor output) { |
140 | FUSER_PERF_SCOPE("FusionDefinition.add_output (tensor)" ); |
141 | self.defineRecord(new nvfuser::OutputRecord<Nvf::TensorView>( |
142 | {self.recordingState(output())})); |
143 | }) |
144 | .def( |
145 | "define_tensor" , |
146 | [](nvfuser::FusionDefinition& self, |
147 | size_t ndims, |
148 | Nvf::DataType dtype = Nvf::DataType::Float, |
149 | bool is_cpu = false) -> nvfuser::Tensor { |
150 | FUSER_PERF_SCOPE("FusionDefinition.define_tensor (simple)" ); |
151 | std::vector<int64_t> maybe_symbolic_sizes(ndims, -1); |
152 | ; |
153 | std::vector<bool> contig_info(ndims, false); |
154 | |
155 | nvfuser::Tensor out = self.defineTensor(); |
156 | self.defineRecord(new nvfuser::TensorRecord( |
157 | {self.recordingState(out())}, |
158 | std::move(maybe_symbolic_sizes), |
159 | std::move(contig_info), |
160 | dtype, |
161 | is_cpu)); |
162 | |
163 | return out; |
164 | }, |
165 | py::arg("ndims" ), |
166 | py::arg("dtype" ) = Nvf::DataType::Float, |
167 | py::arg("is_cpu" ) = false, |
168 | py::return_value_policy::reference) |
169 | .def( |
170 | "define_tensor" , |
171 | [](nvfuser::FusionDefinition& self, |
172 | std::vector<int64_t>& symbolic_sizes, |
173 | std::vector<bool>& contiguous, |
174 | Nvf::DataType dtype = Nvf::DataType::Float, |
175 | bool is_cpu = false) -> nvfuser::Tensor { |
176 | FUSER_PERF_SCOPE("FusionDefinition.define_tensor (default)" ); |
177 | |
178 | for (size_t i = 0; i < symbolic_sizes.size(); ++i) { |
179 | TORCH_CHECK( |
180 | symbolic_sizes[i] == -1 || symbolic_sizes[i] == 1, |
181 | "The value " , |
182 | symbolic_sizes[i], |
183 | " at index " , |
184 | i, |
185 | " was neither broadcast(1) or symbolic(-1)." ); |
186 | } |
187 | |
188 | nvfuser::Tensor out = self.defineTensor(); |
189 | self.defineRecord(new nvfuser::TensorRecord( |
190 | {self.recordingState(out())}, |
191 | symbolic_sizes, |
192 | contiguous, |
193 | dtype, |
194 | is_cpu)); |
195 | |
196 | return out; |
197 | }, |
198 | py::arg("symbolic_sizes" ), |
199 | py::arg("contiguous" ), |
200 | py::arg("dtype" ) = Nvf::DataType::Float, |
201 | py::arg("is_cpu" ) = false, |
202 | py::return_value_policy::reference) |
203 | .def( |
204 | "define_tensor" , |
205 | [](nvfuser::FusionDefinition& self, |
206 | std::vector<int64_t>& sizes, |
207 | std::vector<int64_t>& strides, |
208 | Nvf::DataType dtype = Nvf::DataType::Float, |
209 | bool is_cpu = false) -> nvfuser::Tensor { |
210 | FUSER_PERF_SCOPE("FusionDefinition.define_tensor (integration)" ); |
211 | TORCH_CHECK( |
212 | sizes.size() == strides.size(), |
213 | "The number of sizes does not match the number of strides." , |
214 | sizes.size(), |
215 | strides.size()); |
216 | |
217 | // TensorViewBuilder assumes any dim with a compile time constant |
218 | // size == 1 is a "maybe broadcast" axis, symbolic sizes are |
219 | // identified by -1, and size == 0 is not supported. |
220 | |
221 | // Translate to TensorViewBuilder's view of the world. |
222 | std::vector<int64_t> maybe_symbolic_sizes; |
223 | maybe_symbolic_sizes.reserve(sizes.size()); |
224 | for (const auto i : c10::irange(sizes.size())) { |
225 | TORCH_INTERNAL_ASSERT( |
226 | sizes[i] > 0, |
227 | "Size of " , |
228 | sizes[i], |
229 | " is not supported in nvFuser. Expected size > 0." ); |
230 | if (sizes[i] == 1) { |
231 | maybe_symbolic_sizes.push_back(1); |
232 | } else { |
233 | maybe_symbolic_sizes.push_back(-1); |
234 | } |
235 | } |
236 | |
237 | std::vector<bool> contig_info(strides.size(), false); |
238 | for (int i = contig_info.size() - 1; i >= 0; --i) { |
239 | if (i == static_cast<int>(contig_info.size() - 1)) { |
240 | contig_info[i] = (strides[i] == 1); |
241 | } else { |
242 | contig_info[i] = |
243 | (strides[i] == (strides[i + 1] * sizes[i + 1])); |
244 | } |
245 | } |
246 | |
247 | nvfuser::Tensor out = self.defineTensor(); |
248 | self.defineRecord(new nvfuser::TensorRecord( |
249 | {self.recordingState(out())}, |
250 | std::move(maybe_symbolic_sizes), |
251 | std::move(contig_info), |
252 | dtype, |
253 | is_cpu)); |
254 | |
255 | return out; |
256 | }, |
257 | py::arg("sizes" ), |
258 | py::arg("strides" ), |
259 | py::arg("dtype" ) = Nvf::DataType::Float, |
260 | py::arg("is_cpu" ) = false, |
261 | py::return_value_policy::reference) |
262 | .def( |
263 | "define_constant" , |
264 | [](nvfuser::FusionDefinition& self, double val) -> nvfuser::Scalar { |
265 | FUSER_PERF_SCOPE("FusionDefinition.define_constant (double)" ); |
266 | nvfuser::Scalar out = self.defineScalar(); |
267 | self.defineRecord(new nvfuser::ConstantRecord<Nvf::Double, double>( |
268 | {self.recordingState(out())}, val)); |
269 | return out; |
270 | }, |
271 | py::return_value_policy::reference) |
272 | .def( |
273 | "define_constant" , |
274 | [](nvfuser::FusionDefinition& self, |
275 | std::complex<double> val) -> nvfuser::Scalar { |
276 | FUSER_PERF_SCOPE("FusionDefinition.define_constant (complex)" ); |
277 | nvfuser::Scalar out = self.defineScalar(); |
278 | self.defineRecord( |
279 | new nvfuser:: |
280 | ConstantRecord<Nvf::ComplexDouble, c10::complex<double>>( |
281 | {self.recordingState(out())}, |
282 | static_cast<c10::complex<double>>(val))); |
283 | return out; |
284 | }, |
285 | py::return_value_policy::reference) |
286 | .def( |
287 | "define_constant" , |
288 | [](nvfuser::FusionDefinition& self, bool val) -> nvfuser::Scalar { |
289 | FUSER_PERF_SCOPE("FusionDefinition.define_constant (bool)" ); |
290 | nvfuser::Scalar out = self.defineScalar(); |
291 | self.defineRecord(new nvfuser::ConstantRecord<Nvf::Bool, bool>( |
292 | {self.recordingState(out())}, val)); |
293 | return out; |
294 | }, |
295 | py::return_value_policy::reference) |
296 | .def( |
297 | "define_constant" , |
298 | [](nvfuser::FusionDefinition& self, int64_t val) -> nvfuser::Scalar { |
299 | FUSER_PERF_SCOPE("FusionDefinition.define_constant (int)" ); |
300 | nvfuser::Scalar out = self.defineScalar(); |
301 | self.defineRecord(new nvfuser::ConstantRecord<Nvf::Int, int64_t>( |
302 | {self.recordingState(out())}, val)); |
303 | return out; |
304 | }, |
305 | py::return_value_policy::reference) |
306 | .def( |
307 | "define_scalar" , |
308 | [](nvfuser::FusionDefinition& self, |
309 | Nvf::DataType dtype = Nvf::DataType::Double) -> nvfuser::Scalar { |
310 | FUSER_PERF_SCOPE("FusionDefinition.define_scalar" ); |
311 | nvfuser::Scalar out = self.defineScalar(); |
312 | self.defineRecord( |
313 | new nvfuser::ScalarRecord({self.recordingState(out())}, dtype)); |
314 | return out; |
315 | }, |
316 | py::arg("dtype" ) = Nvf::DataType::Double, |
317 | py::return_value_policy::reference); |
318 | |
319 | //! The Operators class is a nested class of FusionDefinition to allow the |
320 | //! user to query the class for the list of operators. |
321 | //! |
322 | //! Example: |
323 | //! help(FusionDefinition.Operators) |
324 | //! |
325 | //! Additional operators are expected to be defined below as needed. They |
326 | //! may require defining a new RecordFunctor child class if they are unique. |
327 | py::class_<nvfuser::FusionDefinition::Operators> nvf_ops( |
328 | fusion_def, "Operators" ); |
329 | nvf_ops.def(py::init<nvfuser::FusionDefinition*>()); |
330 | |
331 | // ******************** INSERT OP BINDINGS BELOW HERE ******************** |
332 | #define OP_PREFIX "Operators." |
333 | #define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name) \ |
334 | nvf_ops.def( \ |
335 | op_str, \ |
336 | [](nvfuser::FusionDefinition::Operators& self, \ |
337 | nvfuser::Tensor input) -> nvfuser::Tensor { \ |
338 | FUSER_PERF_SCOPE("Operators." op_str); \ |
339 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
340 | nvfuser::Tensor output = fd->defineTensor(); \ |
341 | fd->defineRecord( \ |
342 | new nvfuser::OpRecord<Nvf::TensorView*, Nvf::TensorView*>( \ |
343 | {fd->recordingState(input())}, \ |
344 | {fd->recordingState(output())}, \ |
345 | ("ops." op_str), \ |
346 | static_cast<Nvf::TensorView* (*)(Nvf::TensorView*)>( \ |
347 | Nvf::op_name))); \ |
348 | return output; \ |
349 | }, \ |
350 | py::return_value_policy::reference); \ |
351 | nvf_ops.def( \ |
352 | op_str, \ |
353 | [](nvfuser::FusionDefinition::Operators& self, \ |
354 | nvfuser::Scalar input) -> nvfuser::Scalar { \ |
355 | FUSER_PERF_SCOPE("Operators." op_str); \ |
356 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
357 | nvfuser::Scalar output = fd->defineScalar(); \ |
358 | fd->defineRecord(new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*>( \ |
359 | {fd->recordingState(input())}, \ |
360 | {fd->recordingState(output())}, \ |
361 | ("ops." op_str), \ |
362 | static_cast<Nvf::Val* (*)(Nvf::Val*)>(Nvf::op_name))); \ |
363 | return output; \ |
364 | }, \ |
365 | py::return_value_policy::reference); |
366 | |
367 | NVFUSER_PYTHON_BINDING_UNARY_OP("abs" , abs) |
368 | NVFUSER_PYTHON_BINDING_UNARY_OP("acos" , acos) |
369 | NVFUSER_PYTHON_BINDING_UNARY_OP("asin" , asin) |
370 | NVFUSER_PYTHON_BINDING_UNARY_OP("atan" , atan) |
371 | NVFUSER_PYTHON_BINDING_UNARY_OP("atanh" , atanh) |
372 | NVFUSER_PYTHON_BINDING_UNARY_OP("ceil" , ceil) |
373 | NVFUSER_PYTHON_BINDING_UNARY_OP("cos" , cos) |
374 | NVFUSER_PYTHON_BINDING_UNARY_OP("cosh" , cosh) |
375 | NVFUSER_PYTHON_BINDING_UNARY_OP("exp" , exp) |
376 | NVFUSER_PYTHON_BINDING_UNARY_OP("expm1" , expm1) |
377 | NVFUSER_PYTHON_BINDING_UNARY_OP("erf" , erf) |
378 | NVFUSER_PYTHON_BINDING_UNARY_OP("erfc" , erfc) |
379 | NVFUSER_PYTHON_BINDING_UNARY_OP("floor" , floor) |
380 | NVFUSER_PYTHON_BINDING_UNARY_OP("frac" , frac) |
381 | NVFUSER_PYTHON_BINDING_UNARY_OP("lgamma" , lgamma) |
382 | NVFUSER_PYTHON_BINDING_UNARY_OP("log" , log) |
383 | NVFUSER_PYTHON_BINDING_UNARY_OP("log10" , log10) |
384 | NVFUSER_PYTHON_BINDING_UNARY_OP("log1p" , log1p) |
385 | NVFUSER_PYTHON_BINDING_UNARY_OP("log2" , log2) |
386 | NVFUSER_PYTHON_BINDING_UNARY_OP("neg" , neg) |
387 | NVFUSER_PYTHON_BINDING_UNARY_OP("bitwise_not" , bitwise_not) |
388 | NVFUSER_PYTHON_BINDING_UNARY_OP("relu" , relu) |
389 | NVFUSER_PYTHON_BINDING_UNARY_OP("rand_like" , rand_like) |
390 | NVFUSER_PYTHON_BINDING_UNARY_OP("reciprocal" , reciprocal) |
391 | NVFUSER_PYTHON_BINDING_UNARY_OP("round" , round) |
392 | NVFUSER_PYTHON_BINDING_UNARY_OP("rsqrt" , rsqrt) |
393 | NVFUSER_PYTHON_BINDING_UNARY_OP("set" , set) |
394 | NVFUSER_PYTHON_BINDING_UNARY_OP("sign" , sign) |
395 | NVFUSER_PYTHON_BINDING_UNARY_OP("sigmoid" , sigmoid) |
396 | NVFUSER_PYTHON_BINDING_UNARY_OP("silu" , silu) |
397 | NVFUSER_PYTHON_BINDING_UNARY_OP("sin" , sin) |
398 | NVFUSER_PYTHON_BINDING_UNARY_OP("sinh" , sinh) |
399 | NVFUSER_PYTHON_BINDING_UNARY_OP("sqrt" , sqrt) |
400 | NVFUSER_PYTHON_BINDING_UNARY_OP("tan" , tan) |
401 | NVFUSER_PYTHON_BINDING_UNARY_OP("tanh" , tanh) |
402 | NVFUSER_PYTHON_BINDING_UNARY_OP("trunc" , trunc) |
403 | NVFUSER_PYTHON_BINDING_UNARY_OP("isfinite" , isfinite) |
404 | NVFUSER_PYTHON_BINDING_UNARY_OP("isinf" , isinf) |
405 | NVFUSER_PYTHON_BINDING_UNARY_OP("isnan" , isnan) |
406 | NVFUSER_PYTHON_BINDING_UNARY_OP("isneginf" , isneginf) |
407 | NVFUSER_PYTHON_BINDING_UNARY_OP("isposinf" , isposinf) |
408 | NVFUSER_PYTHON_BINDING_UNARY_OP("isreal" , isreal) |
409 | NVFUSER_PYTHON_BINDING_UNARY_OP("real" , real) |
410 | NVFUSER_PYTHON_BINDING_UNARY_OP("imag" , imag) |
411 | #undef NVFUSER_PYTHON_BINDING_UNARY_OP |
412 | |
//! Binds a two-input operator `op_str` on Operators with four overloads, in
//! this order: (Tensor, Tensor), (Tensor, Scalar), (Scalar, Tensor), which
//! all produce a Tensor, and (Scalar, Scalar), which produces a Scalar.
//! Each overload records an OpRecord built around the `Nvf::op_name`
//! overload with the matching TensorView*/Val* signature.
#define NVFUSER_PYTHON_BINDING_BINARY_OP(op_str, op_name) \
  nvf_ops.def( \
      op_str, \
      [](nvfuser::FusionDefinition::Operators& self, \
         nvfuser::Tensor lhs, \
         nvfuser::Tensor rhs) -> nvfuser::Tensor { \
        FUSER_PERF_SCOPE("Operators." op_str); \
        using OpFn = Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*); \
        nvfuser::FusionDefinition* defn = self.fusion_definition; \
        nvfuser::Tensor result = defn->defineTensor(); \
        defn->defineRecord( \
            new nvfuser:: \
                OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::TensorView*>( \
                    {defn->recordingState(lhs()), \
                     defn->recordingState(rhs())}, \
                    {defn->recordingState(result())}, \
                    ("ops." op_str), \
                    static_cast<OpFn>(Nvf::op_name))); \
        return result; \
      }, \
      py::return_value_policy::reference); \
  nvf_ops.def( \
      op_str, \
      [](nvfuser::FusionDefinition::Operators& self, \
         nvfuser::Tensor lhs, \
         nvfuser::Scalar rhs) -> nvfuser::Tensor { \
        FUSER_PERF_SCOPE("Operators." op_str); \
        using OpFn = Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*); \
        nvfuser::FusionDefinition* defn = self.fusion_definition; \
        nvfuser::Tensor result = defn->defineTensor(); \
        defn->defineRecord( \
            new nvfuser:: \
                OpRecord<Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*>( \
                    {defn->recordingState(lhs()), \
                     defn->recordingState(rhs())}, \
                    {defn->recordingState(result())}, \
                    ("ops." op_str), \
                    static_cast<OpFn>(Nvf::op_name))); \
        return result; \
      }, \
      py::return_value_policy::reference); \
  nvf_ops.def( \
      op_str, \
      [](nvfuser::FusionDefinition::Operators& self, \
         nvfuser::Scalar lhs, \
         nvfuser::Tensor rhs) -> nvfuser::Tensor { \
        FUSER_PERF_SCOPE("Operators." op_str); \
        using OpFn = Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*); \
        nvfuser::FusionDefinition* defn = self.fusion_definition; \
        nvfuser::Tensor result = defn->defineTensor(); \
        defn->defineRecord( \
            new nvfuser:: \
                OpRecord<Nvf::TensorView*, Nvf::Val*, Nvf::TensorView*>( \
                    {defn->recordingState(lhs()), \
                     defn->recordingState(rhs())}, \
                    {defn->recordingState(result())}, \
                    ("ops." op_str), \
                    static_cast<OpFn>(Nvf::op_name))); \
        return result; \
      }, \
      py::return_value_policy::reference); \
  nvf_ops.def( \
      op_str, \
      [](nvfuser::FusionDefinition::Operators& self, \
         nvfuser::Scalar lhs, \
         nvfuser::Scalar rhs) -> nvfuser::Scalar { \
        FUSER_PERF_SCOPE("Operators." op_str); \
        using OpFn = Nvf::Val* (*)(Nvf::Val*, Nvf::Val*); \
        nvfuser::FusionDefinition* defn = self.fusion_definition; \
        nvfuser::Scalar result = defn->defineScalar(); \
        defn->defineRecord( \
            new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*>( \
                {defn->recordingState(lhs()), defn->recordingState(rhs())}, \
                {defn->recordingState(result())}, \
                ("ops." op_str), \
                static_cast<OpFn>(Nvf::op_name))); \
        return result; \
      }, \
      py::return_value_policy::reference);
493 | |
494 | NVFUSER_PYTHON_BINDING_BINARY_OP("add" , add) |
495 | NVFUSER_PYTHON_BINDING_BINARY_OP("atan2" , atan2) |
496 | NVFUSER_PYTHON_BINDING_BINARY_OP("div" , div) |
497 | NVFUSER_PYTHON_BINDING_BINARY_OP("fmod" , fmod) |
498 | NVFUSER_PYTHON_BINDING_BINARY_OP("mul" , mul) |
499 | NVFUSER_PYTHON_BINDING_BINARY_OP("pow" , pow) |
500 | NVFUSER_PYTHON_BINDING_BINARY_OP("remainder" , remainder) |
501 | NVFUSER_PYTHON_BINDING_BINARY_OP("sub" , sub) |
502 | NVFUSER_PYTHON_BINDING_BINARY_OP("mod" , mod) |
503 | NVFUSER_PYTHON_BINDING_BINARY_OP("eq" , eq) |
504 | NVFUSER_PYTHON_BINDING_BINARY_OP("ge" , ge) |
505 | NVFUSER_PYTHON_BINDING_BINARY_OP("gt" , gt) |
506 | NVFUSER_PYTHON_BINDING_BINARY_OP("le" , le) |
507 | NVFUSER_PYTHON_BINDING_BINARY_OP("lt" , lt) |
508 | NVFUSER_PYTHON_BINDING_BINARY_OP("ne" , ne) |
509 | NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_and" , bitwise_and) |
510 | NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_or" , bitwise_or) |
511 | NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_xor" , bitwise_xor) |
512 | NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_left_shift" , bitwise_left_shift) |
513 | NVFUSER_PYTHON_BINDING_BINARY_OP("bitwise_right_shift" , bitwise_left_shift) |
514 | #undef NVFUSER_PYTHON_BINDING_BINARY_OP |
515 | |
516 | #define NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP(op_str, op_name) \ |
517 | nvf_ops.def( \ |
518 | op_str, \ |
519 | [](nvfuser::FusionDefinition::Operators& self, \ |
520 | nvfuser::Tensor arg1, \ |
521 | nvfuser::Tensor arg2, \ |
522 | nvfuser::Scalar arg3) -> nvfuser::Tensor { \ |
523 | FUSER_PERF_SCOPE("Operators." op_str); \ |
524 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
525 | nvfuser::Tensor output = fd->defineTensor(); \ |
526 | fd->defineRecord(new nvfuser::OpRecord< \ |
527 | Nvf::TensorView*, \ |
528 | Nvf::TensorView*, \ |
529 | Nvf::TensorView*, \ |
530 | Nvf::Val*>( \ |
531 | {fd->recordingState(arg1()), \ |
532 | fd->recordingState(arg2()), \ |
533 | fd->recordingState(arg3())}, \ |
534 | {fd->recordingState(output())}, \ |
535 | ("ops." op_str), \ |
536 | static_cast< \ |
537 | Nvf:: \ |
538 | TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \ |
539 | Nvf::op_name))); \ |
540 | return output; \ |
541 | }, \ |
542 | py::return_value_policy::reference); \ |
543 | nvf_ops.def( \ |
544 | op_str, \ |
545 | [](nvfuser::FusionDefinition::Operators& self, \ |
546 | nvfuser::Tensor arg1, \ |
547 | nvfuser::Scalar arg2, \ |
548 | nvfuser::Scalar arg3) -> nvfuser::Tensor { \ |
549 | FUSER_PERF_SCOPE("Operators." op_str); \ |
550 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
551 | nvfuser::Tensor output = fd->defineTensor(); \ |
552 | fd->defineRecord(new nvfuser::OpRecord< \ |
553 | Nvf::TensorView*, \ |
554 | Nvf::TensorView*, \ |
555 | Nvf::Val*, \ |
556 | Nvf::Val*>( \ |
557 | {fd->recordingState(arg1()), \ |
558 | fd->recordingState(arg2()), \ |
559 | fd->recordingState(arg3())}, \ |
560 | {fd->recordingState(output())}, \ |
561 | ("ops." op_str), \ |
562 | static_cast< \ |
563 | Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \ |
564 | Nvf::op_name))); \ |
565 | return output; \ |
566 | }, \ |
567 | py::return_value_policy::reference); \ |
568 | nvf_ops.def( \ |
569 | op_str, \ |
570 | [](nvfuser::FusionDefinition::Operators& self, \ |
571 | nvfuser::Scalar arg1, \ |
572 | nvfuser::Tensor arg2, \ |
573 | nvfuser::Scalar arg3) -> nvfuser::Tensor { \ |
574 | FUSER_PERF_SCOPE("Operators." op_str); \ |
575 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
576 | nvfuser::Tensor output = fd->defineTensor(); \ |
577 | fd->defineRecord(new nvfuser::OpRecord< \ |
578 | Nvf::TensorView*, \ |
579 | Nvf::Val*, \ |
580 | Nvf::TensorView*, \ |
581 | Nvf::Val*>( \ |
582 | {fd->recordingState(arg1()), \ |
583 | fd->recordingState(arg2()), \ |
584 | fd->recordingState(arg3())}, \ |
585 | {fd->recordingState(output())}, \ |
586 | ("ops." op_str), \ |
587 | static_cast< \ |
588 | Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>( \ |
589 | Nvf::op_name))); \ |
590 | return output; \ |
591 | }, \ |
592 | py::return_value_policy::reference); \ |
593 | nvf_ops.def( \ |
594 | op_str, \ |
595 | [](nvfuser::FusionDefinition::Operators& self, \ |
596 | nvfuser::Scalar arg1, \ |
597 | nvfuser::Scalar arg2, \ |
598 | nvfuser::Scalar arg3) -> nvfuser::Scalar { \ |
599 | FUSER_PERF_SCOPE("Operators." op_str); \ |
600 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
601 | nvfuser::Scalar output = fd->defineScalar(); \ |
602 | fd->defineRecord( \ |
603 | new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>( \ |
604 | {fd->recordingState(arg1()), \ |
605 | fd->recordingState(arg2()), \ |
606 | fd->recordingState(arg3())}, \ |
607 | {fd->recordingState(output())}, \ |
608 | ("ops." op_str), \ |
609 | static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \ |
610 | Nvf::op_name))); \ |
611 | return output; \ |
612 | }, \ |
613 | py::return_value_policy::reference); |
614 | |
615 | NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("add_alpha" , add_alpha) |
616 | NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP("sub_alpha" , sub_alpha) |
617 | #undef NVFUSER_PYTHON_BINDING_BINARY_WITH_ALPHA_OP |
618 | |
619 | #define NVFUSER_PYTHON_BINDING_TERNARY_OP(op_str, op_name) \ |
620 | nvf_ops.def( \ |
621 | op_str, \ |
622 | [](nvfuser::FusionDefinition::Operators& self, \ |
623 | nvfuser::Scalar arg1, \ |
624 | nvfuser::Scalar arg2, \ |
625 | nvfuser::Scalar arg3) -> nvfuser::Scalar { \ |
626 | FUSER_PERF_SCOPE("Operators." op_str); \ |
627 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
628 | nvfuser::Scalar output = fd->defineScalar(); \ |
629 | fd->defineRecord( \ |
630 | new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>( \ |
631 | {fd->recordingState(arg1()), \ |
632 | fd->recordingState(arg2()), \ |
633 | fd->recordingState(arg3())}, \ |
634 | {fd->recordingState(output())}, \ |
635 | ("ops." op_str), \ |
636 | static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \ |
637 | Nvf::op_name))); \ |
638 | return output; \ |
639 | }, \ |
640 | py::return_value_policy::reference); \ |
641 | nvf_ops.def( \ |
642 | op_str, \ |
643 | [](nvfuser::FusionDefinition::Operators& self, \ |
644 | nvfuser::Tensor arg1, \ |
645 | nvfuser::Tensor arg2, \ |
646 | nvfuser::Tensor arg3) -> nvfuser::Tensor { \ |
647 | FUSER_PERF_SCOPE("Operators." op_str); \ |
648 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
649 | nvfuser::Tensor output = fd->defineTensor(); \ |
650 | fd->defineRecord(new nvfuser::OpRecord< \ |
651 | Nvf::TensorView*, \ |
652 | Nvf::TensorView*, \ |
653 | Nvf::TensorView*, \ |
654 | Nvf::TensorView*>( \ |
655 | {fd->recordingState(arg1()), \ |
656 | fd->recordingState(arg2()), \ |
657 | fd->recordingState(arg3())}, \ |
658 | {fd->recordingState(output())}, \ |
659 | ("ops." op_str), \ |
660 | static_cast< \ |
661 | Nvf:: \ |
662 | TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::TensorView*)>( \ |
663 | Nvf::op_name))); \ |
664 | return output; \ |
665 | }, \ |
666 | py::return_value_policy::reference); \ |
667 | nvf_ops.def( \ |
668 | op_str, \ |
669 | [](nvfuser::FusionDefinition::Operators& self, \ |
670 | nvfuser::Tensor arg1, \ |
671 | nvfuser::Tensor arg2, \ |
672 | nvfuser::Scalar arg3) -> nvfuser::Tensor { \ |
673 | FUSER_PERF_SCOPE("Operators." op_str); \ |
674 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
675 | nvfuser::Tensor output = fd->defineTensor(); \ |
676 | fd->defineRecord(new nvfuser::OpRecord< \ |
677 | Nvf::TensorView*, \ |
678 | Nvf::TensorView*, \ |
679 | Nvf::TensorView*, \ |
680 | Nvf::Val*>( \ |
681 | {fd->recordingState(arg1()), \ |
682 | fd->recordingState(arg2()), \ |
683 | fd->recordingState(arg3())}, \ |
684 | {fd->recordingState(output())}, \ |
685 | ("ops." op_str), \ |
686 | static_cast< \ |
687 | Nvf:: \ |
688 | TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \ |
689 | Nvf::op_name))); \ |
690 | return output; \ |
691 | }, \ |
692 | py::return_value_policy::reference); \ |
693 | nvf_ops.def( \ |
694 | op_str, \ |
695 | [](nvfuser::FusionDefinition::Operators& self, \ |
696 | nvfuser::Tensor arg1, \ |
697 | nvfuser::Scalar arg2, \ |
698 | nvfuser::Tensor arg3) -> nvfuser::Tensor { \ |
699 | FUSER_PERF_SCOPE("Operators." op_str); \ |
700 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
701 | nvfuser::Tensor output = fd->defineTensor(); \ |
702 | fd->defineRecord(new nvfuser::OpRecord< \ |
703 | Nvf::TensorView*, \ |
704 | Nvf::TensorView*, \ |
705 | Nvf::Val*, \ |
706 | Nvf::TensorView*>( \ |
707 | {fd->recordingState(arg1()), \ |
708 | fd->recordingState(arg2()), \ |
709 | fd->recordingState(arg3())}, \ |
710 | {fd->recordingState(output())}, \ |
711 | ("ops." op_str), \ |
712 | static_cast< \ |
713 | Nvf:: \ |
714 | TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::TensorView*)>( \ |
715 | Nvf::op_name))); \ |
716 | return output; \ |
717 | }, \ |
718 | py::return_value_policy::reference); \ |
719 | nvf_ops.def( \ |
720 | op_str, \ |
721 | [](nvfuser::FusionDefinition::Operators& self, \ |
722 | nvfuser::Scalar arg1, \ |
723 | nvfuser::Tensor arg2, \ |
724 | nvfuser::Tensor arg3) -> nvfuser::Tensor { \ |
725 | FUSER_PERF_SCOPE("Operators." op_str); \ |
726 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
727 | nvfuser::Tensor output = fd->defineTensor(); \ |
728 | fd->defineRecord(new nvfuser::OpRecord< \ |
729 | Nvf::TensorView*, \ |
730 | Nvf::Val*, \ |
731 | Nvf::TensorView*, \ |
732 | Nvf::TensorView*>( \ |
733 | {fd->recordingState(arg1()), \ |
734 | fd->recordingState(arg2()), \ |
735 | fd->recordingState(arg3())}, \ |
736 | {fd->recordingState(output())}, \ |
737 | ("ops." op_str), \ |
738 | static_cast< \ |
739 | Nvf:: \ |
740 | TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::TensorView*)>( \ |
741 | Nvf::op_name))); \ |
742 | return output; \ |
743 | }, \ |
744 | py::return_value_policy::reference); \ |
745 | nvf_ops.def( \ |
746 | op_str, \ |
747 | [](nvfuser::FusionDefinition::Operators& self, \ |
748 | nvfuser::Scalar arg1, \ |
749 | nvfuser::Scalar arg2, \ |
750 | nvfuser::Tensor arg3) -> nvfuser::Tensor { \ |
751 | FUSER_PERF_SCOPE("Operators." op_str); \ |
752 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
753 | nvfuser::Tensor output = fd->defineTensor(); \ |
754 | fd->defineRecord(new nvfuser::OpRecord< \ |
755 | Nvf::TensorView*, \ |
756 | Nvf::Val*, \ |
757 | Nvf::Val*, \ |
758 | Nvf::TensorView*>( \ |
759 | {fd->recordingState(arg1()), \ |
760 | fd->recordingState(arg2()), \ |
761 | fd->recordingState(arg3())}, \ |
762 | {fd->recordingState(output())}, \ |
763 | ("ops." op_str), \ |
764 | static_cast< \ |
765 | Nvf::TensorView* (*)(Nvf::Val*, Nvf::Val*, Nvf::TensorView*)>( \ |
766 | Nvf::op_name))); \ |
767 | return output; \ |
768 | }, \ |
769 | py::return_value_policy::reference); \ |
770 | nvf_ops.def( \ |
771 | op_str, \ |
772 | [](nvfuser::FusionDefinition::Operators& self, \ |
773 | nvfuser::Tensor arg1, \ |
774 | nvfuser::Scalar arg2, \ |
775 | nvfuser::Scalar arg3) -> nvfuser::Tensor { \ |
776 | FUSER_PERF_SCOPE("Operators." op_str); \ |
777 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
778 | nvfuser::Tensor output = fd->defineTensor(); \ |
779 | fd->defineRecord(new nvfuser::OpRecord< \ |
780 | Nvf::TensorView*, \ |
781 | Nvf::TensorView*, \ |
782 | Nvf::Val*, \ |
783 | Nvf::Val*>( \ |
784 | {fd->recordingState(arg1()), \ |
785 | fd->recordingState(arg2()), \ |
786 | fd->recordingState(arg3())}, \ |
787 | {fd->recordingState(output())}, \ |
788 | ("ops." op_str), \ |
789 | static_cast< \ |
790 | Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \ |
791 | Nvf::op_name))); \ |
792 | return output; \ |
793 | }, \ |
794 | py::return_value_policy::reference); \ |
795 | nvf_ops.def( \ |
796 | op_str, \ |
797 | [](nvfuser::FusionDefinition::Operators& self, \ |
798 | nvfuser::Scalar arg1, \ |
799 | nvfuser::Tensor arg2, \ |
800 | nvfuser::Scalar arg3) -> nvfuser::Tensor { \ |
801 | FUSER_PERF_SCOPE("Operators." op_str); \ |
802 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
803 | nvfuser::Tensor output = fd->defineTensor(); \ |
804 | fd->defineRecord(new nvfuser::OpRecord< \ |
805 | Nvf::TensorView*, \ |
806 | Nvf::Val*, \ |
807 | Nvf::TensorView*, \ |
808 | Nvf::Val*>( \ |
809 | {fd->recordingState(arg1()), \ |
810 | fd->recordingState(arg2()), \ |
811 | fd->recordingState(arg3())}, \ |
812 | {fd->recordingState(output())}, \ |
813 | ("ops." op_str), \ |
814 | static_cast< \ |
815 | Nvf::TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>( \ |
816 | Nvf::op_name))); \ |
817 | return output; \ |
818 | }, \ |
819 | py::return_value_policy::reference); |
820 | |
821 | NVFUSER_PYTHON_BINDING_TERNARY_OP("lerp" , lerp) |
822 | NVFUSER_PYTHON_BINDING_TERNARY_OP("where" , where) |
823 | #undef NVFUSER_PYTHON_BINDING_TERNARY_OP |
824 | |
825 | #define NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP(op_str, op_name) \ |
826 | nvf_ops.def( \ |
827 | op_str, \ |
828 | [](nvfuser::FusionDefinition::Operators& self, \ |
829 | nvfuser::Scalar arg1, \ |
830 | nvfuser::Scalar arg2, \ |
831 | nvfuser::Scalar arg3) -> nvfuser::Scalar { \ |
832 | FUSER_PERF_SCOPE("Operators." op_str); \ |
833 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
834 | nvfuser::Scalar output = fd->defineScalar(); \ |
835 | fd->defineRecord( \ |
836 | new nvfuser::OpRecord<Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*>( \ |
837 | {fd->recordingState(arg1()), \ |
838 | fd->recordingState(arg2()), \ |
839 | fd->recordingState(arg3())}, \ |
840 | {fd->recordingState(output())}, \ |
841 | ("ops." op_str), \ |
842 | static_cast<Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \ |
843 | Nvf::op_name))); \ |
844 | return output; \ |
845 | }, \ |
846 | py::return_value_policy::reference); \ |
847 | nvf_ops.def( \ |
848 | op_str, \ |
849 | [](nvfuser::FusionDefinition::Operators& self, \ |
850 | nvfuser::Tensor arg1, \ |
851 | nvfuser::Scalar arg2, \ |
852 | nvfuser::Scalar arg3) -> nvfuser::Tensor { \ |
853 | FUSER_PERF_SCOPE("Operators." op_str); \ |
854 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
855 | nvfuser::Tensor output = fd->defineTensor(); \ |
856 | fd->defineRecord(new nvfuser::OpRecord< \ |
857 | Nvf::TensorView*, \ |
858 | Nvf::TensorView*, \ |
859 | Nvf::Val*, \ |
860 | Nvf::Val*>( \ |
861 | {fd->recordingState(arg1()), \ |
862 | fd->recordingState(arg2()), \ |
863 | fd->recordingState(arg3())}, \ |
864 | {fd->recordingState(output())}, \ |
865 | ("ops." op_str), \ |
866 | static_cast< \ |
867 | Nvf::TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \ |
868 | Nvf::op_name))); \ |
869 | return output; \ |
870 | }, \ |
871 | py::return_value_policy::reference); |
872 | |
873 | NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("clamp" , clamp) |
874 | NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP("threshold" , threshold) |
875 | #undef NVFUSER_PYTHON_BINDING_THRESHOLD_LIKE_OP |
876 | |
877 | #define NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP(op_str, op_name) \ |
878 | nvf_ops.def( \ |
879 | op_str, \ |
880 | [](nvfuser::FusionDefinition::Operators& self, \ |
881 | nvfuser::Scalar arg1, \ |
882 | nvfuser::Scalar arg2, \ |
883 | nvfuser::Scalar arg3, \ |
884 | nvfuser::Scalar arg4) -> nvfuser::Scalar { \ |
885 | FUSER_PERF_SCOPE("Operators." op_str); \ |
886 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
887 | nvfuser::Scalar output = fd->defineScalar(); \ |
888 | fd->defineRecord(new nvfuser::OpRecord< \ |
889 | Nvf::Val*, \ |
890 | Nvf::Val*, \ |
891 | Nvf::Val*, \ |
892 | Nvf::Val*, \ |
893 | Nvf::Val*>( \ |
894 | {fd->recordingState(arg1()), \ |
895 | fd->recordingState(arg2()), \ |
896 | fd->recordingState(arg3()), \ |
897 | fd->recordingState(arg4())}, \ |
898 | {fd->recordingState(output())}, \ |
899 | ("ops." op_str), \ |
900 | static_cast< \ |
901 | Nvf::Val* (*)(Nvf::Val*, Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \ |
902 | Nvf::op_name))); \ |
903 | return output; \ |
904 | }, \ |
905 | py::return_value_policy::reference); \ |
906 | nvf_ops.def( \ |
907 | op_str, \ |
908 | [](nvfuser::FusionDefinition::Operators& self, \ |
909 | nvfuser::Tensor arg1, \ |
910 | nvfuser::Tensor arg2, \ |
911 | nvfuser::Tensor arg3, \ |
912 | nvfuser::Scalar arg4) -> nvfuser::Tensor { \ |
913 | FUSER_PERF_SCOPE("Operators." op_str); \ |
914 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
915 | nvfuser::Tensor output = fd->defineTensor(); \ |
916 | fd->defineRecord(new nvfuser::OpRecord< \ |
917 | Nvf::TensorView*, \ |
918 | Nvf::TensorView*, \ |
919 | Nvf::TensorView*, \ |
920 | Nvf::TensorView*, \ |
921 | Nvf::TensorView*>( \ |
922 | {fd->recordingState(arg1()), \ |
923 | fd->recordingState(arg2()), \ |
924 | fd->recordingState(arg3()), \ |
925 | fd->recordingState(arg4())}, \ |
926 | {fd->recordingState(output())}, \ |
927 | ("ops." op_str), \ |
928 | static_cast< \ |
929 | Nvf:: \ |
930 | TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \ |
931 | Nvf::op_name))); \ |
932 | return output; \ |
933 | }, \ |
934 | py::return_value_policy::reference); \ |
935 | nvf_ops.def( \ |
936 | op_str, \ |
937 | [](nvfuser::FusionDefinition::Operators& self, \ |
938 | nvfuser::Tensor arg1, \ |
939 | nvfuser::Tensor arg2, \ |
940 | nvfuser::Scalar arg3, \ |
941 | nvfuser::Scalar arg4) -> nvfuser::Tensor { \ |
942 | FUSER_PERF_SCOPE("Operators." op_str); \ |
943 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
944 | nvfuser::Tensor output = fd->defineTensor(); \ |
945 | fd->defineRecord(new nvfuser::OpRecord< \ |
946 | Nvf::TensorView*, \ |
947 | Nvf::TensorView*, \ |
948 | Nvf::TensorView*, \ |
949 | Nvf::Val*, \ |
950 | Nvf::Val*>( \ |
951 | {fd->recordingState(arg1()), \ |
952 | fd->recordingState(arg2()), \ |
953 | fd->recordingState(arg3()), \ |
954 | fd->recordingState(arg4())}, \ |
955 | {fd->recordingState(output())}, \ |
956 | ("ops." op_str), \ |
957 | static_cast< \ |
958 | Nvf:: \ |
959 | TensorView* (*)(Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \ |
960 | Nvf::op_name))); \ |
961 | return output; \ |
962 | }, \ |
963 | py::return_value_policy::reference); \ |
964 | nvf_ops.def( \ |
965 | op_str, \ |
966 | [](nvfuser::FusionDefinition::Operators& self, \ |
967 | nvfuser::Tensor arg1, \ |
968 | nvfuser::Scalar arg2, \ |
969 | nvfuser::Tensor arg3, \ |
970 | nvfuser::Scalar arg4) -> nvfuser::Tensor { \ |
971 | FUSER_PERF_SCOPE("Operators." op_str); \ |
972 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
973 | nvfuser::Tensor output = fd->defineTensor(); \ |
974 | fd->defineRecord(new nvfuser::OpRecord< \ |
975 | Nvf::TensorView*, \ |
976 | Nvf::TensorView*, \ |
977 | Nvf::Val*, \ |
978 | Nvf::TensorView*, \ |
979 | Nvf::Val*>( \ |
980 | {fd->recordingState(arg1()), \ |
981 | fd->recordingState(arg2()), \ |
982 | fd->recordingState(arg3()), \ |
983 | fd->recordingState(arg4())}, \ |
984 | {fd->recordingState(output())}, \ |
985 | ("ops." op_str), \ |
986 | static_cast< \ |
987 | Nvf:: \ |
988 | TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>( \ |
989 | Nvf::op_name))); \ |
990 | return output; \ |
991 | }, \ |
992 | py::return_value_policy::reference); \ |
993 | nvf_ops.def( \ |
994 | op_str, \ |
995 | [](nvfuser::FusionDefinition::Operators& self, \ |
996 | nvfuser::Scalar arg1, \ |
997 | nvfuser::Tensor arg2, \ |
998 | nvfuser::Tensor arg3, \ |
999 | nvfuser::Scalar arg4) -> nvfuser::Tensor { \ |
1000 | FUSER_PERF_SCOPE("Operators." op_str); \ |
1001 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
1002 | nvfuser::Tensor output = fd->defineTensor(); \ |
1003 | fd->defineRecord(new nvfuser::OpRecord< \ |
1004 | Nvf::TensorView*, \ |
1005 | Nvf::Val*, \ |
1006 | Nvf::TensorView*, \ |
1007 | Nvf::TensorView*, \ |
1008 | Nvf::Val*>( \ |
1009 | {fd->recordingState(arg1()), \ |
1010 | fd->recordingState(arg2()), \ |
1011 | fd->recordingState(arg3()), \ |
1012 | fd->recordingState(arg4())}, \ |
1013 | {fd->recordingState(output())}, \ |
1014 | ("ops." op_str), \ |
1015 | static_cast< \ |
1016 | Nvf:: \ |
1017 | TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::TensorView*, Nvf::Val*)>( \ |
1018 | Nvf::op_name))); \ |
1019 | return output; \ |
1020 | }, \ |
1021 | py::return_value_policy::reference); \ |
1022 | nvf_ops.def( \ |
1023 | op_str, \ |
1024 | [](nvfuser::FusionDefinition::Operators& self, \ |
1025 | nvfuser::Scalar arg1, \ |
1026 | nvfuser::Scalar arg2, \ |
1027 | nvfuser::Tensor arg3, \ |
1028 | nvfuser::Scalar arg4) -> nvfuser::Tensor { \ |
1029 | FUSER_PERF_SCOPE("Operators." op_str); \ |
1030 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
1031 | nvfuser::Tensor output = fd->defineTensor(); \ |
1032 | fd->defineRecord(new nvfuser::OpRecord< \ |
1033 | Nvf::TensorView*, \ |
1034 | Nvf::Val*, \ |
1035 | Nvf::Val*, \ |
1036 | Nvf::TensorView*, \ |
1037 | Nvf::Val*>( \ |
1038 | {fd->recordingState(arg1()), \ |
1039 | fd->recordingState(arg2()), \ |
1040 | fd->recordingState(arg3()), \ |
1041 | fd->recordingState(arg4())}, \ |
1042 | {fd->recordingState(output())}, \ |
1043 | ("ops." op_str), \ |
1044 | static_cast< \ |
1045 | Nvf:: \ |
1046 | TensorView* (*)(Nvf::Val*, Nvf::Val*, Nvf::TensorView*, Nvf::Val*)>( \ |
1047 | Nvf::op_name))); \ |
1048 | return output; \ |
1049 | }, \ |
1050 | py::return_value_policy::reference); \ |
1051 | nvf_ops.def( \ |
1052 | op_str, \ |
1053 | [](nvfuser::FusionDefinition::Operators& self, \ |
1054 | nvfuser::Tensor arg1, \ |
1055 | nvfuser::Scalar arg2, \ |
1056 | nvfuser::Scalar arg3, \ |
1057 | nvfuser::Scalar arg4) -> nvfuser::Tensor { \ |
1058 | FUSER_PERF_SCOPE("Operators." op_str); \ |
1059 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
1060 | nvfuser::Tensor output = fd->defineTensor(); \ |
1061 | fd->defineRecord(new nvfuser::OpRecord< \ |
1062 | Nvf::TensorView*, \ |
1063 | Nvf::TensorView*, \ |
1064 | Nvf::Val*, \ |
1065 | Nvf::Val*, \ |
1066 | Nvf::Val*>( \ |
1067 | {fd->recordingState(arg1()), \ |
1068 | fd->recordingState(arg2()), \ |
1069 | fd->recordingState(arg3()), \ |
1070 | fd->recordingState(arg4())}, \ |
1071 | {fd->recordingState(output())}, \ |
1072 | ("ops." op_str), \ |
1073 | static_cast< \ |
1074 | Nvf:: \ |
1075 | TensorView* (*)(Nvf::TensorView*, Nvf::Val*, Nvf::Val*, Nvf::Val*)>( \ |
1076 | Nvf::op_name))); \ |
1077 | return output; \ |
1078 | }, \ |
1079 | py::return_value_policy::reference); \ |
1080 | nvf_ops.def( \ |
1081 | op_str, \ |
1082 | [](nvfuser::FusionDefinition::Operators& self, \ |
1083 | nvfuser::Scalar arg1, \ |
1084 | nvfuser::Tensor arg2, \ |
1085 | nvfuser::Scalar arg3, \ |
1086 | nvfuser::Scalar arg4) -> nvfuser::Tensor { \ |
1087 | FUSER_PERF_SCOPE("Operators." op_str); \ |
1088 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
1089 | nvfuser::Tensor output = fd->defineTensor(); \ |
1090 | fd->defineRecord(new nvfuser::OpRecord< \ |
1091 | Nvf::TensorView*, \ |
1092 | Nvf::Val*, \ |
1093 | Nvf::TensorView*, \ |
1094 | Nvf::Val*, \ |
1095 | Nvf::Val*>( \ |
1096 | {fd->recordingState(arg1()), \ |
1097 | fd->recordingState(arg2()), \ |
1098 | fd->recordingState(arg3()), \ |
1099 | fd->recordingState(arg4())}, \ |
1100 | {fd->recordingState(output())}, \ |
1101 | ("ops." op_str), \ |
1102 | static_cast< \ |
1103 | Nvf:: \ |
1104 | TensorView* (*)(Nvf::Val*, Nvf::TensorView*, Nvf::Val*, Nvf::Val*)>( \ |
1105 | Nvf::op_name))); \ |
1106 | return output; \ |
1107 | }, \ |
1108 | py::return_value_policy::reference); |
1109 | |
1110 | NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP("addcmul" , addcmul) |
1111 | #undef NVFUSER_PYTHON_BINDING_TERNARY_WITH_ALPHA_OP |
1112 | |
1113 | #define NVFUSER_PYTHON_BINDING_REDUCTION_OP(op_str, op_name) \ |
1114 | nvf_ops.def( \ |
1115 | op_str, \ |
1116 | [](nvfuser::FusionDefinition::Operators& self, \ |
1117 | nvfuser::Tensor arg, \ |
1118 | const std::vector<int>& axes, \ |
1119 | bool keepdim, \ |
1120 | Nvf::DataType dtype) -> nvfuser::Tensor { \ |
1121 | FUSER_PERF_SCOPE("Operators." op_str); \ |
1122 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
1123 | nvfuser::Tensor output = fd->defineTensor(); \ |
1124 | fd->defineRecord(new nvfuser::ReductionOpRecord( \ |
1125 | {fd->recordingState(arg())}, \ |
1126 | {fd->recordingState(output())}, \ |
1127 | ("ops." op_str), \ |
1128 | static_cast< \ |
1129 | Nvf:: \ |
1130 | TensorView* (*)(Nvf::TensorView*, const std::vector<int>&, bool, Nvf::DataType)>( \ |
1131 | Nvf::op_name), \ |
1132 | axes, \ |
1133 | keepdim, \ |
1134 | dtype)); \ |
1135 | return output; \ |
1136 | }, \ |
1137 | py::arg("arg"), \ |
1138 | py::arg("axes"), \ |
1139 | py::arg("keepdim") = false, \ |
1140 | py::arg("dtype") = Nvf::DataType::Null, \ |
1141 | py::return_value_policy::reference); |
1142 | |
1143 | NVFUSER_PYTHON_BINDING_REDUCTION_OP("sum" , sum) |
1144 | NVFUSER_PYTHON_BINDING_REDUCTION_OP("max" , max) |
1145 | NVFUSER_PYTHON_BINDING_REDUCTION_OP("min" , min) |
1146 | #undef NVFUSER_PYTHON_BINDING_REDUCTION_OP |
1147 | |
1148 | #define NVFUSER_PYTHON_BINDING_CAST_OP(op_str, op_name) \ |
1149 | nvf_ops.def( \ |
1150 | op_str, \ |
1151 | [](nvfuser::FusionDefinition::Operators& self, \ |
1152 | nvfuser::Tensor arg, \ |
1153 | Nvf::DataType dtype) -> nvfuser::Tensor { \ |
1154 | FUSER_PERF_SCOPE("Operators." op_str); \ |
1155 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
1156 | nvfuser::Tensor output = fd->defineTensor(); \ |
1157 | fd->defineRecord( \ |
1158 | new nvfuser::CastOpRecord<Nvf::TensorView*, Nvf::TensorView*>( \ |
1159 | {fd->recordingState(arg())}, \ |
1160 | {fd->recordingState(output())}, \ |
1161 | ("ops." op_str), \ |
1162 | static_cast< \ |
1163 | Nvf::TensorView* (*)(Nvf::DataType, Nvf::TensorView*)>( \ |
1164 | Nvf::op_name), \ |
1165 | dtype)); \ |
1166 | return output; \ |
1167 | }, \ |
1168 | py::arg("arg"), \ |
1169 | py::arg("dtype"), \ |
1170 | py::return_value_policy::reference); \ |
1171 | nvf_ops.def( \ |
1172 | op_str, \ |
1173 | [](nvfuser::FusionDefinition::Operators& self, \ |
1174 | nvfuser::Scalar arg, \ |
1175 | Nvf::DataType dtype) -> nvfuser::Scalar { \ |
1176 | FUSER_PERF_SCOPE("Operators." op_str); \ |
1177 | nvfuser::FusionDefinition* fd = self.fusion_definition; \ |
1178 | nvfuser::Scalar output = fd->defineScalar(); \ |
1179 | fd->defineRecord(new nvfuser::CastOpRecord<Nvf::Val*, Nvf::Val*>( \ |
1180 | {fd->recordingState(arg())}, \ |
1181 | {fd->recordingState(output())}, \ |
1182 | ("ops." op_str), \ |
1183 | static_cast<Nvf::Val* (*)(Nvf::DataType, Nvf::Val*)>( \ |
1184 | Nvf::op_name), \ |
1185 | dtype)); \ |
1186 | return output; \ |
1187 | }, \ |
1188 | py::arg("arg"), \ |
1189 | py::arg("dtype"), \ |
1190 | py::return_value_policy::reference); |
1191 | |
1192 | NVFUSER_PYTHON_BINDING_CAST_OP("cast" , castOp) |
1193 | #undef NVFUSER_PYTHON_BINDING_CAST_OP |
1194 | |
1195 | nvf_ops.def( |
1196 | "permute" , |
1197 | [](nvfuser::FusionDefinition::Operators& self, |
1198 | nvfuser::Tensor arg, |
1199 | std::vector<int64_t>& dims) -> nvfuser::Tensor { |
1200 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1201 | nvfuser::Tensor output = fd->defineTensor(); |
1202 | self.fusion_definition->defineRecord(new nvfuser::PermuteOpRecord( |
1203 | {fd->recordingState(arg())}, {fd->recordingState(output())}, dims)); |
1204 | return output; |
1205 | }, |
1206 | py::arg("arg" ), |
1207 | py::arg("dims" ), |
1208 | py::return_value_policy::reference); |
1209 | nvf_ops.def( |
1210 | "squeeze" , |
1211 | [](nvfuser::FusionDefinition::Operators& self, |
1212 | nvfuser::Tensor arg, |
1213 | std::vector<int64_t>& original_shape, |
1214 | int64_t dim) -> nvfuser::Tensor { |
1215 | FUSER_PERF_SCOPE("Operators.squeeze" ); |
1216 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1217 | nvfuser::Tensor output = fd->defineTensor(); |
1218 | fd->defineRecord(new nvfuser::SqueezeOpRecord( |
1219 | {fd->recordingState(arg())}, |
1220 | {fd->recordingState(output())}, |
1221 | original_shape, |
1222 | dim)); |
1223 | return output; |
1224 | }, |
1225 | py::arg("arg" ), |
1226 | py::arg("original_shape" ), |
1227 | py::arg("dim" ), |
1228 | py::return_value_policy::reference); |
1229 | nvf_ops.def( |
1230 | "view" , |
1231 | [](nvfuser::FusionDefinition::Operators& self, |
1232 | nvfuser::Tensor arg, |
1233 | std::vector<int64_t>& original_shape, |
1234 | std::vector<int64_t>& new_shape) -> nvfuser::Tensor { |
1235 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1236 | nvfuser::Tensor output = fd->defineTensor(); |
1237 | self.fusion_definition->defineRecord(new nvfuser::ViewOpRecord( |
1238 | {fd->recordingState(arg())}, |
1239 | {fd->recordingState(output())}, |
1240 | original_shape, |
1241 | new_shape)); |
1242 | return output; |
1243 | }, |
1244 | py::arg("arg" ), |
1245 | py::arg("original_shape" ), |
1246 | py::arg("new_shape" ), |
1247 | py::return_value_policy::reference); |
1248 | nvf_ops.def( |
1249 | "full" , |
1250 | [](nvfuser::FusionDefinition::Operators& self, |
1251 | std::vector<int64_t>& size, |
1252 | nvfuser::Scalar arg, |
1253 | Nvf::DataType dtype) -> nvfuser::Tensor { |
1254 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1255 | nvfuser::Tensor output = fd->defineTensor(); |
1256 | fd->defineRecord(new nvfuser::FullOpRecord( |
1257 | {fd->recordingState(arg())}, |
1258 | {fd->recordingState(output())}, |
1259 | size, |
1260 | dtype)); |
1261 | return output; |
1262 | }, |
1263 | py::arg("size" ), |
1264 | py::arg("arg" ), |
1265 | py::arg("dtype" ), |
1266 | py::return_value_policy::reference); |
1267 | nvf_ops.def( |
1268 | "var" , |
1269 | [](nvfuser::FusionDefinition::Operators& self, |
1270 | nvfuser::Tensor arg, |
1271 | std::vector<int>& axes, |
1272 | int64_t correction, |
1273 | bool keepdim) -> nvfuser::Tensor { |
1274 | FUSER_PERF_SCOPE("Operators.var" ); |
1275 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1276 | nvfuser::Tensor output = fd->defineTensor(); |
1277 | fd->defineRecord(new nvfuser::VarianceOpRecord( |
1278 | {fd->recordingState(arg())}, |
1279 | {fd->recordingState(output())}, |
1280 | axes, |
1281 | correction, |
1282 | keepdim)); |
1283 | return output; |
1284 | }, |
1285 | py::arg("arg" ), |
1286 | py::arg("axes" ), |
1287 | py::arg("correction" ), |
1288 | py::arg("keepdim" ) = false, |
1289 | py::return_value_policy::reference); |
1290 | nvf_ops.def( |
1291 | "var_mean" , |
1292 | [](nvfuser::FusionDefinition::Operators& self, |
1293 | nvfuser::Tensor arg, |
1294 | std::vector<int>& axes, |
1295 | int64_t correction, |
1296 | bool keepdim) -> decltype(auto) { |
1297 | FUSER_PERF_SCOPE("Operators.var_mean" ); |
1298 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1299 | nvfuser::Tensor var = fd->defineTensor(); |
1300 | nvfuser::Tensor mean = fd->defineTensor(); |
1301 | fd->defineRecord(new nvfuser::VarianceMeanOpRecord( |
1302 | {fd->recordingState(arg())}, |
1303 | {fd->recordingState(var()), fd->recordingState(mean())}, |
1304 | axes, |
1305 | correction, |
1306 | keepdim)); |
1307 | return std::make_tuple(var, mean); |
1308 | }, |
1309 | py::arg("arg" ), |
1310 | py::arg("axes" ), |
1311 | py::arg("correction" ), |
1312 | py::arg("keepdim" ) = false, |
1313 | py::return_value_policy::reference); |
1314 | nvf_ops.def( |
1315 | "batch_norm" , |
1316 | [](nvfuser::FusionDefinition::Operators& self, |
1317 | nvfuser::Tensor arg, |
1318 | c10::optional<nvfuser::Tensor> weight, |
1319 | c10::optional<nvfuser::Tensor> bias, |
1320 | c10::optional<nvfuser::Tensor> running_mean, |
1321 | c10::optional<nvfuser::Tensor> running_var, |
1322 | nvfuser::Scalar momentum, |
1323 | nvfuser::Scalar eps, |
1324 | bool training, |
1325 | bool channels_last) -> decltype(auto) { |
1326 | FUSER_PERF_SCOPE("Operators.batch_norm" ); |
1327 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1328 | nvfuser::Tensor output = fd->defineTensor(); |
1329 | nvfuser::Tensor mean = fd->defineTensor(); |
1330 | nvfuser::Tensor invstd = fd->defineTensor(); |
1331 | auto weight_state = weight.has_value() |
1332 | ? fd->recordingState(weight.value()()) |
1333 | : nvfuser::State(0, nvfuser::StateType::None); |
1334 | auto bias_state = bias.has_value() |
1335 | ? fd->recordingState(bias.value()()) |
1336 | : nvfuser::State(0, nvfuser::StateType::None); |
1337 | auto running_mean_state = running_mean.has_value() |
1338 | ? fd->recordingState(running_mean.value()()) |
1339 | : nvfuser::State(0, nvfuser::StateType::None); |
1340 | auto running_var_state = running_var.has_value() |
1341 | ? fd->recordingState(running_var.value()()) |
1342 | : nvfuser::State(0, nvfuser::StateType::None); |
1343 | fd->defineRecord(new nvfuser::BatchNormOpRecord( |
1344 | {fd->recordingState(arg()), |
1345 | weight_state, |
1346 | bias_state, |
1347 | running_mean_state, |
1348 | running_var_state, |
1349 | fd->recordingState(momentum()), |
1350 | fd->recordingState(eps())}, |
1351 | {fd->recordingState(output()), |
1352 | fd->recordingState(mean()), |
1353 | fd->recordingState(invstd())}, |
1354 | training, |
1355 | channels_last)); |
1356 | return std::make_tuple(output, mean, invstd); |
1357 | }, |
1358 | py::arg("arg" ), |
1359 | py::arg("weight" ).none(true), |
1360 | py::arg("bias" ).none(true), |
1361 | py::arg("running_mean" ).none(true), |
1362 | py::arg("running_var" ).none(true), |
1363 | py::arg("momentum" ), |
1364 | py::arg("eps" ), |
1365 | py::arg("training" ), |
1366 | py::arg("channels_last" ) = false, |
1367 | py::return_value_policy::reference); |
1368 | nvf_ops.def( |
1369 | "broadcast_in_dim" , |
1370 | [](nvfuser::FusionDefinition::Operators& self, |
1371 | nvfuser::Tensor arg, |
1372 | std::vector<int64_t>& output_shape, |
1373 | std::vector<int64_t>& broadcast_dims) -> nvfuser::Tensor { |
1374 | FUSER_PERF_SCOPE("Operators.broadcast_in_dim" ); |
1375 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1376 | TORCH_CHECK( |
1377 | output_shape.size() >= broadcast_dims.size(), |
1378 | "broadcast_dims vector size is too big for output shape!" ); |
1379 | nvfuser::Tensor output = fd->defineTensor(); |
1380 | fd->defineRecord(new nvfuser::BroadcastInDimOpRecord( |
1381 | {fd->recordingState(arg())}, |
1382 | {fd->recordingState(output())}, |
1383 | "ops.broadcast_in_dim" , |
1384 | output_shape, |
1385 | broadcast_dims)); |
1386 | return output; |
1387 | }, |
1388 | py::arg("arg" ), |
1389 | py::arg("output_shape" ), |
1390 | py::arg("broadcast_dims" ), |
1391 | py::return_value_policy::reference); |
1392 | nvf_ops.def( |
1393 | "broadcast" , |
1394 | [](nvfuser::FusionDefinition::Operators& self, |
1395 | nvfuser::Tensor arg, |
1396 | std::vector<bool>& is_broadcast_dim) -> nvfuser::Tensor { |
1397 | FUSER_PERF_SCOPE("Operators.broadcast" ); |
1398 | nvfuser::FusionDefinition* fd = self.fusion_definition; |
1399 | nvfuser::Tensor output = fd->defineTensor(); |
1400 | fd->defineRecord(new nvfuser::BroadcastOpRecord( |
1401 | {fd->recordingState(arg())}, |
1402 | {fd->recordingState(output())}, |
1403 | "ops.broadcast" , |
1404 | is_broadcast_dim)); |
1405 | return output; |
1406 | }, |
1407 | py::arg("arg" ), |
1408 | py::arg("is_broadcast_dim" ), |
1409 | py::return_value_policy::reference); |
1410 | } |
1411 | |
1412 | } // namespace jit |
1413 | } // namespace torch |
1414 | |