1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/jit.h> |
4 | #include <torch/script.h> |
5 | #include <torch/types.h> |
6 | |
7 | #include <string> |
8 | |
9 | TEST(TorchScriptTest, CanCompileMultipleFunctions) { |
10 | auto module = torch::jit::compile(R"JIT( |
11 | def test_mul(a, b): |
12 | return a * b |
13 | def test_relu(a, b): |
14 | return torch.relu(a + b) |
15 | def test_while(a, i): |
16 | while bool(i < 10): |
17 | a += a |
18 | i += 1 |
19 | return a |
20 | def test_len(a : List[int]): |
21 | return len(a) |
22 | )JIT" ); |
23 | auto a = torch::ones(1); |
24 | auto b = torch::ones(1); |
25 | |
26 | ASSERT_EQ(1, module->run_method("test_mul" , a, b).toTensor().item<int64_t>()); |
27 | |
28 | ASSERT_EQ( |
29 | 2, module->run_method("test_relu" , a, b).toTensor().item<int64_t>()); |
30 | |
31 | ASSERT_TRUE( |
32 | 0x200 == |
33 | module->run_method("test_while" , a, b).toTensor().item<int64_t>()); |
34 | |
35 | at::IValue list = c10::List<int64_t>({3, 4}); |
36 | ASSERT_EQ(2, module->run_method("test_len" , list).toInt()); |
37 | } |
38 | |
39 | TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) { |
40 | auto module = torch::jit::compile(R"JIT( |
41 | def nested_loop(a: List[List[Tensor]], b: int): |
42 | return torch.tensor(1.0) + b |
43 | )JIT" ); |
44 | |
45 | auto b = 3; |
46 | |
47 | torch::List<torch::Tensor> list({torch::rand({4, 4})}); |
48 | |
49 | torch::List<torch::List<torch::Tensor>> list_of_lists; |
50 | list_of_lists.push_back(list); |
51 | module->run_method("nested_loop" , list_of_lists, b); |
52 | |
53 | auto generic_list = c10::impl::GenericList(at::TensorType::get()); |
54 | auto empty_generic_list = |
55 | c10::impl::GenericList(at::ListType::create(at::TensorType::get())); |
56 | empty_generic_list.push_back(generic_list); |
57 | module->run_method("nested_loop" , empty_generic_list, b); |
58 | |
59 | auto too_many_lists = c10::impl::GenericList( |
60 | at::ListType::create(at::ListType::create(at::TensorType::get()))); |
61 | too_many_lists.push_back(empty_generic_list); |
62 | try { |
63 | module->run_method("nested_loop" , too_many_lists, b); |
64 | AT_ASSERT(false); |
65 | } catch (const c10::Error& error) { |
66 | AT_ASSERT( |
67 | std::string(error.what_without_backtrace()) |
68 | .find("nested_loop() Expected a value of type 'List[List[Tensor]]'" |
69 | " for argument 'a' but instead found type " |
70 | "'List[List[List[Tensor]]]'" ) == 0); |
71 | }; |
72 | } |
73 | |
74 | TEST(TorchScriptTest, TestDictArgMatching) { |
75 | auto module = torch::jit::compile(R"JIT( |
76 | def dict_op(a: Dict[str, Tensor], b: str): |
77 | return a[b] |
78 | )JIT" ); |
79 | c10::Dict<std::string, at::Tensor> dict; |
80 | dict.insert("hello" , torch::ones({2})); |
81 | auto output = module->run_method("dict_op" , dict, std::string("hello" )); |
82 | ASSERT_EQ(1, output.toTensor()[0].item<int64_t>()); |
83 | } |
84 | |
85 | TEST(TorchScriptTest, TestTupleArgMatching) { |
86 | auto module = torch::jit::compile(R"JIT( |
87 | def tuple_op(a: Tuple[List[int]]): |
88 | return a |
89 | )JIT" ); |
90 | |
91 | c10::List<int64_t> int_list({1}); |
92 | auto tuple_generic_list = c10::ivalue::Tuple::create({int_list}); |
93 | |
94 | // doesn't fail on arg matching |
95 | module->run_method("tuple_op" , tuple_generic_list); |
96 | } |
97 | |
98 | TEST(TorchScriptTest, TestOptionalArgMatching) { |
99 | auto module = torch::jit::compile(R"JIT( |
100 | def optional_tuple_op(a: Optional[Tuple[int, str]]): |
101 | if a is None: |
102 | return 0 |
103 | else: |
104 | return a[0] |
105 | )JIT" ); |
106 | |
107 | auto optional_tuple = c10::ivalue::Tuple::create({2, std::string("hi" )}); |
108 | |
109 | ASSERT_EQ(2, module->run_method("optional_tuple_op" , optional_tuple).toInt()); |
110 | ASSERT_EQ( |
111 | 0, module->run_method("optional_tuple_op" , torch::jit::IValue()).toInt()); |
112 | } |
113 | |
114 | TEST(TorchScriptTest, TestPickle) { |
115 | torch::IValue float_value(2.3); |
116 | |
117 | // TODO: when tensors are stored in the pickle, delete this |
118 | std::vector<at::Tensor> tensor_table; |
119 | auto data = torch::jit::pickle(float_value, &tensor_table); |
120 | |
121 | torch::IValue ivalue = torch::jit::unpickle(data.data(), data.size()); |
122 | |
123 | double diff = ivalue.toDouble() - float_value.toDouble(); |
124 | double eps = 0.0001; |
125 | ASSERT_TRUE(diff < eps && diff > -eps); |
126 | } |
127 | |