1#include <gtest/gtest.h>
2#include <torch/csrc/autograd/generated/variable_factories.h>
3#include <torch/csrc/utils/schema_info.h>
4
5namespace torch {
6namespace utils {
7using c10::SchemaArgType;
8
9TEST(FunctionSchemaIsAliasingTest, Basic) {
10 c10::FunctionSchema schema = torch::jit::parseSchema(
11 "aten::test.Tensor(Tensor(a) self, Tensor(b!) other, Tensor more_other) -> (Tensor(a), Tensor(b!))");
12 ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 0}));
13 ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 1}));
14 ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 0}));
15 ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 1}));
16 ASSERT_FALSE(schema.is_aliasing({SchemaArgType::input, 2}));
17}
18
19TEST(FunctionSchemaIsAliasingTest, InvalidArgument) {
20 c10::FunctionSchema schema = torch::jit::parseSchema(
21 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
22 ASSERT_THROW(schema.is_aliasing({SchemaArgType::input, 4}), c10::Error);
23 ASSERT_THROW(schema.is_aliasing({SchemaArgType::output, 4}), c10::Error);
24}
25
26TEST(FunctionSchemaIsMutableTest, Basic) {
27 c10::FunctionSchema schema = torch::jit::parseSchema(
28 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
29 ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0}));
30 ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
31 ASSERT_TRUE(schema.is_mutable("self"));
32 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
33 ASSERT_FALSE(schema.is_mutable("other"));
34 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2}));
35 ASSERT_FALSE(schema.is_mutable("alpha"));
36}
37
38TEST(FunctionSchemaIsMutableTest, InvalidArgument) {
39 c10::FunctionSchema schema = torch::jit::parseSchema(
40 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
41 ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error);
42 ASSERT_THROW(schema.is_mutable({SchemaArgType::output, 4}), c10::Error);
43 ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
44}
45
46TEST(SchemaInfoIsMutableTest, Basic) {
47 SchemaInfo schema(
48 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
49 ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
50 ASSERT_TRUE(schema.is_mutable("self"));
51 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
52 ASSERT_FALSE(schema.is_mutable("other"));
53 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2}));
54 ASSERT_FALSE(schema.is_mutable("alpha"));
55}
56
57TEST(SchemaInfoIsMutableTest, InvalidArgument) {
58 SchemaInfo schema(
59 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
60 ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error);
61 ASSERT_THROW(schema.is_mutable("named_argument"), c10::Error);
62}
63
64TEST(SchemaInfoIsMutableTest, AliasingInputs) {
65 SchemaInfo schema(
66 "aten::test.Tensor(Tensor(a!) self, Tensor(b) other, *, Scalar alpha=1) -> (Tensor(a!), Tensor(b))");
67 ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0}));
68 ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0}));
69 ASSERT_TRUE(schema.is_mutable("self"));
70 ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1}));
71 ASSERT_FALSE(schema.is_mutable({SchemaArgType::output, 1}));
72 ASSERT_FALSE(schema.is_mutable("other"));
73 at::Tensor input = at::randn({3, 3});
74 schema.addArgumentValue("self", input);
75 schema.addArgumentValue("other", input);
76 ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 1}));
77 ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 1}));
78 ASSERT_TRUE(schema.is_mutable("other"));
79}
80
81TEST(SchemaInfoIsMutableTest, InstanceNorm) {
82 SchemaInfo schema_info(
83 "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor");
84 ASSERT_TRUE(schema_info.is_mutable("running_mean"));
85 ASSERT_TRUE(schema_info.is_mutable("running_var"));
86 schema_info.addArgumentValue("use_input_stats", false);
87 ASSERT_FALSE(schema_info.is_mutable("running_mean"));
88 ASSERT_FALSE(schema_info.is_mutable("running_var"));
89}
90
91TEST(SchemaInfoIsMutableTest, BatchNorm) {
92 SchemaInfo schema_info(
93 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor");
94 ASSERT_TRUE(schema_info.is_mutable("running_mean"));
95 ASSERT_TRUE(schema_info.is_mutable("running_var"));
96 schema_info.addArgumentValue("training", false);
97 ASSERT_FALSE(schema_info.is_mutable("running_mean"));
98 ASSERT_FALSE(schema_info.is_mutable("running_var"));
99}
100
101TEST(SchemaInfoIsNonDeterministicTest, Basic) {
102 SchemaInfo deterministic_schema_info(
103 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
104 SchemaInfo nondeterministic_schema_info(
105 "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor");
106 ASSERT_FALSE(deterministic_schema_info.is_nondeterministic());
107 ASSERT_TRUE(nondeterministic_schema_info.is_nondeterministic());
108}
109
110TEST(SchemaInfoIsNonDeterministicTest, Dropout) {
111 SchemaInfo droupout_schema_info(
112 "aten::dropout(Tensor input, float p, bool train) -> Tensor");
113 ASSERT_TRUE(droupout_schema_info.is_nondeterministic());
114 droupout_schema_info.addArgumentValue("train", false);
115 ASSERT_FALSE(droupout_schema_info.is_nondeterministic());
116}
117
118TEST(FunctionSchemaMayAliasTest, Basic) {
119 c10::FunctionSchema schema = torch::jit::parseSchema(
120 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
121 ASSERT_TRUE(
122 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
123 ASSERT_FALSE(
124 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
125 ASSERT_FALSE(
126 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
127}
128
129TEST(FunctionSchemaMayAliasTest, InvalidArgument) {
130 c10::FunctionSchema schema = torch::jit::parseSchema(
131 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
132 ASSERT_THROW(
133 schema.may_alias({SchemaArgType::input, 15}, {SchemaArgType::output, 0}),
134 c10::Error);
135 ASSERT_THROW(
136 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 15}),
137 c10::Error);
138}
139
140TEST(FunctionSchemaMayAliasTest, Wildcard) {
141 c10::FunctionSchema schema = torch::jit::parseSchema(
142 "aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)");
143 ASSERT_TRUE(
144 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
145 ASSERT_FALSE(
146 schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
147}
148
149TEST(SchemaInfoMayAliasTest, AliasingInputs) {
150 SchemaInfo schema(
151 "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
152 ASSERT_FALSE(
153 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
154 at::Tensor input = at::randn({3, 3});
155 schema.addArgumentValue("self", input);
156 schema.addArgumentValue("other", input);
157 ASSERT_TRUE(
158 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
159}
160
161TEST(SchemaInfoMayAliasTest, AliasingOutputs) {
162 SchemaInfo schema(
163 "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)");
164 ASSERT_FALSE(
165 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
166 at::Tensor input = at::randn({3, 3});
167 schema.addArgumentValue("min", input);
168 schema.addArgumentValue("max", input);
169 ASSERT_TRUE(
170 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
171}
172
173TEST(SchemaInfoMayAliasTest, AliasingInputOutput) {
174 SchemaInfo schema(
175 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
176 ASSERT_TRUE(
177 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
178 ASSERT_FALSE(
179 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
180 at::Tensor input = at::randn({3, 3});
181 schema.addArgumentValue("self", input);
182 schema.addArgumentValue("other", input);
183 ASSERT_TRUE(
184 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
185 ASSERT_TRUE(
186 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
187}
188
189TEST(SchemaInfoMayAliasTest, MultipleWildcardInputs) {
190 SchemaInfo schema(
191 "aten::test.Tensor(Tensor(a) a, Tensor(*) b, Tensor(*) c) -> (Tensor(a), Tensor(*))");
192 ASSERT_TRUE(
193 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
194 ASSERT_TRUE(
195 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1}));
196 ASSERT_TRUE(
197 schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1}));
198 ASSERT_FALSE(
199 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
200 ASSERT_FALSE(
201 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
202 ASSERT_FALSE(
203 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1}));
204 ASSERT_FALSE(
205 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
206 at::Tensor input = at::randn({3, 3});
207 schema.addArgumentValue("a", input);
208 schema.addArgumentValue("b", input);
209 ASSERT_TRUE(
210 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
211 ASSERT_TRUE(
212 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1}));
213 ASSERT_TRUE(
214 schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1}));
215 ASSERT_TRUE(
216 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
217 ASSERT_TRUE(
218 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
219 ASSERT_TRUE(
220 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1}));
221 ASSERT_TRUE(
222 schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
223}
224
225TEST(SchemaInfoMayAliasTest, MultipleNonWildcardInputs) {
226 SchemaInfo schema(
227 "aten::test.Tensor(Tensor(a) a, Tensor(a) b, Tensor(*) c, Tensor(b) d) -> (Tensor(a), Tensor(*))");
228 ASSERT_TRUE(
229 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
230 ASSERT_TRUE(
231 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
232 ASSERT_TRUE(
233 schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
234 ASSERT_TRUE(
235 schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 0}));
236}
237
238TEST(SchemaInfoMayAliasTest, MultipleNonWildcardOutputs) {
239 SchemaInfo schema(
240 "aten::test.Tensor(Tensor(a) a, Tensor(*) b) -> (Tensor(a), Tensor(a))");
241 ASSERT_TRUE(
242 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
243 ASSERT_TRUE(
244 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
245 ASSERT_TRUE(
246 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
247}
248
249TEST(SchemaInfoMayAliasTest, MismatchingTypes) {
250 SchemaInfo schema("aten::test.Tensor(Tensor(a) a) -> int(a)");
251 ASSERT_FALSE(
252 schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
253}
254
255TEST(FunctionSchemaMayContainAliasTest, Basic) {
256 c10::FunctionSchema schema = torch::jit::parseSchema(
257 "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
258 ASSERT_TRUE(schema.may_contain_alias(
259 {SchemaArgType::input, 0}, {SchemaArgType::output, 0}));
260 ASSERT_FALSE(schema.may_contain_alias(
261 {SchemaArgType::input, 1}, {SchemaArgType::output, 0}));
262 ASSERT_FALSE(schema.may_contain_alias(
263 {SchemaArgType::input, 1}, {SchemaArgType::input, 0}));
264}
265
266TEST(FunctionSchemaMayContainAliasTest, Wildcard) {
267 c10::FunctionSchema schema = torch::jit::parseSchema(
268 "aten::test.Tensor(Tensor(*) self) -> (Tensor[], Tensor)");
269 ASSERT_FALSE(
270 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
271 ASSERT_TRUE(schema.may_contain_alias(
272 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
273 ASSERT_TRUE(schema.may_contain_alias(
274 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false));
275 ASSERT_FALSE(schema.may_contain_alias(
276 {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false));
277 ASSERT_FALSE(
278 schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0}));
279}
280
281TEST(FunctionSchemaMayContainAliasTest, InputAndOutputContainers) {
282 c10::FunctionSchema schema =
283 torch::jit::parseSchema("aten::test.Tensor(Tensor[] self) -> Tensor[]");
284 ASSERT_FALSE(
285 schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
286 ASSERT_TRUE(schema.may_contain_alias(
287 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
288 ASSERT_TRUE(schema.may_contain_alias(
289 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false));
290 ASSERT_TRUE(schema.may_contain_alias(
291 {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false));
292}
293
294TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsEqual) {
295 SchemaInfo schema(
296 "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
297 ASSERT_FALSE(schema.may_contain_alias(
298 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
299 at::Tensor input = at::randn({3, 3});
300 schema.addArgumentValue("self", input);
301 schema.addArgumentValue("other", input);
302 ASSERT_TRUE(schema.may_contain_alias(
303 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
304 ASSERT_TRUE(schema.may_contain_alias(
305 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false));
306 ASSERT_TRUE(schema.may_contain_alias(
307 {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false));
308}
309
310TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsContained) {
311 SchemaInfo schema(
312 "aten::test.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor");
313 ASSERT_FALSE(schema.may_contain_alias(
314 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
315 at::Tensor input = at::randn({3, 3});
316 schema.addArgumentValue("self", c10::List<at::Tensor>({input}));
317 schema.addArgumentValue("other", input);
318 ASSERT_TRUE(schema.may_contain_alias(
319 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
320 ASSERT_TRUE(schema.may_contain_alias(
321 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false));
322 ASSERT_FALSE(schema.may_contain_alias(
323 {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false));
324}
325
326TEST(SchemaInfoMayContainAliasTest, ContainAliasOutputs) {
327 SchemaInfo schema(
328 "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)");
329 ASSERT_FALSE(schema.may_contain_alias(
330 {SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
331 at::Tensor input = at::randn({3, 3});
332 schema.addArgumentValue("min", input);
333 schema.addArgumentValue("max", input);
334 ASSERT_TRUE(schema.may_contain_alias(
335 {SchemaArgType::output, 0}, {SchemaArgType::output, 1}));
336}
337
338TEST(SchemaInfoMayContainAliasTest, ContainAliasInputOutput) {
339 SchemaInfo schema(
340 "aten::test.tensor(Tensor(a) self, Tensor[] other) -> Tensor(a)");
341 ASSERT_FALSE(schema.may_contain_alias(
342 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
343 at::Tensor input = at::randn({3, 3});
344 schema.addArgumentValue("other", c10::List<at::Tensor>({input}));
345 schema.addArgumentValue("self", input);
346 ASSERT_TRUE(schema.may_contain_alias(
347 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
348 ASSERT_FALSE(schema.may_contain_alias(
349 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}, false));
350 ASSERT_TRUE(schema.may_contain_alias(
351 {SchemaArgType::input, 1}, {SchemaArgType::output, 0}, false));
352}
353
354TEST(SchemaInfoMayContainAliasTest, InputAndOutputContainers) {
355 SchemaInfo schema(
356 "aten::test.tensor(Tensor self, Tensor[] other) -> Tensor[]");
357 ASSERT_TRUE(schema.may_contain_alias(
358 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
359 ASSERT_FALSE(schema.may_contain_alias(
360 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
361 ASSERT_FALSE(schema.may_contain_alias(
362 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
363 at::Tensor input = at::randn({3, 3});
364 schema.addArgumentValue("other", c10::List<at::Tensor>({input}));
365 schema.addArgumentValue("self", input);
366 ASSERT_TRUE(schema.may_contain_alias(
367 {SchemaArgType::output, 0}, {SchemaArgType::input, 1}));
368 ASSERT_TRUE(schema.may_contain_alias(
369 {SchemaArgType::output, 0}, {SchemaArgType::input, 0}));
370 ASSERT_TRUE(schema.may_contain_alias(
371 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
372}
373
374TEST(SchemaInfoMayContainAliasTest, Wildcard) {
375 SchemaInfo schema(
376 "aten::test.tensor(Tensor a, Tensor[] b, Tensor(*) c) -> Tensor[]");
377 ASSERT_FALSE(schema.may_contain_alias(
378 {SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
379 ASSERT_FALSE(schema.may_contain_alias(
380 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
381 ASSERT_TRUE(schema.may_contain_alias(
382 {SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
383 at::Tensor input = at::randn({3, 3});
384 schema.addArgumentValue("b", c10::List<at::Tensor>({input}));
385 schema.addArgumentValue("a", input);
386 ASSERT_TRUE(schema.may_contain_alias(
387 {SchemaArgType::input, 0}, {SchemaArgType::input, 2}));
388 ASSERT_TRUE(schema.may_contain_alias(
389 {SchemaArgType::input, 0}, {SchemaArgType::input, 1}));
390 ASSERT_TRUE(schema.may_contain_alias(
391 {SchemaArgType::input, 2}, {SchemaArgType::input, 1}));
392}
393} // namespace utils
394} // namespace torch
395