#include <cstddef>
#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

#include <c10/util/Exception.h>
#include <torch/custom_class.h>
#include <torch/script.h>
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | |
7 | struct ScalarTypeClass : public torch::CustomClassHolder { |
8 | ScalarTypeClass(at::ScalarType s) : scalar_type_(s) {} |
9 | at::ScalarType scalar_type_; |
10 | }; |
11 | |
12 | template <class T> |
13 | struct MyStackClass : torch::CustomClassHolder { |
14 | std::vector<T> stack_; |
15 | MyStackClass(std::vector<T> init) : stack_(init.begin(), init.end()) {} |
16 | |
17 | void push(T x) { |
18 | stack_.push_back(x); |
19 | } |
20 | T pop() { |
21 | auto val = stack_.back(); |
22 | stack_.pop_back(); |
23 | return val; |
24 | } |
25 | |
26 | c10::intrusive_ptr<MyStackClass> clone() const { |
27 | return c10::make_intrusive<MyStackClass>(stack_); |
28 | } |
29 | |
30 | void merge(const c10::intrusive_ptr<MyStackClass>& c) { |
31 | for (auto& elem : c->stack_) { |
32 | push(elem); |
33 | } |
34 | } |
35 | |
36 | std::tuple<double, int64_t> return_a_tuple() const { |
37 | return std::make_tuple(1337.0f, 123); |
38 | } |
39 | }; |
40 | } // namespace jit |
41 | } // namespace torch |
42 |