1 | #include <c10/util/Exception.h> |
---|---|
2 | #include <operator_registry.h> |
3 | |
4 | namespace torch { |
5 | namespace executor { |
6 | |
7 | OperatorRegistry& getOperatorRegistry() { |
8 | static OperatorRegistry operator_registry; |
9 | return operator_registry; |
10 | } |
11 | |
12 | bool register_operators(const ArrayRef<Operator>& operators) { |
13 | return getOperatorRegistry().register_operators(operators); |
14 | } |
15 | |
16 | bool OperatorRegistry::register_operators( |
17 | const ArrayRef<Operator>& operators) { |
18 | for (const auto& op : operators) { |
19 | this->operators_map_[op.name_] = op.op_; |
20 | } |
21 | return true; |
22 | } |
23 | |
24 | bool hasOpsFn(const char* name) { |
25 | return getOperatorRegistry().hasOpsFn(name); |
26 | } |
27 | |
28 | bool OperatorRegistry::hasOpsFn(const char* name) { |
29 | auto op = this->operators_map_.find(name); |
30 | return op != this->operators_map_.end(); |
31 | } |
32 | |
33 | OpFunction& getOpsFn(const char* name) { |
34 | return getOperatorRegistry().getOpsFn(name); |
35 | } |
36 | |
37 | OpFunction& OperatorRegistry::getOpsFn(const char* name) { |
38 | auto op = this->operators_map_.find(name); |
39 | TORCH_CHECK_MSG(op != this->operators_map_.end(), "Operator not found!"); |
40 | return op->second; |
41 | } |
42 | |
43 | |
44 | } // namespace executor |
45 | } // namespace torch |
46 |