1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_GEN9_CONCAT_HPP |
18 | #define GPU_OCL_GEN9_CONCAT_HPP |
19 | |
20 | #include "common/engine.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/reorder.hpp" |
23 | #include "common/reorder_pd.hpp" |
24 | #include "common/stream.hpp" |
25 | #include "gpu/gpu_concat_pd.hpp" |
26 | #include "gpu/gpu_primitive.hpp" |
27 | #include "gpu/ocl/ocl_utils.hpp" |
28 | #include "gpu/primitive_conf.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace gpu { |
33 | namespace ocl { |
34 | |
35 | struct gen9_concat_t : public gpu_primitive_t { |
36 | using gpu_primitive_t::gpu_primitive_t; |
37 | struct pd_t : public gpu_concat_pd_t { |
38 | pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, int n, |
39 | int concat_dim, const memory_desc_t *const *src_mds) |
40 | : gpu_concat_pd_t(attr, dst_md, n, concat_dim, src_mds) {} |
41 | |
42 | pd_t(const pd_t &rhs) = default; |
43 | ~pd_t() = default; |
44 | |
45 | DECLARE_CONCAT_PD_T("gen9:any" , gen9_concat_t); |
46 | |
47 | status_t init(engine_t *engine) { |
48 | bool ok = n_inputs() <= 16 && attr()->has_default_values() |
49 | && set_default_params() == status::success |
50 | && !memory_desc_ndims_ok(dst_md()); |
51 | if (!ok) return status::unimplemented; |
52 | |
53 | return init_conf(engine); |
54 | } |
55 | |
56 | status_t init_conf(engine_t *engine); |
57 | status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; |
58 | |
59 | concat_conf_t conf; |
60 | |
61 | protected: |
62 | bool can_use_sub_group_size( |
63 | const compute::compute_engine_t *compute_engine, |
64 | int sub_group_size); |
65 | int calculate_sub_group_size( |
66 | const compute::compute_engine_t *compute_engine); |
67 | std::pair<int, int> calculate_iter_dim_idx_chunk(int num_threads) const; |
68 | }; |
69 | |
70 | status_t init(engine_t *engine) override { |
71 | compute::kernel_ctx_t kernel_ctx; |
72 | |
73 | status_t status = pd()->init_kernel_ctx(kernel_ctx); |
74 | CHECK(status); |
75 | |
76 | status = create_kernel(engine, &kernel, "gen9_concat" , kernel_ctx); |
77 | CHECK(status); |
78 | |
79 | return status::success; |
80 | } |
81 | |
82 | virtual status_t execute(const exec_ctx_t &ctx) const override { |
83 | return execute_concat(ctx); |
84 | } |
85 | |
86 | private: |
87 | status_t execute_concat(const exec_ctx_t &ctx) const; |
88 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
89 | |
90 | compute::kernel_t kernel; |
91 | }; |
92 | |
93 | } // namespace ocl |
94 | } // namespace gpu |
95 | } // namespace impl |
96 | } // namespace dnnl |
97 | |
98 | #endif //GPU_OCL_GEN9_CONCAT_HPP |
99 | |