1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef SOFTMAX_HPP |
18 | #define SOFTMAX_HPP |
19 | |
20 | #include <iostream> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "common.hpp" |
25 | #include "dnn_types.hpp" |
26 | #include "dnnl_common.hpp" |
27 | #include "utils/perf_report.hpp" |
28 | #include "utils/settings.hpp" |
29 | |
30 | namespace softmax { |
31 | |
// Softmax algorithm selector used by this driver.
//
// The lower-case enumerators are aliases: they carry exactly the same values
// as SOFTMAX and LOGSOFTMAX, so either spelling can be used interchangeably.
enum alg_t {
    UNDEF = 0, // no algorithm selected
    SOFTMAX = 1,
    LOGSOFTMAX = 2,
    // Aliases of the two algorithms above.
    softmax_accurate = SOFTMAX,
    softmax_log = LOGSOFTMAX,
};
// Conversion helpers for alg_t. Only declared in this header; the definitions
// (and therefore the exact handling of unknown inputs) live elsewhere.

// Takes a C string and yields an alg_t value.
// NOTE(review): presumably returns UNDEF for unrecognized strings - confirm
// against the definition.
alg_t str2alg(const char *str);
// Takes an alg_t value and yields a C-string name for it.
const char *alg2str(alg_t alg);
// Takes a driver-level alg_t value and yields a library dnnl_alg_kind_t value.
dnnl_alg_kind_t alg2alg_kind(alg_t alg);
42 | |
43 | struct settings_t : public base_settings_t { |
44 | settings_t() = default; |
45 | |
46 | // ctor to save certain fields from resetting |
47 | settings_t(const char *perf_template) : settings_t() { |
48 | this->perf_template = perf_template; |
49 | } |
50 | |
51 | prb_dims_t prb_dims; |
52 | |
53 | std::vector<dir_t> dir {FWD_D}; |
54 | std::vector<dnnl_data_type_t> sdt {dnnl_f32}, ddt {dnnl_f32}; |
55 | std::vector<std::string> stag {tag::abx}, dtag {tag::any}; |
56 | std::vector<alg_t> alg {SOFTMAX}; |
57 | std::vector<int> axis {1}; |
58 | |
59 | const char *perf_template_csv() const { |
60 | static const std::string args |
61 | = "%dir%,%sdt%,%ddt%,%stag%,%dtag%,%alg%,%axis%" ; |
62 | return perf_template_csv_base(args); |
63 | } |
64 | |
65 | void reset() { *this = settings_t(perf_template); } |
66 | }; |
67 | |
68 | struct prb_t : public prb_dims_t { |
69 | prb_t(const prb_dims_t &prb_dims, dir_t dir, dnnl_data_type_t sdt, |
70 | dnnl_data_type_t ddt, const std::string &stag, |
71 | const std::string &dtag, alg_t alg, int axis, bool inplace, |
72 | const attr_t &attr, const thr_ctx_t &ctx_init, |
73 | const thr_ctx_t &ctx_exe, int64_t mb = 0) |
74 | : prb_dims_t(prb_dims) |
75 | , dir(dir) |
76 | , sdt(sdt) |
77 | , ddt(ddt) |
78 | , stag(stag) |
79 | , dtag(dtag) |
80 | , alg(alg) |
81 | , axis(axis) |
82 | , inplace(inplace) |
83 | , attr(attr) |
84 | , ctx_init(ctx_init) |
85 | , ctx_exe(ctx_exe) |
86 | , user_mb(mb) { |
87 | if (mb) dims[0] = mb; |
88 | } |
89 | |
90 | dir_t dir; |
91 | dnnl_data_type_t sdt, ddt; |
92 | std::string stag, dtag; |
93 | alg_t alg; |
94 | int axis; |
95 | bool inplace; |
96 | attr_t attr; |
97 | thr_ctx_t ctx_init, ctx_exe; |
98 | int64_t user_mb; |
99 | }; |
// Stream-insertion overload for prb_t. Only declared here; the definition,
// and with it the exact output format, lives elsewhere.
std::ostream &operator<<(std::ostream &s, const prb_t &prb);
101 | |
102 | struct perf_report_t : public base_perf_report_t { |
103 | perf_report_t(const prb_t *prb, const char *perf_template) |
104 | : base_perf_report_t(perf_template) |
105 | , p_(prb) |
106 | , sdt_({p_->sdt}) |
107 | , ddt_(p_->ddt) |
108 | , stag_({normalize_tag(p_->stag, p_->ndims)}) |
109 | , dtag_(normalize_tag(p_->dtag, p_->ndims)) {} |
110 | |
111 | void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); } |
112 | |
113 | void dump_desc(std::ostream &s) const override { |
114 | s << static_cast<const prb_dims_t &>(*p_); |
115 | } |
116 | |
117 | void dump_desc_csv(std::ostream &s) const override { dump_desc(s); } |
118 | |
119 | const attr_t *attr() const override { return &p_->attr; } |
120 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
121 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
122 | const std::string *name() const override { return &p_->name; } |
123 | const int *axis() const override { return &p_->axis; } |
124 | const dir_t *dir() const override { return &p_->dir; } |
125 | const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; } |
126 | const dnnl_data_type_t *ddt() const override { return &ddt_; } |
127 | const int64_t *user_mb() const override { return &p_->user_mb; } |
128 | const std::vector<std::string> *stag() const override { return &stag_; } |
129 | const std::string *dtag() const override { return &dtag_; } |
130 | |
131 | private: |
132 | const prb_t *p_; |
133 | std::vector<dnnl_data_type_t> sdt_; |
134 | dnnl_data_type_t ddt_; |
135 | std::vector<std::string> stag_; |
136 | std::string dtag_; |
137 | }; |
138 | |
139 | inline void map_off_to_mb_ic( |
140 | const prb_t *prb, int64_t off, int64_t &mb, int64_t &ic) { |
141 | for (int i = prb->ndims - 1; i > 1; i--) |
142 | off /= prb->dims[i]; |
143 | |
144 | ic = off % prb->dims[1]; |
145 | off /= prb->dims[1]; |
146 | mb = off % prb->dims[0]; |
147 | off /= prb->dims[0]; |
148 | assert(off == 0); |
149 | } |
150 | |
151 | inline void get_sizes(const prb_t *prb, int64_t &outer_size, |
152 | int64_t &inner_size, int64_t &axis_size) { |
153 | outer_size = inner_size = axis_size = 1; |
154 | for (int i = 0; i < prb->axis; i++) |
155 | outer_size *= prb->dims[i]; |
156 | for (int i = prb->axis + 1; i < prb->ndims; i++) |
157 | inner_size *= prb->dims[i]; |
158 | axis_size = prb->dims[prb->axis]; |
159 | } |
160 | |
// Entry points of the softmax driver. They are only declared in this header;
// the definitions live elsewhere, so the notes below are limited to what the
// signatures show.

// Both take the problem and a result object they may update.
// NOTE(review): judging by the names, these presumably mark `res` as skipped
// for unimplemented / invalid problems - confirm against the definitions.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Takes the problem and a set of arguments; `prim_ref` is optional and
// defaults to nullptr.
// NOTE(review): presumably computes the reference result into `args` - confirm.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// Takes one problem plus its result object and returns an int status.
int doit(const prb_t *prb, res_t *res);
// Takes command-line style arguments and returns an int status.
int bench(int argc, char **argv);
168 | |
169 | } // namespace softmax |
170 | |
171 | #endif |
172 | |