#include <gtest/gtest.h>
#include <torch/csrc/jit/operator_upgraders/utils.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>

#include <test/cpp/jit/test_utils.h>

#include <algorithm>
#include <string>
#include <vector>
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | TEST(UpgraderUtils, FindCorrectUpgrader) { |
13 | std::vector<UpgraderEntry> dummy_entry = { |
14 | {4, "foo__0_3" , "foo.bar()" }, |
15 | {8, "foo__4_7" , "foo.bar()" }, |
16 | }; |
17 | |
18 | auto upgrader_at_6 = findUpgrader(dummy_entry, 6); |
19 | EXPECT_TRUE(upgrader_at_6.has_value()); |
20 | EXPECT_EQ(upgrader_at_6.value().upgrader_name, "foo__4_7" ); |
21 | |
22 | auto upgrader_at_1 = findUpgrader(dummy_entry, 1); |
23 | EXPECT_TRUE(upgrader_at_1.has_value()); |
24 | EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3" ); |
25 | |
26 | auto upgrader_at_10 = findUpgrader(dummy_entry, 10); |
27 | EXPECT_TRUE(upgrader_at_1.has_value()); |
28 | EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3" ); |
29 | } |
30 | |
31 | TEST(UpgraderUtils, IsVersionMapSorted) { |
32 | auto map = get_operator_version_map(); |
33 | // tests if the each list of UpgraderEntry in the map is sorted by |
34 | // their bumped_at_version field. |
35 | for (const auto& entry : map) { |
36 | std::vector<int> versions; |
37 | for (const auto& el : entry.second) { |
38 | versions.push_back(el.bumped_at_version); |
39 | } |
40 | EXPECT_TRUE(std::is_sorted(versions.begin(), versions.end())); |
41 | } |
42 | } |
43 | |
44 | TEST(UpgraderUtils, FindIfOpIsCurrent) { |
45 | std::vector<UpgraderEntry> dummy_entry = { |
46 | {4, "foo__0_3" , "foo.bar()" }, |
47 | {8, "foo__4_7" , "foo.bar()" }, |
48 | }; |
49 | |
50 | auto isCurrent = isOpCurrentBasedOnUpgraderEntries(dummy_entry, 6); |
51 | auto isCurrentV2 = isOpCurrentBasedOnUpgraderEntries(dummy_entry, 8); |
52 | EXPECT_FALSE(isCurrent); |
53 | EXPECT_TRUE(isCurrentV2); |
54 | |
55 | // symbol based look up |
56 | test_only_add_entry("foo" , dummy_entry[0]); |
57 | test_only_add_entry("foo" , dummy_entry[1]); |
58 | EXPECT_FALSE(isOpSymbolCurrent("foo" , 6)); |
59 | EXPECT_TRUE(isOpSymbolCurrent("foo" , 8)); |
60 | test_only_remove_entry("foo" ); |
61 | } |
62 | |
63 | TEST(UpgraderUtils, CanLoadHistoricOp) { |
64 | std::vector<UpgraderEntry> dummy_entry = { |
65 | {4, "foo__0_3" , "foo.bar()" }, |
66 | {8, "foo__4_7" , "foo.foo()" }, |
67 | }; |
68 | |
69 | std::vector<std::string> schemas = {"foo.bar()" , "foo.foo()" }; |
70 | |
71 | // symbol based look up |
72 | test_only_add_entry("old_op_not_exist.first" , dummy_entry[0]); |
73 | test_only_add_entry("old_op_not_exist.second" , dummy_entry[1]); |
74 | |
75 | auto oldSchemas = loadPossibleHistoricOps("old_op_not_exist" , 2); |
76 | EXPECT_EQ(oldSchemas.size(), 2); |
77 | for (const auto& entry : oldSchemas) { |
78 | EXPECT_TRUE( |
79 | std::find(schemas.begin(), schemas.end(), entry) != schemas.end()); |
80 | } |
81 | |
82 | auto oldSchemasWithCurrentVersion = |
83 | loadPossibleHistoricOps("old_op_not_exist" , 9); |
84 | EXPECT_EQ(oldSchemasWithCurrentVersion.size(), 0); |
85 | |
86 | test_only_remove_entry("old_op_not_exist.first" ); |
87 | test_only_remove_entry("old_op_not_exist.first" ); |
88 | |
89 | // it is ok to have old schemas without overload |
90 | test_only_add_entry("old_op_not_exist_no_overload" , dummy_entry[0]); |
91 | auto oldSchemasNoOverload = |
92 | loadPossibleHistoricOps("old_op_not_exist_no_overload" , 2); |
93 | EXPECT_EQ(oldSchemasNoOverload.size(), 1); |
94 | EXPECT_EQ(oldSchemasNoOverload[0], "foo.bar()" ); |
95 | test_only_remove_entry("old_op_not_exist_no_overload" ); |
96 | } |
97 | |
98 | } // namespace jit |
99 | } // namespace torch |
100 | |