1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef UTILS_DIMS_T_HPP |
18 | #define UTILS_DIMS_T_HPP |
19 | |
#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
24 | |
// A list of signed 64-bit dimension values. `int64_t` is provided by
// <cstdint>, which this header includes directly so that it does not depend on
// another header pulling it in transitively.
using dims_t = std::vector<int64_t>;
// A list of `dims_t` objects (prb_vdims_t keeps one entry per input).
using vdims_t = std::vector<dims_t>;
27 | |
28 | struct prb_dims_t { |
29 | dims_t dims; |
30 | int ndims; |
31 | std::string name; |
32 | |
33 | int64_t nelems(int mask) const; |
34 | }; |
35 | |
36 | // Note: we could use a single type to contain both dims_t and vdims_t versions. |
37 | // Two different types allow to separate features and members availability which |
38 | // don't make much sense for dims_t. |
39 | struct prb_vdims_t { |
40 | vdims_t vdims; |
41 | // Destination dimensions with all broadcasts incorporated. Drivers inherit |
42 | // this member and may modify it due to driver specifics. |
43 | dims_t dst_dims; |
44 | int ndims; |
45 | std::string name; |
46 | |
47 | int n_inputs() const { return static_cast<int>(vdims.size()); } |
48 | int get_broadcast_mask(int i_input = 1) const; |
49 | int64_t nelems(int i_input, int mask) const; |
50 | }; |
51 | |
// Positional tags for the SRC, WEI, and DST strides. The enumerators count up
// from zero, so STRIDES_SIZE ends up equal to the number of tags before it.
enum {
    STRIDES_SRC = 0, // 0
    STRIDES_WEI, // 1
    STRIDES_DST, // 2
    STRIDES_SIZE, // 3: number of stride tags above
};
59 | |
60 | dims_t off2dims_idx(const dims_t &dims, int64_t off); |
61 | std::string dims2str(const dims_t &dims); |
62 | std::string vdims2str(const vdims_t &vdims); |
63 | std::ostream &operator<<(std::ostream &s, const prb_dims_t &prb_dims); |
64 | std::ostream &operator<<(std::ostream &s, const prb_vdims_t &prb_vdims); |
65 | |
66 | #endif |
67 | |