/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/reorder/gen_reorder.hpp"

#include <iostream>
#include <utility>

#include "common/impl_registration.hpp"
#include "common/utils.hpp"
#include "common/verbose.hpp"
#include "gpu/jit/ir/kernel_info.hpp"
#include "gpu/jit/reorder/config.hpp"
#include "gpu/jit/reorder/reorder_kernel.hpp"
#include "gpu/jit/utils/utils.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

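// Checks whether this nGEN-based reorder implementation supports the requested
// engines, data types, and memory layouts; any failed check returns
// status::unimplemented so dispatching can fall back to another reorder.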
status_t gen_reorder_t::pd_t::init(
        engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
    const auto src_dt = src_md()->data_type;
    const auto dst_dt = dst_md()->data_type;
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
    auto *device_info = compute_engine->device_info();
    bool ok = src_engine == dst_engine && src_engine->kind() == engine_kind::gpu
            && device_info->gpu_arch() > compute::gpu_arch_t::xe_lp
            && IMPLICATION(src_dt == data_type::f16 || dst_dt == data_type::f16,
                    compute_engine->mayiuse(compute::device_ext_t::khr_fp16))
            && IMPLICATION(src_dt == data_type::bf16,
                    utils::one_of(dst_dt, data_type::bf16, data_type::f32))
            && IMPLICATION(dst_dt == data_type::bf16,
                    utils::one_of(src_dt, data_type::bf16, data_type::f32))
            && IMPLICATION(src_dt == data_type::f64 || dst_dt == data_type::f64,
                    compute_engine->mayiuse(compute::device_ext_t::khr_fp64))
            && attr()->has_default_values() && extra_ok();
    if (!ok) return status::unimplemented;

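    // Runtime-sized tensors are not supported; src and dst must also agree on
    // the number of logical dimensions.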
    memory_desc_wrapper src_mdw {src_md()};
    memory_desc_wrapper dst_mdw {dst_md()};
    if (src_mdw.has_runtime_dims_or_strides()) return status::unimplemented;
    if (src_mdw.ndims() != dst_mdw.ndims()) return status::unimplemented;
    int ndims = src_mdw.ndims();

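    // Build IR layouts directly from the user-provided memory descriptors,
    // keeping their original (non-normalized) blocking.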
    layout_t src_layout {src_mdw, /*do_normalize=*/false};
    layout_t dst_layout {dst_mdw, /*do_normalize=*/false};

    if (src_layout.elems() == 0 || dst_layout.elems() == 0)
        return status::unimplemented;

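    // Problem size per dimension: the larger of the source and destination
    // sizes (they can differ, e.g., when only one side is padded).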
    std::vector<dim_t> dims(ndims);
    for (int i = 0; i < ndims; ++i)
        dims[i] = std::max(src_layout.dim(i), dst_layout.dim(i));

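    // For the outermost block of every dimension, the combined inner blocking
    // must evenly divide the problem size; otherwise this implementation
    // cannot handle the layout.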
    auto check_layout = [&](const layout_t &l) {
        for (auto &eb : l.enumerated_blocks()) {
            auto &b = eb.second;
            if (l.is_outermost(eb)) {
                dim_t inner = l.dim(b.dim_idx) / b.block;
                if (dims[b.dim_idx] % inner) return false;
            }
        }
        return true;
    };

    if (!check_layout(src_layout)) return status::unimplemented;
    if (!check_layout(dst_layout)) return status::unimplemented;
    if (!compute_engine->mayiuse_ngen_kernels()) return status::unimplemented;
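    // Fixed execution configuration for the reorder kernel: 128 GRF registers
    // and a SIMD width of 16.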
    cfg = std::make_shared<reorder_config_t>(engine, src_md(), dst_md());
    cfg->exec_cfg.set_regs(128);
    cfg->exec_cfg.set_simd(16);
    CHECK(init_kernel_info());

    return status::success;
}

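// Registers the user-visible src/dst buffers and an internal element-count
// argument, then derives the kernel's ND-range from the execution config and
// the src/dst layouts.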
status_t gen_reorder_t::pd_t::init_kernel_info() {
    auto &info = kernel_info;
    auto elems = cfg->dst_layout.elems();

    info = std::make_shared<kernel_info_t>();
    auto src_buf = make_buffer("src_user");
    auto dst_buf = make_buffer("dst_user");
    info->register_user_arg(src_buf, DNNL_ARG_SRC, /*is_input=*/true);
    info->register_user_arg(dst_buf, DNNL_ARG_DST, /*is_input=*/false);
    auto elems_var = var_t::make(type_t::u32(), "elems");
    info->register_internal_arg(elems_var, uint32_t(elems));
    info->set_nd_range(reorder_kernel_t<>::nd_range(
            cfg->exec_cfg, cfg->src_layout, cfg->dst_layout));

    return status::success;
}

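// Generates and compiles the reorder kernel for the configuration prepared in
// pd_t::init().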
status_t gen_reorder_t::init(engine_t *engine) {
    auto &cfg = *pd()->cfg;
    auto &info = *pd()->kernel_info;

    kernel_ = make_kernel<reorder_kernel_t>(this, engine, cfg.exec_cfg,
            "gen_reorder", info, cfg.src_layout, cfg.dst_layout, false,
            grf_mode_t::any);
    if (!kernel_) return status::runtime_error;
    return status::success;
}

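// Resolves the kernel arguments registered in init_kernel_info() against the
// memory objects of the execution context and launches the kernel over its
// ND-range.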
status_t gen_reorder_t::execute(const exec_ctx_t &ctx) const {
    auto &info = *pd()->kernel_info;

    std::vector<memory_storage_wrapper_t> storage_list;
    info.init_memory_storage_list(storage_list, ctx, this);

    compute::kernel_arg_list_t arg_list;
    info.set_args(arg_list, storage_list);

    CHECK(parallel_for(ctx, info.nd_range(), kernel_, arg_list));
    return status::success;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl