1 | #pragma once |
2 | |
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <utility>

#include <c10/util/ArrayRef.h>
#include "Evalue.h"
9 | |
10 | namespace torch { |
11 | namespace executor { |
12 | |
13 | using OpFunction = std::function<void(EValue**)>; |
14 | |
15 | template<typename T> |
16 | using ArrayRef = at::ArrayRef<T>; |
17 | |
18 | #define EXECUTORCH_SCOPE_PROF(x) |
19 | |
20 | struct Operator { |
21 | const char* name_; |
22 | OpFunction op_; |
23 | |
24 | Operator() = default; |
25 | |
26 | /** |
27 | * We are doing a copy of the string pointer instead of duplicating the string |
28 | * itself, we require the lifetime of the operator name to be at least as long |
29 | * as the operator registry. |
30 | */ |
31 | explicit Operator(const char* name, OpFunction func) |
32 | : name_(name), op_(func) {} |
33 | }; |
34 | |
35 | /** |
36 | * See OperatorRegistry::hasOpsFn() |
37 | */ |
38 | bool hasOpsFn(const char* name); |
39 | |
40 | /** |
41 | * See OperatorRegistry::getOpsFn() |
42 | */ |
43 | OpFunction& getOpsFn(const char* name); |
44 | |
45 | |
46 | [[nodiscard]] bool register_operators(const ArrayRef<Operator>&); |
47 | |
48 | struct OperatorRegistry { |
49 | public: |
50 | OperatorRegistry() : operatorRegSize_(0) {} |
51 | |
52 | bool register_operators(const ArrayRef<Operator>&); |
53 | |
54 | /** |
55 | * Checks whether an operator with a given name is registered |
56 | */ |
57 | bool hasOpsFn(const char* name); |
58 | |
59 | /** |
60 | * Checks whether an operator with a given name is registered |
61 | */ |
62 | OpFunction& getOpsFn(const char* name); |
63 | |
64 | private: |
65 | std::map<const char*, OpFunction> operators_map_; |
66 | uint32_t operatorRegSize_; |
67 | }; |
68 | |
69 | } // namespace executor |
70 | } // namespace torch |
71 | |