#include "operator_registry.h"

#include <cstddef>

#include <gtest/gtest.h>

4 | namespace torch { |
5 | namespace executor { |
6 | |
7 | // add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) |
8 | TEST(OperatorRegistrationTest, Add) { |
9 | EValue values[4]; |
10 | values[0] = EValue(at::ones({2, 3})); |
11 | values[1] = EValue(at::ones({2, 3})); |
12 | values[2] = EValue(int64_t(1)); |
13 | values[3] = EValue(at::zeros({2, 3})); |
14 | ASSERT_TRUE(hasOpsFn("aten::add.out" )); |
15 | auto op = getOpsFn("aten::add.out" ); |
16 | |
17 | EValue* kernel_values[4]; |
18 | for (size_t i = 0; i < 4; i++) { |
19 | kernel_values[i] = &values[i]; |
20 | } |
21 | op(kernel_values); |
22 | at::Tensor expected = at::ones({2, 3}); |
23 | expected = at::fill(expected, 2); |
24 | ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor())); |
25 | |
26 | } |
27 | |
28 | // custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!) |
29 | TEST(OperatorRegistrationTest, CustomAdd3) { |
30 | EValue values[4]; |
31 | values[0] = EValue(at::ones({2, 3})); |
32 | values[1] = EValue(at::ones({2, 3})); |
33 | values[2] = EValue(at::ones({2, 3})); |
34 | values[3] = EValue(at::zeros({2, 3})); |
35 | ASSERT_TRUE(hasOpsFn("custom::add_3.out" )); |
36 | auto op = getOpsFn("custom::add_3.out" ); |
37 | |
38 | EValue* kernel_values[4]; |
39 | for (size_t i = 0; i < 4; i++) { |
40 | kernel_values[i] = &values[i]; |
41 | } |
42 | op(kernel_values); |
43 | at::Tensor expected = at::ones({2, 3}); |
44 | expected = at::fill(expected, 3); |
45 | ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor())); |
46 | |
47 | } |
48 | } // namespace executor |
49 | } // namespace torch |
50 | |