1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_IR_ELTWISE_HPP |
18 | #define GPU_JIT_IR_ELTWISE_HPP |
19 | |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "gpu/jit/ir/ir.hpp" |
24 | #include "gpu/jit/utils/utils.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace gpu { |
29 | namespace jit { |
30 | |
31 | class eltwise_t : public func_impl_t { |
32 | public: |
33 | IR_DECL_DERIVED_TYPE_ID(eltwise_t, func_impl_t) |
34 | |
35 | static func_t make( |
36 | alg_kind_t alg_kind, float scale, float alpha, float beta) { |
37 | return func_t(new eltwise_t(alg_kind, scale, alpha, beta)); |
38 | } |
39 | |
40 | bool is_equal(const object_impl_t &obj) const override { |
41 | if (!obj.is<self_type>()) return false; |
42 | auto &other = obj.as<self_type>(); |
43 | |
44 | return (alg_kind == other.alg_kind) && (scale == other.scale) |
45 | && (alpha == other.alpha) && (beta == other.beta); |
46 | } |
47 | |
48 | size_t get_hash() const override { |
49 | return ir_utils::get_hash(alg_kind, scale, alpha, beta); |
50 | } |
51 | |
52 | std::string str() const override { |
53 | switch (alg_kind) { |
54 | case alg_kind::eltwise_relu: return "relu" ; |
55 | case alg_kind::eltwise_tanh: return "tanh" ; |
56 | case alg_kind::eltwise_elu: return "elu" ; |
57 | case alg_kind::eltwise_square: return "square" ; |
58 | case alg_kind::eltwise_abs: return "abs" ; |
59 | case alg_kind::eltwise_sqrt: return "sqrt" ; |
60 | case alg_kind::eltwise_swish: return "swish" ; |
61 | case alg_kind::eltwise_linear: return "linear" ; |
62 | case alg_kind::eltwise_soft_relu: return "soft_relu" ; |
63 | case alg_kind::eltwise_logistic: return "logistic" ; |
64 | case alg_kind::eltwise_mish: return "mish" ; |
65 | case alg_kind::eltwise_exp: return "exp" ; |
66 | case alg_kind::eltwise_log: return "log" ; |
67 | case alg_kind::eltwise_clip: return "clip" ; |
68 | case alg_kind::eltwise_clip_v2: return "clip_v2" ; |
69 | case alg_kind::eltwise_pow: return "pow" ; |
70 | case alg_kind::eltwise_gelu_tanh: return "gelu_tanh" ; |
71 | case alg_kind::eltwise_gelu_erf: return "gelu_erf" ; |
72 | case alg_kind::eltwise_hardswish: return "hardswish" ; |
73 | case alg_kind::eltwise_relu_use_dst_for_bwd: |
74 | return "relu_use_dst_for_bwd" ; |
75 | case alg_kind::eltwise_tanh_use_dst_for_bwd: |
76 | return "tanh_use_dst_for_bwd" ; |
77 | case alg_kind::eltwise_elu_use_dst_for_bwd: |
78 | return "elu_use_dst_for_bwd" ; |
79 | case alg_kind::eltwise_sqrt_use_dst_for_bwd: |
80 | return "sqrt_use_dst_for_bwd" ; |
81 | case alg_kind::eltwise_logistic_use_dst_for_bwd: |
82 | return "logistic_use_dst_for_bwd" ; |
83 | case alg_kind::eltwise_exp_use_dst_for_bwd: |
84 | return "exp_use_dst_for_bwd" ; |
85 | case alg_kind::eltwise_clip_v2_use_dst_for_bwd: |
86 | return "clip_v2_use_dst_for_bwd" ; |
87 | case alg_kind::eltwise_round: return "round" ; |
88 | default: ir_error_not_expected(); |
89 | } |
90 | return "unknown" ; |
91 | } |
92 | |
93 | IR_DEFINE_ARG_GET(elems, 0) |
94 | IR_DEFINE_ARG_GET(data, 1) |
95 | |
96 | alg_kind_t alg_kind; |
97 | float scale; |
98 | float alpha; |
99 | float beta; |
100 | |
101 | private: |
102 | eltwise_t(alg_kind_t alg_kind, float scale, float alpha, float beta) |
103 | : func_impl_t(_type_info()) |
104 | , alg_kind(alg_kind) |
105 | , scale(scale) |
106 | , alpha(alpha) |
107 | , beta(beta) {} |
108 | }; |
109 | |
110 | } // namespace jit |
111 | } // namespace gpu |
112 | } // namespace impl |
113 | } // namespace dnnl |
114 | |
115 | #endif |
116 | |