1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REF_INNER_PRODUCT_HPP
18#define CPU_REF_INNER_PRODUCT_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/primitive_attr_postops.hpp"
28
29#include "cpu/cpu_inner_product_pd.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34
35struct ref_inner_product_fwd_t : public primitive_t {
36 struct pd_t : public cpu_inner_product_fwd_pd_t {
37 using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
38
39 DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t);
40
41 status_t init(engine_t *engine) {
42 using namespace data_type;
43 using smask_t = primitive_attr_t::skip_mask_t;
44 const auto src_type = src_md(0)->data_type;
45 const auto wei_type = weights_md(0)->data_type;
46 const auto bia_type = weights_md(1)->data_type;
47 const auto dst_type = dst_md(0)->data_type;
48
49 const bool allow_all_tags = true; // ref should support all tags
50
51 bool ok = is_fwd() && platform::has_data_type_support(src_type)
52 && platform::has_data_type_support(wei_type)
53 && platform::has_data_type_support(bia_type)
54 && platform::has_data_type_support(dst_type)
55 && utils::one_of(src_type, f32, bf16, f16)
56 && wei_type == src_type
57 && utils::one_of(dst_type, f32, src_type)
58 && IMPLICATION(
59 with_bias(), utils::one_of(bia_type, f32, src_type))
60 && set_default_params(allow_all_tags) == status::success
61 && attr()->has_default_values(
62 smask_t::post_ops | smask_t::sum_dt)
63 && attr()->post_ops_.check_sum_consistent_dt(dst_type)
64 && attr_.set_default_formats(dst_md(0)) == status::success;
65 return ok ? status::success : status::unimplemented;
66 }
67 };
68
69 ref_inner_product_fwd_t(const pd_t *apd) : primitive_t(apd) {}
70
71 status_t init(engine_t *engine) override {
72 ref_post_ops
73 = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
74 if (!ref_post_ops) return status::out_of_memory;
75 return status::success;
76 }
77
78 status_t execute(const exec_ctx_t &ctx) const override {
79 return execute_forward(ctx);
80 }
81
82private:
83 status_t execute_forward(const exec_ctx_t &ctx) const;
84 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
85 std::unique_ptr<ref_post_ops_t> ref_post_ops;
86};
87
88struct ref_inner_product_bwd_data_t : public primitive_t {
89 struct pd_t : public cpu_inner_product_bwd_data_pd_t {
90 using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t;
91
92 DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t);
93
94 status_t init(engine_t *engine) {
95 using namespace data_type;
96 const auto diff_src_type = diff_src_md(0)->data_type;
97 const auto wei_type = weights_md(0)->data_type;
98 const auto diff_dst_type = diff_dst_md(0)->data_type;
99
100 const bool allow_all_tags = true; // ref should support all tags
101
102 bool ok = desc()->prop_kind == prop_kind::backward_data
103 && platform::has_data_type_support(diff_src_type)
104 && platform::has_data_type_support(wei_type)
105 && platform::has_data_type_support(diff_dst_type)
106 && utils::one_of(diff_src_type, f32, wei_type)
107 && utils::one_of(wei_type, f32, bf16, f16)
108 && diff_dst_type == wei_type && attr()->has_default_values()
109 && set_default_params(allow_all_tags) == status::success;
110 return ok ? status::success : status::unimplemented;
111 }
112 };
113
114 ref_inner_product_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
115
116 status_t execute(const exec_ctx_t &ctx) const override {
117 return execute_backward_data(ctx);
118 }
119
120private:
121 status_t execute_backward_data(const exec_ctx_t &ctx) const;
122 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
123};
124
125struct ref_inner_product_bwd_weights_t : public primitive_t {
126 struct pd_t : public cpu_inner_product_bwd_weights_pd_t {
127 using cpu_inner_product_bwd_weights_pd_t::
128 cpu_inner_product_bwd_weights_pd_t;
129
130 DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t);
131
132 status_t init(engine_t *engine) {
133 using namespace data_type;
134 const auto src_type = src_md(0)->data_type;
135 const auto diff_wei_type = diff_weights_md(0)->data_type;
136 const auto diff_bia_type = diff_weights_md(1)->data_type;
137 const auto diff_dst_type = diff_dst_md(0)->data_type;
138
139 const bool allow_all_tags = true; // ref should support all tags
140
141 bool ok = desc()->prop_kind == prop_kind::backward_weights
142 && platform::has_data_type_support(src_type)
143 && platform::has_data_type_support(diff_wei_type)
144 && platform::has_data_type_support(diff_bia_type)
145 && utils::one_of(src_type, f32, bf16, f16)
146 && utils::one_of(diff_wei_type, f32, src_type)
147 && IMPLICATION(with_bias(),
148 utils::one_of(diff_bia_type, f32, src_type))
149 && diff_dst_type == src_type && attr()->has_default_values()
150 && set_default_params(allow_all_tags) == status::success;
151 return ok ? status::success : status::unimplemented;
152 }
153 };
154
155 ref_inner_product_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}
156
157 status_t execute(const exec_ctx_t &ctx) const override {
158 return execute_backward_weights(ctx);
159 }
160
161private:
162 status_t execute_backward_weights(const exec_ctx_t &ctx) const;
163 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
164};
165
166} // namespace cpu
167} // namespace impl
168} // namespace dnnl
169
170#endif
171
172// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
173