1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/passes/utils/memory_dag.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
9 | TEST(MemoryDAGTest, Basic) { |
10 | auto graph = std::make_shared<Graph>(); |
11 | const Value* aValue = graph->addInput(); |
12 | const Value* bValue = graph->addInput(); |
13 | const Value* cValue = graph->addInput(); |
14 | const Value* dValue = graph->addInput(); |
15 | const Value* eValue = graph->addInput(); |
16 | const Value* fValue = graph->addInput(); |
17 | const Value* gValue = graph->addInput(); |
18 | |
19 | { |
20 | // a <- b <- c |
21 | // b <- d |
22 | // a <- e |
23 | // f <- e |
24 | // g is by itself |
25 | auto t = std::make_unique<MemoryDAGBuilder>(); |
26 | auto a = t->makeFreshValue(aValue); |
27 | auto b = t->makeFreshValue(bValue); |
28 | auto c = t->makeFreshValue(cValue); |
29 | auto d = t->makeFreshValue(dValue); |
30 | auto e = t->makeFreshValue(eValue); |
31 | auto f = t->makeFreshValue(fValue); |
32 | auto g = t->makeFreshValue(gValue); |
33 | t->makePointerTo(b, a); |
34 | t->makePointerTo(c, b); |
35 | t->makePointerTo(d, b); |
36 | t->makePointerTo(e, a); |
37 | t->makePointerTo(e, f); |
38 | |
39 | auto dag = std::make_unique<MemoryDAG>(std::move(t)); |
40 | |
41 | /** |
42 | * Test mayAlias() |
43 | */ |
44 | // Values should alias themselves |
45 | EXPECT_TRUE(dag->mayAlias(a, a)); |
46 | EXPECT_TRUE(dag->mayAlias(g, g)); |
47 | |
48 | // Values that point to the same location should alias |
49 | EXPECT_TRUE(dag->mayAlias(a, b)); |
50 | EXPECT_TRUE(dag->mayAlias(a, c)); |
51 | EXPECT_TRUE(dag->mayAlias(c, d)); |
52 | |
53 | // e may point to a OR f |
54 | EXPECT_TRUE(dag->mayAlias(e, a)); |
55 | EXPECT_TRUE(dag->mayAlias(e, f)); |
56 | // But a and f don't alias |
57 | EXPECT_FALSE(dag->mayAlias(a, f)); |
58 | } |
59 | { |
60 | // x(y) -> x contains y |
61 | |
62 | // b(a) |
63 | // c(a) |
64 | auto t = std::make_unique<MemoryDAGBuilder>(); |
65 | auto a = t->makeFreshValue(aValue); |
66 | auto b = t->makeFreshValue(bValue); |
67 | t->addToContainedElements(a, b); |
68 | |
69 | auto c = t->makeFreshValue(cValue); |
70 | t->addToContainedElements(a, c); |
71 | |
72 | auto dag = std::make_unique<MemoryDAG>(std::move(t)); |
73 | EXPECT_TRUE(dag->mayContainAlias(a, b)); |
74 | EXPECT_TRUE(dag->mayContainAlias(b, a)); |
75 | |
76 | EXPECT_TRUE(dag->mayContainAlias(a, c)); |
77 | EXPECT_TRUE(dag->mayContainAlias(c, a)); |
78 | |
79 | EXPECT_TRUE(dag->mayContainAlias(b, c)); |
80 | EXPECT_TRUE(dag->mayContainAlias(c, b)); |
81 | |
82 | // containers contain an element in themselves |
83 | EXPECT_TRUE(dag->mayContainAlias(b, b)); |
84 | EXPECT_TRUE(dag->mayContainAlias(c, c)); |
85 | EXPECT_TRUE(dag->mayContainAlias(a, a)); |
86 | } |
87 | { |
88 | // b(a) |
89 | // c(a) |
90 | // d(b(a)) |
91 | auto t = std::make_unique<MemoryDAGBuilder>(); |
92 | auto a = t->makeFreshValue(aValue); |
93 | auto b = t->makeFreshValue(bValue); |
94 | t->addToContainedElements(a, b); |
95 | |
96 | auto c = t->makeFreshValue(cValue); |
97 | t->addToContainedElements(a, c); |
98 | |
99 | auto d = t->makeFreshValue(dValue); |
100 | t->addToContainedElements(b, d); |
101 | |
102 | auto dag = std::make_unique<MemoryDAG>(std::move(t)); |
103 | EXPECT_TRUE(dag->mayContainAlias(b, d)); |
104 | EXPECT_TRUE(dag->mayContainAlias(d, b)); |
105 | |
106 | EXPECT_TRUE(dag->mayContainAlias(c, d)); |
107 | EXPECT_TRUE(dag->mayContainAlias(d, c)); |
108 | |
109 | EXPECT_TRUE(dag->mayContainAlias(a, d)); |
110 | } |
111 | { |
112 | // f(e) |
113 | auto t = std::make_unique<MemoryDAGBuilder>(); |
114 | auto a = t->makeFreshValue(aValue); |
115 | auto b = t->makeFreshValue(bValue); |
116 | t->addToContainedElements(a, b); |
117 | |
118 | auto c = t->makeFreshValue(cValue); |
119 | t->addToContainedElements(a, c); |
120 | |
121 | auto d = t->makeFreshValue(dValue); |
122 | t->addToContainedElements(b, d); |
123 | |
124 | auto f = t->makeFreshValue(aValue); |
125 | auto e = t->makeFreshValue(bValue); |
126 | |
127 | t->addToContainedElements(f, e); |
128 | |
129 | auto dag = std::make_unique<MemoryDAG>(std::move(t)); |
130 | for (auto elem : {a, b, c, d}) { |
131 | EXPECT_FALSE(dag->mayContainAlias(f, elem)); |
132 | EXPECT_FALSE(dag->mayContainAlias(e, elem)); |
133 | } |
134 | } |
135 | } |
136 | |
137 | } // namespace jit |
138 | } // namespace torch |
139 | |