1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_IR_ELTWISE_HPP
18#define GPU_JIT_IR_ELTWISE_HPP
19
20#include <string>
21#include <vector>
22
23#include "gpu/jit/ir/ir.hpp"
24#include "gpu/jit/utils/utils.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace jit {
30
31class eltwise_t : public func_impl_t {
32public:
33 IR_DECL_DERIVED_TYPE_ID(eltwise_t, func_impl_t)
34
35 static func_t make(
36 alg_kind_t alg_kind, float scale, float alpha, float beta) {
37 return func_t(new eltwise_t(alg_kind, scale, alpha, beta));
38 }
39
40 bool is_equal(const object_impl_t &obj) const override {
41 if (!obj.is<self_type>()) return false;
42 auto &other = obj.as<self_type>();
43
44 return (alg_kind == other.alg_kind) && (scale == other.scale)
45 && (alpha == other.alpha) && (beta == other.beta);
46 }
47
48 size_t get_hash() const override {
49 return ir_utils::get_hash(alg_kind, scale, alpha, beta);
50 }
51
52 std::string str() const override {
53 switch (alg_kind) {
54 case alg_kind::eltwise_relu: return "relu";
55 case alg_kind::eltwise_tanh: return "tanh";
56 case alg_kind::eltwise_elu: return "elu";
57 case alg_kind::eltwise_square: return "square";
58 case alg_kind::eltwise_abs: return "abs";
59 case alg_kind::eltwise_sqrt: return "sqrt";
60 case alg_kind::eltwise_swish: return "swish";
61 case alg_kind::eltwise_linear: return "linear";
62 case alg_kind::eltwise_soft_relu: return "soft_relu";
63 case alg_kind::eltwise_logistic: return "logistic";
64 case alg_kind::eltwise_mish: return "mish";
65 case alg_kind::eltwise_exp: return "exp";
66 case alg_kind::eltwise_log: return "log";
67 case alg_kind::eltwise_clip: return "clip";
68 case alg_kind::eltwise_clip_v2: return "clip_v2";
69 case alg_kind::eltwise_pow: return "pow";
70 case alg_kind::eltwise_gelu_tanh: return "gelu_tanh";
71 case alg_kind::eltwise_gelu_erf: return "gelu_erf";
72 case alg_kind::eltwise_hardswish: return "hardswish";
73 case alg_kind::eltwise_relu_use_dst_for_bwd:
74 return "relu_use_dst_for_bwd";
75 case alg_kind::eltwise_tanh_use_dst_for_bwd:
76 return "tanh_use_dst_for_bwd";
77 case alg_kind::eltwise_elu_use_dst_for_bwd:
78 return "elu_use_dst_for_bwd";
79 case alg_kind::eltwise_sqrt_use_dst_for_bwd:
80 return "sqrt_use_dst_for_bwd";
81 case alg_kind::eltwise_logistic_use_dst_for_bwd:
82 return "logistic_use_dst_for_bwd";
83 case alg_kind::eltwise_exp_use_dst_for_bwd:
84 return "exp_use_dst_for_bwd";
85 case alg_kind::eltwise_clip_v2_use_dst_for_bwd:
86 return "clip_v2_use_dst_for_bwd";
87 case alg_kind::eltwise_round: return "round";
88 default: ir_error_not_expected();
89 }
90 return "unknown";
91 }
92
93 IR_DEFINE_ARG_GET(elems, 0)
94 IR_DEFINE_ARG_GET(data, 1)
95
96 alg_kind_t alg_kind;
97 float scale;
98 float alpha;
99 float beta;
100
101private:
102 eltwise_t(alg_kind_t alg_kind, float scale, float alpha, float beta)
103 : func_impl_t(_type_info())
104 , alg_kind(alg_kind)
105 , scale(scale)
106 , alpha(alpha)
107 , beta(beta) {}
108};
109
110} // namespace jit
111} // namespace gpu
112} // namespace impl
113} // namespace dnnl
114
115#endif
116