1 | #pragma once |
2 | |
#include <ATen/core/function.h>
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <functional>
#include <string>
#include <utility>
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | struct BuiltinOpFunction : public Function { |
14 | BuiltinOpFunction( |
15 | c10::QualifiedName qualname, |
16 | c10::FunctionSchema schema, |
17 | std::function<void(Stack&)> callable, |
18 | std::string doc_string = "" ) |
19 | : name_(std::move(qualname)), |
20 | callable_(std::move(callable)), |
21 | schema_(std::move(schema)), |
22 | doc_string_(std::move(doc_string)) { |
23 | TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1); |
24 | } |
25 | |
26 | c10::string_view doc_string() const override { |
27 | return doc_string_; |
28 | } |
29 | |
30 | void run(Stack& stack) override { |
31 | callable_(stack); |
32 | } |
33 | |
34 | c10::intrusive_ptr<c10::ivalue::Future> runAsync( |
35 | Stack& stack, |
36 | TaskLauncher /* not used */) override { |
37 | run(stack); |
38 | auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type()); |
39 | res->markCompleted(std::move(stack.front())); |
40 | return res; |
41 | } |
42 | |
43 | const c10::QualifiedName& qualname() const override { |
44 | return name_; |
45 | } |
46 | |
47 | // if this isn't yet defined, run its method_creator function |
48 | void ensure_defined() override { |
49 | // nop |
50 | } |
51 | |
52 | const c10::FunctionSchema& getSchema() const override { |
53 | return schema_; |
54 | } |
55 | |
56 | size_t num_inputs() const override { |
57 | return schema_.arguments().size(); |
58 | } |
59 | |
60 | Function& setSchema(c10::FunctionSchema schema) override { |
61 | schema_ = std::move(schema); |
62 | return *this; |
63 | } |
64 | |
65 | bool call(Stack& stack, c10::optional<size_t>, c10::function_ref<void(const Code&)>) override { |
66 | run(stack); |
67 | return false; |
68 | } |
69 | |
70 | bool call(Stack& stack, c10::function_ref<void(const mobile::Code&)>) override { |
71 | run(stack); |
72 | return false; |
73 | } |
74 | |
75 | ~BuiltinOpFunction() override = default; |
76 | |
77 | private: |
78 | c10::QualifiedName name_; |
79 | |
80 | std::function<void(Stack&)> callable_; |
81 | |
82 | c10::FunctionSchema schema_; |
83 | |
84 | std::string doc_string_; |
85 | }; |
86 | |
87 | } // namespace jit |
88 | } // namespace torch |
89 | |