#pragma once

#include <executor.h>
#include <expr_evaluator.h>
#include <ir_all_nodes.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_magic_zero.h>
#include <transform_replay.h>

#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/torch.h>

#include <gtest/gtest.h>

#include <cstddef>

// Tests go in torch::jit
namespace torch {
namespace jit {

using namespace torch::jit::fuser::cuda;

namespace {
// Make a tensor of dimensionality=ndims that is known to be fully contiguous
// but has unknown sizes
TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
  return TensorViewBuilder()
      .ndims(ndims)
      .dtype(dtype)
      .contiguity(std::vector<bool>(ndims, true))
      .build();
}

// Make a tensor of dimensionality=ndims that is known to be non-contiguous
// and has unknown sizes
TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
}

// Make a non-contiguous tensor of compile-time known sizes
TensorView* makeConcreteTensor(
    std::vector<int64_t> shape,
    DataType dtype = DataType::Float) {
  return TensorViewBuilder().shape(shape).dtype(dtype).build();
}

// Make a fully contiguous tensor of compile-time known sizes
TensorView* makeContigConcreteTensor(
    std::vector<int64_t> shape,
    DataType dtype = DataType::Float) {
  return TensorViewBuilder()
      .shape(shape)
      .dtype(dtype)
      .contiguity(std::vector<bool>(shape.size(), true))
      .build();
}
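
// A minimal usage sketch for the factory helpers above (tv names and sizes
// are illustrative; assumes an active Fusion via FusionGuard, as in the tests
// that include this header):
//
//   Fusion fusion;
//   FusionGuard fg(&fusion);
//   auto tv0 = makeContigTensor(2);            // 2-D, contiguous, symbolic sizes
//   auto tv1 = makeConcreteTensor({128, 256}); // compile-time known sizes
//   fusion.addInput(tv0);
//   fusion.addInput(tv1);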

void checkIntValue(
    ExpressionEvaluator& evaluator,
    Val* val,
    Int::ScalarType expected_value) {
  TORCH_CHECK(val->isAnInt());
  const auto actual_value = evaluator.evaluate(val);
  TORCH_CHECK(actual_value.has_value());
  TORCH_CHECK(actual_value.value() == expected_value);
}

void checkIntValue(
    kir::ExpressionEvaluator& evaluator,
    const Val* val,
    Int::ScalarType expected_value) {
  const auto actual_value = evaluator.evaluate(val);
  TORCH_CHECK(actual_value.has_value());
  TORCH_CHECK(actual_value.value() == expected_value);
}
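
// Sketch of how checkIntValue is typically combined with an evaluator
// (illustrative values; assumes tv0 is an input TensorView of the fusion
// under test):
//
//   ExpressionEvaluator evaluator(&fusion);
//   auto* extent = tv0->axis(0)->extent();
//   evaluator.bind(extent, 128);
//   checkIntValue(evaluator, extent, 128);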

// prime numbers
int64_t prime_numbers[] = {
    2,    3,    5,    7,    11,   13,   17,   19,   23,   29,   31,   37,
    41,   43,   47,   53,   59,   61,   67,   71,   73,   79,   83,   89,
    97,   101,  103,  107,  109,  113,  127,  131,  137,  139,  149,  151,
    157,  163,  167,  173,  179,  181,  191,  193,  197,  199,  211,  223,
    227,  229,  233,  239,  241,  251,  257,  263,  269,  271,  277,  281,
    283,  293,  307,  311,  313,  317,  331,  337,  347,  349,  353,  359,
    367,  373,  379,  383,  389,  397,  401,  409,  419,  421,  431,  433,
    439,  443,  449,  457,  461,  463,  467,  479,  487,  491,  499,  503,
    509,  521,  523,  541,  547,  557,  563,  569,  571,  577,  587,  593,
    599,  601,  607,  613,  617,  619,  631,  641,  643,  647,  653,  659,
    661,  673,  677,  683,  691,  701,  709,  719,  727,  733,  739,  743,
    751,  757,  761,  769,  773,  787,  797,  809,  811,  821,  823,  827,
    829,  839,  853,  857,  859,  863,  877,  881,  883,  887,  907,  911,
    919,  929,  937,  941,  947,  953,  967,  971,  977,  983,  991,  997,
    1009, 1013, 1019, 1021, 1031, 1033, 1039, 1049, 1051, 1061, 1063, 1069,
    1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163,
    1171, 1181, 1187, 1193, 1201, 1213, 1217, 1223};

// Returns true if the current device's compute capability is at least
// major.minor
bool deviceMajorMinorCheck(int major, int minor = 0) {
  auto dev_prop = at::cuda::getCurrentDeviceProperties();
  if (dev_prop->major < major ||
      (dev_prop->major == major && dev_prop->minor < minor)) {
    return false;
  }
  return true;
}

// Returns the number of SMs on the current device
int deviceSMCount() {
  int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  return sm_count;
}
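
// Typical use of the device queries above in a test body (sketch; mirrors
// the compute-capability guards used throughout these tests):
//
//   if (!deviceMajorMinorCheck(7)) {
//     GTEST_SKIP() << "skipping test on pre-Volta GPUs";
//   }
//   const int num_blocks = deviceSMCount() * 8; // e.g. size work by SM count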

// Evicts the L2 cache by allocating and cloning a float buffer that spans the
// device's full L2 size
void clearL2Cache() {
  torch::NoGradGuard no_grad;
  auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
  auto options =
      torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0);

  auto l2_elems = l2_cache_size / 4;
  torch::Tensor t0 = torch::empty(l2_elems, options);
  torch::Tensor t1 = torch::clone(t0);
}

// Returns the copy of tv in the lowered kernel, matched by name
TensorView* loweredTv(TensorView* tv, GpuLower& gpulw) {
  auto used_tvs = ir_utils::allTvs(gpulw.kernel()->as<Fusion>());
  TensorView* matching_tv = nullptr;
  for (auto lowered_tv : used_tvs) {
    if (lowered_tv->name() == tv->name()) {
      matching_tv = lowered_tv;
    }
  }
  TORCH_INTERNAL_ASSERT(matching_tv != nullptr);
  return matching_tv;
}

class PredicatedChecker : public kir::IrVisitor {
 public:
  // Checks if the provided tv is written to within a non-trivial conditional
  static bool isPredicated(TensorView* tv, GpuLower& gpulw) {
    PredicatedChecker checker(
        loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs());
    return checker.is_predicated_;
  }

 private:
  PredicatedChecker() = delete;

  PredicatedChecker(TensorView* tv, std::vector<Expr*> exprs) : tv_(tv) {
    kir::IrVisitor::handle(exprs);
  }

  using kir::IrVisitor::handle;
  bool is_predicated_ = false;
  bool predicated_ite_ = false;
  TensorView* tv_ = nullptr;

  void handle(kir::IfThenElse* ite) final {
    auto prev_ite = predicated_ite_;
    predicated_ite_ = !ite->predicate()->value()->isConstScalar();
    kir::IrVisitor::handle(ite);
    predicated_ite_ = prev_ite;
  }

  void handle(Expr* expr) final {
    if (expr->outputs().size() && expr->outputs()[0]->isA<kir::TensorIndex>()) {
      auto ti = expr->outputs()[0]->as<kir::TensorIndex>();
      if (ti->view() == tv_) {
        is_predicated_ = is_predicated_ || predicated_ite_;
        if (expr->predicate() != nullptr &&
            !expr->predicate()->value()->isConst()) {
          is_predicated_ = true;
        }
      }
    }
    kir::IrVisitor::handle(expr);
  }
};
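
// Usage sketch for PredicatedChecker (tv1 is a hypothetical tensor in the
// fusion being lowered):
//
//   GpuLower gpulw(&fusion);
//   TORCH_CHECK(!PredicatedChecker::isPredicated(tv1, gpulw));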

class UnswitchInElseChecker : public kir::IrVisitor {
 public:
  // Checks if there are any unswitched for loops within an else clause
  static bool check(GpuLower& gpulw) {
    UnswitchInElseChecker checker(gpulw.kernel()->topLevelExprs());
    return checker.found_in_else_;
  }

 private:
  UnswitchInElseChecker() = delete;
  UnswitchInElseChecker(std::vector<Expr*> exprs) {
    kir::IrVisitor::handle(exprs);
  }

  using kir::IrVisitor::handle;
  bool within_else_ = false;
  bool found_in_else_ = false;

  void handle(kir::IfThenElse* ite) final {
    auto prev_within_else = within_else_;
    within_else_ = true;
    kir::IrVisitor::handle(ite->elseBody().exprs());
    within_else_ = prev_within_else;
  }

  void handle(kir::ForLoop* for_loop) final {
    if (for_loop->iter_domain()->getParallelType() == ParallelType::Unswitch) {
      found_in_else_ = found_in_else_ || within_else_;
    }
    kir::IrVisitor::handle(for_loop);
  }
};
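
// Usage sketch for UnswitchInElseChecker, after scheduling a fusion with an
// unswitched loop:
//
//   GpuLower gpulw(&fusion);
//   TORCH_CHECK(!UnswitchInElseChecker::check(gpulw));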

class PredicateMagicZeroChecker : public kir::IrVisitor {
 public:
  // Checks if all predicated domains of the provided tv are protected with
  // magic zero
  static bool isProtected(TensorView* tv, GpuLower& gpulw) {
    PredicateMagicZeroChecker checker(
        loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs());
    return checker.is_protected_;
  }

 private:
  using kir::IrVisitor::handle;

  PredicateMagicZeroChecker(TensorView* tv, std::vector<Expr*> exprs)
      : tv_(tv) {
    handle(exprs);
  }

  void handle(kir::IfThenElse* ite) final {
    auto prev_predicate = predicate_;
    predicate_ = ite->predicate()->value();
    kir::IrVisitor::handle(ite);
    predicate_ = prev_predicate;
  }

  void handle(Expr* expr) final {
    if (expr->outputs().size() && expr->outputs()[0]->isA<kir::TensorIndex>()) {
      auto ti = expr->outputs()[0]->as<kir::TensorIndex>();
      if (ti->view() == tv_) {
        is_protected_ = checkPredicateOfTensor(predicate_);
        return;
      }
    }

    if (expr->isA<kir::ForLoop>()) {
      handle(expr->as<kir::ForLoop>());
    } else if (expr->isA<kir::IfThenElse>()) {
      handle(expr->as<kir::IfThenElse>());
    } else {
      for (auto input : expr->inputs()) {
        handle(input);
      }
    }
  }

  // Returns true if all predicated domains are protected
  bool checkPredicateOfTensor(Val* predicate) {
    auto id_predicates = decomposeCompoundPredicate(predicate);
    for (auto id_predicate : id_predicates) {
      // Just check if nvfuser_zero is used. Not perfect but probably
      // good enough.
      is_magic_zero_found_ = false;
      handle(id_predicate);
      if (!is_magic_zero_found_) {
        return false;
      }
    }
    return true;
  }

  // Decompose "X && Y" to a vector of {X, Y}.
  std::vector<Val*> decomposeCompoundPredicate(Val* predicate) {
    if (auto binary_op = dynamic_cast<BinaryOp*>(predicate->definition())) {
      if (binary_op->getBinaryOpType() == BinaryOpType::And) {
        auto pred = decomposeCompoundPredicate(binary_op->lhs());
        auto rhs_pred = decomposeCompoundPredicate(binary_op->rhs());
        pred.insert(pred.end(), rhs_pred.begin(), rhs_pred.end());
        return pred;
      }
    }

    return {predicate};
  }

  void handle(Val* val) final {
    if (isMagicZero(val)) {
      is_magic_zero_found_ = true;
      return;
    }

    auto def = val->definition();
    if (def != nullptr) {
      handle(def);
    }
  }

 private:
  bool is_protected_ = false;
  Val* predicate_ = nullptr;
  TensorView* tv_ = nullptr;
  bool is_magic_zero_found_ = false;
};
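
// Usage sketch for PredicateMagicZeroChecker (tv2 is a hypothetical
// intermediate tensor whose index predicates should be protected by
// nvfuser_zero):
//
//   GpuLower gpulw(&fusion);
//   TORCH_CHECK(PredicateMagicZeroChecker::isProtected(tv2, gpulw));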

// Basically just TransformPropagator, except that it checks the consistency
// of replayPasC with getMatchedLeafPosWithoutReplayPasC, replayCasP with
// getMatchedLeafPosWithoutReplayCasP, and fullSelfReplay with fullSelfMatching:
// - After replayPasC, getMatchedLeafPosWithoutReplayPasC should return the
//   same replayed position
// - After replayCasP, getMatchedLeafPosWithoutReplayCasP should return the
//   same replayed position
// - After fullSelfReplay, fullSelfMatching should return true
struct TransformPropagatorWithCheck : public TransformPropagator {
 public:
  virtual void propagateC2P(TensorView* from, TensorView* to) override {
    TransformPropagator::propagateC2P(from, to);
    auto from_pos = replayed_pos_.at(from);
    auto to_pos = replayed_pos_.at(to);
    TORCH_CHECK(
        TransformReplay::getMatchedLeafPosWithoutReplayPasC(
            to, from, from_pos) == (int)to_pos);
  }
  virtual void propagateP2C(TensorView* from, TensorView* to) override {
    TransformPropagator::propagateP2C(from, to);
    auto from_pos = replayed_pos_.at(from);
    auto to_pos = replayed_pos_.at(to);
    TORCH_CHECK(
        TransformReplay::getMatchedLeafPosWithoutReplayCasP(
            to, from, from_pos) == (int)to_pos);
  }
  virtual void propagateSibling(TensorView* from, TensorView* to) override {
    TransformPropagator::propagateSibling(from, to);
    auto from_pos = replayed_pos_.at(from);
    auto to_pos = replayed_pos_.at(to);
    TORCH_CHECK(from_pos == to_pos);
    TORCH_CHECK(TransformReplay::fullSelfMatching(from, to));
  }
  using TransformPropagator::TransformPropagator;
};
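
// Drop-in usage sketch wherever a plain TransformPropagator would be used
// (reference_tv is a hypothetical already-scheduled tensor; assumes the
// spanning-tree driver used by the scheduling tests):
//
//   TransformPropagatorWithCheck propagator(reference_tv);
//   MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);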

} // namespace

// RAII guard that disables TF32 for cuDNN and restores the previous setting
// when it goes out of scope
class ContextCudnnTF32Disabled {
 public:
  ContextCudnnTF32Disabled() {
    flag_ = at::globalContext().allowTF32CuDNN();
    at::globalContext().setAllowTF32CuDNN(false);
  }

  ~ContextCudnnTF32Disabled() {
    at::globalContext().setAllowTF32CuDNN(flag_);
  }

 private:
  bool flag_;
};
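
// Usage sketch: instantiate for the scope in which a cuDNN-backed reference
// is computed, so the comparison is not affected by TF32 (the conv call is
// illustrative):
//
//   {
//     ContextCudnnTF32Disabled disable_tf32;
//     auto ref = at::conv2d(input, weight);
//   }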

// Fixture class must be uniquely identified, i.e., can't be in an
// anonymous namespace
class NVFuserTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // requires PASCAL or newer
    if (!deviceMajorMinorCheck(6)) {
      GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
    }
    setFillAllocationWithNan(true);
  }
};
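
// Test files that include this header define their tests against this
// fixture, e.g. (sketch; FusionExample_CUDA is a hypothetical test name):
//
//   TEST_F(NVFuserTest, FusionExample_CUDA) {
//     Fusion fusion;
//     FusionGuard fg(&fusion);
//     // ... build, schedule, and validate the fusion ...
//   }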

} // namespace jit
} // namespace torch