#pragma once

#include <executor_utils.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <ir_iostream.h>
#include <lower_utils.h>

#include <ATen/cuda/CUDAContext.h>

#include <unordered_map>

// Tests go in torch::jit
namespace torch {
namespace jit {

using namespace torch::jit::fuser::cuda;

namespace {

struct ValidationConstants {
  // Tolerances generated from randn + add + sum fusion
  // compared against double precision
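  // Each row below is a {reduction size, measured max abs error} pair; see
  // getTolerance() for how the table is looked up.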
  std::array<std::array<double, 2>, 20> sum_tolerances_float = {
      {{4, 1.68222e-06}, {8, 2.23704e-06}, {16, 2.95788e-06},
       {32, 4.4778e-06}, {64, 6.75395e-06}, {128, 8.57934e-06},
       {256, 1.30594e-05}, {512, 2.19122e-05}, {1024, 3.3451e-05},
       {2048, 5.78476e-05}, {4096, 0.000108292}, {8192, 0.00012207},
       {16384, 0.000136882}, {32768, 0.000248561}, {65536, 0.000407594},
       {131072, 0.000500901}, {262144, 0.000923019}, {524288, 0.00156909},
       {1048576, 0.00223107}, {2097152, 0.00343043}}};

  // Tolerances generated from randn + add + sum fusion
  // compared against double precision
  std::array<std::array<double, 2>, 20> sum_tolerances_half = {
      {{4, 0.00390625}, {8, 0.0078125}, {16, 0.0078125},
       {32, 0.0155334}, {64, 0.0156269}, {128, 0.0312042},
       {256, 0.0312548}, {512, 0.0619979}, {1024, 0.0625103},
       {2048, 0.124686}, {4096, 0.12501}, {8192, 0.24945},
       {16384, 0.250049}, {32768, 0.498946}, {65536, 0.500071},
       {131072, 0.985087}, {262144, 1.00006}, {524288, 1.99234},
       {1048576, 2.00032}, {2097152, 3.99073}}};

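  // A value of -1 means "unset"; in that case getTolerance() falls back to the
  // first entries of the tables above for the no-reduction case.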
  double base_half_abs_tol = -1;
  double base_half_rel_tol = -1;
  double base_float_abs_tol = -1;
  double base_float_rel_tol = -1;
};

// Returns abs and relative tolerance values to use for validation
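// For example, with the default tables a float sum over 1024 elements maps to
// the 2048-element row ("next entry up"), giving roughly {5.78e-05, 5.78e-07}
// as the {absolute, relative} pair.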
std::pair<double, double> getTolerance(
    DataType dtype,
    int64_t reduction_size,
    const ValidationConstants& tolerances) {
  switch (dtype) {
    case DataType::ComplexFloat:
    case DataType::ComplexDouble:
    case DataType::Float:
    // TODO: Pull new tolerances for Double; for now we just use the float
    // tolerances, as they should be no worse.
    case DataType::Double: {
      const auto& sum_tolerance_entry = tolerances.sum_tolerances_float;
      const auto& base_abs = tolerances.base_float_abs_tol;
      const auto& base_rel = tolerances.base_float_rel_tol;

      if (reduction_size <= 1) {
        // No reduction case
        if (base_abs == -1 || base_rel == -1) {
          return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
        } else {
          return {base_abs, base_rel};
        }
      } else {
        // Reduction case
        size_t entry = 0;
        while (entry < sum_tolerance_entry.size() &&
               sum_tolerance_entry[entry][0] < reduction_size) {
          entry++;
        }
        double abs_tol = 0.0;
        if (entry + 1 < sum_tolerance_entry.size()) {
          // Grab the next entry up so we have some margin
          abs_tol = sum_tolerance_entry[entry + 1][1];
        } else {
          // If we hit the end of the list, return twice the max error we
          // measured
          abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
        }
        // Relative tol we're going to set to 1% of abs tol just for
        // a small margin of rel error.
        return {abs_tol, abs_tol * 0.01};
      }
    }
    case DataType::Half: {
      // Copied from float case
      const auto& sum_tolerance_entry = tolerances.sum_tolerances_half;
      const auto& base_abs = tolerances.base_half_abs_tol;
      const auto& base_rel = tolerances.base_half_rel_tol;

      if (reduction_size <= 1) {
        // No reduction case
        if (base_abs == -1 || base_rel == -1) {
          return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
        } else {
          return {base_abs, base_rel};
        }
      } else {
        // Reduction case
        size_t entry = 0;
        while (entry < sum_tolerance_entry.size() &&
               sum_tolerance_entry[entry][0] < reduction_size) {
          entry++;
        }
        double abs_tol = 0.0;
        if (entry + 1 < sum_tolerance_entry.size()) {
          // Grab the next entry up so we have some margin
          abs_tol = sum_tolerance_entry[entry + 1][1];
        } else {
          // If we hit the end of the list, return twice the max error we
          // measured
          abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
        }
        // Relative tol we're going to set to 1% of abs tol just for
        // a small margin of rel error.
        return {abs_tol, abs_tol * 0.01};
      }
    }
    case DataType::BFloat16: {
      // Copied from float case
      const auto& sum_tolerance_entry = tolerances.sum_tolerances_half;
      const auto& base_abs = tolerances.base_half_abs_tol;
      const auto& base_rel = tolerances.base_half_rel_tol;

      if (reduction_size <= 1) {
        // No reduction case
        if (base_abs == -1 || base_rel == -1) {
          return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
        } else {
          return {base_abs * 10.0, base_rel * 10.0};
        }
      } else {
        // Reduction case
        size_t entry = 0;
        while (entry < sum_tolerance_entry.size() &&
               sum_tolerance_entry[entry][0] < reduction_size) {
          entry++;
        }
        double abs_tol = 0.0;
        if (entry + 1 < sum_tolerance_entry.size()) {
          // Grab the next entry up so we have some margin
          abs_tol = sum_tolerance_entry[entry + 1][1];
        } else {
          // If we hit the end of the list, return twice the max error we
          // measured
          abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
        }
        // Relative tol we're going to set to 1% of abs tol just for
        // a small margin of rel error.
        return {abs_tol * 10.0, abs_tol * 0.01 * 10.0};
      }
    }
    case DataType::Int:
      return {0.0, 0.0};
    case DataType::Int32:
      return {0.0, 0.0};
    case DataType::Bool:
      return {0.0, 0.0};
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Do not have tolerance computation for type ", dtype, ".");
  }
}

class ReductionSizeMapper : private IterVisitor {
 public:
  //! Runs through the fusion and determines how many reductions were performed
  //! to compute each tensorview.
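  //! For example, if tv2 = sum(tv1, {0}) reduces a dimension of extent 128 and
  //! tv1 itself was produced without reductions, tv2 maps to 128; reduction
  //! sizes multiply through chains of reductions (tv1 and tv2 are hypothetical
  //! names for illustration).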
  static std::unordered_map<TensorView*, int64_t> computeReductionSizes(
      Fusion* fusion,
      ExpressionEvaluator& expr_eval) {
    ReductionSizeMapper mapper(fusion, expr_eval);
    return mapper.reduction_map;
  }

 private:
  ReductionSizeMapper(Fusion* fusion, ExpressionEvaluator& expr_eval)
      : expr_eval_(expr_eval) {
    // Initialize input values
    for (auto inp : fusion->inputs()) {
      if (inp->isA<TensorView>()) {
        auto tv = inp->as<TensorView>();
        // Shouldn't have any reductions, but run it through the analysis
        // anyway.
        reduction_map[tv] = getReductionSize(tv);
      }
    }

    IterVisitor::traverse(fusion);

    // Catch up with dangling outputs.
    for (auto out : fusion->outputs()) {
      if (out->isA<TensorView>()) {
        auto tv = out->as<TensorView>();
        // It's possible that we have a dangling output that's not generated by
        // any expression, e.g. a 0-size workspace or a null tensor.
        if (reduction_map.count(tv) == 0) {
          // Shouldn't have any reductions, but run it through the analysis
          // anyway.
          reduction_map[tv] = getReductionSize(tv);
        }
      }
    }
  }

  int64_t getReductionSize(const TensorView* tv) {
    int64_t reduction_elements = 1;
    for (auto id : tv->getMaybeRFactorDomain()) {
      if (id->isReduction()) {
        auto inferred_extent = expr_eval_.evaluate(id->extent());
        TORCH_INTERNAL_ASSERT(
            inferred_extent.has_value(),
            "Couldn't figure out the dimensions of a tensorview during evaluation for validation. ",
            id,
            " in ",
            tv);
        reduction_elements =
            reduction_elements * inferred_extent->as<int64_t>();
      }
    }
    return reduction_elements;
  }

  void handle(Expr* expr) override {
    if (!ir_utils::isTvOp(expr)) {
      return;
    }

    int64_t inp_reduction_elements = 1;
    for (auto inp : expr->inputs()) {
      if (inp->isA<TensorView>()) {
        if (auto tv = inp->as<TensorView>()) {
          inp_reduction_elements =
              std::max(inp_reduction_elements, reduction_map.at(tv));
        }
      }
    }

    for (auto out : expr->outputs()) {
      if (out->isA<TensorView>()) {
        auto tv = out->as<TensorView>();
        reduction_map[tv] = getReductionSize(tv) * inp_reduction_elements;
      }
    }
  }

 private:
  using IterVisitor::handle;

  std::unordered_map<TensorView*, int64_t> reduction_map;
  ExpressionEvaluator& expr_eval_;
};

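// Builds an ExpressionEvaluator with the fusion inputs bound, then, for
// parallel dimensions whose extents cannot be inferred from the inputs, binds
// the extents provided by the launch constraints.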
ExpressionEvaluator bindInputsAndLaunchParams(
    Fusion* fusion,
    const at::ArrayRef<IValue>& aten_inputs,
    const LaunchParams& launch_constraints) {
  // index_mode is not important here
  KernelArgumentHolder argument_holder(KernelIndexMode::INT64);
  argument_holder.push(aten_inputs);

  auto expr_eval = executor_utils::bindFusionInputs(argument_holder, fusion);
  for (auto val : fusion->vals()) {
    if (!val->isA<TensorView>()) {
      continue;
    }

    // Roughly taken from executor.cpp/computeLaunchParams
    auto tv = val->as<TensorView>();
    for (auto id : tv->domain()->domain()) {
      if (!(id->isThread() && id->extent()->definition() == nullptr)) {
        continue;
      }

      if (id->isBroadcast()) {
        continue;
      }

      auto extent = id->extent();
      auto inferred_extent = expr_eval.evaluate(extent);
      auto p_type = id->getParallelType();

      if (inferred_extent.has_value()) {
        // This value could have been inferred, make sure it was set right.
        TORCH_CHECK(
            inferred_extent.value() == launch_constraints.getDim(p_type) ||
                launch_constraints.getRawVal(p_type) == -1,
            "inferred that ",
            p_type,
            " should be set to ",
            inferred_extent.value(),
            " but launch constraints specified ",
            launch_constraints.getRawVal(p_type));
      } else {
        // Bind the launch constraint into our evaluation context
        if (launch_constraints.hasDim(id->getParallelType())) {
          expr_eval.bind(extent, launch_constraints.getDim(p_type));
        }
      }
    }
  }
  return expr_eval;
}

// Validation will look through the fusion and figure out how many elements
// were reduced to create each output. It will then compute a tolerance to use
// for allclose based on experimental results. The experimental results were
// based on adding two tensors and then summing them, so this assumes we're
// always summing values between -2 and 2. If we start summing values larger
// than that, this approach might not hold.
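//
// A typical call from a test looks roughly like the following (a sketch;
// cg_outputs, aten_inputs, and aten_output are placeholder names for the
// fusion outputs, the input IValues, and the reference result):
//
//   testValidate(
//       &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);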
void testValidate(
    Fusion* fusion,
    const std::vector<at::Tensor>& fusion_outputs,
    const at::ArrayRef<IValue>& aten_inputs,
    const std::vector<at::Tensor>& aten_outputs,
    int line_number,
    const char* file_name,
    std::string err_msg = "",
    const LaunchParams& lparams = LaunchParams(),
    const ValidationConstants& tolerances = ValidationConstants()) {
  FusionGuard fg(fusion);

  auto expr_eval = bindInputsAndLaunchParams(fusion, aten_inputs, lparams);

  auto reduction_sizes =
      ReductionSizeMapper::computeReductionSizes(fusion, expr_eval);

  auto output_alias_indices = fusion->getOutputAliasIndices();

  TORCH_INTERNAL_ASSERT(
      fusion_outputs.size() == aten_outputs.size() &&
          aten_outputs.size() ==
              fusion->outputs().size() - output_alias_indices.size(),
      "Number of outputs doesn't match.");

  TORCH_INTERNAL_ASSERT(
      fusion->inputs().size() == aten_inputs.size(),
      "Number of inputs doesn't match.");

  for (size_t i = 0; i < fusion->inputs().size(); i++) {
    if (fusion->inputs()[i]->isA<TensorView>()) {
      TORCH_INTERNAL_ASSERT(
          aten_inputs[i].isTensor(), "Mismatch of tensor inputs.");

      auto fusion_input_tv = fusion->inputs()[i]->as<TensorView>();
      auto at_tensor = aten_inputs[i].toTensor();

      TORCH_INTERNAL_ASSERT(
          at_tensor.dim() ==
              static_cast<int64_t>(TensorDomain::noReductions(
                                       fusion_input_tv->getMaybeRFactorDomain())
                                       .size()),
          "Dimensionality mismatch in inputs.");
    }
  }

  for (size_t i = 0, j = 0; i < fusion->outputs().size(); i++) {
    TORCH_INTERNAL_ASSERT(
        fusion->outputs()[i]->isA<TensorView>(), "Mismatch of tensor outputs.");
    if (output_alias_indices.count(i) != 0) {
      // This is an aliased output; let's not check it.
      continue;
    }

    auto fusion_output_tensor = fusion_outputs[j];
    auto fusion_output_tv = fusion->outputs()[i]->as<TensorView>();
    auto aten_output_tensor = aten_outputs[j];

    TORCH_INTERNAL_ASSERT(
        reduction_sizes.count(fusion_output_tv),
        "Missed reduction size count on fusion output at index: ",
        i);

    int64_t reduction_size = reduction_sizes.at(fusion_output_tv);

    TORCH_INTERNAL_ASSERT(
        aten_output_tensor.dim() == fusion_output_tensor.dim() &&
            fusion_outputs[j].dim() ==
                static_cast<int64_t>(
                    TensorDomain::noReductions(
                        fusion_output_tv->getMaybeRFactorDomain())
                        .size()),
        "Dimensionality mismatch in outputs.");

    auto tolerance_values = getTolerance(
        fusion_output_tv->getDataType().value(), reduction_size, tolerances);

    if (aten_output_tensor.is_floating_point() ||
        aten_output_tensor.is_complex()) {
      TORCH_INTERNAL_ASSERT(
          aten_output_tensor.allclose(
              fusion_output_tensor.to(aten_output_tensor.dtype()),
              tolerance_values.second,
              tolerance_values.first),
          "\n",
          err_msg,
          "\nValidation error in output ",
          j,
          " on line ",
          line_number,
          " in file ",
          file_name,
          ".\n Detected abs error of: ",
          aten_output_tensor.sub(fusion_output_tensor)
              .abs()
              .max()
              .item()
              .to<double>(),
          "\n absolute tolerance was set to ",
          tolerance_values.first,
          "\n and relative tolerance set to ",
          tolerance_values.second);
    } else {
      TORCH_INTERNAL_ASSERT(
          aten_output_tensor.equal(
              fusion_output_tensor.to(aten_output_tensor.dtype())),
          "\n",
          err_msg,
          ".\n Validation error in output ",
          j,
          " on line ",
          line_number,
          " in file ",
          file_name,
          ".\n Values are not equal and are not a floating type.");
    }
    j++;
  }
}

} // namespace
} // namespace jit
} // namespace torch