#if defined(USE_CUDA)
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

#include <fusion.h>
#include <lower_utils.h>
#include <ops/all_ops.h>
#include <scheduler/utils.h>
#include <test/test_gpu_validator.h>
#include <test/test_utils.h>

// Tests go in torch::jit
namespace torch {
namespace jit {

using namespace torch::jit::fuser::cuda;

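// Exercises scheduler_utils::splitDims: each requested {axis, factor} split is
// applied to the tensor, and the dims vector passed by the caller is updated
// in place to track the new positions of the original axes.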
TEST_F(NVFuserTest, FusionSplitDims_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int64_t* p = prime_numbers;
  auto tv = makeConcreteTensor(
      {p[0] * p[1] * p[2], p[3], p[4], p[5] * p[6], p[7], p[8], p[9] * p[10]});
  std::vector<size_t> dims{0, 1, 2, 3, 4, 5, 6};
  scheduler_utils::splitDims(
      tv, {{0, p[2]}, {0, p[1]}, {3, p[6]}, {6, p[10]}}, dims);
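  // After the splits, every leaf axis should have a single prime extent, and
  // each entry of dims should point at the outermost axis produced from the
  // corresponding original dimension.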
  TORCH_CHECK(tv->nDims() == 11);
  for (auto i : c10::irange(11)) {
    TORCH_CHECK(tv->axis(i)->extent()->evaluateInt() == p[i]);
  }
  std::vector<size_t> expect{0, 3, 4, 5, 7, 8, 9};
  TORCH_CHECK(dims == expect);
}

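// Exercises scheduler_utils::mergeDims: the listed axes are merged into a
// single axis, and the dims vector is updated so every merged axis maps to the
// resulting position.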
TEST_F(NVFuserTest, FusionMergeDims_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int64_t* p = prime_numbers;
  auto tv = makeConcreteTensor(
      {p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10]});
  std::vector<size_t> dims{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  auto merged = scheduler_utils::mergeDims(tv, {2, 3, 7, 8, 9}, dims);
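  // mergeDims returns the position of the merged axis; its extent is the
  // product of the merged extents, and every merged input axis in dims should
  // now map to that position.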
  TORCH_CHECK(merged == (size_t)2);
  std::vector<int64_t> expect_shape{
      p[0], p[1], p[2] * p[3] * p[7] * p[8] * p[9], p[4], p[5], p[6], p[10]};
  TORCH_CHECK(tv->nDims() == expect_shape.size());
  for (auto i : c10::irange(expect_shape.size())) {
    TORCH_CHECK(tv->axis(i)->extent()->evaluateInt() == expect_shape[i]);
  }
  std::vector<size_t> expect_dims{0, 1, 2, 2, 3, 4, 5, 2, 2, 2, 6};
  TORCH_CHECK(dims == expect_dims);
}

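// Exercises scheduler_utils::domainReorderAsRfactorMap: after a series of
// merges, splits, and reorders, the returned old-to-new map should describe
// the permutation that puts the leaf domain back into the ordering implied by
// the rfactor domain (here [a*c, do*bo, bi*di], per the inline comments).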
TEST_F(NVFuserTest, FusionReorderAsRFactor_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int a = 1, b = 2, c = 3, d = 4;

  TensorView* tv0 = makeConcreteTensor({a, b, c, d});
  fusion.addInput(tv0);
  fusion.addOutput(tv0);

  // [a, b, c, d]
  tv0->merge(0, 2);
  // [a*c, b, d]
  tv0->split(1, 2);
  // [a*c, bo, bi, d]
  tv0->split(3, 3);
  // [a*c, bo, bi, do, di]
  tv0->reorder({{1, 4}, {2, 1}, {3, 3}, {4, 2}});
  // [a*c, bi, di, do, bo]
  tv0->merge(3);
  tv0->merge(1);
  // [a*c, bi*di, do*bo]
  tv0->reorder({{0, 2}});
  // [bi*di, do*bo, a*c]
  // Order we want is:
  // [a*c, do*bo, bi*di]
  auto old2new = scheduler_utils::domainReorderAsRfactorMap(tv0);
  TORCH_CHECK(old2new[0] == 2);
  TORCH_CHECK(old2new[1] == 1);
  TORCH_CHECK(old2new[2] == 0);
}

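// Exercises scheduler_utils::disjointViewSets: dimensions that a view op
// merges (here axes 1 and 2 of tv0, combined into the size-12 axis) should end
// up exactly mapped in the returned disjoint sets.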
TEST_F(NVFuserTest, FusionDisjointViewSet_CUDA) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  auto tv0 = makeConcreteTensor({2, 3, 4});
  fusion->addInput(tv0);

  auto tv1 = view(tv0, {2, 3, 4}, {2, 12});

  auto tv2 = makeConcreteTensor({2, 12});
  fusion->addInput(tv2);

  auto tv3 = add(tv2, tv1);
  fusion->addOutput(tv3);

  auto disjoint_exact = scheduler_utils::disjointViewSets(fusion.get());

  TORCH_INTERNAL_ASSERT(
      disjoint_exact.strictAreMapped(tv0->axis(1), tv0->axis(2)));
}

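// Exercises scheduler_utils::allMatchingViews: the view ops in this fusion do
// not all perform the same reshape ({x, y, z} -> {x * y, z} vs.
// {x, y, z} -> {x, y * z}), so the analysis should report that not all views
// match.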
TEST_F(NVFuserTest, FusionMatchingViews_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int x = 2, y = 3, z = 4;

  auto tv0 = makeConcreteTensor({x, y, z});
  fusion.addInput(tv0);

  auto tv1 = view(tv0, {x, y, z}, {x * y, z});

  auto tv2 = sin(tv1);

  auto tv3 = view(tv2, {x * y, z}, {x, y * z});
  fusion.addOutput(tv3);

  auto tv4 = makeConcreteTensor({x, y, z});
  fusion.addInput(tv4);

  auto tv5 = view(tv4, {x, y, z}, {x, y * z});
  fusion.addOutput(tv5);

  // Link tv0 and tv4 together so the view analysis is done based on the state
  // before the views actually happen.
  auto tv6 = add(tv0, tv4);
  fusion.addOutput(tv6);

  TORCH_INTERNAL_ASSERT(!scheduler_utils::allMatchingViews(&fusion));
}

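// Exercises scheduler_utils::getBroadcastMultiples and breakIsDisjoint:
// dimensions that the views tie together must share a disjoint-set id, and the
// broadcast multiples are computed against tv4 as the reference tensor.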
TEST_F(NVFuserTest, FusionBroadcastViewMultiples_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  int a = 2, b = 3, c = 5, d = 7, e = 11, f = 13;

  auto tv0 = makeConcreteTensor({a, b, c, d, e, f});
  fusion.addInput(tv0);

  // tie e and f together (swapping values next to each other ensures they'll
  // be merged then split by the view)
  auto tv1 = view(tv0, {a, b, c, d, e, f}, {a, b, c, d, f, e});
  fusion.addOutput(tv1);

  // swap d and e
  auto tv2 = transpose(tv1, 3, 4);
  // tie c and e together
  auto tv3 = view(tv2, {a, b, c, e, d, f}, {a, b, e, c, d, f});

  fusion.addOutput(tv3);

  auto tv4 = set(tv0);
  // Use tv4 as the reference
  fusion.addOutput(tv4);

  // a, b, d aren't tied to anything so they are valid broadcasts from the
  // perspective of broadcast multiples analysis.
  auto tv5 = makeConcreteTensor({1, 1, c, 1, e, f});
  fusion.addInput(tv5);

  // c, e, and f are tied together so this shouldn't be counted as a broadcast
  // dim in the reference since it's a partial bcast
  auto tv6 = makeConcreteTensor({a, b, c, 1, 1, 1});
  fusion.addInput(tv6);

  // c, e, and f are tied together, so this should be counted as a broadcast
  // dim in the reference since the entire tied set is broadcast (not a partial
  // bcast)
  auto tv7 = makeConcreteTensor({a, b, 1, 1, 1, 1});
  fusion.addInput(tv7);

  // plug the broadcasts into the fusion
  auto tv8 = add(tv5, tv4);
  auto tv9 = add(tv6, tv8);
  auto tv10 = add(tv7, tv9);
  fusion.addOutput(tv10);

  auto bcast_info =
      scheduler_utils::getBroadcastMultiples(tv4, DataType::Int32);

  // linked c, e, and f together so they should have the same id.
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[5] == 0);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[4] == 0);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[3] == 1);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[2] == 0);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[1] == 2);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[0] == 3);

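  // breakIsDisjoint(ids, pos) reports whether splitting the dimensions into
  // [0, pos) and [pos, end) keeps every tied group on a single side; only
  // break positions 0, 1, and 2 qualify here, since any later break splits the
  // tied {c, e, f} group.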
  TORCH_CHECK(
      scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 0));
  TORCH_CHECK(
      scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 1));
  TORCH_CHECK(
      scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 2));
  TORCH_CHECK(
      !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 3));
  TORCH_CHECK(
      !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 4));
  TORCH_CHECK(
      !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 5));

  // tv0 [a, b, c, d, e, f]
  // tv1 [a, b, c, d, e, f]
  // tv3 [a, b, c, d, e, f]
  // tv4 [a, b, c, d, e, f]
  // tv5 [1, 1, c, 1, e, f] -> Left bcasts should show up in some multiples
  // tv6 [a, b, c, 1, 1, 1] -> view interferes with bcasts, none of these
  // should show up
  // tv7 [a, b, 1, 1, 1, 1] -> These broadcasts could be recognized
  // tv10 [a, b, c, d, e, f]

  TORCH_CHECK(
      bcast_info.broadcast_multiples[0].lhs_multiple == 0 &&
      bcast_info.broadcast_multiples[0].rhs_multiple == 8 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[1].lhs_multiple == 7 * 4 &&
      bcast_info.broadcast_multiples[1].rhs_multiple == 8 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[2].lhs_multiple == 7 * 4 &&
      bcast_info.broadcast_multiples[2].rhs_multiple == 7 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[3].lhs_multiple == 8 * 4 &&
      bcast_info.broadcast_multiples[3].rhs_multiple == 7 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[4].lhs_multiple == 8 * 4 &&
      bcast_info.broadcast_multiples[4].rhs_multiple == 7 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[5].lhs_multiple == 8 * 4 &&
      bcast_info.broadcast_multiples[5].rhs_multiple == 7 * 4);
}

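// Exercises ir_utils::overrideContiguityGuard: the guard overrides a
// TensorView's contiguity flags for its lifetime and restores the original
// values when it goes out of scope.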
TEST_F(NVFuserTest, FusionTVDomainGuard_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  std::vector<bool> all_true = {true, true};
  std::vector<bool> all_false = {false, false};
  std::vector<bool> false_true = {false, true};
  auto tv = TensorViewBuilder().ndims(2).contiguity(false_true).build();
  TORCH_CHECK(tv->domain()->contiguity() == false_true);
  {
    auto guard = ir_utils::overrideContiguityGuard(tv, true);
    TORCH_CHECK(tv->domain()->contiguity() == all_true);
  }
  TORCH_CHECK(tv->domain()->contiguity() == false_true);
  {
    auto guard = ir_utils::overrideContiguityGuard(tv, false);
    TORCH_CHECK(tv->domain()->contiguity() == all_false);
  }
  TORCH_CHECK(tv->domain()->contiguity() == false_true);
  {
    auto guard1 = ir_utils::overrideContiguityGuard(tv, true);
    auto guard2 = std::move(guard1);
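    // The guard should be movable: after the move the override stays active,
    // and the original contiguity is restored exactly once when guard2 is
    // destroyed.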
    TORCH_CHECK(tv->domain()->contiguity() == all_true);
  }
  TORCH_CHECK(tv->domain()->contiguity() == false_true);
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)