1 | #include <gtest/gtest.h> |
2 | #include <torch/csrc/autograd/generated/variable_factories.h> |
3 | #include <torch/csrc/utils/schema_info.h> |
4 | |
5 | namespace torch { |
6 | namespace utils { |
7 | using c10::SchemaArgType; |
8 | |
9 | TEST(FunctionSchemaIsAliasingTest, Basic) { |
10 | c10::FunctionSchema schema = torch::jit::parseSchema( |
11 | "aten::test.Tensor(Tensor(a) self, Tensor(b!) other, Tensor more_other) -> (Tensor(a), Tensor(b!))" ); |
12 | ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 0})); |
13 | ASSERT_TRUE(schema.is_aliasing({SchemaArgType::output, 1})); |
14 | ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 0})); |
15 | ASSERT_TRUE(schema.is_aliasing({SchemaArgType::input, 1})); |
16 | ASSERT_FALSE(schema.is_aliasing({SchemaArgType::input, 2})); |
17 | } |
18 | |
19 | TEST(FunctionSchemaIsAliasingTest, InvalidArgument) { |
20 | c10::FunctionSchema schema = torch::jit::parseSchema( |
21 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
22 | ASSERT_THROW(schema.is_aliasing({SchemaArgType::input, 4}), c10::Error); |
23 | ASSERT_THROW(schema.is_aliasing({SchemaArgType::output, 4}), c10::Error); |
24 | } |
25 | |
26 | TEST(FunctionSchemaIsMutableTest, Basic) { |
27 | c10::FunctionSchema schema = torch::jit::parseSchema( |
28 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
29 | ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0})); |
30 | ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0})); |
31 | ASSERT_TRUE(schema.is_mutable("self" )); |
32 | ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1})); |
33 | ASSERT_FALSE(schema.is_mutable("other" )); |
34 | ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2})); |
35 | ASSERT_FALSE(schema.is_mutable("alpha" )); |
36 | } |
37 | |
38 | TEST(FunctionSchemaIsMutableTest, InvalidArgument) { |
39 | c10::FunctionSchema schema = torch::jit::parseSchema( |
40 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
41 | ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error); |
42 | ASSERT_THROW(schema.is_mutable({SchemaArgType::output, 4}), c10::Error); |
43 | ASSERT_THROW(schema.is_mutable("named_argument" ), c10::Error); |
44 | } |
45 | |
46 | TEST(SchemaInfoIsMutableTest, Basic) { |
47 | SchemaInfo schema( |
48 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
49 | ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0})); |
50 | ASSERT_TRUE(schema.is_mutable("self" )); |
51 | ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1})); |
52 | ASSERT_FALSE(schema.is_mutable("other" )); |
53 | ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 2})); |
54 | ASSERT_FALSE(schema.is_mutable("alpha" )); |
55 | } |
56 | |
57 | TEST(SchemaInfoIsMutableTest, InvalidArgument) { |
58 | SchemaInfo schema( |
59 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
60 | ASSERT_THROW(schema.is_mutable({SchemaArgType::input, 4}), c10::Error); |
61 | ASSERT_THROW(schema.is_mutable("named_argument" ), c10::Error); |
62 | } |
63 | |
64 | TEST(SchemaInfoIsMutableTest, AliasingInputs) { |
65 | SchemaInfo schema( |
66 | "aten::test.Tensor(Tensor(a!) self, Tensor(b) other, *, Scalar alpha=1) -> (Tensor(a!), Tensor(b))" ); |
67 | ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 0})); |
68 | ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 0})); |
69 | ASSERT_TRUE(schema.is_mutable("self" )); |
70 | ASSERT_FALSE(schema.is_mutable({SchemaArgType::input, 1})); |
71 | ASSERT_FALSE(schema.is_mutable({SchemaArgType::output, 1})); |
72 | ASSERT_FALSE(schema.is_mutable("other" )); |
73 | at::Tensor input = at::randn({3, 3}); |
74 | schema.addArgumentValue("self" , input); |
75 | schema.addArgumentValue("other" , input); |
76 | ASSERT_TRUE(schema.is_mutable({SchemaArgType::input, 1})); |
77 | ASSERT_TRUE(schema.is_mutable({SchemaArgType::output, 1})); |
78 | ASSERT_TRUE(schema.is_mutable("other" )); |
79 | } |
80 | |
81 | TEST(SchemaInfoIsMutableTest, InstanceNorm) { |
82 | SchemaInfo schema_info( |
83 | "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor" ); |
84 | ASSERT_TRUE(schema_info.is_mutable("running_mean" )); |
85 | ASSERT_TRUE(schema_info.is_mutable("running_var" )); |
86 | schema_info.addArgumentValue("use_input_stats" , false); |
87 | ASSERT_FALSE(schema_info.is_mutable("running_mean" )); |
88 | ASSERT_FALSE(schema_info.is_mutable("running_var" )); |
89 | } |
90 | |
91 | TEST(SchemaInfoIsMutableTest, BatchNorm) { |
92 | SchemaInfo schema_info( |
93 | "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor" ); |
94 | ASSERT_TRUE(schema_info.is_mutable("running_mean" )); |
95 | ASSERT_TRUE(schema_info.is_mutable("running_var" )); |
96 | schema_info.addArgumentValue("training" , false); |
97 | ASSERT_FALSE(schema_info.is_mutable("running_mean" )); |
98 | ASSERT_FALSE(schema_info.is_mutable("running_var" )); |
99 | } |
100 | |
101 | TEST(SchemaInfoIsNonDeterministicTest, Basic) { |
102 | SchemaInfo deterministic_schema_info( |
103 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
104 | SchemaInfo nondeterministic_schema_info( |
105 | "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor" ); |
106 | ASSERT_FALSE(deterministic_schema_info.is_nondeterministic()); |
107 | ASSERT_TRUE(nondeterministic_schema_info.is_nondeterministic()); |
108 | } |
109 | |
110 | TEST(SchemaInfoIsNonDeterministicTest, Dropout) { |
111 | SchemaInfo droupout_schema_info( |
112 | "aten::dropout(Tensor input, float p, bool train) -> Tensor" ); |
113 | ASSERT_TRUE(droupout_schema_info.is_nondeterministic()); |
114 | droupout_schema_info.addArgumentValue("train" , false); |
115 | ASSERT_FALSE(droupout_schema_info.is_nondeterministic()); |
116 | } |
117 | |
118 | TEST(FunctionSchemaMayAliasTest, Basic) { |
119 | c10::FunctionSchema schema = torch::jit::parseSchema( |
120 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
121 | ASSERT_TRUE( |
122 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); |
123 | ASSERT_FALSE( |
124 | schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); |
125 | ASSERT_FALSE( |
126 | schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::input, 0})); |
127 | } |
128 | |
129 | TEST(FunctionSchemaMayAliasTest, InvalidArgument) { |
130 | c10::FunctionSchema schema = torch::jit::parseSchema( |
131 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
132 | ASSERT_THROW( |
133 | schema.may_alias({SchemaArgType::input, 15}, {SchemaArgType::output, 0}), |
134 | c10::Error); |
135 | ASSERT_THROW( |
136 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 15}), |
137 | c10::Error); |
138 | } |
139 | |
140 | TEST(FunctionSchemaMayAliasTest, Wildcard) { |
141 | c10::FunctionSchema schema = torch::jit::parseSchema( |
142 | "aten::test.Tensor(Tensor(*) self) -> (Tensor(*), Tensor)" ); |
143 | ASSERT_TRUE( |
144 | schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0})); |
145 | ASSERT_FALSE( |
146 | schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0})); |
147 | } |
148 | |
149 | TEST(SchemaInfoMayAliasTest, AliasingInputs) { |
150 | SchemaInfo schema( |
151 | "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" ); |
152 | ASSERT_FALSE( |
153 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
154 | at::Tensor input = at::randn({3, 3}); |
155 | schema.addArgumentValue("self" , input); |
156 | schema.addArgumentValue("other" , input); |
157 | ASSERT_TRUE( |
158 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
159 | } |
160 | |
161 | TEST(SchemaInfoMayAliasTest, AliasingOutputs) { |
162 | SchemaInfo schema( |
163 | "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)" ); |
164 | ASSERT_FALSE( |
165 | schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1})); |
166 | at::Tensor input = at::randn({3, 3}); |
167 | schema.addArgumentValue("min" , input); |
168 | schema.addArgumentValue("max" , input); |
169 | ASSERT_TRUE( |
170 | schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1})); |
171 | } |
172 | |
173 | TEST(SchemaInfoMayAliasTest, AliasingInputOutput) { |
174 | SchemaInfo schema( |
175 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
176 | ASSERT_TRUE( |
177 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); |
178 | ASSERT_FALSE( |
179 | schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); |
180 | at::Tensor input = at::randn({3, 3}); |
181 | schema.addArgumentValue("self" , input); |
182 | schema.addArgumentValue("other" , input); |
183 | ASSERT_TRUE( |
184 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); |
185 | ASSERT_TRUE( |
186 | schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); |
187 | } |
188 | |
189 | TEST(SchemaInfoMayAliasTest, MultipleWildcardInputs) { |
190 | SchemaInfo schema( |
191 | "aten::test.Tensor(Tensor(a) a, Tensor(*) b, Tensor(*) c) -> (Tensor(a), Tensor(*))" ); |
192 | ASSERT_TRUE( |
193 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); |
194 | ASSERT_TRUE( |
195 | schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1})); |
196 | ASSERT_TRUE( |
197 | schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1})); |
198 | ASSERT_FALSE( |
199 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
200 | ASSERT_FALSE( |
201 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2})); |
202 | ASSERT_FALSE( |
203 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1})); |
204 | ASSERT_FALSE( |
205 | schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); |
206 | at::Tensor input = at::randn({3, 3}); |
207 | schema.addArgumentValue("a" , input); |
208 | schema.addArgumentValue("b" , input); |
209 | ASSERT_TRUE( |
210 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); |
211 | ASSERT_TRUE( |
212 | schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 1})); |
213 | ASSERT_TRUE( |
214 | schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 1})); |
215 | ASSERT_TRUE( |
216 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
217 | ASSERT_TRUE( |
218 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2})); |
219 | ASSERT_TRUE( |
220 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 1})); |
221 | ASSERT_TRUE( |
222 | schema.may_alias({SchemaArgType::input, 1}, {SchemaArgType::output, 0})); |
223 | } |
224 | |
225 | TEST(SchemaInfoMayAliasTest, MultipleNonWildcardInputs) { |
226 | SchemaInfo schema( |
227 | "aten::test.Tensor(Tensor(a) a, Tensor(a) b, Tensor(*) c, Tensor(b) d) -> (Tensor(a), Tensor(*))" ); |
228 | ASSERT_TRUE( |
229 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
230 | ASSERT_TRUE( |
231 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 2})); |
232 | ASSERT_TRUE( |
233 | schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::input, 1})); |
234 | ASSERT_TRUE( |
235 | schema.may_alias({SchemaArgType::input, 2}, {SchemaArgType::output, 0})); |
236 | } |
237 | |
238 | TEST(SchemaInfoMayAliasTest, MultipleNonWildcardOutputs) { |
239 | SchemaInfo schema( |
240 | "aten::test.Tensor(Tensor(a) a, Tensor(*) b) -> (Tensor(a), Tensor(a))" ); |
241 | ASSERT_TRUE( |
242 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
243 | ASSERT_TRUE( |
244 | schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::output, 1})); |
245 | ASSERT_TRUE( |
246 | schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 1})); |
247 | } |
248 | |
249 | TEST(SchemaInfoMayAliasTest, MismatchingTypes) { |
250 | SchemaInfo schema("aten::test.Tensor(Tensor(a) a) -> int(a)" ); |
251 | ASSERT_FALSE( |
252 | schema.may_alias({SchemaArgType::input, 0}, {SchemaArgType::output, 0})); |
253 | } |
254 | |
255 | TEST(FunctionSchemaMayContainAliasTest, Basic) { |
256 | c10::FunctionSchema schema = torch::jit::parseSchema( |
257 | "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))" ); |
258 | ASSERT_TRUE(schema.may_contain_alias( |
259 | {SchemaArgType::input, 0}, {SchemaArgType::output, 0})); |
260 | ASSERT_FALSE(schema.may_contain_alias( |
261 | {SchemaArgType::input, 1}, {SchemaArgType::output, 0})); |
262 | ASSERT_FALSE(schema.may_contain_alias( |
263 | {SchemaArgType::input, 1}, {SchemaArgType::input, 0})); |
264 | } |
265 | |
266 | TEST(FunctionSchemaMayContainAliasTest, Wildcard) { |
267 | c10::FunctionSchema schema = torch::jit::parseSchema( |
268 | "aten::test.Tensor(Tensor(*) self) -> (Tensor[], Tensor)" ); |
269 | ASSERT_FALSE( |
270 | schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0})); |
271 | ASSERT_TRUE(schema.may_contain_alias( |
272 | {SchemaArgType::output, 0}, {SchemaArgType::input, 0})); |
273 | ASSERT_TRUE(schema.may_contain_alias( |
274 | {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false)); |
275 | ASSERT_FALSE(schema.may_contain_alias( |
276 | {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false)); |
277 | ASSERT_FALSE( |
278 | schema.may_alias({SchemaArgType::output, 1}, {SchemaArgType::input, 0})); |
279 | } |
280 | |
281 | TEST(FunctionSchemaMayContainAliasTest, InputAndOutputContainers) { |
282 | c10::FunctionSchema schema = |
283 | torch::jit::parseSchema("aten::test.Tensor(Tensor[] self) -> Tensor[]" ); |
284 | ASSERT_FALSE( |
285 | schema.may_alias({SchemaArgType::output, 0}, {SchemaArgType::input, 0})); |
286 | ASSERT_TRUE(schema.may_contain_alias( |
287 | {SchemaArgType::output, 0}, {SchemaArgType::input, 0})); |
288 | ASSERT_TRUE(schema.may_contain_alias( |
289 | {SchemaArgType::output, 0}, {SchemaArgType::input, 0}, false)); |
290 | ASSERT_TRUE(schema.may_contain_alias( |
291 | {SchemaArgType::input, 0}, {SchemaArgType::output, 0}, false)); |
292 | } |
293 | |
294 | TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsEqual) { |
295 | SchemaInfo schema( |
296 | "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" ); |
297 | ASSERT_FALSE(schema.may_contain_alias( |
298 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
299 | at::Tensor input = at::randn({3, 3}); |
300 | schema.addArgumentValue("self" , input); |
301 | schema.addArgumentValue("other" , input); |
302 | ASSERT_TRUE(schema.may_contain_alias( |
303 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
304 | ASSERT_TRUE(schema.may_contain_alias( |
305 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false)); |
306 | ASSERT_TRUE(schema.may_contain_alias( |
307 | {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false)); |
308 | } |
309 | |
310 | TEST(SchemaInfoMayContainAliasTest, ContainAliasInputsContained) { |
311 | SchemaInfo schema( |
312 | "aten::test.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor" ); |
313 | ASSERT_FALSE(schema.may_contain_alias( |
314 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
315 | at::Tensor input = at::randn({3, 3}); |
316 | schema.addArgumentValue("self" , c10::List<at::Tensor>({input})); |
317 | schema.addArgumentValue("other" , input); |
318 | ASSERT_TRUE(schema.may_contain_alias( |
319 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
320 | ASSERT_TRUE(schema.may_contain_alias( |
321 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1}, false)); |
322 | ASSERT_FALSE(schema.may_contain_alias( |
323 | {SchemaArgType::input, 1}, {SchemaArgType::input, 0}, false)); |
324 | } |
325 | |
326 | TEST(SchemaInfoMayContainAliasTest, ContainAliasOutputs) { |
327 | SchemaInfo schema( |
328 | "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)" ); |
329 | ASSERT_FALSE(schema.may_contain_alias( |
330 | {SchemaArgType::output, 0}, {SchemaArgType::output, 1})); |
331 | at::Tensor input = at::randn({3, 3}); |
332 | schema.addArgumentValue("min" , input); |
333 | schema.addArgumentValue("max" , input); |
334 | ASSERT_TRUE(schema.may_contain_alias( |
335 | {SchemaArgType::output, 0}, {SchemaArgType::output, 1})); |
336 | } |
337 | |
338 | TEST(SchemaInfoMayContainAliasTest, ContainAliasInputOutput) { |
339 | SchemaInfo schema( |
340 | "aten::test.tensor(Tensor(a) self, Tensor[] other) -> Tensor(a)" ); |
341 | ASSERT_FALSE(schema.may_contain_alias( |
342 | {SchemaArgType::output, 0}, {SchemaArgType::input, 1})); |
343 | at::Tensor input = at::randn({3, 3}); |
344 | schema.addArgumentValue("other" , c10::List<at::Tensor>({input})); |
345 | schema.addArgumentValue("self" , input); |
346 | ASSERT_TRUE(schema.may_contain_alias( |
347 | {SchemaArgType::output, 0}, {SchemaArgType::input, 1})); |
348 | ASSERT_FALSE(schema.may_contain_alias( |
349 | {SchemaArgType::output, 0}, {SchemaArgType::input, 1}, false)); |
350 | ASSERT_TRUE(schema.may_contain_alias( |
351 | {SchemaArgType::input, 1}, {SchemaArgType::output, 0}, false)); |
352 | } |
353 | |
354 | TEST(SchemaInfoMayContainAliasTest, InputAndOutputContainers) { |
355 | SchemaInfo schema( |
356 | "aten::test.tensor(Tensor self, Tensor[] other) -> Tensor[]" ); |
357 | ASSERT_TRUE(schema.may_contain_alias( |
358 | {SchemaArgType::output, 0}, {SchemaArgType::input, 1})); |
359 | ASSERT_FALSE(schema.may_contain_alias( |
360 | {SchemaArgType::output, 0}, {SchemaArgType::input, 0})); |
361 | ASSERT_FALSE(schema.may_contain_alias( |
362 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
363 | at::Tensor input = at::randn({3, 3}); |
364 | schema.addArgumentValue("other" , c10::List<at::Tensor>({input})); |
365 | schema.addArgumentValue("self" , input); |
366 | ASSERT_TRUE(schema.may_contain_alias( |
367 | {SchemaArgType::output, 0}, {SchemaArgType::input, 1})); |
368 | ASSERT_TRUE(schema.may_contain_alias( |
369 | {SchemaArgType::output, 0}, {SchemaArgType::input, 0})); |
370 | ASSERT_TRUE(schema.may_contain_alias( |
371 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
372 | } |
373 | |
374 | TEST(SchemaInfoMayContainAliasTest, Wildcard) { |
375 | SchemaInfo schema( |
376 | "aten::test.tensor(Tensor a, Tensor[] b, Tensor(*) c) -> Tensor[]" ); |
377 | ASSERT_FALSE(schema.may_contain_alias( |
378 | {SchemaArgType::input, 0}, {SchemaArgType::input, 2})); |
379 | ASSERT_FALSE(schema.may_contain_alias( |
380 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
381 | ASSERT_TRUE(schema.may_contain_alias( |
382 | {SchemaArgType::input, 2}, {SchemaArgType::input, 1})); |
383 | at::Tensor input = at::randn({3, 3}); |
384 | schema.addArgumentValue("b" , c10::List<at::Tensor>({input})); |
385 | schema.addArgumentValue("a" , input); |
386 | ASSERT_TRUE(schema.may_contain_alias( |
387 | {SchemaArgType::input, 0}, {SchemaArgType::input, 2})); |
388 | ASSERT_TRUE(schema.may_contain_alias( |
389 | {SchemaArgType::input, 0}, {SchemaArgType::input, 1})); |
390 | ASSERT_TRUE(schema.may_contain_alias( |
391 | {SchemaArgType::input, 2}, {SchemaArgType::input, 1})); |
392 | } |
393 | } // namespace utils |
394 | } // namespace torch |
395 | |