1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/utils/object_ptr.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
// Entry point for installing the JIT IR Python bindings on `module`.
// NOTE(review): the definition is not in this header; presumably it registers
// the Python-facing wrappers for the IR types declared in ir.h — confirm
// against the .cpp.
void initPythonIRBindings(PyObject* module);
10 | |
11 | // execute a Python function, used for Ops we can't optimize but that we want to |
12 | // optimize around |
13 | struct ConcretePythonOp : public PythonOp { |
14 | static Symbol Kind; |
15 | |
16 | ConcretePythonOp(Graph* graph) : PythonOp(graph, ::c10::prim::PythonOp) {} |
17 | ConcretePythonOp* init( |
18 | THPObjectPtr&& pyobj, |
19 | const std::string& cconv, |
20 | pyobj_list&& scalar_args) { |
21 | this->pyobj = std::move(pyobj); |
22 | this->scalar_args = std::move(scalar_args); |
23 | this->cconv = cconv; |
24 | return this; |
25 | } |
26 | // The Python object which contains the implementation of this function. |
27 | // This is either a class (non-legacy) or an object (legacy). See |
28 | // TraceInterpreterState for execution semantics. |
29 | THPObjectPtr pyobj; |
30 | // The calling convention for the Python function. |
31 | // 'c' -- constant argument |
32 | // 'd' -- dynamic argument |
33 | std::string cconv; |
34 | // Scalar arguments to the Python function. Not necessarily passed to |
35 | // the function in this order; see cconv for the correct order. |
36 | std::vector<THPObjectPtr> scalar_args; |
37 | |
38 | std::string name() const override; |
39 | void cloneFrom(Node* other_) override; |
40 | Node* allocNewInstance(Graph* g) override { |
41 | return new ConcretePythonOp(g); |
42 | } |
43 | // recover the autograd.Function instance, if this PythonOp's function |
44 | // was originally SomeFunction.apply |
45 | // used in ONNX for discovering symbolics |
46 | c10::optional<THPObjectPtr> autogradFunction() const override; |
47 | void writeScalars(std::ostream& out) const override; |
48 | void lint_python() const override; |
49 | }; |
50 | |
51 | } // namespace jit |
52 | } // namespace torch |
53 | |