1 | #pragma once |
2 | |
3 | #include <vector> |
4 | |
5 | #include <ATen/core/function.h> |
6 | #include <ATen/core/function_schema.h> |
7 | #include <ATen/core/ivalue.h> |
8 | #include <torch/csrc/jit/mobile/code.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | enum OpCode : uint8_t; |
13 | struct Instruction; |
14 | struct OperatorString; |
15 | |
16 | namespace mobile { |
17 | |
18 | class TORCH_API Function : public torch::jit::Function { |
19 | public: |
20 | explicit Function(c10::QualifiedName name); |
21 | Function( |
22 | c10::QualifiedName name, |
23 | Code code, |
24 | at::optional<c10::FunctionSchema> schema); |
25 | void run(Stack& stack) override; |
26 | at::IValue operator()(Stack& stack); |
27 | void ensure_defined() override {} |
28 | size_t num_inputs() const override; |
29 | const c10::QualifiedName& qualname() const override; |
30 | bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override; |
31 | |
32 | // NOTE: the APIs below is dangerous: if you call append_instruction with |
33 | // dbg_handle and then call it without; then the dbg_handle will become |
34 | // misaligned. Therefore only use ONE variant at time. |
35 | void append_instruction(OpCode op, int X, int N, int64_t dbg_handle); |
36 | void append_instruction(OpCode op, int X, int N); |
37 | void append_operator( |
38 | const std::string& name, |
39 | const std::string& overload_name, |
40 | const c10::optional<int>& num_specified_args); |
41 | void append_constant(const c10::IValue& constant); |
42 | void append_type(const c10::TypePtr& type); |
43 | void append_function(mobile::Function& func); |
44 | |
45 | void set_register_size(size_t size); |
46 | |
47 | int64_t get_debug_handle(size_t pc) const; |
48 | const Code& get_code() const; |
49 | Code& get_code(); |
50 | |
51 | torch::jit::Function& setSchema(c10::FunctionSchema schema) override; |
52 | bool hasSchema() const; |
53 | const c10::FunctionSchema& getSchema() const override; |
54 | |
55 | // Returns the debug handle corresponding to where the execution |
56 | // is halted due to exception. |
57 | // If no corresponding debug handle is found then -1 is returned. |
58 | const std::vector<int64_t>& getExceptionDebugHandles() const; |
59 | static Function& registerFunc( |
60 | const std::string& qualified_name, |
61 | const std::vector<Instruction>& instructions, |
62 | const std::vector<c10::IValue>& constants, |
63 | const std::vector<c10::TypePtr>& types, |
64 | const size_t register_size); |
65 | |
66 | // if not initialize, initialize by loading operators. |
67 | // return true of all op loaded, return false if some op is not found |
68 | // in the current runtime. Then, the ops that did not found will be filled |
69 | // in unsupported_op_names |
70 | bool initialize_operators(bool should_check_operators); |
71 | |
72 | private: |
73 | c10::QualifiedName name_; |
74 | Code code_; |
75 | at::optional<c10::FunctionSchema> schema_; // (byte-code version 4+) |
76 | }; |
77 | |
78 | c10::optional<std::function<void(Stack&)>> makeOperatorFunction( |
79 | c10::OperatorName opname, |
80 | c10::optional<int> num_specified_args); |
81 | |
82 | TORCH_API std::string operator_str(const c10::OperatorName& opname); |
83 | |
84 | } // namespace mobile |
85 | } // namespace jit |
86 | } // namespace torch |
87 | |