1#include <gtest/gtest.h>
2#include <torch/csrc/jit/tensorexpr/eval.h>
3#include <torch/csrc/jit/tensorexpr/expr.h>
4#include <torch/csrc/jit/tensorexpr/loopnest.h>
5#include <torch/csrc/jit/tensorexpr/operators/operators.h>
6#include <torch/torch.h>
7
8using namespace torch::jit::tensorexpr;
9
10using Tensors = std::vector<Tensor>;
11using Args = std::vector<CodeGen::BufferArg>;
12std::unique_ptr<SimpleIREvaluator> compile(
13 const Args& inputs,
14 const Tensors& outputs) {
15 LoopNest nest({outputs});
16 nest.prepareForCodegen();
17 nest.simplify();
18 auto join = inputs;
19 join.insert(join.end(), outputs.begin(), outputs.end());
20 return std::make_unique<SimpleIREvaluator>(nest.root_stmt(), join);
21}
22
23TEST(Ops, Sum) {
24 constexpr int M = 8;
25 constexpr int N = 16;
26 std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
27 std::vector<std::vector<ExprHandle>> outputShapes = {{N}, {M}, {}};
28 for (unsigned idx = 0; idx < testDims.size(); idx++) {
29 const auto& dims = testDims[idx];
30 const auto& outShape = outputShapes[idx];
31
32 BufHandle a("a", {M, N}, kFloat);
33 std::vector<ExprHandle> outStrides =
34 c10::fmap<ExprHandle>(make_contiguous_strides(outShape));
35 Tensor b = computeSum(
36 {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
37 auto cg = compile({a}, {b});
38
39 auto at = at::arange(M * N, at::kFloat).view({M, N});
40 auto ref = at::sum(at, dims);
41 auto bt = at::empty_like(ref);
42
43 cg->call({at.data_ptr<float>(), bt.data_ptr<float>()});
44
45 ASSERT_TRUE(at::allclose(bt, ref));
46 }
47}
48
49TEST(Ops, ChannelsLastSum) {
50 constexpr int A = 2;
51 constexpr int B = 3;
52 constexpr int C = 4;
53 constexpr int D = 5;
54 constexpr int E = 6;
55 std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
56
57 std::vector<std::vector<ExprHandle>> outputShapes = {
58 {B, C, D, E}, {A, C, D, E}, {C, D, E}};
59 for (unsigned idx = 0; idx < testDims.size(); idx++) {
60 const auto& dims = testDims[idx];
61 const auto& outShape = outputShapes[idx];
62
63 BufHandle a("a", {A, B, C, D, E}, kFloat);
64 std::vector<ExprHandle> outStrides =
65 c10::fmap<ExprHandle>(make_channels_last_strides(outShape));
66 Tensor b = computeSum(
67 {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
68 auto cg = compile({a}, {b});
69
70 auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E});
71 auto ref = at::sum(at, dims);
72 auto bt = at::empty_like(ref);
73
74 cg->call({at.data_ptr<float>(), bt.data_ptr<float>()});
75
76 ASSERT_TRUE(at::allclose(bt, ref));
77 }
78}
79