1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ref_shuffle.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace gpu {
22namespace ocl {
23
24using namespace format_tag;
25
26status_t ref_shuffle_t::pd_t::init_conf(engine_t *engine) {
27 const memory_desc_wrapper input_mdw(is_fwd() ? src_md() : diff_dst_md());
28 conf.data_type = input_mdw.data_type();
29 const memory_desc_wrapper output_mdw(is_fwd() ? dst_md() : diff_src_md());
30
31 conf.src_md_info = memory_desc_info_t::create(input_mdw);
32 conf.dst_md_info = memory_desc_info_t::create(output_mdw);
33
34 conf.axis = axis();
35
36 conf.transpose_col = is_fwd() ? group_size() : axis_size() / group_size();
37 conf.transpose_row = is_fwd() ? axis_size() / group_size() : group_size();
38
39 set_offsets(input_mdw, off.src_off);
40
41 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
42 conf.dispatch = compute_engine->create_dispatch(input_mdw.md_);
43 for (int i = 0; i < MAX_NDIMS; ++i) {
44 auto dim_str = utils::format("D%d", i);
45 if (i < input_mdw.ndims()) {
46 conf.dispatch.define_dim(dim_str, i, input_mdw.dims()[i], 1);
47 } else {
48 conf.dispatch.define_dim(dim_str, 1);
49 }
50 }
51 conf.dispatch.generate();
52
53 return status::success;
54}
55
56status_t ref_shuffle_t::pd_t::init_kernel_ctx(
57 compute::kernel_ctx_t &kernel_ctx) const {
58 kernel_ctx.set_data_type(conf.data_type);
59 kernel_ctx.define_int("AXIS", conf.axis);
60 kernel_ctx.define_int("TRANSPOSE_ROW", conf.transpose_row);
61 kernel_ctx.define_int("TRANSPOSE_COL", conf.transpose_col);
62
63 def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
64 def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
65 def_dispatch(kernel_ctx, conf.dispatch);
66
67 return status::success;
68}
69
70template <dnnl_format_tag_t tag>
71status_t ref_shuffle_t::execute_(const exec_ctx_t &ctx) const {
72 status_t status = status::success;
73
74 auto &src = pd()->is_fwd() ? CTX_IN_STORAGE(DNNL_ARG_SRC)
75 : CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
76 auto &dst = pd()->is_fwd()
77 ? CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status)
78 : CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DIFF_SRC, status);
79 CHECK(status);
80
81 const auto &conf = pd()->conf;
82
83 compute::kernel_arg_list_t arg_list;
84 arg_list.set(0, src);
85 arg_list.set(1, dst);
86
87 auto nd_range = conf.dispatch.nd_range();
88 status = parallel_for(ctx, nd_range, kernel_, arg_list);
89 return status;
90}
91template status_t ref_shuffle_t::execute_<any>(const exec_ctx_t &ctx) const;
92
93} // namespace ocl
94} // namespace gpu
95} // namespace impl
96} // namespace dnnl
97