#pragma once

#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <utility>

#include <c10/util/ArrayRef.h>
#include "Evalue.h"

10namespace torch {
11namespace executor {
12
13using OpFunction = std::function<void(EValue**)>;
14
15template<typename T>
16using ArrayRef = at::ArrayRef<T>;
17
18#define EXECUTORCH_SCOPE_PROF(x)
19
20struct Operator {
21 const char* name_;
22 OpFunction op_;
23
24 Operator() = default;
25
26 /**
27 * We are doing a copy of the string pointer instead of duplicating the string
28 * itself, we require the lifetime of the operator name to be at least as long
29 * as the operator registry.
30 */
31 explicit Operator(const char* name, OpFunction func)
32 : name_(name), op_(func) {}
33};
34
35/**
36 * See OperatorRegistry::hasOpsFn()
37 */
38bool hasOpsFn(const char* name);
39
40/**
41 * See OperatorRegistry::getOpsFn()
42 */
43OpFunction& getOpsFn(const char* name);
44
45
46[[nodiscard]] bool register_operators(const ArrayRef<Operator>&);
47
48struct OperatorRegistry {
49 public:
50 OperatorRegistry() : operatorRegSize_(0) {}
51
52 bool register_operators(const ArrayRef<Operator>&);
53
54 /**
55 * Checks whether an operator with a given name is registered
56 */
57 bool hasOpsFn(const char* name);
58
59 /**
60 * Checks whether an operator with a given name is registered
61 */
62 OpFunction& getOpsFn(const char* name);
63
64 private:
65 std::map<const char*, OpFunction> operators_map_;
66 uint32_t operatorRegSize_;
67};
68
69} // namespace executor
70} // namespace torch
71