1 | #include <gtest/gtest.h> |
2 | |
3 | #include <test/cpp/jit/test_utils.h> |
4 | #include <torch/csrc/jit/testing/file_check.h> |
5 | #include "torch/csrc/jit/ir/ir.h" |
6 | #include "torch/csrc/jit/ir/irparser.h" |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | TEST(JitTypeTest, IsComplete) { |
12 | auto tt = c10::TensorType::create( |
13 | at::kFloat, |
14 | at::kCPU, |
15 | c10::SymbolicShape(std::vector<c10::optional<int64_t>>({1, 49})), |
16 | std::vector<c10::Stride>( |
17 | {c10::Stride{2, true, 1}, |
18 | c10::Stride{1, true, 1}, |
19 | c10::Stride{0, true, c10::nullopt}}), |
20 | false); |
21 | TORCH_INTERNAL_ASSERT(!tt->isComplete()); |
22 | TORCH_INTERNAL_ASSERT(!tt->strides().isComplete()); |
23 | } |
24 | |
25 | TEST(JitTypeTest, UnifyTypes) { |
26 | auto bool_tensor = TensorType::get()->withScalarType(at::kBool); |
27 | auto opt_bool_tensor = OptionalType::create(bool_tensor); |
28 | auto unified_opt_bool = unifyTypes(bool_tensor, opt_bool_tensor); |
29 | TORCH_INTERNAL_ASSERT(opt_bool_tensor->isSubtypeOf(**unified_opt_bool)); |
30 | |
31 | auto tensor = TensorType::get(); |
32 | TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(*opt_bool_tensor)); |
33 | auto unified = unifyTypes(opt_bool_tensor, tensor); |
34 | TORCH_INTERNAL_ASSERT(unified); |
35 | auto elem = (*unified)->expectRef<OptionalType>().getElementType(); |
36 | TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(*TensorType::get())); |
37 | |
38 | auto opt_tuple_none_int = OptionalType::create( |
39 | TupleType::create({NoneType::get(), IntType::get()})); |
40 | auto tuple_int_none = TupleType::create({IntType::get(), NoneType::get()}); |
41 | auto out = unifyTypes(opt_tuple_none_int, tuple_int_none); |
42 | TORCH_INTERNAL_ASSERT(out); |
43 | |
44 | std::stringstream ss; |
45 | ss << (*out)->annotation_str(); |
46 | testing::FileCheck() |
47 | .check("Optional[Tuple[Optional[int], Optional[int]]]" ) |
48 | ->run(ss.str()); |
49 | |
50 | auto fut_1 = FutureType::create(IntType::get()); |
51 | auto fut_2 = FutureType::create(NoneType::get()); |
52 | auto fut_out = unifyTypes(fut_1, fut_2); |
53 | TORCH_INTERNAL_ASSERT(fut_out); |
54 | TORCH_INTERNAL_ASSERT((*fut_out)->isSubtypeOf( |
55 | *FutureType::create(OptionalType::create(IntType::get())))); |
56 | |
57 | auto dict_1 = DictType::create(IntType::get(), NoneType::get()); |
58 | auto dict_2 = DictType::create(IntType::get(), IntType::get()); |
59 | auto dict_out = unifyTypes(dict_1, dict_2); |
60 | TORCH_INTERNAL_ASSERT(!dict_out); |
61 | } |
62 | |
63 | } // namespace jit |
64 | } // namespace torch |
65 | |