1#include "operator_registry.h"
2#include <gtest/gtest.h>
3
4namespace torch {
5namespace executor {
6
7// add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
8TEST(OperatorRegistrationTest, Add) {
9 EValue values[4];
10 values[0] = EValue(at::ones({2, 3}));
11 values[1] = EValue(at::ones({2, 3}));
12 values[2] = EValue(int64_t(1));
13 values[3] = EValue(at::zeros({2, 3}));
14 ASSERT_TRUE(hasOpsFn("aten::add.out"));
15 auto op = getOpsFn("aten::add.out");
16
17 EValue* kernel_values[4];
18 for (size_t i = 0; i < 4; i++) {
19 kernel_values[i] = &values[i];
20 }
21 op(kernel_values);
22 at::Tensor expected = at::ones({2, 3});
23 expected = at::fill(expected, 2);
24 ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
25
26}
27
28// custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!)
29TEST(OperatorRegistrationTest, CustomAdd3) {
30 EValue values[4];
31 values[0] = EValue(at::ones({2, 3}));
32 values[1] = EValue(at::ones({2, 3}));
33 values[2] = EValue(at::ones({2, 3}));
34 values[3] = EValue(at::zeros({2, 3}));
35 ASSERT_TRUE(hasOpsFn("custom::add_3.out"));
36 auto op = getOpsFn("custom::add_3.out");
37
38 EValue* kernel_values[4];
39 for (size_t i = 0; i < 4; i++) {
40 kernel_values[i] = &values[i];
41 }
42 op(kernel_values);
43 at::Tensor expected = at::ones({2, 3});
44 expected = at::fill(expected, 3);
45 ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
46
47}
48} // namespace executor
49} // namespace torch
50