1#include <gtest/gtest.h>
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/passes/utils/memory_dag.h>
5
6namespace torch {
7namespace jit {
8
9TEST(MemoryDAGTest, Basic) {
10 auto graph = std::make_shared<Graph>();
11 const Value* aValue = graph->addInput();
12 const Value* bValue = graph->addInput();
13 const Value* cValue = graph->addInput();
14 const Value* dValue = graph->addInput();
15 const Value* eValue = graph->addInput();
16 const Value* fValue = graph->addInput();
17 const Value* gValue = graph->addInput();
18
19 {
20 // a <- b <- c
21 // b <- d
22 // a <- e
23 // f <- e
24 // g is by itself
25 auto t = std::make_unique<MemoryDAGBuilder>();
26 auto a = t->makeFreshValue(aValue);
27 auto b = t->makeFreshValue(bValue);
28 auto c = t->makeFreshValue(cValue);
29 auto d = t->makeFreshValue(dValue);
30 auto e = t->makeFreshValue(eValue);
31 auto f = t->makeFreshValue(fValue);
32 auto g = t->makeFreshValue(gValue);
33 t->makePointerTo(b, a);
34 t->makePointerTo(c, b);
35 t->makePointerTo(d, b);
36 t->makePointerTo(e, a);
37 t->makePointerTo(e, f);
38
39 auto dag = std::make_unique<MemoryDAG>(std::move(t));
40
41 /**
42 * Test mayAlias()
43 */
44 // Values should alias themselves
45 EXPECT_TRUE(dag->mayAlias(a, a));
46 EXPECT_TRUE(dag->mayAlias(g, g));
47
48 // Values that point to the same location should alias
49 EXPECT_TRUE(dag->mayAlias(a, b));
50 EXPECT_TRUE(dag->mayAlias(a, c));
51 EXPECT_TRUE(dag->mayAlias(c, d));
52
53 // e may point to a OR f
54 EXPECT_TRUE(dag->mayAlias(e, a));
55 EXPECT_TRUE(dag->mayAlias(e, f));
56 // But a and f don't alias
57 EXPECT_FALSE(dag->mayAlias(a, f));
58 }
59 {
60 // x(y) -> x contains y
61
62 // b(a)
63 // c(a)
64 auto t = std::make_unique<MemoryDAGBuilder>();
65 auto a = t->makeFreshValue(aValue);
66 auto b = t->makeFreshValue(bValue);
67 t->addToContainedElements(a, b);
68
69 auto c = t->makeFreshValue(cValue);
70 t->addToContainedElements(a, c);
71
72 auto dag = std::make_unique<MemoryDAG>(std::move(t));
73 EXPECT_TRUE(dag->mayContainAlias(a, b));
74 EXPECT_TRUE(dag->mayContainAlias(b, a));
75
76 EXPECT_TRUE(dag->mayContainAlias(a, c));
77 EXPECT_TRUE(dag->mayContainAlias(c, a));
78
79 EXPECT_TRUE(dag->mayContainAlias(b, c));
80 EXPECT_TRUE(dag->mayContainAlias(c, b));
81
82 // containers contain an element in themselves
83 EXPECT_TRUE(dag->mayContainAlias(b, b));
84 EXPECT_TRUE(dag->mayContainAlias(c, c));
85 EXPECT_TRUE(dag->mayContainAlias(a, a));
86 }
87 {
88 // b(a)
89 // c(a)
90 // d(b(a))
91 auto t = std::make_unique<MemoryDAGBuilder>();
92 auto a = t->makeFreshValue(aValue);
93 auto b = t->makeFreshValue(bValue);
94 t->addToContainedElements(a, b);
95
96 auto c = t->makeFreshValue(cValue);
97 t->addToContainedElements(a, c);
98
99 auto d = t->makeFreshValue(dValue);
100 t->addToContainedElements(b, d);
101
102 auto dag = std::make_unique<MemoryDAG>(std::move(t));
103 EXPECT_TRUE(dag->mayContainAlias(b, d));
104 EXPECT_TRUE(dag->mayContainAlias(d, b));
105
106 EXPECT_TRUE(dag->mayContainAlias(c, d));
107 EXPECT_TRUE(dag->mayContainAlias(d, c));
108
109 EXPECT_TRUE(dag->mayContainAlias(a, d));
110 }
111 {
112 // f(e)
113 auto t = std::make_unique<MemoryDAGBuilder>();
114 auto a = t->makeFreshValue(aValue);
115 auto b = t->makeFreshValue(bValue);
116 t->addToContainedElements(a, b);
117
118 auto c = t->makeFreshValue(cValue);
119 t->addToContainedElements(a, c);
120
121 auto d = t->makeFreshValue(dValue);
122 t->addToContainedElements(b, d);
123
124 auto f = t->makeFreshValue(aValue);
125 auto e = t->makeFreshValue(bValue);
126
127 t->addToContainedElements(f, e);
128
129 auto dag = std::make_unique<MemoryDAG>(std::move(t));
130 for (auto elem : {a, b, c, d}) {
131 EXPECT_FALSE(dag->mayContainAlias(f, elem));
132 EXPECT_FALSE(dag->mayContainAlias(e, elem));
133 }
134 }
135}
136
137} // namespace jit
138} // namespace torch
139