1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef SHUFFLE_HPP
18#define SHUFFLE_HPP
19
20#include <assert.h>
21#include <limits.h>
22#include <stdint.h>
23
24#include <iostream>
25
26#include "common.hpp"
27#include "dnn_types.hpp"
28#include "dnnl_common.hpp"
29#include "dnnl_debug.hpp"
30#include "utils/perf_report.hpp"
31#include "utils/settings.hpp"
32
33namespace shuffle {
34
35struct settings_t : public base_settings_t {
36 settings_t() = default;
37
38 // ctor to save certain fields from resetting
39 settings_t(const char *perf_template) : settings_t() {
40 this->perf_template = perf_template;
41 }
42
43 prb_dims_t prb_dims;
44
45 std::vector<dir_t> dir {FWD_D};
46 std::vector<dnnl_data_type_t> dt {dnnl_f32};
47 std::vector<std::string> tag {tag::abx};
48 std::vector<int64_t> group {1};
49 std::vector<int> axis {1};
50
51 const char *perf_template_csv() const {
52 static const std::string args = "%dir%,%dt%,%tag%,%group%,%axis%";
53 return perf_template_csv_base(args);
54 }
55
56 void reset() { *this = settings_t(perf_template); }
57};
58
59struct prb_t : public prb_dims_t {
60 prb_t(const prb_dims_t &prb_dims, dir_t dir, dnnl_data_type_t dt,
61 const std::string &tag, int axis, int64_t group, const attr_t &attr,
62 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe)
63 : prb_dims_t(prb_dims)
64 , dir(dir)
65 , dt(dt)
66 , tag(tag)
67 , axis(axis)
68 , group(group)
69 , attr(attr)
70 , ctx_init(ctx_init)
71 , ctx_exe(ctx_exe) {}
72 ~prb_t() {}
73
74 dir_t dir;
75 dnnl_data_type_t dt;
76 std::string tag;
77 int axis;
78 int64_t group;
79 attr_t attr;
80 thr_ctx_t ctx_init, ctx_exe;
81};
82std::ostream &operator<<(std::ostream &s, const prb_t &prb);
83
84struct perf_report_t : public base_perf_report_t {
85 perf_report_t(const prb_t *prb, const char *perf_template)
86 : base_perf_report_t(perf_template)
87 , p_(prb)
88 , tag_(normalize_tag(p_->tag, p_->ndims)) {}
89
90 void dump_desc(std::ostream &s) const override {
91 s << static_cast<const prb_dims_t &>(*p_);
92 }
93
94 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
95
96 const attr_t *attr() const override { return &p_->attr; }
97 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
98 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
99 const std::string *name() const override { return &p_->name; }
100 const int *axis() const override { return &p_->axis; }
101 const int64_t *group() const override { return &p_->group; }
102 const dir_t *dir() const override { return &p_->dir; }
103 const dnnl_data_type_t *dt() const override { return &p_->dt; }
104 const std::string *tag() const override { return &tag_; }
105
106private:
107 const prb_t *p_;
108 std::string tag_;
109};
110
111inline size_t data_off(const prb_t *prb, int64_t mb, int64_t c, int64_t d,
112 int64_t h, int64_t w) {
113 const auto &dims = prb->dims;
114 return (((mb * dims[1] + c) * dims[2] + d) * dims[3] + h) * dims[4] + w;
115}
116
// Driver entry points. They are only declared here; the definitions live in
// the driver's source files.

// NOTE(review): by name, these presumably mark `res` as skipped for problems
// that are unimplemented / invalid — confirm against the definitions.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Reference computation for `prb` over `args`. `prim_ref` is optional and
// defaults to nullptr.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// Runs a single problem and records the outcome in `res`. Returns an int
// status (presumably OK/FAIL codes — confirm).
int doit(const prb_t *prb, res_t *res);
// Driver entry point taking command-line style arguments.
int bench(int argc, char **argv);
124} // namespace shuffle
125
126#endif
127