1 | #include <ATen/core/ivalue.h> |
2 | #include <c10/util/Exception.h> |
3 | #include <torch/csrc/Export.h> |
4 | #include <torch/csrc/jit/api/module.h> |
5 | #include <torch/script.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
// Export annotation for this file's test entry point: expands to TORCH_API
// on non-MSVC compilers and to nothing when compiling with MSVC.
#if !defined(_MSC_VER)
#define JIT_TEST_API TORCH_API
#else
#define JIT_TEST_API
#endif
15 | |
16 | namespace { |
17 | |
// Returns true when the process appears to run on Sandcastle. That is the
// case when the SANDCASTLE environment variable is set (to any value, even
// empty), or when TW_JOB_USER is set to exactly "sandcastle".
bool isSandcastle() {
  if (std::getenv("SANDCASTLE") != nullptr) {
    return true;
  }
  const char* const job_user = std::getenv("TW_JOB_USER");
  return job_user != nullptr && std::string(job_user) == "sandcastle";
}
24 | |
25 | void testEvalModeForLoadedModule() { |
26 | if (isSandcastle()) |
27 | return; // The module file to load is not generated in Sandcastle |
28 | std::string module_path = "dropout_model.pt" ; |
29 | torch::jit::Module module = torch::jit::load(module_path); |
30 | AT_ASSERT(module.attr("dropout" ).toModule().is_training()); |
31 | module.eval(); |
32 | AT_ASSERT(!module.attr("dropout" ).toModule().is_training()); |
33 | module.train(); |
34 | AT_ASSERT(module.attr("dropout" ).toModule().is_training()); |
35 | } |
36 | |
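// TODO: this test has never been run and is currently broken, so it stays
// commented out. One visible problem in the code below: the stream is not
// opened with `std::ios::binary`, and `std::istream_iterator<char>` skips
// whitespace bytes. As a result, the pickle archive would not be read
// verbatim. Reading via `std::istreambuf_iterator<char>` from a binary-mode
// stream avoids both issues.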
38 | // void testSerializationInterop() { |
39 | // if (isSandcastle()) { |
40 | // // The module file to load is not generated in Sandcastle |
41 | // return; |
42 | // } |
43 | |
44 | // // This should be generated by `test/cpp/jit/tests_setup.py` |
45 | // std::ifstream input_stream("ivalue.pt"); |
46 | // std::vector<char> input; |
47 | // input.insert( |
48 | // input.begin(), |
49 | // std::istream_iterator<char>(input_stream), |
50 | // std::istream_iterator<char>()); |
51 | // IValue ivalue = pickle_load(input); |
52 | |
53 | // auto elements = ivalue.toTupleRef().elements(); |
54 | // auto ones = torch::ones({2, 2}); |
55 | // AT_ASSERT(ones.equal(elements.at(0).toTensor())); |
56 | |
57 | // // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
58 | // auto twos = torch::ones({3, 5}) * 2; |
59 | // AT_ASSERT(twos.equal(elements.at(1).toTensor())); |
60 | // } |
61 | |
62 | void testTorchSaveError() { |
63 | if (isSandcastle()) { |
64 | // The file to load is not generated in Sandcastle |
65 | return; |
66 | } |
67 | |
68 | // This should be generated by `test/cpp/jit/tests_setup.py` |
69 | bool passed = true; |
70 | try { |
71 | torch::jit::load("eager_value.pt" ); |
72 | passed = false; |
73 | } catch (const std::exception& c) { |
74 | } |
75 | // Ensure torch::jit::load did not run |
76 | AT_ASSERT(passed); |
77 | } |
78 | } // namespace |
79 | |
80 | JIT_TEST_API void runJITCPPTests() { |
81 | // TODO: this test never ran before and is broken. |
82 | // testSerializationInterop(); |
83 | testEvalModeForLoadedModule(); |
84 | testTorchSaveError(); |
85 | } |
86 | } // namespace jit |
87 | } // namespace torch |
88 | |