1#include <gtest/gtest.h>
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/runtime/custom_operator.h>
5#include <torch/csrc/jit/testing/file_check.h>
6#include <torch/jit.h>
7
8#include <sstream>
9#include <string>
10
11namespace torch {
12namespace jit {
13
14TEST(SchemaMatchingTest, VarType) {
15 RegisterOperators reg({
16 Operator(
17 "aten::test_vartype(t[] a, t b) -> (t)",
18 [](Stack& stack) {
19 c10::List<double> list;
20 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
21 double a;
22 pop(stack, list, a);
23 push(stack, a);
24 },
25 c10::AliasAnalysisKind::FROM_SCHEMA),
26 });
27 Module m("m");
28 m.define(R"(
29 def test(self):
30 a = (1.0, 2.0)
31 return torch.test_vartype(a, 2.0)
32 )");
33 auto result = m.run_method("test");
34 TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0);
35
36 const std::string error_example = R"JIT(
37 def test_2(self):
38 a = (1.0, 2.0)
39 non_float = (1, 1)
40 return torch.test_vartype(a, non_float)
41 )JIT";
42
43 std::string err = "";
44 try {
45 m.define(error_example);
46 } catch (const std::exception& e) {
47 err = e.what();
48 }
49 TORCH_INTERNAL_ASSERT(
50 err.find("previously matched to type") != std::string::npos);
51}
52
53TEST(SchemaMatchingTest, VarType2) {
54 RegisterOperators reg({
55 Operator(
56 "aten::test_vartype2(t a, t[] b) -> (t[])",
57 [](Stack& stack) {
58 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
59 double a;
60 c10::List<double> list;
61 pop(stack, a, list);
62 push(stack, a);
63 },
64 AliasAnalysisKind::FROM_SCHEMA),
65 });
66 Module m("m");
67 m.define(R"JIT(
68 def test(self):
69 a = (1.0, 2.0)
70 return torch.test_vartype2(3.0, a)
71 )JIT");
72 auto result = m.run_method("test");
73 TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0);
74
75 static const auto error_exam2 = R"JIT(
76 def test_2(self):
77 a = (1, 2)
78 return torch.test_vartype2(3.0, a)
79 )JIT";
80
81 std::string err = "";
82 try {
83 m.define(error_exam2);
84 } catch (const std::exception& e) {
85 err = e.what();
86 }
87 TORCH_INTERNAL_ASSERT(
88 err.find("previously matched to type") != std::string::npos);
89}
90} // namespace jit
91} // namespace torch
92