1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/runtime/custom_operator.h> |
5 | #include <torch/csrc/jit/testing/file_check.h> |
6 | #include <torch/jit.h> |
7 | |
8 | #include <sstream> |
9 | #include <string> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | |
14 | TEST(SchemaMatchingTest, VarType) { |
15 | RegisterOperators reg({ |
16 | Operator( |
17 | "aten::test_vartype(t[] a, t b) -> (t)" , |
18 | [](Stack& stack) { |
19 | c10::List<double> list; |
20 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
21 | double a; |
22 | pop(stack, list, a); |
23 | push(stack, a); |
24 | }, |
25 | c10::AliasAnalysisKind::FROM_SCHEMA), |
26 | }); |
27 | Module m("m" ); |
28 | m.define(R"( |
29 | def test(self): |
30 | a = (1.0, 2.0) |
31 | return torch.test_vartype(a, 2.0) |
32 | )" ); |
33 | auto result = m.run_method("test" ); |
34 | TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0); |
35 | |
36 | const std::string error_example = R"JIT( |
37 | def test_2(self): |
38 | a = (1.0, 2.0) |
39 | non_float = (1, 1) |
40 | return torch.test_vartype(a, non_float) |
41 | )JIT" ; |
42 | |
43 | std::string err = "" ; |
44 | try { |
45 | m.define(error_example); |
46 | } catch (const std::exception& e) { |
47 | err = e.what(); |
48 | } |
49 | TORCH_INTERNAL_ASSERT( |
50 | err.find("previously matched to type" ) != std::string::npos); |
51 | } |
52 | |
53 | TEST(SchemaMatchingTest, VarType2) { |
54 | RegisterOperators reg({ |
55 | Operator( |
56 | "aten::test_vartype2(t a, t[] b) -> (t[])" , |
57 | [](Stack& stack) { |
58 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
59 | double a; |
60 | c10::List<double> list; |
61 | pop(stack, a, list); |
62 | push(stack, a); |
63 | }, |
64 | AliasAnalysisKind::FROM_SCHEMA), |
65 | }); |
66 | Module m("m" ); |
67 | m.define(R"JIT( |
68 | def test(self): |
69 | a = (1.0, 2.0) |
70 | return torch.test_vartype2(3.0, a) |
71 | )JIT" ); |
72 | auto result = m.run_method("test" ); |
73 | TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0); |
74 | |
75 | static const auto error_exam2 = R"JIT( |
76 | def test_2(self): |
77 | a = (1, 2) |
78 | return torch.test_vartype2(3.0, a) |
79 | )JIT" ; |
80 | |
81 | std::string err = "" ; |
82 | try { |
83 | m.define(error_exam2); |
84 | } catch (const std::exception& e) { |
85 | err = e.what(); |
86 | } |
87 | TORCH_INTERNAL_ASSERT( |
88 | err.find("previously matched to type" ) != std::string::npos); |
89 | } |
90 | } // namespace jit |
91 | } // namespace torch |
92 | |