1 | #include <gtest/gtest.h> |
2 | #include <torch/csrc/jit/tensorexpr/eval.h> |
3 | #include <torch/csrc/jit/tensorexpr/expr.h> |
4 | #include <torch/csrc/jit/tensorexpr/loopnest.h> |
5 | #include <torch/csrc/jit/tensorexpr/operators/operators.h> |
6 | #include <torch/torch.h> |
7 | |
8 | using namespace torch::jit::tensorexpr; |
9 | |
10 | using Tensors = std::vector<Tensor>; |
11 | using Args = std::vector<CodeGen::BufferArg>; |
12 | std::unique_ptr<SimpleIREvaluator> compile( |
13 | const Args& inputs, |
14 | const Tensors& outputs) { |
15 | LoopNest nest({outputs}); |
16 | nest.prepareForCodegen(); |
17 | nest.simplify(); |
18 | auto join = inputs; |
19 | join.insert(join.end(), outputs.begin(), outputs.end()); |
20 | return std::make_unique<SimpleIREvaluator>(nest.root_stmt(), join); |
21 | } |
22 | |
23 | TEST(Ops, Sum) { |
24 | constexpr int M = 8; |
25 | constexpr int N = 16; |
26 | std::vector<IntList> testDims = {{0}, {1}, {0, 1}}; |
27 | std::vector<std::vector<ExprHandle>> outputShapes = {{N}, {M}, {}}; |
28 | for (unsigned idx = 0; idx < testDims.size(); idx++) { |
29 | const auto& dims = testDims[idx]; |
30 | const auto& outShape = outputShapes[idx]; |
31 | |
32 | BufHandle a("a" , {M, N}, kFloat); |
33 | std::vector<ExprHandle> outStrides = |
34 | c10::fmap<ExprHandle>(make_contiguous_strides(outShape)); |
35 | Tensor b = computeSum( |
36 | {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); |
37 | auto cg = compile({a}, {b}); |
38 | |
39 | auto at = at::arange(M * N, at::kFloat).view({M, N}); |
40 | auto ref = at::sum(at, dims); |
41 | auto bt = at::empty_like(ref); |
42 | |
43 | cg->call({at.data_ptr<float>(), bt.data_ptr<float>()}); |
44 | |
45 | ASSERT_TRUE(at::allclose(bt, ref)); |
46 | } |
47 | } |
48 | |
49 | TEST(Ops, ChannelsLastSum) { |
50 | constexpr int A = 2; |
51 | constexpr int B = 3; |
52 | constexpr int C = 4; |
53 | constexpr int D = 5; |
54 | constexpr int E = 6; |
55 | std::vector<IntList> testDims = {{0}, {1}, {0, 1}}; |
56 | |
57 | std::vector<std::vector<ExprHandle>> outputShapes = { |
58 | {B, C, D, E}, {A, C, D, E}, {C, D, E}}; |
59 | for (unsigned idx = 0; idx < testDims.size(); idx++) { |
60 | const auto& dims = testDims[idx]; |
61 | const auto& outShape = outputShapes[idx]; |
62 | |
63 | BufHandle a("a" , {A, B, C, D, E}, kFloat); |
64 | std::vector<ExprHandle> outStrides = |
65 | c10::fmap<ExprHandle>(make_channels_last_strides(outShape)); |
66 | Tensor b = computeSum( |
67 | {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); |
68 | auto cg = compile({a}, {b}); |
69 | |
70 | auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E}); |
71 | auto ref = at::sum(at, dims); |
72 | auto bt = at::empty_like(ref); |
73 | |
74 | cg->call({at.data_ptr<float>(), bt.data_ptr<float>()}); |
75 | |
76 | ASSERT_TRUE(at::allclose(bt, ref)); |
77 | } |
78 | } |
79 | |