1 | #include <caffe2/serialize/versions.h> |
---|---|
2 | #include <gtest/gtest.h> |
3 | #include <torch/csrc/jit/api/module.h> |
4 | #include <torch/csrc/jit/operator_upgraders/upgraders.h> |
5 | #include <torch/csrc/jit/operator_upgraders/version_map.h> |
6 | #include <torch/csrc/jit/serialization/import.h> |
7 | |
8 | #include <test/cpp/jit/test_utils.h> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | // Basic tests to check if C++ torch::jit::load |
14 | // can load the upgraders fine |
15 | // TODO (tugsuu) add more tests |
16 | TEST(UpgraderLoad, CanPopulateUpgradersGraph) { |
17 | Module m("m"); |
18 | m.define(R"( |
19 | def forward(self, x: Tensor): |
20 | b = 5 |
21 | return torch.div(x, b) |
22 | )"); |
23 | std::stringstream ms; |
24 | m.save(ms); |
25 | auto loaded_m = torch::jit::load(ms); |
26 | auto version_map = get_operator_version_map(); |
27 | auto upgraders = dump_upgraders_map(); |
28 | |
29 | for (const auto& entry : version_map) { |
30 | auto list_of_upgraders_for_op = entry.second; |
31 | for (const auto& upgrader_entry : list_of_upgraders_for_op) { |
32 | EXPECT_TRUE( |
33 | upgraders.find(upgrader_entry.upgrader_name) != upgraders.end()); |
34 | } |
35 | } |
36 | |
37 | auto test_graph = loaded_m.get_method("forward").graph(); |
38 | // should have saved with version 4, so it is still up to date |
39 | testing::FileCheck().check_count("aten::div", 1, true)->run(*test_graph); |
40 | } |
41 | |
42 | } // namespace jit |
43 | } // namespace torch |
44 |