1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/utils/object_ptr.h>
5
6namespace torch {
7namespace jit {
8
9void initPythonIRBindings(PyObject* module);
10
11// execute a Python function, used for Ops we can't optimize but that we want to
12// optimize around
13struct ConcretePythonOp : public PythonOp {
14 static Symbol Kind;
15
16 ConcretePythonOp(Graph* graph) : PythonOp(graph, ::c10::prim::PythonOp) {}
17 ConcretePythonOp* init(
18 THPObjectPtr&& pyobj,
19 const std::string& cconv,
20 pyobj_list&& scalar_args) {
21 this->pyobj = std::move(pyobj);
22 this->scalar_args = std::move(scalar_args);
23 this->cconv = cconv;
24 return this;
25 }
26 // The Python object which contains the implementation of this function.
27 // This is either a class (non-legacy) or an object (legacy). See
28 // TraceInterpreterState for execution semantics.
29 THPObjectPtr pyobj;
30 // The calling convention for the Python function.
31 // 'c' -- constant argument
32 // 'd' -- dynamic argument
33 std::string cconv;
34 // Scalar arguments to the Python function. Not necessarily passed to
35 // the function in this order; see cconv for the correct order.
36 std::vector<THPObjectPtr> scalar_args;
37
38 std::string name() const override;
39 void cloneFrom(Node* other_) override;
40 Node* allocNewInstance(Graph* g) override {
41 return new ConcretePythonOp(g);
42 }
43 // recover the autograd.Function instance, if this PythonOp's function
44 // was originally SomeFunction.apply
45 // used in ONNX for discovering symbolics
46 c10::optional<THPObjectPtr> autogradFunction() const override;
47 void writeScalars(std::ostream& out) const override;
48 void lint_python() const override;
49};
50
51} // namespace jit
52} // namespace torch
53