/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_REF_RESAMPLING_HPP
#define GPU_OCL_REF_RESAMPLING_HPP

#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resampling_pd.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

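// Reference forward resampling primitive for the OpenCL GPU backend; init()
// builds the "ref_resampling_fwd" kernel from the context prepared by the pd.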
struct ref_resampling_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_resampling_fwd_pd_t {
        pd_t(const resampling_desc_t *adesc, const primitive_attr_t *attr,
                const resampling_fwd_pd_t *hint_fwd_pd)
            : gpu_resampling_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
        virtual ~pd_t() {}

        DECLARE_COMMON_PD_T("ref:any", ref_resampling_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            assert(engine->kind() == engine_kind::gpu);
            using sm = primitive_attr_t::skip_mask_t;
            const auto attr_skip_mask = sm::post_ops;

            // Forward only; attributes must be default except for post-ops,
            // and the post-op chain (including binary post-ops) must be
            // compatible with the destination.
            bool ok = is_fwd() && set_default_params() == status::success
                    && attr()->has_default_values(attr_skip_mask)
                    && post_ops_with_binary_ok(attr(), dst_md()->data_type, 5)
                    && attr_.set_default_formats(dst_md(0)) == status::success;
            if (!ok) return status::unimplemented;

            return init_conf(engine);
        }
        compute::dispatch_t dispatch;
        resampling_conf_t conf;

        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
        status_t init_conf(engine_t *engine);
    };

    status_t init(engine_t *engine) override {
        using namespace alg_kind;

        compute::kernel_ctx_t kernel_ctx;
        status_t status = pd()->init_kernel_ctx(kernel_ctx);
        CHECK(status);

        // Build the OpenCL kernel from the context filled in by the pd.
        create_kernel(engine, &kernel_, "ref_resampling_fwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};

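// Reference backward resampling primitive for the OpenCL GPU backend; init()
// builds the "ref_resampling_bwd" kernel from the context prepared by the pd.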
struct ref_resampling_bwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_resampling_bwd_pd_t {
        pd_t(const resampling_desc_t *adesc, const primitive_attr_t *attr,
                const resampling_fwd_pd_t *hint_fwd_pd)
            : gpu_resampling_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
        virtual ~pd_t() {}

        DECLARE_COMMON_PD_T("ref:any", ref_resampling_bwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            assert(engine->kind() == engine_kind::gpu);
            // Backward supports default attributes only (no post-ops).
            bool ok = !is_fwd() && set_default_params() == status::success
                    && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            return init_conf(engine);
        }
        resampling_conf_t conf;

        status_t init_conf(engine_t *engine);
        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
    };

    status_t init(engine_t *engine) override {
        using namespace alg_kind;

        compute::kernel_ctx_t kernel_ctx;
        status_t status = pd()->init_kernel_ctx(kernel_ctx);
        CHECK(status);

        create_kernel(engine, &kernel_, "ref_resampling_bwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif