1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_ZERO_PAD_REF_ZERO_PAD_HPP
18#define GPU_OCL_ZERO_PAD_REF_ZERO_PAD_HPP
19
20#include "gpu/gpu_primitive.hpp"
21#include "gpu/gpu_resource.hpp"
22#include "gpu/gpu_zero_pad_pd.hpp"
23#include "gpu/primitive_conf.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace gpu {
28namespace ocl {
29
30struct ref_zero_pad_t : public gpu_primitive_t {
31 using gpu_primitive_t::gpu_primitive_t;
32 struct pd_t : public gpu_zero_pad_pd_t {
33 using gpu_zero_pad_pd_t::gpu_zero_pad_pd_t;
34
35 DECLARE_COMMON_PD_T("ocl:ref:any", ref_zero_pad_t);
36 status_t init(engine_t *engine) {
37 auto *compute_engine
38 = utils::downcast<compute::compute_engine_t *>(engine);
39 if (!compute_engine->mayiuse_sub_group(16))
40 return status::unimplemented;
41 return status::success;
42 }
43 };
44
45 ;
46
47 status_t init(engine_t *engine) override {
48 compute::kernel_ctx_t kernel_ctx;
49 create_kernel(engine, &kernel_, "ref_zero_pad", kernel_ctx);
50 create_kernel(
51 engine, &kernel_subg16_, "ref_zero_pad_subg_16", kernel_ctx);
52 create_kernel(engine, &kernel_subg16_mask_and_clear_dt_1b_,
53 "ref_zero_pad_subg_16_mask_and_clear_dt_1b", kernel_ctx);
54 if (!kernel_ || !kernel_subg16_ || !kernel_subg16_mask_and_clear_dt_1b_)
55 return status::runtime_error;
56 return status::success;
57 }
58 status_t execute(const exec_ctx_t &ctx) const override;
59
60private:
61 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
62 status_t execute_ref(const exec_ctx_t &ctx) const;
63 status_t execute_subg_16(const exec_ctx_t &ctx,
64 const memory_desc_wrapper &mdw,
65 const blocking_desc_t &blocking_desc) const;
66 status_t execute_subg_16_mask_and_clear_dt_1B(const exec_ctx_t &ctx,
67 const memory_desc_wrapper &mdw,
68 const blocking_desc_t &blocking_desc) const;
69 compute::kernel_t kernel_;
70 compute::kernel_t kernel_subg16_;
71 compute::kernel_t kernel_subg16_mask_and_clear_dt_1b_;
72};
73
74} // namespace ocl
75} // namespace gpu
76} // namespace impl
77} // namespace dnnl
78
79#endif
80