1#pragma once
2
3#include <ATen/core/function.h>
4#include <ATen/core/ivalue.h>
5#include <c10/util/Exception.h>
6#include <c10/util/intrusive_ptr.h>
7#include <functional>
8#include <utility>
9
10namespace torch {
11namespace jit {
12
13struct BuiltinOpFunction : public Function {
14 BuiltinOpFunction(
15 c10::QualifiedName qualname,
16 c10::FunctionSchema schema,
17 std::function<void(Stack&)> callable,
18 std::string doc_string = "")
19 : name_(std::move(qualname)),
20 callable_(std::move(callable)),
21 schema_(std::move(schema)),
22 doc_string_(std::move(doc_string)) {
23 TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
24 }
25
26 c10::string_view doc_string() const override {
27 return doc_string_;
28 }
29
30 void run(Stack& stack) override {
31 callable_(stack);
32 }
33
34 c10::intrusive_ptr<c10::ivalue::Future> runAsync(
35 Stack& stack,
36 TaskLauncher /* not used */) override {
37 run(stack);
38 auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type());
39 res->markCompleted(std::move(stack.front()));
40 return res;
41 }
42
43 const c10::QualifiedName& qualname() const override {
44 return name_;
45 }
46
47 // if this isn't yet defined, run its method_creator function
48 void ensure_defined() override {
49 // nop
50 }
51
52 const c10::FunctionSchema& getSchema() const override {
53 return schema_;
54 }
55
56 size_t num_inputs() const override {
57 return schema_.arguments().size();
58 }
59
60 Function& setSchema(c10::FunctionSchema schema) override {
61 schema_ = std::move(schema);
62 return *this;
63 }
64
65 bool call(Stack& stack, c10::optional<size_t>, c10::function_ref<void(const Code&)>) override {
66 run(stack);
67 return false;
68 }
69
70 bool call(Stack& stack, c10::function_ref<void(const mobile::Code&)>) override {
71 run(stack);
72 return false;
73 }
74
75 ~BuiltinOpFunction() override = default;
76
77 private:
78 c10::QualifiedName name_;
79
80 std::function<void(Stack&)> callable_;
81
82 c10::FunctionSchema schema_;
83
84 std::string doc_string_;
85};
86
87} // namespace jit
88} // namespace torch
89