/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/ir/mul_add.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Performs the following operation:
//     buf = alpha * buf + beta
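//
// Illustrative sketch (assuming a 32-byte GRF and f32 data, so step_bytes = 64):
// for size = 128 the loop below emits two 16-element load/update/store chunks,
// conceptually:
//     store(buf,  0, load<f32 x 16>(buf,  0) * bcast(alpha, 16) + bcast(beta, 16))
//     store(buf, 64, load<f32 x 16>(buf, 64) * bcast(alpha, 16) + bcast(beta, 16))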
stmt_t create_mul_add_stmt(ir_context_t &ir_ctx, const expr_t &buf, int size,
        const type_t &type, float alpha, float beta) {
    if (alpha == 1 && beta == 0) return stmt_t();

    stmt_t ret;
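    // Process the buffer in chunks of up to two GRF registers per iteration.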
    int step_bytes = 2 * ir_ctx.hw_cfg().grf_size();
    for (int i = 0; i < size; i += step_bytes) {
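        // Number of `type` elements in this chunk; the last chunk may be
        // shorter than a full step.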
        auto elems = std::min(step_bytes, size - i) / type.size();
        auto e_alpha = shuffle_t::make_broadcast(alpha, elems);
        auto e_beta = shuffle_t::make_broadcast(beta, elems);
        auto e = load_t::make(type.with_elems(elems), buf, i);
        // Avoid extra IR expressions when not needed.
        if (alpha == 0)
            e = shuffle_t::make_broadcast(expr_t(0.0f), elems);
        else if (alpha != 1)
            e *= e_alpha;
        if (beta != 0) e += e_beta;
        ir_assert(e.type().scalar() == type);
        ret = ret.append(store_t::make(buf, i, e));
    }
    return ret;
}
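
// Hypothetical call site (illustrative names only, not taken from this file):
// scale an f32 accumulator buffer in place, e.g. before applying post-ops.
//
//     stmt_t s = create_mul_add_stmt(
//             ir_ctx, acc_buf, acc_size, type_t::f32(), alpha, beta);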

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl