#pragma once

#include <executor.h>
#include <expr_evaluator.h>
#include <ir_all_nodes.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_magic_zero.h>
#include <transform_replay.h>

#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/torch.h>

#include <gtest/gtest.h>

#include <cstddef>

// Tests go in torch::jit
namespace torch {
namespace jit {

using namespace torch::jit::fuser::cuda;

namespace {
// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
  return TensorViewBuilder()
      .ndims(ndims)
      .dtype(dtype)
      .contiguity(std::vector<bool>(ndims, true))
      .build();
}

// Make a symbolic tensor of dimensionality=ndims with unknown sizes and
// unknown contiguity
TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
}

// Make a tensor of compile-time known sizes but unknown contiguity
TensorView* makeConcreteTensor(
    std::vector<int64_t> shape,
    DataType dtype = DataType::Float) {
  return TensorViewBuilder().shape(shape).dtype(dtype).build();
}

// Make a fully contiguous tensor of compile-time known sizes
TensorView* makeContigConcreteTensor(
    std::vector<int64_t> shape,
    DataType dtype = DataType::Float) {
  return TensorViewBuilder()
      .shape(shape)
      .dtype(dtype)
      .contiguity(std::vector<bool>(shape.size(), true))
      .build();
}
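
// A minimal usage sketch for the factory helpers above. The fusion body is
// illustrative, not taken from any particular test:
//
//   Fusion fusion;
//   FusionGuard fg(&fusion);
//   TensorView* tv0 = makeSymbolicTensor(2); // 2-D, unknown sizes/contiguity
//   TensorView* tv1 = makeContigConcreteTensor({4, 8}); // static 4x8, contig
//   fusion.addInput(tv0);
//   fusion.addInput(tv1);
//   fusion.addOutput(add(tv0, tv1));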

// Check that the expression evaluator resolves val to expected_value
void checkIntValue(
    ExpressionEvaluator& evaluator,
    Val* val,
    Int::ScalarType expected_value) {
  TORCH_CHECK(val->isAnInt());
  const auto actual_value = evaluator.evaluate(val);
  TORCH_CHECK(actual_value.has_value());
  TORCH_CHECK(actual_value.value() == expected_value);
}

// Same check against the kernel IR evaluator
void checkIntValue(
    kir::ExpressionEvaluator& evaluator,
    const Val* val,
    Int::ScalarType expected_value) {
  const auto actual_value = evaluator.evaluate(val);
  TORCH_CHECK(actual_value.has_value());
  TORCH_CHECK(actual_value.value() == expected_value);
}
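
// Usage sketch (names are illustrative): bind the inputs on the evaluator,
// then verify that derived scalars resolve to the expected constants.
//
//   ExpressionEvaluator evaluator(&fusion);
//   evaluator.bind(size_val, 128);
//   checkIntValue(evaluator, derived_val, 128);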

// Prime numbers, handy as test extents that avoid accidentally dividing
// evenly by split or vectorization factors
int64_t prime_numbers[] = {
    2,    3,    5,    7,    11,   13,   17,   19,   23,   29,   31,   37,
    41,   43,   47,   53,   59,   61,   67,   71,   73,   79,   83,   89,
    97,   101,  103,  107,  109,  113,  127,  131,  137,  139,  149,  151,
    157,  163,  167,  173,  179,  181,  191,  193,  197,  199,  211,  223,
    227,  229,  233,  239,  241,  251,  257,  263,  269,  271,  277,  281,
    283,  293,  307,  311,  313,  317,  331,  337,  347,  349,  353,  359,
    367,  373,  379,  383,  389,  397,  401,  409,  419,  421,  431,  433,
    439,  443,  449,  457,  461,  463,  467,  479,  487,  491,  499,  503,
    509,  521,  523,  541,  547,  557,  563,  569,  571,  577,  587,  593,
    599,  601,  607,  613,  617,  619,  631,  641,  643,  647,  653,  659,
    661,  673,  677,  683,  691,  701,  709,  719,  727,  733,  739,  743,
    751,  757,  761,  769,  773,  787,  797,  809,  811,  821,  823,  827,
    829,  839,  853,  857,  859,  863,  877,  881,  883,  887,  907,  911,
    919,  929,  937,  941,  947,  953,  967,  971,  977,  983,  991,  997,
    1009, 1013, 1019, 1021, 1031, 1033, 1039, 1049, 1051, 1061, 1063, 1069,
    1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163,
    1171, 1181, 1187, 1193, 1201, 1213, 1217, 1223};

// Returns true if the current device's compute capability is at least
// major.minor
bool deviceMajorMinorCheck(int major, int minor = 0) {
  auto dev_prop = at::cuda::getCurrentDeviceProperties();
  if (dev_prop->major < major ||
      (dev_prop->major == major && dev_prop->minor < minor)) {
    return false;
  }
  return true;
}
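
// Typical use: guard tests that need a minimum compute capability, e.g.
//
//   if (!deviceMajorMinorCheck(8)) {
//     GTEST_SKIP() << "skipping test on pre-Ampere GPUs";
//   }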

int deviceSMCount() {
  int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  return sm_count;
}

// Evict the L2 cache by streaming an L2-sized buffer through it
void clearL2Cache() {
  torch::NoGradGuard no_grad;
  auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
  auto options =
      torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0);

  auto l2_elems = l2_cache_size / 4; // 4 bytes per float element
  torch::Tensor t0 = torch::empty(l2_elems, options);
  torch::Tensor t1 = torch::clone(t0); // read + write touch 2x the L2 size
}
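
// Intended use (sketch; 'fe' is a FusionExecutor, name illustrative): call
// between timed launches so every iteration starts from a cold L2.
//
//   for (int i = 0; i < iters; i++) {
//     clearL2Cache();
//     auto outputs = fe.runFusion(inputs);
//   }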

// Returns the lowered copy of tv held by gpulw's kernel
TensorView* loweredTv(TensorView* tv, GpuLower& gpulw) {
  auto used_tvs = ir_utils::allTvs(gpulw.kernel()->as<Fusion>());
  TensorView* matching_tv = nullptr;
  for (auto lowered_tv : used_tvs) {
    if (lowered_tv->name() == tv->name()) {
      matching_tv = lowered_tv;
    }
  }
  TORCH_INTERNAL_ASSERT(matching_tv != nullptr);
  return matching_tv;
}

class PredicatedChecker : public kir::IrVisitor {
 public:
  // Checks if the provided tv is written to within a non-trivial conditional
  static bool isPredicated(TensorView* tv, GpuLower& gpulw) {
    PredicatedChecker checker(
        loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs());
    return checker.is_predicated_;
  }

 private:
  PredicatedChecker() = delete;

  PredicatedChecker(TensorView* tv, std::vector<Expr*> exprs) : tv_(tv) {
    kir::IrVisitor::handle(exprs);
  }

  using kir::IrVisitor::handle;
  bool is_predicated_ = false;
  bool predicated_ite_ = false;
  TensorView* tv_ = nullptr;

  void handle(kir::IfThenElse* ite) final {
    auto prev_ite = predicated_ite_;
    predicated_ite_ = !ite->predicate()->value()->isConstScalar();
    kir::IrVisitor::handle(ite);
    predicated_ite_ = prev_ite;
  }

  void handle(Expr* expr) final {
    if (expr->outputs().size() && expr->outputs()[0]->isA<kir::TensorIndex>()) {
      auto ti = expr->outputs()[0]->as<kir::TensorIndex>();
      if (ti->view() == tv_) {
        is_predicated_ = is_predicated_ || predicated_ite_;
        if (expr->predicate() != nullptr &&
            !expr->predicate()->value()->isConst()) {
          is_predicated_ = true;
        }
      }
    }
    kir::IrVisitor::handle(expr);
  }
};
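
// Usage sketch (assumes 'fusion' is fully scheduled):
//
//   GpuLower gpulw(&fusion);
//   TORCH_CHECK(!PredicatedChecker::isPredicated(tv, gpulw));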

class UnswitchInElseChecker : public kir::IrVisitor {
 public:
  // Checks if there are any unswitched for loops within an else clause
  static bool check(GpuLower& gpulw) {
    UnswitchInElseChecker checker(gpulw.kernel()->topLevelExprs());
    return checker.found_in_else_;
  }

 private:
  UnswitchInElseChecker() = delete;
  UnswitchInElseChecker(std::vector<Expr*> exprs) {
    kir::IrVisitor::handle(exprs);
  }

  using kir::IrVisitor::handle;
  bool within_else_ = false;
  bool found_in_else_ = false;

  void handle(kir::IfThenElse* ite) final {
    auto prev_within_else = within_else_;
    within_else_ = true;
    kir::IrVisitor::handle(ite->elseBody().exprs());
    within_else_ = prev_within_else;
  }

  void handle(kir::ForLoop* for_loop) final {
    if (for_loop->iter_domain()->getParallelType() == ParallelType::Unswitch) {
      found_in_else_ = found_in_else_ || within_else_;
    }
    kir::IrVisitor::handle(for_loop);
  }
};
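
// Driven the same way as PredicatedChecker above: lower the scheduled fusion
// with GpuLower, then call UnswitchInElseChecker::check(gpulw).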

class PredicateMagicZeroChecker : public kir::IrVisitor {
 public:
  // Checks if all predicated domains of the provided tv are protected with
  // magic zero, i.e., nvfuser_zero: a runtime variable that always holds 0
  // but is opaque to the compiler, added to indices to keep predicates in
  // unrolled loops from being mis-optimized
  static bool isProtected(TensorView* tv, GpuLower& gpulw) {
    PredicateMagicZeroChecker checker(
        loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs());
    return checker.is_protected_;
  }

 private:
  using kir::IrVisitor::handle;

  PredicateMagicZeroChecker(TensorView* tv, std::vector<Expr*> exprs)
      : tv_(tv) {
    handle(exprs);
  }

  void handle(kir::IfThenElse* ite) final {
    auto prev_predicate = predicate_;
    predicate_ = ite->predicate()->value();
    kir::IrVisitor::handle(ite);
    predicate_ = prev_predicate;
  }

  void handle(Expr* expr) final {
    if (expr->outputs().size() && expr->outputs()[0]->isA<kir::TensorIndex>()) {
      auto ti = expr->outputs()[0]->as<kir::TensorIndex>();
      if (ti->view() == tv_) {
        is_protected_ = checkPredicateOfTensor(predicate_);
        return;
      }
    }

    if (expr->isA<kir::ForLoop>()) {
      handle(expr->as<kir::ForLoop>());
    } else if (expr->isA<kir::IfThenElse>()) {
      handle(expr->as<kir::IfThenElse>());
    } else {
      for (auto input : expr->inputs()) {
        handle(input);
      }
    }
  }

  // Returns true if all predicated domains are protected
  bool checkPredicateOfTensor(Val* predicate) {
    auto id_predicates = decomposeCompoundPredicate(predicate);
    for (auto id_predicate : id_predicates) {
      // Just check if nvfuser_zero is used. Not perfect but probably
      // good enough.
      is_magic_zero_found_ = false;
      handle(id_predicate);
      if (!is_magic_zero_found_) {
        return false;
      }
    }
    return true;
  }

  // Decompose "X && Y" into a vector of {X, Y}.
  std::vector<Val*> decomposeCompoundPredicate(Val* predicate) {
    if (auto binary_op = dynamic_cast<BinaryOp*>(predicate->definition())) {
      if (binary_op->getBinaryOpType() == BinaryOpType::And) {
        auto pred = decomposeCompoundPredicate(binary_op->lhs());
        auto rhs_pred = decomposeCompoundPredicate(binary_op->rhs());
        pred.insert(pred.end(), rhs_pred.begin(), rhs_pred.end());
        return pred;
      }
    }

    return {predicate};
  }

  void handle(Val* val) final {
    if (isMagicZero(val)) {
      is_magic_zero_found_ = true;
      return;
    }

    auto def = val->definition();
    if (def != nullptr) {
      handle(def);
    }
  }

 private:
  bool is_protected_ = false;
  Val* predicate_ = nullptr;
  TensorView* tv_ = nullptr;
  bool is_magic_zero_found_ = false;
};
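
// Same lowering pattern as the other checkers (sketch):
//
//   GpuLower gpulw(&fusion);
//   TORCH_CHECK(PredicateMagicZeroChecker::isProtected(tv, gpulw));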

// Basically just TransformPropagator, except that it checks the consistency
// of replayPasC with getMatchedLeafPosWithoutReplayPasC, of replayCasP with
// getMatchedLeafPosWithoutReplayCasP, and of fullSelfReplay with
// fullSelfMatching:
// - After replayPasC, getMatchedLeafPosWithoutReplayPasC should return the
//   same replayed position
// - After replayCasP, getMatchedLeafPosWithoutReplayCasP should return the
//   same replayed position
// - After fullSelfReplay, fullSelfMatching should return true
struct TransformPropagatorWithCheck : public TransformPropagator {
 public:
  virtual void propagateC2P(TensorView* from, TensorView* to) override {
    TransformPropagator::propagateC2P(from, to);
    auto from_pos = replayed_pos_.at(from);
    auto to_pos = replayed_pos_.at(to);
    TORCH_CHECK(
        TransformReplay::getMatchedLeafPosWithoutReplayPasC(
            to, from, from_pos) == (int)to_pos);
  }
  virtual void propagateP2C(TensorView* from, TensorView* to) override {
    TransformPropagator::propagateP2C(from, to);
    auto from_pos = replayed_pos_.at(from);
    auto to_pos = replayed_pos_.at(to);
    TORCH_CHECK(
        TransformReplay::getMatchedLeafPosWithoutReplayCasP(
            to, from, from_pos) == (int)to_pos);
  }
  virtual void propagateSibling(TensorView* from, TensorView* to) override {
    TransformPropagator::propagateSibling(from, to);
    auto from_pos = replayed_pos_.at(from);
    auto to_pos = replayed_pos_.at(to);
    TORCH_CHECK(from_pos == to_pos);
    TORCH_CHECK(TransformReplay::fullSelfMatching(from, to));
  }
  using TransformPropagator::TransformPropagator;
};
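
// Usage sketch: propagators are driven by a spanning tree over the fusion
// (the spanning-tree type below is the one normally paired with
// TransformPropagator; adjust to whatever the test uses):
//
//   tv_ref->split(0, 32);
//   TransformPropagatorWithCheck propagator(tv_ref);
//   MaxRootDomainInfoSpanningTree(tv_ref).traverse(&propagator);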

} // namespace

// RAII guard that disables TF32 for cuDNN and restores the previous setting
// on destruction
class ContextCudnnTF32Disabled {
 public:
  ContextCudnnTF32Disabled() {
    flag_ = at::globalContext().allowTF32CuDNN();
    at::globalContext().setAllowTF32CuDNN(false);
  }

  ~ContextCudnnTF32Disabled() {
    at::globalContext().setAllowTF32CuDNN(flag_);
  }

 private:
  bool flag_;
};
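
// Scope usage example: TF32 stays off only for the guard's lifetime.
//
//   {
//     ContextCudnnTF32Disabled tf32_off;
//     // run FP32 cuDNN reference computations here
//   }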

// Fixture class must be uniquely identified, i.e., can't be in an
// anonymous namespace
class NVFuserTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // requires Pascal or newer
    if (!deviceMajorMinorCheck(6)) {
      GTEST_SKIP() << "skipping tests on pre-Pascal GPUs";
    }
    setFillAllocationWithNan(true);
  }
};
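
// Tests then use the fixture via TEST_F; the test name below is illustrative:
//
//   TEST_F(NVFuserTest, FusionExample_CUDA) {
//     Fusion fusion;
//     FusionGuard fg(&fusion);
//     // build, schedule, and run the fusion
//   }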

} // namespace jit
} // namespace torch