1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_GEN9_CONCAT_HPP
18#define GPU_OCL_GEN9_CONCAT_HPP
19
20#include "common/engine.hpp"
21#include "common/primitive.hpp"
22#include "common/reorder.hpp"
23#include "common/reorder_pd.hpp"
24#include "common/stream.hpp"
25#include "gpu/gpu_concat_pd.hpp"
26#include "gpu/gpu_primitive.hpp"
27#include "gpu/ocl/ocl_utils.hpp"
28#include "gpu/primitive_conf.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace gpu {
33namespace ocl {
34
35struct gen9_concat_t : public gpu_primitive_t {
36 using gpu_primitive_t::gpu_primitive_t;
37 struct pd_t : public gpu_concat_pd_t {
38 pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, int n,
39 int concat_dim, const memory_desc_t *const *src_mds)
40 : gpu_concat_pd_t(attr, dst_md, n, concat_dim, src_mds) {}
41
42 pd_t(const pd_t &rhs) = default;
43 ~pd_t() = default;
44
45 DECLARE_CONCAT_PD_T("gen9:any", gen9_concat_t);
46
47 status_t init(engine_t *engine) {
48 bool ok = n_inputs() <= 16 && attr()->has_default_values()
49 && set_default_params() == status::success
50 && !memory_desc_ndims_ok(dst_md());
51 if (!ok) return status::unimplemented;
52
53 return init_conf(engine);
54 }
55
56 status_t init_conf(engine_t *engine);
57 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
58
59 concat_conf_t conf;
60
61 protected:
62 bool can_use_sub_group_size(
63 const compute::compute_engine_t *compute_engine,
64 int sub_group_size);
65 int calculate_sub_group_size(
66 const compute::compute_engine_t *compute_engine);
67 std::pair<int, int> calculate_iter_dim_idx_chunk(int num_threads) const;
68 };
69
70 status_t init(engine_t *engine) override {
71 compute::kernel_ctx_t kernel_ctx;
72
73 status_t status = pd()->init_kernel_ctx(kernel_ctx);
74 CHECK(status);
75
76 status = create_kernel(engine, &kernel, "gen9_concat", kernel_ctx);
77 CHECK(status);
78
79 return status::success;
80 }
81
82 virtual status_t execute(const exec_ctx_t &ctx) const override {
83 return execute_concat(ctx);
84 }
85
86private:
87 status_t execute_concat(const exec_ctx_t &ctx) const;
88 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
89
90 compute::kernel_t kernel;
91};
92
93} // namespace ocl
94} // namespace gpu
95} // namespace impl
96} // namespace dnnl
97
98#endif //GPU_OCL_GEN9_CONCAT_HPP
99