1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_PRIMITIVE_ATTR_POSTOPS_HPP
18#define CPU_PRIMITIVE_ATTR_POSTOPS_HPP
19
20#include <vector>
21
22#include "common/primitive.hpp"
23#include "common/primitive_attr.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28
29float compute_binary_scalar(alg_kind_t alg, float x, float y);
30float compute_eltwise_scalar_fwd(
31 const alg_kind_t alg, float s, float alpha, float beta);
32float compute_eltwise_scalar_bwd(
33 const alg_kind_t alg, float dd, float s, float alpha, float beta);
34
35struct ref_binary_scalar_t {
36 ref_binary_scalar_t(alg_kind_t alg);
37 ref_binary_scalar_t(const post_ops_t::entry_t::binary_t &binary);
38
39 float compute_scalar(float src0, float src1) const;
40
41private:
42 const alg_kind_t alg_;
43};
44
45struct ref_eltwise_scalar_fwd_t {
46 ref_eltwise_scalar_fwd_t(
47 alg_kind_t alg, float alpha, float beta, float scale);
48 ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise);
49
50 float compute_scalar(float s) const;
51
52 const alg_kind_t alg_;
53 const float alpha_;
54 const float beta_;
55 const float scale_;
56};
57
58struct ref_post_ops_t {
59 struct args_t {
60 args_t() : dst_val(0.f), ctx(nullptr), l_offset(-1), dst_md(nullptr) {}
61
62 float dst_val; // sum arg
63 const exec_ctx_t *ctx; // binary arg
64 dim_t l_offset; // binary arg
65 const memory_desc_t *dst_md; // binary arg
66 };
67
68 ref_post_ops_t(const post_ops_t &po, bool skip_sum = false);
69
70 virtual ~ref_post_ops_t() = default;
71
72 status_t execute(float &res, const args_t &args = args_t()) const;
73
74private:
75 const post_ops_t &po_;
76 // some primitives for example gemm are able to perform sum postop itself,
77 // in such cases executing sum should be skipped
78 const bool skip_sum_;
79
80 std::vector<ref_eltwise_scalar_fwd_t> eltwise_po_;
81 std::vector<ref_binary_scalar_t> binary_po_;
82};
83
84} // namespace cpu
85} // namespace impl
86} // namespace dnnl
87
88#endif
89