1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_PASS_EXPR_SCALARIZER_HPP
18#define GPU_JIT_PASS_EXPR_SCALARIZER_HPP
19
20#include "gpu/jit/ir/ir.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27class expr_scalarizer_t : public ir_mutator_t {
28public:
29 expr_scalarizer_t(int elems, int idx,
30 const object_map_t<expr_t, std::vector<expr_t>> &vec_vars)
31 : elems_(elems), idx_(idx), vec_vars_(vec_vars) {}
32
33 object_t _mutate(const cast_t &obj) override {
34 if (obj.is_bool_vec_u16()) return obj;
35 auto type = obj.type;
36 auto expr = mutate(obj.expr);
37 if (!type.is_scalar()) {
38 ir_assert(type.elems() == elems_) << expr;
39 type = type.scalar();
40 }
41 return cast_t::make(type, expr, obj.saturate);
42 }
43
44 object_t _mutate(const var_t &obj) override {
45 if (obj.type.is_scalar()) return obj;
46
47 auto it = vec_vars_.find(obj);
48 ir_assert(it != vec_vars_.end()) << "Can't find variable: " << obj;
49 ir_assert(int(it->second.size()) == elems_);
50 return it->second[idx_];
51 }
52
53 object_t _mutate(const shuffle_t &obj) override {
54 expr_t new_obj = ir_mutator_t::_mutate(obj);
55 auto &shuffle = new_obj.as<shuffle_t>();
56 ir_assert(shuffle.type.elems() == elems_) << new_obj;
57 return new_obj[idx_];
58 }
59
60private:
61 int elems_;
62 int idx_;
63 const object_map_t<expr_t, std::vector<expr_t>> &vec_vars_;
64};
65
66} // namespace jit
67} // namespace gpu
68} // namespace impl
69} // namespace dnnl
70
71#endif
72