1 | #pragma once |
2 | |
3 | #include <ATen/core/function.h> |
4 | #include <ATen/core/ivalue.h> |
5 | #include <ATen/core/stack.h> |
6 | #include <torch/csrc/api/include/torch/imethod.h> |
7 | #include <torch/csrc/jit/api/function_impl.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>; |
13 | |
14 | // A method in a module, e.g. f in: |
15 | // |
16 | // class M(ScriptModule): |
17 | // @script_method |
18 | // def f(self, x): |
19 | // ... |
20 | // Note: because Method/Module are exposed to python these |
21 | // classes use python method naming conventions |
22 | struct TORCH_API Method : public torch::IMethod { |
23 | Method(ObjectPtr owner, Function* function); |
24 | |
25 | // the module that contains this method. |
26 | Module owner() const; |
27 | void run(Stack& stack); |
28 | void run(Stack&& stack) { |
29 | run(stack); |
30 | } |
31 | |
32 | c10::IValue operator()( |
33 | std::vector<c10::IValue> stack, |
34 | const Kwargs& kwargs = Kwargs()) const override; |
35 | |
36 | // Run method async. Invocation on this function would invokes a JIT |
37 | // interpreter that executes ops inline, one by one, on caller's thread. A |
38 | // model can utilize async op, i.e. `fork`, to launch an asynchronous task |
39 | // which will be launched on provided `taskLauncher`. |
40 | c10::intrusive_ptr<c10::ivalue::Future> run_async( |
41 | std::vector<c10::IValue> stack, |
42 | const Kwargs& kwargs = Kwargs(), |
43 | TaskLauncher taskLauncher = at::launch); |
44 | |
45 | std::shared_ptr<Graph> graph() const { |
46 | return toGraphFunction(*function_).graph(); |
47 | } |
48 | |
49 | const std::string& name() const override { |
50 | return function_->name(); |
51 | } |
52 | |
53 | size_t num_inputs() const { |
54 | return function_->num_inputs(); |
55 | } |
56 | |
57 | GraphExecutor& get_executor() { |
58 | return toGraphFunction(*function_).get_executor(); |
59 | } |
60 | |
61 | Function& function() const { |
62 | return *function_; |
63 | } |
64 | |
65 | private: |
66 | void setArgumentNames(std::vector<std::string>&) const override; |
67 | |
68 | // Methods are uniqued onwed by a single module. This raw pointer allows |
69 | // looking up the module. |
70 | ObjectPtr owner_; |
71 | |
72 | // Underlying unbound function |
73 | Function* function_; |
74 | }; |
75 | |
76 | namespace script { |
77 | // We once had a `script::` namespace that was deleted. This is for backcompat |
78 | // of the public API; new code should not use this type alias. |
79 | using Method = ::torch::jit::Method; |
80 | } // namespace script |
81 | |
82 | } // namespace jit |
83 | } // namespace torch |
84 | |