1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef UTILS_DIMS_T_HPP
18#define UTILS_DIMS_T_HPP
19
#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
24
// Dimension sizes of a single shape. Spelled with std::int64_t from <cstdint>
// instead of relying on a global int64_t leaking in through other headers.
using dims_t = std::vector<std::int64_t>;
// A list of shapes, e.g. one per input of a multi-input problem.
using vdims_t = std::vector<dims_t>;
27
28struct prb_dims_t {
29 dims_t dims;
30 int ndims;
31 std::string name;
32
33 int64_t nelems(int mask) const;
34};
35
// Note: a single type could hold both the dims_t and the vdims_t versions.
// Keeping two separate types lets us expose some features and members (e.g.
// per-input broadcast masks) only on the vdims_t version, since they make no
// sense for a single dims_t.
39struct prb_vdims_t {
40 vdims_t vdims;
41 // Destination dimensions with all broadcasts incorporated. Drivers inherit
42 // this member and may modify it due to driver specifics.
43 dims_t dst_dims;
44 int ndims;
45 std::string name;
46
47 int n_inputs() const { return static_cast<int>(vdims.size()); }
48 int get_broadcast_mask(int i_input = 1) const;
49 int64_t nelems(int i_input, int mask) const;
50};
51
// Indices of the stride entries for the source, weights and destination
// tensors. STRIDES_SIZE equals the number of such entries (3).
enum {
    STRIDES_SRC = 0,
    STRIDES_WEI, // 1
    STRIDES_DST, // 2
    STRIDES_SIZE, // 3
};
59
60dims_t off2dims_idx(const dims_t &dims, int64_t off);
61std::string dims2str(const dims_t &dims);
62std::string vdims2str(const vdims_t &vdims);
63std::ostream &operator<<(std::ostream &s, const prb_dims_t &prb_dims);
64std::ostream &operator<<(std::ostream &s, const prb_vdims_t &prb_vdims);
65
66#endif
67