1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CONCAT_HPP |
18 | #define CONCAT_HPP |
19 | |
20 | #include <iostream> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "common.hpp" |
25 | #include "dnn_types.hpp" |
26 | #include "dnnl_common.hpp" |
27 | #include "utils/perf_report.hpp" |
28 | #include "utils/settings.hpp" |
29 | |
30 | namespace concat { |
31 | |
32 | struct settings_t : public base_settings_t { |
33 | settings_t() = default; |
34 | |
35 | // ctor to save certain fields from resetting |
36 | settings_t(const char *perf_template) : settings_t() { |
37 | this->perf_template = perf_template; |
38 | } |
39 | |
40 | prb_vdims_t prb_vdims; |
41 | |
42 | std::vector<dnnl_data_type_t> sdt {dnnl_f32}, ddt {dnnl_f32}; |
43 | std::vector<std::vector<std::string>> stag {{tag::abx}}; |
44 | std::vector<std::string> dtag {tag::undef}; |
45 | std::vector<int> axis {1}; |
46 | |
47 | const char *perf_template_csv() const { |
48 | static const std::string args = "%sdt%,%ddt%,%stag%,%dtag%,%axis%" ; |
49 | return perf_template_csv_base(args); |
50 | } |
51 | |
52 | void reset() { *this = settings_t(perf_template); } |
53 | }; |
54 | |
55 | struct prb_t : public prb_vdims_t { |
56 | prb_t(const prb_vdims_t &prb_vdims, dnnl_data_type_t sdt, |
57 | dnnl_data_type_t ddt, const std::vector<std::string> &stag, |
58 | const std::string &dtag, int axis, const attr_t &attr, |
59 | const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe) |
60 | : prb_vdims_t(prb_vdims) |
61 | , sdt(sdt) |
62 | , ddt(ddt) |
63 | , stag(stag) |
64 | , dtag(dtag) |
65 | , axis(axis) |
66 | , attr(attr) |
67 | , ctx_init(ctx_init) |
68 | , ctx_exe(ctx_exe) { |
69 | // If dst is omitted by `dtag = tag::undef`, omit `ddt` as well. |
70 | if (dtag == tag::undef) this->ddt = dnnl_data_type_undef; |
71 | |
72 | // Broadcast tag if needed |
73 | if (stag.size() == 1) { |
74 | const auto val = stag[0]; // Need a copy here. |
75 | this->stag.assign(prb_vdims.n_inputs(), val); |
76 | } |
77 | |
78 | dst_dims[axis] = axis_size(); |
79 | } |
80 | ~prb_t() {} |
81 | |
82 | dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward. |
83 | dnnl_data_type_t sdt, ddt; |
84 | std::vector<std::string> stag; |
85 | std::string dtag; |
86 | int axis; |
87 | attr_t attr; |
88 | thr_ctx_t ctx_init, ctx_exe; |
89 | |
90 | int64_t axis_size() const { |
91 | int64_t as = 0; |
92 | for (int i = 0; i < n_inputs(); ++i) |
93 | as += vdims[i].at(axis); |
94 | return as; |
95 | } |
96 | }; |
97 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
98 | |
99 | struct perf_report_t : public base_perf_report_t { |
100 | perf_report_t(const prb_t *prb, const char *perf_template) |
101 | : base_perf_report_t(perf_template) |
102 | , p_(prb) |
103 | , sdt_({p_->sdt}) |
104 | , stag_({}) |
105 | , dtag_(normalize_tag(p_->dtag, p_->ndims)) { |
106 | for (size_t d = 0; d < p_->stag.size(); d++) |
107 | stag_.push_back(normalize_tag(p_->stag[d], p_->ndims)); |
108 | } |
109 | |
110 | void dump_desc(std::ostream &s) const override { |
111 | s << static_cast<const prb_vdims_t &>(*p_); |
112 | } |
113 | |
114 | void dump_desc_csv(std::ostream &s) const override { dump_desc(s); } |
115 | |
116 | const attr_t *attr() const override { return &p_->attr; } |
117 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
118 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
119 | const std::string *name() const override { return &p_->name; } |
120 | const int *axis() const override { return &p_->axis; } |
121 | const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; } |
122 | const dnnl_data_type_t *ddt() const override { return &p_->ddt; } |
123 | const std::vector<std::string> *stag() const override { return &stag_; } |
124 | const std::string *dtag() const override { return &dtag_; } |
125 | |
126 | private: |
127 | const prb_t *p_; |
128 | std::vector<dnnl_data_type_t> sdt_; |
129 | std::vector<std::string> stag_; |
130 | std::string dtag_; |
131 | }; |
132 | |
133 | void skip_unimplemented_prb(const prb_t *prb, res_t *res); |
134 | void skip_invalid_prb(const prb_t *prb, res_t *res); |
135 | void compute_ref(const prb_t *prb, const args_t &args, |
136 | dnnl_primitive_t prim_ref = nullptr); |
137 | |
138 | int doit(const prb_t *prb, res_t *res); |
139 | int bench(int argc, char **argv); |
140 | |
141 | } // namespace concat |
142 | |
143 | #endif |
144 | |