1#include <gtest/gtest.h>
2
3#include <test/cpp/jit/test_utils.h>
4#include <torch/csrc/jit/testing/file_check.h>
5#include "torch/csrc/jit/ir/ir.h"
6#include "torch/csrc/jit/ir/irparser.h"
7
8namespace torch {
9namespace jit {
10
11TEST(JitTypeTest, IsComplete) {
12 auto tt = c10::TensorType::create(
13 at::kFloat,
14 at::kCPU,
15 c10::SymbolicShape(std::vector<c10::optional<int64_t>>({1, 49})),
16 std::vector<c10::Stride>(
17 {c10::Stride{2, true, 1},
18 c10::Stride{1, true, 1},
19 c10::Stride{0, true, c10::nullopt}}),
20 false);
21 TORCH_INTERNAL_ASSERT(!tt->isComplete());
22 TORCH_INTERNAL_ASSERT(!tt->strides().isComplete());
23}
24
25TEST(JitTypeTest, UnifyTypes) {
26 auto bool_tensor = TensorType::get()->withScalarType(at::kBool);
27 auto opt_bool_tensor = OptionalType::create(bool_tensor);
28 auto unified_opt_bool = unifyTypes(bool_tensor, opt_bool_tensor);
29 TORCH_INTERNAL_ASSERT(opt_bool_tensor->isSubtypeOf(**unified_opt_bool));
30
31 auto tensor = TensorType::get();
32 TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(*opt_bool_tensor));
33 auto unified = unifyTypes(opt_bool_tensor, tensor);
34 TORCH_INTERNAL_ASSERT(unified);
35 auto elem = (*unified)->expectRef<OptionalType>().getElementType();
36 TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(*TensorType::get()));
37
38 auto opt_tuple_none_int = OptionalType::create(
39 TupleType::create({NoneType::get(), IntType::get()}));
40 auto tuple_int_none = TupleType::create({IntType::get(), NoneType::get()});
41 auto out = unifyTypes(opt_tuple_none_int, tuple_int_none);
42 TORCH_INTERNAL_ASSERT(out);
43
44 std::stringstream ss;
45 ss << (*out)->annotation_str();
46 testing::FileCheck()
47 .check("Optional[Tuple[Optional[int], Optional[int]]]")
48 ->run(ss.str());
49
50 auto fut_1 = FutureType::create(IntType::get());
51 auto fut_2 = FutureType::create(NoneType::get());
52 auto fut_out = unifyTypes(fut_1, fut_2);
53 TORCH_INTERNAL_ASSERT(fut_out);
54 TORCH_INTERNAL_ASSERT((*fut_out)->isSubtypeOf(
55 *FutureType::create(OptionalType::create(IntType::get()))));
56
57 auto dict_1 = DictType::create(IntType::get(), NoneType::get());
58 auto dict_2 = DictType::create(IntType::get(), IntType::get());
59 auto dict_out = unifyTypes(dict_1, dict_2);
60 TORCH_INTERNAL_ASSERT(!dict_out);
61}
62
63} // namespace jit
64} // namespace torch
65