/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_SHUFFLE_BY_REORDER_HPP
#define GPU_OCL_SHUFFLE_BY_REORDER_HPP

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/reorder.hpp"
#include "common/reorder_pd.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/gpu_shuffle_pd.hpp"
#include "gpu/ocl/ocl_engine.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// Implements shuffle using the reorder kernel.
// Pretends that, instead of the one dimension to be shuffled, there are two
// smaller dimensions, then reorders the tensor to swap those two.
// The reorder kernel is used more often and so is expected to be better
// optimized.
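// Illustrative example (not part of the implementation): a forward shuffle
// over an axis of size 6 with group_size() == 2 yields 3 groups; the axis is
// viewed as a 3x2 (groups x group_size) matrix and the reorder writes it back
// transposed as 2x3 (group_size x groups), which is exactly the shuffle
// permutation. The backward pass swaps the roles of the two fake dimensions
// and therefore applies the inverse permutation.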
struct shuffle_by_reorder_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_shuffle_pd_t {
        using gpu_shuffle_pd_t::gpu_shuffle_pd_t;

        DECLARE_COMMON_PD_T("ocl:reorder:any", shuffle_by_reorder_t);

        status_t init(engine_t *engine) {
            const auto &md_src = is_fwd() ? src_md() : diff_src_md();
            const auto &md_dst = is_fwd() ? dst_md() : diff_dst_md();
            const memory_desc_wrapper src_d(md_src);
            const memory_desc_wrapper dst_d(md_dst);

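            // The fake 4D descriptors built below are derived from md_src
            // only, so the source and destination must share the same dense
            // layout for one pair of descriptors to describe both tensors;
            // hence the src_d == dst_d and src_d.is_dense() checks.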
            bool ok = src_d.data_type() == dst_d.data_type()
                    && md_src->format_kind == format_kind::blocked
                    && attr()->has_default_values()
                    && set_default_formats_common() && src_d == dst_d
                    && src_d.is_dense();
            if (!ok) return status::unimplemented;

            // Abort if there's blocking on the dimension that's going to be
            // shuffled; such a shuffle cannot be reduced to a simple reorder.
            // TODO: if both group_size and groups are multiples of blocking it
            // still could be possible to use reorder.
            for (int i = 0; i < md_src->format_desc.blocking.inner_nblks; i++) {
                if (md_src->format_desc.blocking.inner_idxs[i] == axis()) {
                    return status::unimplemented;
                }
            }

            auto tensor_size
                    = utils::array_product(md_src->dims, md_src->ndims);
            // groups and group_size() are the sizes of the two fake dimensions;
            // groups * group_size() == size of the original single dimension
            auto groups = md_src->dims[axis()] / group_size();
            // prepare the 2 dimensions to be reordered
            auto tr_rows = is_fwd() ? group_size() : groups;
            auto tr_cols = is_fwd() ? groups : group_size();
            // combine all dimensions below axis() together with all blocks
            // into a single dimension that's not going to be reordered
            auto stride_of_axis = md_src->format_desc.blocking.strides[axis()];
            // combine all dimensions above axis() into a single dimension
            // that's not going to be reordered
            auto remaining = tensor_size
                    / md_src->format_desc.blocking.strides[axis()] / tr_cols
                    / tr_rows;

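            // fake_src describes the 4D view {remaining, tr_cols, tr_rows,
            // stride_of_axis} with plain dense (row-major) strides, while
            // fake_dst swaps the strides of the two middle dimensions; the
            // reorder fake_src -> fake_dst therefore transposes tr_cols and
            // tr_rows in memory, which is the shuffle itself.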
            memory_desc_t fake_src;
            memory_desc_t fake_dst;

            dims_t d = {remaining, tr_cols, tr_rows, stride_of_axis};
            dims_t strides_src = {d[3] * d[2] * d[1], d[3] * d[2], d[3], 1};
            dims_t strides_dst = {d[3] * d[2] * d[1], d[3], d[1] * d[3], 1};

            CHECK(memory_desc_init_by_strides(
                    fake_src, 4, d, md_src->data_type, strides_src));
            CHECK(memory_desc_init_by_strides(
                    fake_dst, 4, d, md_src->data_type, strides_dst));

            CHECK(reorder_primitive_desc_create(
                    reorder_pd_, engine, &fake_src, &fake_dst));
            return status::success;
        }

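        // Primitive descriptor of the nested reorder that implements the
        // shuffle; consumed by shuffle_by_reorder_t::init().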
        std::shared_ptr<primitive_desc_t> reorder_pd_;
    };

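    // Instantiates the nested reorder primitive from the descriptor prepared
    // in pd_t::init().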
    status_t init(engine_t *engine) override {
        return create_nested_primitive(reorder_, pd()->reorder_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        exec_args_t r_args;

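        // Map the shuffle arguments onto the nested reorder: forward reorders
        // SRC into DST, backward reorders DIFF_DST into DIFF_SRC (the inverse
        // permutation is already encoded in the swapped fake dimensions set
        // up by pd_t::init()).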
        auto src = pd()->is_fwd() ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST;
        auto dst = pd()->is_fwd() ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;

        r_args[DNNL_ARG_SRC] = ctx.args().at(src);
        r_args[DNNL_ARG_DST] = ctx.args().at(dst);
        exec_ctx_t r_ctx(ctx, std::move(r_args));

        return reorder_->execute(r_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> reorder_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif