1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/ir/reduce.hpp"
18
19#include <vector>
20
21#include "gpu/jit/ir/tensor.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace gpu {
26namespace jit {
27
28stmt_t create_reduce_stmt(const layout_t &src, const layout_t &dst,
29 const expr_t &src_buf, const expr_t &dst_buf, const tensor_t &_subtile,
30 uint32_t reduction_mask, bool drop_dims) {
31 auto subtile = _subtile;
32 if (subtile.is_empty()) subtile = tensor_t(src.dims());
33 ir_assert(src.ndims() == subtile.ndims());
34 int ndims = src.ndims();
35
36 // Align dst layout with src layout according to the mask if needed.
37 layout_t dst_aligned;
38 if (drop_dims) {
39 std::vector<int> dst2src(dst.ndims());
40 int dst_dim_idx = 0;
41 for (int i = 0; i < ndims; i++) {
42 if ((reduction_mask & (1 << i)) != 0) {
43 dst2src[dst_dim_idx] = i;
44 dst_dim_idx++;
45 }
46 }
47 ir_assert(dst_dim_idx == dst.ndims()) << "Incompatible reduction mask.";
48
49 auto dst_blocks = dst.blocks();
50 for (auto &b : dst_blocks)
51 b.dim_idx = dst2src[b.dim_idx];
52
53 // Create final layout.
54 dst_aligned = layout_t(dst.type(), ndims, dst.offset(), dst_blocks);
55 } else {
56 dst_aligned = dst;
57 }
58
59 std::vector<dim_t> dst_tile_dims = subtile.dims();
60 std::vector<expr_t> dst_tile_start = subtile.start();
61 for (int i = 0; i < ndims; i++) {
62 if ((reduction_mask & (1 << i)) == 0) {
63 dst_tile_dims[i] = 1;
64 dst_tile_start[i] = expr_t(0);
65 continue;
66 }
67 }
68 dst_aligned = dst_aligned.map(tensor_t(dst_tile_dims, dst_tile_start));
69
70 auto func = reduce_t::make(src, dst_aligned);
71 return func.call({dst_buf, src_buf});
72}
73
74} // namespace jit
75} // namespace gpu
76} // namespace impl
77} // namespace dnnl
78