1#ifdef TORCH_ENABLE_LLVM
2
#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
10
11using namespace torch::indexing;
12namespace te = torch::jit::tensorexpr;
13
14static void vectorize(te::LoopNest* ln, te::Tensor target, int width) {
15 auto loops = ln->getLoopStmtsFor(target);
16 te::ForPtr inner, tail;
17 ln->splitWithTail(loops[0], width, &inner, &tail);
18 ASSERT_TRUE(te::LoopNest::vectorize(inner));
19}
20
21std::string diffs(const at::Tensor& a, const at::Tensor& b) {
22 auto diff = torch::abs(a.flatten() - b.flatten());
23 auto count_diffs = torch::sum(diff > 0.f);
24 auto greatest_diff_index = torch::argmax(diff);
25 std::stringstream ss;
26 ss << "Found " << count_diffs << " unequal element(s). "
27 << "The greatest difference was " << diff.index({greatest_diff_index})
28 << " at index " << greatest_diff_index;
29 return ss.str();
30}
31
32TEST(Approx, log_vml) {
33 te::VarHandle N("N", te::kInt);
34 te::BufHandle A("A", {N}, te::kFloat);
35 te::Tensor B = te::Compute(
36 "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); });
37
38 te::LoopNest ln({B});
39 ln.prepareForCodegen();
40 vectorize(&ln, B, 8);
41 te::StmtPtr s = ln.root_stmt();
42 s = te::IRSimplifier::simplify(s);
43 te::LLVMCodeGen cg(s, {A, B, N});
44
45 auto eps = std::numeric_limits<float>::epsilon();
46 auto test = [&](const at::Tensor& A_t) {
47 at::Tensor B_ref = at::log(A_t);
48 at::Tensor B_t = at::empty_like(A_t);
49 auto ap = A_t.data_ptr<float>();
50 auto bp = B_t.data_ptr<float>();
51 cg.call({ap, bp, A_t.numel()});
52 // Results should be bit-identical.
53 ASSERT_TRUE(torch::allclose(
54 B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true))
55 << "Input[:8]\n"
56 << A_t.index({Slice(0, 8)}) << "\n"
57 << "Test[:8]\n"
58 << B_t.index({Slice(0, 8)}) << "\n"
59 << "Ref[:8]\n"
60 << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref);
61 };
62
63 // Generate every single-precision FP value in [1.0, 2.0).
64 at::Tensor A_t = torch::arange(1.0f, 2.0f, eps);
65 ASSERT_EQ(A_t.numel(), 1 << 23);
66
67 test(A_t);
68
69 test(A_t * 2.0f);
70 test(A_t * 0.5f);
71
72 test(A_t * 4.0f);
73 test(A_t * 0.25f);
74
75 test(A_t * powf(2.0f, 16));
76 test(A_t * powf(2.0f, -16));
77
78 test(A_t * powf(2.0f, 126));
79 test(A_t * powf(2.0f, -126));
80
81 test(torch::full({32}, INFINITY));
82 test(torch::full({32}, NAN));
83
84 auto min = std::numeric_limits<float>::min();
85 auto denorm_min = std::numeric_limits<float>::denorm_min();
86
87 // Denormals aren't bit precise, because sleef isn't bit-precise either.
88 A_t = torch::arange(0.0f, min, denorm_min);
89 ASSERT_EQ(A_t.numel(), 1 << 23);
90 auto B_ref = at::log(A_t);
91 auto B_t = at::empty_like(B_ref);
92 cg.call({A_t.data_ptr<float>(), B_t.data_ptr<float>(), A_t.numel()});
93 ASSERT_TRUE(torch::allclose(B_t, B_ref));
94}
95
96#endif // TORCH_ENABLE_LLVM
97