1#include <gtest/gtest.h>
2
3#include <torch/csrc/jit/api/function_impl.h>
4#include <torch/csrc/jit/runtime/argument_spec.h>
5#include <torch/jit.h>
6
7#include "test/cpp/jit/test_utils.h"
8
9namespace torch {
10namespace jit {
11
12namespace {
13
14at::Device device(const autograd::Variable& v) {
15 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
16 return v.device();
17}
18
19bool isEqual(at::IntArrayRef lhs, at::IntArrayRef rhs) {
20 return lhs.size() == rhs.size() &&
21 std::equal(lhs.begin(), lhs.end(), rhs.begin());
22}
23
24bool isEqual(const CompleteArgumentInfo& ti, const autograd::Variable& v) {
25 if (!ti.defined())
26 return ti.defined() == v.defined();
27 return ti.device() == device(v) && ti.requires_grad() == v.requires_grad() &&
28 ti.type() == v.scalar_type() && isEqual(ti.sizes(), v.sizes()) &&
29 isEqual(ti.strides(), v.strides());
30}
31
32bool isEqual(const ArgumentInfo& ti, const autograd::Variable& v) {
33 if (!ti.defined())
34 return ti.defined() == v.defined();
35 return ti.device() == device(v) && ti.requires_grad() == v.requires_grad() &&
36 ti.type() == v.scalar_type() && ti.dim() == v.dim();
37}
38
39autograd::Variable var(
40 at::TensorOptions t,
41 at::IntArrayRef sizes,
42 bool requires_grad) {
43 return autograd::make_variable(at::rand(sizes, t), requires_grad);
44}
45autograd::Variable undef() {
46 return autograd::Variable();
47}
48} // namespace
49
50TEST(ArgumentSpecTest, CompleteArgumentSpec_CUDA) {
51 auto const CF = at::CPU(at::kFloat);
52 auto const CD = at::CPU(at::kDouble);
53 auto const GF = at::CUDA(at::kFloat);
54 auto const GD = at::CUDA(at::kDouble);
55
56 auto list = createStack(
57 {var(CF, {1}, true),
58 var(CD, {1, 2}, false),
59 var(GF, {}, true),
60 var(GD, {4, 5, 6}, false),
61 undef()});
62
63 // make sure we have some non-standard strides
64 list[1].toTensor().transpose_(0, 1);
65
66 // same list but different backing values
67 auto list2 = createStack(
68 {var(CF, {1}, true),
69 var(CD, {1, 2}, false),
70 var(GF, {}, true),
71 var(GD, {4, 5, 6}, false),
72 undef()});
73 list2[1].toTensor().transpose_(0, 1);
74
75 CompleteArgumentSpec a(true, list);
76 CompleteArgumentSpec b(true, list);
77 ASSERT_EQ(a.hashCode(), b.hashCode());
78
79 ASSERT_EQ(a, b);
80 CompleteArgumentSpec d(true, list2);
81 ASSERT_EQ(d, a);
82 ASSERT_EQ(d.hashCode(), a.hashCode());
83
84 for (size_t i = 0; i < list.size(); ++i) {
85 ASSERT_TRUE(isEqual(a.at(i), list[i].toTensor()));
86 }
87 CompleteArgumentSpec no_grad(/*with_grad=*/false, list);
88 ASSERT_TRUE(no_grad != a);
89
90 std::unordered_set<CompleteArgumentSpec> spec;
91 spec.insert(a); // we use a below, so no move
92 ASSERT_TRUE(spec.count(b) > 0);
93 ASSERT_EQ(spec.count(no_grad), 0);
94 spec.insert(std::move(no_grad));
95 ASSERT_EQ(spec.count(CompleteArgumentSpec(true, list)), 1);
96
97 list2[1].toTensor().transpose_(0, 1);
98 CompleteArgumentSpec c(true, list2); // same as list, except for one stride
99 ASSERT_FALSE(c == a);
100 ASSERT_EQ(spec.count(c), 0);
101
102 Stack stack = {var(CF, {1, 2}, true), 3, var(CF, {1, 2}, true)};
103 CompleteArgumentSpec with_const(true, stack);
104 ASSERT_EQ(with_const.at(2).sizes().size(), 2);
105}
106
// TODO: this test is disabled (commented out) for unknown reasons and never runs; re-enable it or delete it.
108// static size_t hashCode(const TensorTypePtr& ptr) {
109// return std::hash<TensorType>()(*ptr.get());
110// }
111
112// TEST(ArgumentSpecTest, VaryingShape) {
113// c10::VaryingShape<int64_t> vs(c10::optional<size_t>{});
114// auto ptt_empty1 = TensorType::create({}, {}, vs, vs, false);
115// auto ptt_empty2 = TensorType::create({}, {}, vs, vs, false);
116// ASSERT_EQ(hashCode(ptt_empty1), hashCode(ptt_empty2));
117
118// c10::VaryingShape<int64_t> vs22(std::vector<int64_t>{2, 2});
119// auto ptt_vs22_vs22_1 = TensorType::create({}, {}, vs22, vs22, false);
120// auto ptt_vs22_vs22_2 = TensorType::create({}, {}, vs22, vs22, false);
121// ASSERT_EQ(hashCode(ptt_vs22_vs22_1), hashCode(ptt_vs22_vs22_2));
122
123// c10::VaryingShape<int64_t> vs23(std::vector<int64_t>{2, 3});
124// auto ptt_vs22_vs23_2 = TensorType::create({}, {}, vs22, vs23, false);
125// ASSERT_NE(hashCode(ptt_vs22_vs22_1), hashCode(ptt_vs22_vs23_2));
126
127// auto ptt_vs22_vs22_1_true = TensorType::create({}, {}, vs22, vs22, true);
128// auto ptt_vs22_vs22_2_true = TensorType::create({}, {}, vs22, vs22, true);
129// ASSERT_EQ(hashCode(ptt_vs22_vs22_1_true), hashCode(ptt_vs22_vs22_2_true));
130
131// auto ptt_vs22_vs22_1_false = TensorType::create({}, {}, vs22, vs22, false);
132// ASSERT_NE(hashCode(ptt_vs22_vs22_1_true), hashCode(ptt_vs22_vs22_1_false));
133// }
134
135TEST(ArgumentSpecTest, Basic_CUDA) {
136 auto& CF = at::CPU(at::kFloat);
137 auto& CD = at::CPU(at::kDouble);
138 auto& GF = at::CUDA(at::kFloat);
139 auto& GD = at::CUDA(at::kDouble);
140
141 auto graph = toGraphFunction(jit::compile(R"JIT(
142 def fn(a, b, c, d, e):
143 return a, b, c, d, e
144 )JIT")
145 ->get_function("fn"))
146 .graph();
147
148 ArgumentSpecCreator arg_spec_creator(*graph);
149
150 auto list = createStack(
151 {var(CF, {1}, true),
152 var(CD, {1, 2}, false),
153 var(GF, {}, true),
154 var(GD, {4, 5, 6}, false),
155 undef()});
156
157 // make sure we have some non-standard strides
158 list[1].toTensor().transpose_(0, 1);
159
160 // same list but different backing values
161 auto list2 = createStack(
162 {var(CF, {1}, true),
163 var(CD, {1, 2}, false),
164 var(GF, {}, true),
165 var(GD, {4, 5, 6}, false),
166 undef()});
167 list2[1].toTensor().transpose_(0, 1);
168
169 ArgumentSpec a = arg_spec_creator.create(true, list);
170 ArgumentSpec b = arg_spec_creator.create(true, list);
171 ASSERT_EQ(a.hashCode(), b.hashCode());
172
173 ASSERT_EQ(a, b);
174 ArgumentSpec d = arg_spec_creator.create(true, list2);
175 ASSERT_EQ(d, a);
176 ASSERT_EQ(d.hashCode(), a.hashCode());
177
178 for (size_t i = 0; i < list.size(); ++i) {
179 ASSERT_TRUE(isEqual(a.tensorAt(i), list[i].toTensor()));
180 }
181 ArgumentSpec no_grad = arg_spec_creator.create(/*with_grad=*/false, list);
182 ASSERT_TRUE(no_grad != a);
183
184 std::unordered_set<ArgumentSpec> spec;
185 spec.insert(a); // we still need a for the test below
186 ASSERT_TRUE(spec.count(b) > 0);
187 ASSERT_EQ(spec.count(no_grad), 0);
188 spec.insert(std::move(no_grad));
189 ASSERT_EQ(spec.count(arg_spec_creator.create(true, list)), 1);
190
191 list2[1].toTensor().transpose_(0, 1);
192 ArgumentSpec c = arg_spec_creator.create(
193 true, list2); // same as list, except for one stride, used to be
194 // different, now the same
195 ASSERT_TRUE(c == a);
196 ASSERT_EQ(spec.count(c), 1);
197}
198
199} // namespace jit
200} // namespace torch
201