1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_CONVOLUTION_INNER_PRODUCT_HPP
18#define GPU_OCL_CONVOLUTION_INNER_PRODUCT_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "gpu/compute/compute.hpp"
26#include "gpu/gpu_inner_product_pd.hpp"
27#include "gpu/gpu_primitive.hpp"
28#include "gpu/ocl/ocl_stream.hpp"
29#include "gpu/ocl/ocl_utils.hpp"
30#include "gpu/primitive_conf.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace gpu {
35namespace ocl {
36
37struct convolution_inner_product_fwd_t : public gpu_primitive_t {
38 struct pd_t : public gpu_inner_product_fwd_pd_t {
39 using gpu_inner_product_fwd_pd_t::gpu_inner_product_fwd_pd_t;
40
41 pd_t(const pd_t &rhs) = default;
42
43 DECLARE_COMMON_PD_T("ocl:conv", convolution_inner_product_fwd_t);
44
45 status_t init(engine_t *engine) {
46 using namespace data_type;
47 using namespace prop_kind;
48 using namespace data_type;
49 assert(engine->kind() == engine_kind::gpu);
50 auto *compute_engine
51 = utils::downcast<compute::compute_engine_t *>(engine);
52
53 const auto attr_skip_mask = primitive_attr_t::skip_mask_t::scales
54 | primitive_attr_t::skip_mask_t::post_ops;
55
56 bool ok = true
57 && utils::one_of(desc()->prop_kind, forward_training,
58 forward_inference)
59 && set_default_params(true) == status::success
60 && IMPLICATION(with_bias(),
61 utils::one_of(desc()->bias_desc.data_type, u8, s8,
62 bf16, f16, f32))
63 && attr()->has_default_values(attr_skip_mask)
64 && post_ops_with_binary_ok(
65 attr(), desc()->dst_desc.data_type)
66 && attr_.set_default_formats(dst_md(0)) == status::success
67 && IMPLICATION(desc()->src_desc.data_type == f16,
68 compute_engine->mayiuse(
69 compute::device_ext_t::khr_fp16))
70 && (invariant_src_md()->format_desc.blocking.inner_nblks > 0
71 || invariant_wei_md()
72 ->format_desc.blocking.inner_nblks
73 > 0
74 || (src_md_.format_kind == format_kind::any
75 && weights_md_.format_kind
76 == format_kind::any));
77
78 if (!ok) return status::unimplemented;
79
80 CHECK(init_conf(engine));
81 CHECK(init_scratchpad());
82 return status::success;
83 }
84
85 status_t init_conf(engine_t *engine);
86 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
87
88 inner_product_conf_t conf;
89
90 std::shared_ptr<primitive_desc_t> cpd_;
91 std::shared_ptr<primitive_desc_t> rpd_postop_;
92 std::shared_ptr<primitive_desc_t> rpd_dst_;
93
94 private:
95 status_t init_scratchpad();
96 };
97
98 convolution_inner_product_fwd_t(const pd_t *apd) : gpu_primitive_t(apd) {}
99
100 status_t init(engine_t *engine) override {
101 CHECK(create_nested_primitive(conv_, pd()->cpd_, engine));
102 if (pd()->rpd_postop_)
103 CHECK(create_nested_primitive(
104 postop_reorder_, pd()->rpd_postop_, engine));
105 if (pd()->rpd_dst_)
106 CHECK(create_nested_primitive(
107 dst_reorder_, pd()->rpd_dst_, engine));
108 return status::success;
109 }
110
111 status_t execute(const exec_ctx_t &ctx) const override {
112 return execute_forward(ctx);
113 }
114
115private:
116 status_t execute_forward(const exec_ctx_t &ctx) const;
117 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
118 std::shared_ptr<primitive_t> conv_;
119 std::shared_ptr<primitive_t> postop_reorder_;
120 std::shared_ptr<primitive_t> dst_reorder_;
121};
122
123} // namespace ocl
124} // namespace gpu
125} // namespace impl
126} // namespace dnnl
127
128#endif
129
130// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
131