1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/ir/ir_builder.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace jit { |
23 | |
24 | void ir_builder_t::init_kernel_grid(const grid_info_t &kernel_grid, |
25 | const grid_info_t &tg_grid, int simd_size, constraint_set_t &cset, |
26 | std::vector<stmt_t> &init_stmts) { |
27 | int grid_ndims = kernel_grid.ndims(); |
28 | for (int i = 0; i < grid_ndims; i++) { |
29 | local_id_[i] |
30 | = var_t::make(type_t::u16(), "local_id" + std::to_string(i)); |
31 | int local_id_bound = tg_grid.dim(i); |
32 | if (i == 0) local_id_bound *= simd_size; |
33 | cset.add_constraint(local_id_[i] >= 0); |
34 | cset.add_constraint(local_id_[i] < local_id_bound); |
35 | |
36 | cset.add_constraint(kernel_grid.idx(i) >= 0); |
37 | cset.add_constraint(kernel_grid.idx(i) < kernel_grid.dim(i)); |
38 | cset.add_constraint(tg_grid.idx(i) >= 0); |
39 | cset.add_constraint(tg_grid.idx(i) < tg_grid.dim(i)); |
40 | } |
41 | |
42 | for (int i = 0; i < grid_ndims; i++) { |
43 | auto value = local_id_[i]; |
44 | if (i == 0) value /= simd_size; |
45 | auto &type = tg_grid.idx(i).type(); |
46 | init_stmts.push_back(let_t::make(tg_grid.idx(i), cast(value, type))); |
47 | } |
48 | } |
49 | |
50 | } // namespace jit |
51 | } // namespace gpu |
52 | } // namespace impl |
53 | } // namespace dnnl |
54 | |