1#pragma once
2
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/FunctionRef.h>

#include <utility>
8
// Forward declaration of the schema type referenced by
// torch::jit::Function::getSchema() / setSchema().
namespace c10 {
struct FunctionSchema;
} // namespace c10
12
13namespace at {
14TORCH_API void launch(std::function<void()> func);
15}
16
17namespace torch {
18namespace jit {
19
// Forward declarations: this header only refers to these types through
// references, smart pointers and callback signatures, so their full
// definitions are not needed here.
struct Code;
struct Graph;

namespace mobile {
// Distinct from torch::jit::Code; used by the mobile-interpreter overload of
// Function::call().
struct Code;
} // namespace mobile
26
27using Stack = std::vector<at::IValue>;
28using Kwargs = std::unordered_map<std::string, at::IValue>;
// Callable that is handed a unit of work (a nullary function) and decides how
// to run it. Used as the launcher parameter of Function::runAsync().
using TaskLauncher = std::function<void(std::function<void()>)>;

// Stateless exception type derived from std::exception.
// NOTE(review): the name suggests it flags a recursive method call; the throw
// sites are not visible from this header -- confirm there.
struct RecursiveMethodCallError : std::exception {};
31
32TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph, bool disable_autocast=false);
33
34// A Function is a pure Graph with no implicit `self` object bound.
35// It contains schema information and the executor that manages the
36// execution of the function. Method is a wrapper around an
37// underlying Function that also provides a `self` object.
38struct TORCH_API Function {
39 virtual c10::string_view doc_string() const {
40 static constexpr c10::string_view no_doc_string = "";
41 return no_doc_string;
42 }
43
44 virtual bool isGraphFunction() const {
45 return false;
46 }
47
48 virtual void run(Stack& stack) = 0;
49
50 virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
51 Stack& /*stack*/,
52 TaskLauncher taskLauncher = at::launch) {
53 (void)taskLauncher; // Suppress unused variable warning
54 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
55 return {};
56 }
57
58 at::IValue operator()(
59 Stack stack,
60 const Kwargs& kwargs = Kwargs()) {
61 getSchema().checkAndNormalizeInputs(stack, kwargs);
62 run(stack);
63 return stack.front();
64 }
65
66 virtual const c10::QualifiedName& qualname() const = 0;
67
68 const std::string& name() const {
69 return qualname().name();
70 }
71
72 // if this isn't yet defined, run its method_creator function
73 virtual void ensure_defined() = 0;
74
75 virtual const c10::FunctionSchema& getSchema() const = 0;
76
77 virtual size_t num_inputs() const = 0;
78
79 virtual Function& setSchema(c10::FunctionSchema schema) = 0;
80
81 // call() defines how different interpreter implementations interacts with
82 // Function objects. Basically interpreters need to provide a callback to
83 // communicate to Functions what to do if provided a Code object.
84 // Alternatively we could design the signature to return an optional Code
85 // object, but that requires special handling the null case in interpreter
86 // and the fallback behavior is not well defined by interpreter but rather
87 // Function themselves, so a callback approach is more reasonable than
88 // returning values.
89 // If call() returns true, then callback completes successfully, otherwise
90 // call() returns false.
91
92 // Overload for server interpreter, a bailout size is needed for graph executor.
93 virtual bool call(Stack&, c10::optional<size_t>, c10::function_ref<void(const Code&)>) {
94 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
95 return false;
96 }
97
98 // Overload for mobile interpreter.
99 virtual bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) {
100 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
101 return false;
102 }
103
104 virtual ~Function() = default;
105};
106} // namespace jit
107} // namespace torch
108