1 | #ifdef TORCH_ENABLE_LLVM |
2 | |
3 | #include <gtest/gtest.h> |
4 | #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> |
5 | #include <torch/csrc/jit/tensorexpr/llvm_codegen.h> |
6 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
7 | #include <torch/csrc/jit/tensorexpr/tensor.h> |
8 | #include <torch/torch.h> |
9 | #include <cstring> |
10 | |
11 | using namespace torch::indexing; |
12 | namespace te = torch::jit::tensorexpr; |
13 | |
14 | static void vectorize(te::LoopNest* ln, te::Tensor target, int width) { |
15 | auto loops = ln->getLoopStmtsFor(target); |
16 | te::ForPtr inner, tail; |
17 | ln->splitWithTail(loops[0], width, &inner, &tail); |
18 | ASSERT_TRUE(te::LoopNest::vectorize(inner)); |
19 | } |
20 | |
21 | std::string diffs(const at::Tensor& a, const at::Tensor& b) { |
22 | auto diff = torch::abs(a.flatten() - b.flatten()); |
23 | auto count_diffs = torch::sum(diff > 0.f); |
24 | auto greatest_diff_index = torch::argmax(diff); |
25 | std::stringstream ss; |
26 | ss << "Found " << count_diffs << " unequal element(s). " |
27 | << "The greatest difference was " << diff.index({greatest_diff_index}) |
28 | << " at index " << greatest_diff_index; |
29 | return ss.str(); |
30 | } |
31 | |
// Checks the vectorized log_vml approximation against ATen's reference log
// across every mantissa pattern, several exponent binades, special values
// (inf/NaN), and denormals.
TEST(Approx, log_vml) {
  // Build a 1-D kernel B[i] = log_vml(A[i]) with a runtime length N.
  te::VarHandle N("N" , te::kInt);
  te::BufHandle A("A" , {N}, te::kFloat);
  te::Tensor B = te::Compute(
      "B" , {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); });

  te::LoopNest ln({B});
  ln.prepareForCodegen();
  // Split by 8 and vectorize the inner loop so log_vml is exercised through
  // the SIMD code path (with a scalar tail for the remainder).
  vectorize(&ln, B, 8);
  te::StmtPtr s = ln.root_stmt();
  s = te::IRSimplifier::simplify(s);
  te::LLVMCodeGen cg(s, {A, B, N});

  auto eps = std::numeric_limits<float>::epsilon();
  // Runs the compiled kernel on A_t and compares against at::log. The
  // tolerance is a single relative epsilon (with atol=0), i.e. effectively
  // bit-identical results are expected; equal_nan makes NaN == NaN pass.
  auto test = [&](const at::Tensor& A_t) {
    at::Tensor B_ref = at::log(A_t);
    at::Tensor B_t = at::empty_like(A_t);
    auto ap = A_t.data_ptr<float>();
    auto bp = B_t.data_ptr<float>();
    cg.call({ap, bp, A_t.numel()});
    // Results should be bit-identical.
    ASSERT_TRUE(torch::allclose(
        B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true))
        << "Input[:8]\n"
        << A_t.index({Slice(0, 8)}) << "\n"
        << "Test[:8]\n"
        << B_t.index({Slice(0, 8)}) << "\n"
        << "Ref[:8]\n"
        << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref);
  };

  // Generate every single-precision FP value in [1.0, 2.0).
  at::Tensor A_t = torch::arange(1.0f, 2.0f, eps);
  ASSERT_EQ(A_t.numel(), 1 << 23);

  test(A_t);

  // Multiplying by a power of two changes only the exponent, so each of the
  // following re-tests all 2^23 mantissas in a different binade.
  test(A_t * 2.0f);
  test(A_t * 0.5f);

  test(A_t * 4.0f);
  test(A_t * 0.25f);

  test(A_t * powf(2.0f, 16));
  test(A_t * powf(2.0f, -16));

  // Near the extremes of the normal float exponent range.
  test(A_t * powf(2.0f, 126));
  test(A_t * powf(2.0f, -126));

  // Special values: infinity and NaN inputs (equal_nan in `test` lets the
  // NaN case compare equal).
  test(torch::full({32}, INFINITY));
  test(torch::full({32}, NAN));

  auto min = std::numeric_limits<float>::min();
  auto denorm_min = std::numeric_limits<float>::denorm_min();

  // Denormals aren't bit precise, because sleef isn't bit-precise either.
  // So this final sweep over [0, FLT_MIN) uses allclose's default tolerances
  // instead of the strict rtol=eps comparison in `test`.
  A_t = torch::arange(0.0f, min, denorm_min);
  ASSERT_EQ(A_t.numel(), 1 << 23);
  auto B_ref = at::log(A_t);
  auto B_t = at::empty_like(B_ref);
  cg.call({A_t.data_ptr<float>(), B_t.data_ptr<float>(), A_t.numel()});
  ASSERT_TRUE(torch::allclose(B_t, B_ref));
}
95 | |
96 | #endif // TORCH_ENABLE_LLVM |
97 | |