1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_PRIMITIVE_ATTR_POSTOPS_HPP |
18 | #define CPU_PRIMITIVE_ATTR_POSTOPS_HPP |
19 | |
20 | #include <vector> |
21 | |
22 | #include "common/primitive.hpp" |
23 | #include "common/primitive_attr.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | |
29 | float compute_binary_scalar(alg_kind_t alg, float x, float y); |
30 | float compute_eltwise_scalar_fwd( |
31 | const alg_kind_t alg, float s, float alpha, float beta); |
32 | float compute_eltwise_scalar_bwd( |
33 | const alg_kind_t alg, float dd, float s, float alpha, float beta); |
34 | |
35 | struct ref_binary_scalar_t { |
36 | ref_binary_scalar_t(alg_kind_t alg); |
37 | ref_binary_scalar_t(const post_ops_t::entry_t::binary_t &binary); |
38 | |
39 | float compute_scalar(float src0, float src1) const; |
40 | |
41 | private: |
42 | const alg_kind_t alg_; |
43 | }; |
44 | |
45 | struct ref_eltwise_scalar_fwd_t { |
46 | ref_eltwise_scalar_fwd_t( |
47 | alg_kind_t alg, float alpha, float beta, float scale); |
48 | ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise); |
49 | |
50 | float compute_scalar(float s) const; |
51 | |
52 | const alg_kind_t alg_; |
53 | const float alpha_; |
54 | const float beta_; |
55 | const float scale_; |
56 | }; |
57 | |
58 | struct ref_post_ops_t { |
59 | struct args_t { |
60 | args_t() : dst_val(0.f), ctx(nullptr), l_offset(-1), dst_md(nullptr) {} |
61 | |
62 | float dst_val; // sum arg |
63 | const exec_ctx_t *ctx; // binary arg |
64 | dim_t l_offset; // binary arg |
65 | const memory_desc_t *dst_md; // binary arg |
66 | }; |
67 | |
68 | ref_post_ops_t(const post_ops_t &po, bool skip_sum = false); |
69 | |
70 | virtual ~ref_post_ops_t() = default; |
71 | |
72 | status_t execute(float &res, const args_t &args = args_t()) const; |
73 | |
74 | private: |
75 | const post_ops_t &po_; |
76 | // some primitives for example gemm are able to perform sum postop itself, |
77 | // in such cases executing sum should be skipped |
78 | const bool skip_sum_; |
79 | |
80 | std::vector<ref_eltwise_scalar_fwd_t> eltwise_po_; |
81 | std::vector<ref_binary_scalar_t> binary_po_; |
82 | }; |
83 | |
84 | } // namespace cpu |
85 | } // namespace impl |
86 | } // namespace dnnl |
87 | |
88 | #endif |
89 | |