#include <gtest/gtest.h>

#include <atomic>

#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/runtime/graph_executor.h"
#include "torch/jit.h"
#include "torch/script.h"
#include "torch/torch.h"
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | TEST(GraphExecutorTest, Basic_CUDA) { |
13 | constexpr int batch_size = 4; |
14 | constexpr int input_size = 256; |
15 | |
16 | int hidden_size = 2 * input_size; |
17 | |
18 | auto input = at::randn({batch_size, input_size}, at::kCUDA); |
19 | auto hx = at::randn({batch_size, hidden_size}, at::kCUDA); |
20 | auto cx = at::randn({batch_size, hidden_size}, at::kCUDA); |
21 | auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA)); |
22 | auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA)); |
23 | |
24 | auto g = build_lstm(); |
25 | GraphExecutor executor(g, "" ); |
26 | auto stack = createStack({input, hx, cx, w_ih, w_hh}); |
27 | executor.run(stack); |
28 | ASSERT_EQ(stack.size(), 2); |
29 | at::Tensor r0, r1; |
30 | std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh); |
31 | ASSERT_TRUE(almostEqual(stack[0].toTensor(), r0)); |
32 | ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1)); |
33 | } |
34 | |
35 | TEST(GraphExecutorTest, runAsync_executor) { |
36 | /* |
37 | TODO: there are some problem with C++ parsing script program involving |
38 | fork. Use the test module below for now. |
39 | issue about this: github.com/pytorch/pytorch/issues/46368 |
40 | The test module file is generated by following: |
41 | class DemoModule(torch.nn.Module): |
42 | def forward(self): |
43 | r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) |
44 | r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100)) |
45 | return r1.wait() + r2.wait() |
46 | demo = DemoModule() |
47 | torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt') |
48 | */ |
49 | std::string filePath(__FILE__); |
50 | auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\" ) + 1); |
51 | testModelFile.append("test_interpreter_async.pt" ); |
52 | auto module = load(testModelFile); |
53 | auto graph = module.get_method("forward" ).graph(); |
54 | GraphExecutor graphExecutor(graph, "" ); |
55 | auto asyncCounter = 0; |
56 | std::mutex mtx; |
57 | // a dummy executor which actually use at::launch, but add up a counter |
58 | auto launcher = [&](std::function<void()> f) { |
59 | mtx.lock(); |
60 | ++asyncCounter; |
61 | mtx.unlock(); |
62 | at::launch(std::move(f)); |
63 | }; |
64 | std::vector<IValue> stack; |
65 | // NOLINTNEXTLINE(modernize-use-emplace) |
66 | stack.push_back(module._ivalue()); |
67 | graphExecutor.runAsync(stack, launcher)->wait(); |
68 | ASSERT_TRUE(asyncCounter > 0); |
69 | } |
70 | |
71 | } // namespace jit |
72 | } // namespace torch |
73 | |