1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_IR_POST_OPS_HPP |
18 | #define GPU_JIT_IR_POST_OPS_HPP |
19 | |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "gpu/jit/ir/ir.hpp" |
24 | #include "gpu/jit/ir/tensor.hpp" |
25 | #include "gpu/jit/utils/utils.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | namespace jit { |
31 | |
32 | class post_op_tensor_info_t { |
33 | public: |
34 | post_op_tensor_info_t() = default; |
35 | |
36 | post_op_tensor_info_t(bool is_input, bool is_output, const view_t &view, |
37 | const expr_t &buf, uint32_t mask, const expr_t &op_var, float scale) |
38 | : is_input_(is_input) |
39 | , is_output_(is_output) |
40 | , view_(view) |
41 | , buf_(buf) |
42 | , mask_(mask) |
43 | , op_var_(op_var) |
44 | , scale_(scale) { |
45 | if (op_var_.is_empty()) |
46 | op_var_ = var_t::make(type_t::f32(), make_op_var_name(buf)); |
47 | if (scale != 1) |
48 | ir_assert(is_output_) |
49 | << "Scale is supported with output tensors only." ; |
50 | } |
51 | |
52 | bool is_input() const { return is_input_; } |
53 | |
54 | bool is_output() const { return is_output_; } |
55 | |
56 | bool needs_masked_update() const { return needs_masked_update_; } |
57 | |
58 | const view_t &view() const { return view_; } |
59 | |
60 | const expr_t &buf() const { return buf_; } |
61 | |
62 | const uint32_t &mask() const { return mask_; } |
63 | |
64 | const expr_t &op_var() const { return op_var_; } |
65 | |
66 | float scale() const { return scale_; } |
67 | |
68 | post_op_tensor_info_t create_sub_tensor(const tensor_t &tile) const { |
69 | auto ret = *this; |
70 | ret.view_ = ret.view_.create_sub_view(tile); |
71 | return ret; |
72 | } |
73 | |
74 | void require_masked_update() { needs_masked_update_ = true; } |
75 | |
76 | private: |
77 | static std::string make_op_var_name(const expr_t &buf) { |
78 | auto *var = buf.as_ptr<var_t>(); |
79 | if (var) return var->name; |
80 | |
81 | auto *ptr = buf.as_ptr<ptr_t>(); |
82 | if (ptr) { |
83 | auto prefix = make_op_var_name(ptr->base); |
84 | ir_assert(is_const(ptr->off)); |
85 | int off = to_cpp<int>(ptr->off); |
86 | return prefix + "_" + std::to_string(off); |
87 | } |
88 | |
89 | ir_error_not_expected() << "Can't generate op var name: " << buf; |
90 | return "unknown" ; |
91 | } |
92 | bool is_input_; |
93 | bool is_output_; |
94 | bool needs_masked_update_ = false; |
95 | view_t view_; |
96 | expr_t buf_; |
97 | uint32_t mask_; |
98 | expr_t op_var_; |
99 | float scale_; |
100 | }; |
101 | |
102 | // There are two types of post-ops: |
// - Eltwise: lhs = eltwise(rhs), where rhs must be equal to lhs
104 | // Eltwise is supported via special IR function eltwise_t |
105 | // - Generic post-op: lhs = rhs |
106 | // Left-hand side (lhs) represents a single post-op tensor. Right-hand side |
107 | // tensor (rhs) is an IR expression over post-op tensors and constants. |
108 | // |
109 | // Post-op tensors support broadcast (when used from rhs) and reduction (when |
110 | // used from lhs) semantics. |
111 | // |
112 | // If lhs is (a x 1) tensor and rhs is (a x b) tensor then rhs is reduced: |
113 | // lhs(i, 0) = sum over j rhs(i, j) |
114 | // |
115 | // If lhs is (a x b) tensor and rhs is (a x 1) tensor then rhs is broadcasted: |
116 | // lhs(i, j) = rhs(i, 0) |
117 | class post_op_t { |
118 | public: |
119 | post_op_t() = default; |
120 | |
121 | post_op_t(const expr_t &lhs, const expr_t &rhs, |
122 | const func_t &eltwise = func_t()) |
123 | : lhs_(lhs), rhs_(simplify_rewrite(rhs)), eltwise_(eltwise) {} |
124 | |
125 | const expr_t &lhs() const { return lhs_; } |
126 | |
127 | const expr_t &rhs() const { return rhs_; } |
128 | |
129 | const func_t &eltwise() const { return eltwise_; } |
130 | |
131 | bool uses(const expr_t &op_var) const { |
132 | if (contains_object(lhs_, op_var)) return true; |
133 | if (contains_object(rhs_, op_var)) return true; |
134 | return false; |
135 | } |
136 | |
137 | private: |
138 | expr_t lhs_; |
139 | expr_t rhs_; |
140 | func_t eltwise_; |
141 | }; |
142 | |
143 | inline op_kind_t alg_kind_to_op_kind(alg_kind_t alg) { |
144 | switch (alg) { |
145 | case alg_kind::binary_add: return op_kind_t::_add; |
146 | case alg_kind::binary_sub: return op_kind_t::_sub; |
147 | case alg_kind::binary_mul: return op_kind_t::_mul; |
148 | case alg_kind::binary_div: return op_kind_t::_div; |
149 | case alg_kind::binary_min: return op_kind_t::_min; |
150 | case alg_kind::binary_max: return op_kind_t::_max; |
151 | case alg_kind::binary_ge: return op_kind_t::_ge; |
152 | case alg_kind::binary_gt: return op_kind_t::_gt; |
153 | case alg_kind::binary_le: return op_kind_t::_le; |
154 | case alg_kind::binary_lt: return op_kind_t::_lt; |
155 | case alg_kind::binary_eq: return op_kind_t::_eq; |
156 | case alg_kind::binary_ne: return op_kind_t::_ne; |
157 | default: ir_error_not_expected(); |
158 | } |
159 | return op_kind_t::undef; |
160 | } |
161 | |
162 | } // namespace jit |
163 | } // namespace gpu |
164 | } // namespace impl |
165 | } // namespace dnnl |
166 | |
167 | #endif |
168 | |