1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_IR_POST_OPS_HPP
18#define GPU_JIT_IR_POST_OPS_HPP
19
20#include <string>
21#include <vector>
22
23#include "gpu/jit/ir/ir.hpp"
24#include "gpu/jit/ir/tensor.hpp"
25#include "gpu/jit/utils/utils.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30namespace jit {
31
32class post_op_tensor_info_t {
33public:
34 post_op_tensor_info_t() = default;
35
36 post_op_tensor_info_t(bool is_input, bool is_output, const view_t &view,
37 const expr_t &buf, uint32_t mask, const expr_t &op_var, float scale)
38 : is_input_(is_input)
39 , is_output_(is_output)
40 , view_(view)
41 , buf_(buf)
42 , mask_(mask)
43 , op_var_(op_var)
44 , scale_(scale) {
45 if (op_var_.is_empty())
46 op_var_ = var_t::make(type_t::f32(), make_op_var_name(buf));
47 if (scale != 1)
48 ir_assert(is_output_)
49 << "Scale is supported with output tensors only.";
50 }
51
52 bool is_input() const { return is_input_; }
53
54 bool is_output() const { return is_output_; }
55
56 bool needs_masked_update() const { return needs_masked_update_; }
57
58 const view_t &view() const { return view_; }
59
60 const expr_t &buf() const { return buf_; }
61
62 const uint32_t &mask() const { return mask_; }
63
64 const expr_t &op_var() const { return op_var_; }
65
66 float scale() const { return scale_; }
67
68 post_op_tensor_info_t create_sub_tensor(const tensor_t &tile) const {
69 auto ret = *this;
70 ret.view_ = ret.view_.create_sub_view(tile);
71 return ret;
72 }
73
74 void require_masked_update() { needs_masked_update_ = true; }
75
76private:
77 static std::string make_op_var_name(const expr_t &buf) {
78 auto *var = buf.as_ptr<var_t>();
79 if (var) return var->name;
80
81 auto *ptr = buf.as_ptr<ptr_t>();
82 if (ptr) {
83 auto prefix = make_op_var_name(ptr->base);
84 ir_assert(is_const(ptr->off));
85 int off = to_cpp<int>(ptr->off);
86 return prefix + "_" + std::to_string(off);
87 }
88
89 ir_error_not_expected() << "Can't generate op var name: " << buf;
90 return "unknown";
91 }
92 bool is_input_;
93 bool is_output_;
94 bool needs_masked_update_ = false;
95 view_t view_;
96 expr_t buf_;
97 uint32_t mask_;
98 expr_t op_var_;
99 float scale_;
100};
101
102// There are two types of post-ops:
// - Eltwise: lhs = eltwise(rhs) and rhs must be equal to lhs
104// Eltwise is supported via special IR function eltwise_t
105// - Generic post-op: lhs = rhs
106// Left-hand side (lhs) represents a single post-op tensor. Right-hand side
107// tensor (rhs) is an IR expression over post-op tensors and constants.
108//
109// Post-op tensors support broadcast (when used from rhs) and reduction (when
110// used from lhs) semantics.
111//
112// If lhs is (a x 1) tensor and rhs is (a x b) tensor then rhs is reduced:
113// lhs(i, 0) = sum over j rhs(i, j)
114//
115// If lhs is (a x b) tensor and rhs is (a x 1) tensor then rhs is broadcasted:
116// lhs(i, j) = rhs(i, 0)
117class post_op_t {
118public:
119 post_op_t() = default;
120
121 post_op_t(const expr_t &lhs, const expr_t &rhs,
122 const func_t &eltwise = func_t())
123 : lhs_(lhs), rhs_(simplify_rewrite(rhs)), eltwise_(eltwise) {}
124
125 const expr_t &lhs() const { return lhs_; }
126
127 const expr_t &rhs() const { return rhs_; }
128
129 const func_t &eltwise() const { return eltwise_; }
130
131 bool uses(const expr_t &op_var) const {
132 if (contains_object(lhs_, op_var)) return true;
133 if (contains_object(rhs_, op_var)) return true;
134 return false;
135 }
136
137private:
138 expr_t lhs_;
139 expr_t rhs_;
140 func_t eltwise_;
141};
142
143inline op_kind_t alg_kind_to_op_kind(alg_kind_t alg) {
144 switch (alg) {
145 case alg_kind::binary_add: return op_kind_t::_add;
146 case alg_kind::binary_sub: return op_kind_t::_sub;
147 case alg_kind::binary_mul: return op_kind_t::_mul;
148 case alg_kind::binary_div: return op_kind_t::_div;
149 case alg_kind::binary_min: return op_kind_t::_min;
150 case alg_kind::binary_max: return op_kind_t::_max;
151 case alg_kind::binary_ge: return op_kind_t::_ge;
152 case alg_kind::binary_gt: return op_kind_t::_gt;
153 case alg_kind::binary_le: return op_kind_t::_le;
154 case alg_kind::binary_lt: return op_kind_t::_lt;
155 case alg_kind::binary_eq: return op_kind_t::_eq;
156 case alg_kind::binary_ne: return op_kind_t::_ne;
157 default: ir_error_not_expected();
158 }
159 return op_kind_t::undef;
160}
161
162} // namespace jit
163} // namespace gpu
164} // namespace impl
165} // namespace dnnl
166
167#endif
168