1#if defined(USE_CUDA)
2#include <gmock/gmock-matchers.h>
3#include <gtest/gtest.h>
4
5#include <fusion.h>
6#include <lower_utils.h>
7#include <ops/all_ops.h>
8#include <scheduler/utils.h>
9#include <test/test_gpu_validator.h>
10#include <test/test_utils.h>
11
12// Tests go in torch::jit
13namespace torch {
14namespace jit {
15
16using namespace torch::jit::fuser::cuda;
17
18TEST_F(NVFuserTest, FusionSplitDims_CUDA) {
19 Fusion fusion;
20 FusionGuard fg(&fusion);
21
22 int64_t* p = prime_numbers;
23 auto tv = makeConcreteTensor(
24 {p[0] * p[1] * p[2], p[3], p[4], p[5] * p[6], p[7], p[8], p[9] * p[10]});
25 std::vector<size_t> dims{0, 1, 2, 3, 4, 5, 6};
26 scheduler_utils::splitDims(
27 tv, {{0, p[2]}, {0, p[1]}, {3, p[6]}, {6, p[10]}}, dims);
28 TORCH_CHECK(tv->nDims() == 11);
29 for (auto i : c10::irange(11)) {
30 TORCH_CHECK(tv->axis(i)->extent()->evaluateInt() == p[i]);
31 }
32 std::vector<size_t> expect{0, 3, 4, 5, 7, 8, 9};
33 TORCH_CHECK(dims == expect);
34}
35
36TEST_F(NVFuserTest, FusionMergeDims_CUDA) {
37 Fusion fusion;
38 FusionGuard fg(&fusion);
39
40 int64_t* p = prime_numbers;
41 auto tv = makeConcreteTensor(
42 {p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10]});
43 std::vector<size_t> dims{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
44 auto merged = scheduler_utils::mergeDims(tv, {2, 3, 7, 8, 9}, dims);
45 TORCH_CHECK(merged == (size_t)2);
46 std::vector<int64_t> expect_shape{
47 p[0], p[1], p[2] * p[3] * p[7] * p[8] * p[9], p[4], p[5], p[6], p[10]};
48 TORCH_CHECK(tv->nDims() == expect_shape.size());
49 for (auto i : c10::irange(expect_shape.size())) {
50 TORCH_CHECK(tv->axis(i)->extent()->evaluateInt() == expect_shape[i]);
51 }
52 std::vector<size_t> expect_dims{0, 1, 2, 2, 3, 4, 5, 2, 2, 2, 6};
53 TORCH_CHECK(dims == expect_dims);
54}
55
56TEST_F(NVFuserTest, FusionReorderAsRFactor_CUDA) {
57 Fusion fusion;
58 FusionGuard fg(&fusion);
59
60 int a = 1, b = 2, c = 3, d = 4;
61
62 TensorView* tv0 = makeConcreteTensor({a, b, c, d});
63 fusion.addInput(tv0);
64 fusion.addOutput(tv0);
65
66 // [a, b, c, d]
67 tv0->merge(0, 2);
68 // [a*c, b, d]
69 tv0->split(1, 2);
70 // [a*c, bo, bi, d]
71 tv0->split(3, 3);
72 // [a*c, bo, bi, do, di]
73 tv0->reorder({{1, 4}, {2, 1}, {3, 3}, {4, 2}});
74 // [a*c, bi, di, do, bo]
75 tv0->merge(3);
76 tv0->merge(1);
77 // [a*c, bi*di, do*bo]
78 tv0->reorder({{0, 2}});
79 // [bi*di, do*bo, a*c]
80 // Order we want is:
81 // [a*c, do*bo, bi*di]
82 auto old2new = scheduler_utils::domainReorderAsRfactorMap(tv0);
83 TORCH_CHECK(old2new[0] == 2);
84 TORCH_CHECK(old2new[1] == 1);
85 TORCH_CHECK(old2new[2] == 0);
86}
87
88TEST_F(NVFuserTest, FusionDisjointViewSet_CUDA) {
89 auto fusion = std::make_unique<Fusion>();
90 FusionGuard fg(fusion.get());
91
92 auto tv0 = makeConcreteTensor({2, 3, 4});
93 fusion->addInput(tv0);
94
95 auto tv1 = view(tv0, {2, 3, 4}, {2, 12});
96
97 auto tv2 = makeConcreteTensor({2, 12});
98 fusion->addInput(tv2);
99
100 auto tv3 = add(tv2, tv1);
101 fusion->addOutput(tv3);
102
103 auto disjoint_exact = scheduler_utils::disjointViewSets(fusion.get());
104
105 TORCH_INTERNAL_ASSERT(
106 disjoint_exact.strictAreMapped(tv0->axis(1), tv0->axis(2)));
107}
108
109TEST_F(NVFuserTest, FusionMatchingViews_CUDA) {
110 Fusion fusion;
111 FusionGuard fg(&fusion);
112
113 int x = 2, y = 3, z = 4;
114
115 auto tv0 = makeConcreteTensor({x, y, z});
116 fusion.addInput(tv0);
117
118 auto tv1 = view(tv0, {x, y, z}, {x * y, z});
119
120 auto tv2 = sin(tv1);
121
122 auto tv3 = view(tv2, {x * y, z}, {x, y * z});
123 fusion.addOutput(tv3);
124
125 auto tv4 = makeConcreteTensor({x, y, z});
126 fusion.addInput(tv4);
127
128 auto tv5 = view(tv4, {x, y, z}, {x, y * z});
129 fusion.addOutput(tv5);
130
131 // Link 0 and 3 together for view analysis done based on before the views
132 // actually happened.
133 auto tv6 = add(tv0, tv4);
134 fusion.addOutput(tv6);
135
136 TORCH_INTERNAL_ASSERT(!scheduler_utils::allMatchingViews(&fusion));
137}
138
// Checks getBroadcastMultiples: views and transposes tie axes together into
// disjoint sets, and broadcast multiples are computed on the reference (tv4)
// against all other tensors in the fusion.
TEST_F(NVFuserTest, FusionBroadcastViewMultiples_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Distinct prime extents so every axis is uniquely identifiable.
  int a = 2, b = 3, c = 5, d = 7, e = 11, f = 13;

  auto tv0 = makeConcreteTensor({a, b, c, d, e, f});
  fusion.addInput(tv0);

  // tie e and f together (swapping values next to each other enforces they'll
  // be merged then split by view)
  auto tv1 = view(tv0, {a, b, c, d, e, f}, {a, b, c, d, f, e});
  fusion.addOutput(tv1);

  // swap d and e
  auto tv2 = transpose(tv1, 3, 4);
  // tie c and e together
  auto tv3 = view(tv2, {a, b, c, e, d, f}, {a, b, e, c, d, f});

  fusion.addOutput(tv3);

  auto tv4 = set(tv0);
  // Use tv4 as the reference
  fusion.addOutput(tv4);

  // a, b, d aren't tied to anything so they are valid broadcasts from the
  // perspective of broadcast multiples analysis.
  auto tv5 = makeConcreteTensor({1, 1, c, 1, e, f});
  fusion.addInput(tv5);

  // c, e, and f are tied together so this shouldn't be counted as a broadcast
  // dim in the reference since it's a partial bcast
  auto tv6 = makeConcreteTensor({a, b, c, 1, 1, 1});
  fusion.addInput(tv6);

  // c, e, and f are tied together this should be counted as a broadcast dim in
  // the reference since it's a partial bcast
  auto tv7 = makeConcreteTensor({a, b, 1, 1, 1, 1});
  fusion.addInput(tv7);

  // plug the broadcasts into the fusion
  auto tv8 = add(tv5, tv4);
  auto tv9 = add(tv6, tv8);
  auto tv10 = add(tv7, tv9);
  fusion.addOutput(tv10);

  auto bcast_info =
      scheduler_utils::getBroadcastMultiples(tv4, DataType::Int32);

  // linked c, e, and f together so they should have the same id; a, d, and b
  // each get their own set id.
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[5] == 0);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[4] == 0);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[3] == 1);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[2] == 0);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[1] == 2);
  TORCH_CHECK(bcast_info.view_disjoint_set_ids[0] == 3);

  // A break position is disjoint only when no view-linked set straddles it:
  // breaks at 0-2 are before the {c, e, f} group, breaks at 3-5 cut through it.
  TORCH_CHECK(
      scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 0));
  TORCH_CHECK(
      scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 1));
  TORCH_CHECK(
      scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 2));
  TORCH_CHECK(
      !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 3));
  TORCH_CHECK(
      !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 4));
  TORCH_CHECK(
      !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 5));

  // tv0 [a, b, c, d, e, f]
  // tv1 [a, b, c, d, e, f]
  // tv3 [a, b, c, d, e, f]
  // tv4 [a, b, c, d, e, f]
  // tv5 [1, 1, c, 1, e, f] -> Left bcasts should show up in some multiples
  // tv6 [a, b, c, 1, 1, 1] -> view interferes with bcasts, none of these
  // should show up
  // tv7 [a, b, 1, 1, 1, 1] -> These broadcasts could be recognized
  // tv10 [a, b, c, d, e, f]

  // Expected multiples below are written as <tensor count> * 4 — presumably
  // 4 is sizeof(Int32); TODO confirm against getBroadcastMultiples.
  TORCH_CHECK(
      bcast_info.broadcast_multiples[0].lhs_multiple == 0 &&
      bcast_info.broadcast_multiples[0].rhs_multiple == 8 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[1].lhs_multiple == 7 * 4 &&
      bcast_info.broadcast_multiples[1].rhs_multiple == 8 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[2].lhs_multiple == 7 * 4 &&
      bcast_info.broadcast_multiples[2].rhs_multiple == 7 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[3].lhs_multiple == 8 * 4 &&
      bcast_info.broadcast_multiples[3].rhs_multiple == 7 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[4].lhs_multiple == 8 * 4 &&
      bcast_info.broadcast_multiples[4].rhs_multiple == 7 * 4);

  TORCH_CHECK(
      bcast_info.broadcast_multiples[5].lhs_multiple == 8 * 4 &&
      bcast_info.broadcast_multiples[5].rhs_multiple == 7 * 4);
}
243
244TEST_F(NVFuserTest, FusionTVDomainGuard_CUDA) {
245 Fusion fusion;
246 FusionGuard fg(&fusion);
247
248 std::vector<bool> all_true = {true, true};
249 std::vector<bool> all_false = {false, false};
250 std::vector<bool> false_true = {false, true};
251 auto tv = TensorViewBuilder().ndims(2).contiguity(false_true).build();
252 TORCH_CHECK(tv->domain()->contiguity() == false_true);
253 {
254 auto guard = ir_utils::overrideContiguityGuard(tv, true);
255 TORCH_CHECK(tv->domain()->contiguity() == all_true);
256 }
257 TORCH_CHECK(tv->domain()->contiguity() == false_true);
258 {
259 auto guard = ir_utils::overrideContiguityGuard(tv, false);
260 TORCH_CHECK(tv->domain()->contiguity() == all_false);
261 }
262 TORCH_CHECK(tv->domain()->contiguity() == false_true);
263 {
264 auto guard1 = ir_utils::overrideContiguityGuard(tv, true);
265 auto guard2 = std::move(guard1);
266 TORCH_CHECK(tv->domain()->contiguity() == all_true);
267 }
268 TORCH_CHECK(tv->domain()->contiguity() == false_true);
269}
270
271} // namespace jit
272} // namespace torch
273#endif // #if defined(USE_CUDA)
274