1 | #pragma once |
2 | |
3 | #include <ATen/core/op_registration/op_registration.h> |
4 | #include <ATen/core/stack.h> |
5 | #include <torch/csrc/jit/runtime/operator.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
10 | /// Registration class for new operators. Effectively calls |
11 | /// `torch::jit::registerOperator` for every supplied operator, but allows doing |
12 | /// so in the global scope when a `RegisterOperators` object is assigned to a |
13 | /// static variable. |
14 | /// Note: This is *not* the custom operator API. If you want to register custom |
15 | /// operators, take a look at torch::RegisterOperators. |
16 | struct TORCH_API RegisterOperators { |
17 | RegisterOperators() = default; |
18 | |
19 | /// Registers a vector of already created `Operator`s. |
20 | /// The operator element is now optional to filter null ops. It's backward |
21 | /// compatible and works for selective operator registration. |
22 | explicit RegisterOperators(std::vector<c10::optional<Operator>> operators) { |
23 | for (c10::optional<Operator>& o : operators) { |
24 | if (o) { |
25 | registerOperator(std::move(o.value())); |
26 | } |
27 | } |
28 | } |
29 | }; |
30 | |
31 | } // namespace jit |
32 | } // namespace torch |
33 | |