#include <gtest/gtest.h>
#include <torch/csrc/jit/operator_upgraders/utils.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>

#include <test/cpp/jit/test_utils.h>

#include <algorithm>
#include <string>
#include <vector>
8
9namespace torch {
10namespace jit {
11
12TEST(UpgraderUtils, FindCorrectUpgrader) {
13 std::vector<UpgraderEntry> dummy_entry = {
14 {4, "foo__0_3", "foo.bar()"},
15 {8, "foo__4_7", "foo.bar()"},
16 };
17
18 auto upgrader_at_6 = findUpgrader(dummy_entry, 6);
19 EXPECT_TRUE(upgrader_at_6.has_value());
20 EXPECT_EQ(upgrader_at_6.value().upgrader_name, "foo__4_7");
21
22 auto upgrader_at_1 = findUpgrader(dummy_entry, 1);
23 EXPECT_TRUE(upgrader_at_1.has_value());
24 EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3");
25
26 auto upgrader_at_10 = findUpgrader(dummy_entry, 10);
27 EXPECT_TRUE(upgrader_at_1.has_value());
28 EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3");
29}
30
31TEST(UpgraderUtils, IsVersionMapSorted) {
32 auto map = get_operator_version_map();
33 // tests if the each list of UpgraderEntry in the map is sorted by
34 // their bumped_at_version field.
35 for (const auto& entry : map) {
36 std::vector<int> versions;
37 for (const auto& el : entry.second) {
38 versions.push_back(el.bumped_at_version);
39 }
40 EXPECT_TRUE(std::is_sorted(versions.begin(), versions.end()));
41 }
42}
43
44TEST(UpgraderUtils, FindIfOpIsCurrent) {
45 std::vector<UpgraderEntry> dummy_entry = {
46 {4, "foo__0_3", "foo.bar()"},
47 {8, "foo__4_7", "foo.bar()"},
48 };
49
50 auto isCurrent = isOpCurrentBasedOnUpgraderEntries(dummy_entry, 6);
51 auto isCurrentV2 = isOpCurrentBasedOnUpgraderEntries(dummy_entry, 8);
52 EXPECT_FALSE(isCurrent);
53 EXPECT_TRUE(isCurrentV2);
54
55 // symbol based look up
56 test_only_add_entry("foo", dummy_entry[0]);
57 test_only_add_entry("foo", dummy_entry[1]);
58 EXPECT_FALSE(isOpSymbolCurrent("foo", 6));
59 EXPECT_TRUE(isOpSymbolCurrent("foo", 8));
60 test_only_remove_entry("foo");
61}
62
63TEST(UpgraderUtils, CanLoadHistoricOp) {
64 std::vector<UpgraderEntry> dummy_entry = {
65 {4, "foo__0_3", "foo.bar()"},
66 {8, "foo__4_7", "foo.foo()"},
67 };
68
69 std::vector<std::string> schemas = {"foo.bar()", "foo.foo()"};
70
71 // symbol based look up
72 test_only_add_entry("old_op_not_exist.first", dummy_entry[0]);
73 test_only_add_entry("old_op_not_exist.second", dummy_entry[1]);
74
75 auto oldSchemas = loadPossibleHistoricOps("old_op_not_exist", 2);
76 EXPECT_EQ(oldSchemas.size(), 2);
77 for (const auto& entry : oldSchemas) {
78 EXPECT_TRUE(
79 std::find(schemas.begin(), schemas.end(), entry) != schemas.end());
80 }
81
82 auto oldSchemasWithCurrentVersion =
83 loadPossibleHistoricOps("old_op_not_exist", 9);
84 EXPECT_EQ(oldSchemasWithCurrentVersion.size(), 0);
85
86 test_only_remove_entry("old_op_not_exist.first");
87 test_only_remove_entry("old_op_not_exist.first");
88
89 // it is ok to have old schemas without overload
90 test_only_add_entry("old_op_not_exist_no_overload", dummy_entry[0]);
91 auto oldSchemasNoOverload =
92 loadPossibleHistoricOps("old_op_not_exist_no_overload", 2);
93 EXPECT_EQ(oldSchemasNoOverload.size(), 1);
94 EXPECT_EQ(oldSchemasNoOverload[0], "foo.bar()");
95 test_only_remove_entry("old_op_not_exist_no_overload");
96}
97
98} // namespace jit
99} // namespace torch
100