1 | #pragma once |
2 | |
#include <executor_utils.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <ir_iostream.h>
#include <lower_utils.h>

#include <ATen/cuda/CUDAContext.h>

#include <algorithm>
#include <array>
#include <cstdint>
#include <unordered_map>
#include <utility>
12 | |
13 | // Tests go in torch::jit |
14 | namespace torch { |
15 | namespace jit { |
16 | |
17 | using namespace torch::jit::fuser::cuda; |
18 | |
19 | namespace { |
20 | |
//! Experimentally measured tolerances used by testValidate.
//! Each table row is a {reduction size, max observed abs error} pair,
//! sorted by ascending reduction size; getTolerance interpolates into
//! these tables. The base_* members, when set (!= -1), override the
//! table fallback for the no-reduction case.
struct ValidationConstants {
  // Tolerances generated from randn + add + sum fusion
  // compared against double precision
  std::array<std::array<double, 2>, 20> sum_tolerances_float = {
      {{4, 1.68222e-06},     {8, 2.23704e-06},     {16, 2.95788e-06},
       {32, 4.4778e-06},     {64, 6.75395e-06},    {128, 8.57934e-06},
       {256, 1.30594e-05},   {512, 2.19122e-05},   {1024, 3.3451e-05},
       {2048, 5.78476e-05},  {4096, 0.000108292},  {8192, 0.00012207},
       {16384, 0.000136882}, {32768, 0.000248561}, {65536, 0.000407594},
       {131072, 0.000500901}, {262144, 0.000923019}, {524288, 0.00156909},
       {1048576, 0.00223107}, {2097152, 0.00343043}}};

  // Tolerances generated from randn + add + sum fusion
  // compared against double precision
  std::array<std::array<double, 2>, 20> sum_tolerances_half = {
      {{4, 0.00390625},    {8, 0.0078125},    {16, 0.0078125},
       {32, 0.0155334},    {64, 0.0156269},   {128, 0.0312042},
       {256, 0.0312548},   {512, 0.0619979},  {1024, 0.0625103},
       {2048, 0.124686},   {4096, 0.12501},   {8192, 0.24945},
       {16384, 0.250049},  {32768, 0.498946}, {65536, 0.500071},
       {131072, 0.985087}, {262144, 1.00006}, {524288, 1.99234},
       {1048576, 2.00032}, {2097152, 3.99073}}};

  // -1 means "unset": getTolerance falls back to the first table rows
  // for the no-reduction case.
  double base_half_abs_tol = -1;
  double base_half_rel_tol = -1;
  double base_float_abs_tol = -1;
  double base_float_rel_tol = -1;
};
49 | |
50 | // Returns abs and relative values to use for validation |
51 | std::pair<double, double> getTolerance( |
52 | DataType dtype, |
53 | int64_t reduction_size, |
54 | const ValidationConstants& tolerances) { |
55 | switch (dtype) { |
56 | case DataType::ComplexFloat: |
57 | case DataType::ComplexDouble: |
58 | case DataType::Float: |
59 | // TODO: Pull new tolerances for Double, for now we will just use float |
60 | // tolerances as it should be no worse. |
61 | case DataType::Double: { |
62 | const auto& sum_tolerance_entry = tolerances.sum_tolerances_float; |
63 | const auto& base_abs = tolerances.base_float_abs_tol; |
64 | const auto& base_rel = tolerances.base_float_rel_tol; |
65 | |
66 | if (reduction_size <= 1) { |
67 | // No reduction case |
68 | if (base_abs == -1 || base_rel == -1) { |
69 | return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]}; |
70 | } else { |
71 | return {base_abs, base_rel}; |
72 | } |
73 | } else { |
74 | // Reduction case |
75 | size_t entry = 0; |
76 | while (entry < sum_tolerance_entry.size() && |
77 | sum_tolerance_entry[entry][0] < reduction_size) { |
78 | entry++; |
79 | } |
80 | double abs_tol = 0.0; |
81 | if (entry + 1 < sum_tolerance_entry.size()) { |
82 | // Grab the next entry up so we have some margin |
83 | abs_tol = sum_tolerance_entry[entry + 1][1]; |
84 | } else { |
85 | // If we hit the end of the list, return twice the max error we |
86 | // measured |
87 | abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.; |
88 | } |
89 | // Relative tol we're going to set to 1% of abs tol just for |
90 | // a small margin of rel error. |
91 | return {abs_tol, abs_tol * 0.01}; |
92 | } |
93 | } |
94 | case DataType::Half: { |
95 | // Copied from float case |
96 | const auto& sum_tolerance_entry = tolerances.sum_tolerances_half; |
97 | const auto& base_abs = tolerances.base_half_abs_tol; |
98 | const auto& base_rel = tolerances.base_half_rel_tol; |
99 | |
100 | if (reduction_size <= 1) { |
101 | // No reduction case |
102 | if (base_abs == -1 || base_rel == -1) { |
103 | return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]}; |
104 | } else { |
105 | return {base_abs, base_rel}; |
106 | } |
107 | } else { |
108 | // Reduction case |
109 | size_t entry = 0; |
110 | while (sum_tolerance_entry[entry][0] < reduction_size && |
111 | entry < sum_tolerance_entry.size()) { |
112 | entry++; |
113 | } |
114 | double abs_tol = 0.0; |
115 | if (entry + 1 < sum_tolerance_entry.size()) { |
116 | // Grab the next entry up so we have some margin |
117 | abs_tol = sum_tolerance_entry[entry + 1][1]; |
118 | } else { |
119 | // If we hit the end of the list, return twice the max error we |
120 | // measured |
121 | abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.; |
122 | } |
123 | // Relative tol we're going to set to 1% of abs tol just for |
124 | // a small margin of rel error. |
125 | return {abs_tol, abs_tol * 0.01}; |
126 | } |
127 | } |
128 | case DataType::BFloat16: { |
129 | // Copied from float case |
130 | const auto& sum_tolerance_entry = tolerances.sum_tolerances_half; |
131 | const auto& base_abs = tolerances.base_half_abs_tol; |
132 | const auto& base_rel = tolerances.base_half_rel_tol; |
133 | |
134 | if (reduction_size <= 1) { |
135 | // No reduction case |
136 | if (base_abs == -1 || base_rel == -1) { |
137 | return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]}; |
138 | } else { |
139 | return {base_abs * 10.0, base_rel * 10.0}; |
140 | } |
141 | } else { |
142 | // Reduction case |
143 | size_t entry = 0; |
144 | while (sum_tolerance_entry[entry][0] < reduction_size && |
145 | entry < sum_tolerance_entry.size()) { |
146 | entry++; |
147 | } |
148 | double abs_tol = 0.0; |
149 | if (entry + 1 < sum_tolerance_entry.size()) { |
150 | // Grab the next entry up so we have some margin |
151 | abs_tol = sum_tolerance_entry[entry + 1][1]; |
152 | } else { |
153 | // If we hit the end of the list, return twice the max error we |
154 | // measured |
155 | abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.; |
156 | } |
157 | // Relative tol we're going to set to 1% of abs tol just for |
158 | // a small margin of rel error. |
159 | return {abs_tol * 10.0, abs_tol * 0.01 * 10.0}; |
160 | } |
161 | } |
162 | case DataType::Int: |
163 | return {0.0, 0.0}; |
164 | case DataType::Int32: |
165 | return {0.0, 0.0}; |
166 | case DataType::Bool: |
167 | return {0.0, 0.0}; |
168 | default: |
169 | TORCH_INTERNAL_ASSERT( |
170 | false, "Do not have tolerance computation for type " , dtype, "." ); |
171 | } |
172 | } |
173 | |
174 | class ReductionSizeMapper : private IterVisitor { |
175 | public: |
176 | //! Runs through the fusion and determines how many reductions were performed |
177 | //! to compute each tensorview. |
178 | static std::unordered_map<TensorView*, int64_t> computeReductionSizes( |
179 | Fusion* fusion, |
180 | ExpressionEvaluator& expr_eval) { |
181 | ReductionSizeMapper mapper(fusion, expr_eval); |
182 | return mapper.reduction_map; |
183 | } |
184 | |
185 | private: |
186 | ReductionSizeMapper(Fusion* fusion, ExpressionEvaluator& expr_eval) |
187 | : expr_eval_(expr_eval) { |
188 | // Initialize input values |
189 | for (auto inp : fusion->inputs()) { |
190 | if (inp->isA<TensorView>()) { |
191 | auto tv = inp->as<TensorView>(); |
192 | // Shouldn't have any reductions, but run it through analysis anyways. |
193 | reduction_map[tv] = getReductionSize(tv); |
194 | } |
195 | } |
196 | |
197 | IterVisitor::traverse(fusion); |
198 | |
199 | // catch up with dangling outputs; |
200 | for (auto out : fusion->outputs()) { |
201 | if (out->isA<TensorView>()) { |
202 | auto tv = out->as<TensorView>(); |
203 | // possible that we have a dangling output that's not generated by any |
204 | // expression. e.g. 0 workspace or null tensor |
205 | if (reduction_map.count(tv) == 0) { |
206 | // Shouldn't have any reductions, but run it through analysis anyways. |
207 | reduction_map[tv] = getReductionSize(tv); |
208 | } |
209 | } |
210 | } |
211 | } |
212 | |
213 | int64_t getReductionSize(const TensorView* tv) { |
214 | int64_t reduction_elements = 1; |
215 | for (auto id : tv->getMaybeRFactorDomain()) { |
216 | if (id->isReduction()) { |
217 | auto inferred_extent = expr_eval_.evaluate(id->extent()); |
218 | TORCH_INTERNAL_ASSERT( |
219 | inferred_extent.has_value(), |
220 | "Couldn't figure out what the dimensions of a tensorview is in evaluation for validation. " , |
221 | id, |
222 | " in " , |
223 | tv); |
224 | reduction_elements = |
225 | reduction_elements * inferred_extent->as<int64_t>(); |
226 | } |
227 | } |
228 | return reduction_elements; |
229 | } |
230 | |
231 | void handle(Expr* expr) override { |
232 | if (!ir_utils::isTvOp(expr)) { |
233 | return; |
234 | } |
235 | |
236 | int64_t inp_reduction_elements = 1; |
237 | for (auto inp : expr->inputs()) { |
238 | if (inp->isA<TensorView>()) { |
239 | if (auto tv = inp->as<TensorView>()) { |
240 | inp_reduction_elements = |
241 | std::max(inp_reduction_elements, reduction_map.at(tv)); |
242 | } |
243 | } |
244 | } |
245 | |
246 | for (auto out : expr->outputs()) { |
247 | if (out->isA<TensorView>()) { |
248 | auto tv = out->as<TensorView>(); |
249 | reduction_map[tv] = getReductionSize(tv) * inp_reduction_elements; |
250 | } |
251 | } |
252 | } |
253 | |
254 | private: |
255 | using IterVisitor::handle; |
256 | |
257 | std::unordered_map<TensorView*, int64_t> reduction_map; |
258 | ExpressionEvaluator& expr_eval_; |
259 | }; |
260 | |
261 | ExpressionEvaluator bindInputsAndLaunchParams( |
262 | Fusion* fusion, |
263 | const at::ArrayRef<IValue>& aten_inputs, |
264 | const LaunchParams& launch_constraints) { |
265 | // index_mode is not important here |
266 | KernelArgumentHolder argument_holder(KernelIndexMode::INT64); |
267 | argument_holder.push(aten_inputs); |
268 | |
269 | auto expr_eval = executor_utils::bindFusionInputs(argument_holder, fusion); |
270 | for (auto val : fusion->vals()) { |
271 | if (!val->isA<TensorView>()) { |
272 | continue; |
273 | } |
274 | |
275 | // Roughly taken from executor.cpp/computeLaunchParams |
276 | auto tv = val->as<TensorView>(); |
277 | for (auto id : tv->domain()->domain()) { |
278 | if (!(id->isThread() && id->extent()->definition() == nullptr)) { |
279 | continue; |
280 | } |
281 | |
282 | if (id->isBroadcast()) { |
283 | continue; |
284 | } |
285 | |
286 | auto extent = id->extent(); |
287 | auto inferred_extent = expr_eval.evaluate(extent); |
288 | auto p_type = id->getParallelType(); |
289 | |
290 | if (inferred_extent.has_value()) { |
291 | // This value could have been inferred, make sure it was set right. |
292 | TORCH_CHECK( |
293 | inferred_extent.value() == launch_constraints.getDim(p_type) || |
294 | launch_constraints.getRawVal(p_type) == -1, |
295 | "inferred that " , |
296 | p_type, |
297 | " should be set to " , |
298 | inferred_extent.value(), |
299 | " but launch constraints specified " , |
300 | launch_constraints.getRawVal(p_type)); |
301 | } else { |
302 | // Bind the launch constraint into our evaluation context |
303 | if (launch_constraints.hasDim(id->getParallelType())) { |
304 | expr_eval.bind(extent, launch_constraints.getDim(p_type)); |
305 | } |
306 | } |
307 | } |
308 | } |
309 | return expr_eval; |
310 | } |
311 | |
// Validation will look through the fusion and figure out how many elements were
// reduced to create each output. It will then compute a tolerance to use for
// allclose based on experimental results. The experimental results were based
// on adding two tensors then summing them. This of course has an assumption
// that we're always summing values between -2 and 2. If we start summing values
// larger than that this approach might not hold.
void testValidate(
    Fusion* fusion,
    const std::vector<at::Tensor>& fusion_outputs,
    const at::ArrayRef<IValue>& aten_inputs,
    const std::vector<at::Tensor>& aten_outputs,
    int line_number,
    const char* file_name,
    std::string err_msg = "",
    const LaunchParams& lparams = LaunchParams(),
    const ValidationConstants& tolerances = ValidationConstants()) {
  FusionGuard fg(fusion);

  // Bind inputs (and any launch-constraint thread dims) so reduction extents
  // can be evaluated below.
  auto expr_eval = bindInputsAndLaunchParams(fusion, aten_inputs, lparams);

  // Per-output count of reduced elements; drives tolerance selection.
  auto reduction_sizes =
      ReductionSizeMapper::computeReductionSizes(fusion, expr_eval);

  // Aliased outputs are not returned by the executor, so they are skipped in
  // the comparisons below.
  auto output_alias_indices = fusion->getOutputAliasIndices();

  TORCH_INTERNAL_ASSERT(
      fusion_outputs.size() == aten_outputs.size() &&
          aten_outputs.size() ==
              fusion->outputs().size() - output_alias_indices.size(),
      "Number of outputs don't match.");

  TORCH_INTERNAL_ASSERT(
      fusion->inputs().size() == aten_inputs.size(),
      "Number of inputs don't match.");

  // Sanity-check each fusion tensor input against the corresponding ATen
  // input's dimensionality (reduction axes excluded).
  for (size_t i = 0; i < fusion->inputs().size(); i++) {
    if (fusion->inputs()[i]->isA<TensorView>()) {
      TORCH_INTERNAL_ASSERT(
          aten_inputs[i].isTensor(), "Mismatch of tensor inputs.");

      auto fusion_input_tv = fusion->inputs()[i]->as<TensorView>();
      auto at_tensor = aten_inputs[i].toTensor();

      TORCH_INTERNAL_ASSERT(
          at_tensor.dim() ==
              static_cast<int64_t>(TensorDomain::noReductions(
                                       fusion_input_tv->getMaybeRFactorDomain())
                                       .size()),
          "Dimensionality mismatch in inputs.");
    }
  }

  // i indexes fusion->outputs(); j indexes the (shorter) fusion_outputs /
  // aten_outputs vectors, which exclude aliased outputs.
  for (size_t i = 0, j = 0; i < fusion->outputs().size(); i++) {
    TORCH_INTERNAL_ASSERT(
        fusion->outputs()[i]->isA<TensorView>(), "Mismatch of tensor outputs.");
    if (output_alias_indices.count(i) != 0) {
      // this is an aliased output, let's not check this;
      continue;
    }

    auto fusion_output_tensor = fusion_outputs[j];
    auto fusion_output_tv = fusion->outputs()[i]->as<TensorView>();
    auto aten_output_tensor = aten_outputs[j];

    TORCH_INTERNAL_ASSERT(
        reduction_sizes.count(fusion_output_tv),
        "Missed reduction size count on fusion output at index: ",
        i);

    int64_t reduction_size = reduction_sizes.at(fusion_output_tv);

    TORCH_INTERNAL_ASSERT(
        aten_output_tensor.dim() == fusion_output_tensor.dim() &&
            fusion_outputs[j].dim() ==
                static_cast<int64_t>(
                    TensorDomain::noReductions(
                        fusion_output_tv->getMaybeRFactorDomain())
                        .size()),
        "Dimensionality mismatch in outputs.");

    // {abs, rel} tolerances derived from the output's dtype and how many
    // elements were reduced to produce it.
    auto tolerance_values = getTolerance(
        fusion_output_tv->getDataType().value(), reduction_size, tolerances);

    if (aten_output_tensor.is_floating_point() ||
        aten_output_tensor.is_complex()) {
      // allclose takes (other, rtol, atol): second is rel, first is abs.
      TORCH_INTERNAL_ASSERT(
          aten_output_tensor.allclose(
              fusion_output_tensor.to(aten_output_tensor.dtype()),
              tolerance_values.second,
              tolerance_values.first),
          "\n",
          err_msg,
          "\nValidation error in output ",
          j,
          " on line ",
          line_number,
          " in file ",
          file_name,
          ".\n  Detected abs error of: ",
          aten_output_tensor.sub(fusion_output_tensor)
              .abs()
              .max()
              .item()
              .to<double>(),
          "\n    absolute tolerance was set to ",
          tolerance_values.first,
          "\n    and relative tolerance set to ",
          tolerance_values.second);
    } else {
      // Integer/bool outputs must match exactly.
      TORCH_INTERNAL_ASSERT(
          aten_output_tensor.equal(
              fusion_output_tensor.to(aten_output_tensor.dtype())),
          "\n",
          err_msg,
          ".\n  Validation error in output ",
          j,
          " on line ",
          line_number,
          " in file ",
          file_name,
          ".\n  Values are not equal and are not a floating type.");
    }
    j++;
  }
}
437 | |
438 | } // namespace |
439 | } // namespace jit |
440 | } // namespace torch |
441 | |