1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <limits.h> |
19 | #include <stdint.h> |
20 | |
21 | #include <cctype> |
22 | #include <fstream> |
23 | #include <functional> |
24 | #include <string> |
25 | #include <utility> |
26 | #include <vector> |
27 | |
28 | #include "oneapi/dnnl/dnnl.h" |
29 | |
30 | #include "common.hpp" |
31 | #include "utils/parser.hpp" |
32 | |
33 | #include "utils/parallel.hpp" |
34 | |
35 | // BENCHDNN_MEMORY_CHECK macro enables guarding mechanism for memory allocation: |
36 | // memory block is allocated on a page boundary and the page after the block is |
37 | // protected to catch possible invalid accesses. |
38 | // |
39 | // Note that the macro affects the correctness mode only. |
40 | #ifdef __unix__ |
41 | #define BENCHDNN_MEMORY_CHECK |
42 | #endif |
43 | |
44 | #ifdef BENCHDNN_MEMORY_CHECK |
45 | #include <stdlib.h> |
46 | #include <unistd.h> |
47 | #include <sys/mman.h> |
48 | #endif |
49 | |
// Benchmark mode flags. The global `bench_mode` holds the active combination;
// is_bench_mode() checks whether any bit of a given flag set is enabled in it.
bench_mode_t RUN {0x1}; // Run mode.
bench_mode_t CORR {0x2}; // Correctness mode. The default one.
bench_mode_t PERF {0x4}; // Performance mode. May be combined with CORR.
bench_mode_t LIST {0x8}; // Listing mode. Standalone mode to only create prb.
// Profiling-based performance mode (may be only combined with performance
// mode).
bench_mode_t PROF {0x10};
56 | |
57 | bool is_bench_mode(bench_mode_t user_mode) { |
58 | return !(bench_mode & user_mode).none(); |
59 | } |
60 | |
61 | /* result structure */ |
62 | const char *state2str(res_state_t state) { |
63 | if (state == UNTESTED) return "UNTESTED_FAILED" ; // for easier fail search |
64 | |
65 | #define CASE(x) \ |
66 | if (state == (x)) return STRINGIFY(x) |
67 | CASE(PASSED); |
68 | CASE(SKIPPED); |
69 | CASE(MISTRUSTED); |
70 | CASE(UNIMPLEMENTED); |
71 | CASE(INVALID_ARGUMENTS); |
72 | CASE(FAILED); |
73 | CASE(LISTED); |
74 | CASE(EXECUTED); |
75 | #undef CASE |
76 | assert(!"unknown res state" ); |
77 | return "STATE_UNDEF" ; |
78 | } |
79 | |
80 | const char *skip_reason2str(skip_reason_t skip_reason) { |
81 | #define CASE(x) \ |
82 | if (skip_reason == (x)) return STRINGIFY(x) |
83 | CASE(CASE_NOT_SUPPORTED); |
84 | CASE(DATA_TYPE_NOT_SUPPORTED); |
85 | CASE(INVALID_CASE); |
86 | CASE(NOT_ENOUGH_RAM); |
87 | CASE(SKIP_IMPL_HIT); |
88 | CASE(SKIP_START); |
89 | #undef CASE |
90 | return "SKIP_UNKNOWN" ; |
91 | } |
92 | |
93 | void parse_result(res_t &res, const char *pstr) { |
94 | auto &bs = benchdnn_stat; |
95 | const char *state = state2str(res.state); |
96 | |
97 | switch (res.state) { |
98 | case UNTESTED: |
99 | BENCHDNN_PRINT(0, "%d:%s __REPRO: %s\n" , bs.tests, state, pstr); |
100 | bs.failed++; |
101 | break; |
102 | case EXECUTED: |
103 | if (is_bench_mode(RUN)) |
104 | BENCHDNN_PRINT(0, "%d:%s __REPRO: %s\n" , bs.tests, state, pstr); |
105 | bs.passed++; |
106 | break; |
107 | case FAILED: |
108 | bs.failed++; |
109 | BENCHDNN_PRINT(0, "%d:%s (errors:%lu total:%lu) __REPRO: %s\n" , |
110 | bs.tests, state, (unsigned long)res.errors, |
111 | (unsigned long)res.total, pstr); |
112 | break; |
113 | case SKIPPED: |
114 | BENCHDNN_PRINT(0, "%d:%s (%s) __REPRO: %s\n" , bs.tests, state, |
115 | skip_reason2str(res.reason), pstr); |
116 | bs.skipped++; |
117 | break; |
118 | case UNIMPLEMENTED: |
119 | BENCHDNN_PRINT(0, "%d:%s __REPRO: %s\n" , bs.tests, state, pstr); |
120 | bs.unimplemented++; |
121 | bs.failed++; |
122 | break; |
123 | case INVALID_ARGUMENTS: |
124 | BENCHDNN_PRINT(0, "%d:%s __REPRO: %s\n" , bs.tests, state, pstr); |
125 | bs.invalid_arguments++; |
126 | bs.failed++; |
127 | break; |
128 | case MISTRUSTED: |
129 | BENCHDNN_PRINT(0, "%d:%s __REPRO: %s\n" , bs.tests, state, pstr); |
130 | bs.mistrusted++; |
131 | break; |
132 | case PASSED: |
133 | BENCHDNN_PRINT(0, "%d:%s __REPRO: %s\n" , bs.tests, state, pstr); |
134 | bs.passed++; |
135 | break; |
136 | case LISTED: |
137 | BENCHDNN_PRINT(0, "%d:%s __REPRO: %s\n" , bs.tests, state, pstr); |
138 | bs.listed++; |
139 | break; |
140 | default: assert(!"unknown state" ); SAFE_V(FAIL); |
141 | } |
142 | |
143 | bs.tests++; |
144 | assert(bs.tests == (bs.passed + bs.skipped + bs.mistrusted + bs.failed) |
145 | || bs.tests == bs.listed); |
146 | |
147 | if (is_bench_mode(PERF)) { |
148 | using bt = timer::timer_t; |
149 | const auto &t = res.timer_map.perf_timer(); |
150 | for (int mode = 0; mode < (int)bt::n_modes; ++mode) |
151 | bs.ms[mode] += t.ms((bt::mode_t)mode); |
152 | } |
153 | |
154 | if (is_bench_mode(CORR)) { |
155 | using bt = timer::timer_t; |
156 | const auto &t = res.timer_map.get_timer(timer::timer_t::ref_timer); |
157 | bs.ms[bt::mode_t::sum] += t.sec(bt::mode_t::sum); |
158 | } |
159 | } |
160 | |
161 | /* misc */ |
162 | |
#ifdef BENCHDNN_MEMORY_CHECK
// Returns the system page size in bytes, or 0 if it cannot be determined.
static size_t memory_check_page_size() {
    const long page_sz = sysconf(_SC_PAGESIZE);
    return page_sz > 0 ? static_cast<size_t>(page_sz) : 0;
}

// Returns `ptr` rounded down to a multiple of sizeof(void *), viewed as an
// array of pointers. Slots [-2] and [-1] of that array hold the bookkeeping
// pointers written by zmalloc_protect(). Unsigned arithmetic is used so that
// the rounding is well defined for every address.
static void **memory_check_header(void *ptr) {
    const uintptr_t misalign
            = reinterpret_cast<uintptr_t>(ptr) % sizeof(void *);
    return reinterpret_cast<void **>(static_cast<uint8_t *>(ptr) - misalign);
}

// Allocates `size` bytes so that the block ends exactly where a
// PROT_NONE-protected page begins, making any access past the end fault
// immediately. The returned pointer is therefore NOT aligned to anything in
// particular. Returns nullptr if the allocation or the protection fails, or
// if `size` is too large for the bookkeeping to fit without wrap-around.
// Memory must be released with zfree_protect().
static void *zmalloc_protect(size_t size) {
    const size_t page_sz = memory_check_page_size();
    if (page_sz == 0) return nullptr;

    // Room for the two bookkeeping pointers plus up to sizeof(void *) - 1
    // bytes lost when aligning them down.
    const size_t overhead = 3 * sizeof(void *);
    // Reject sizes for which the size computations below would overflow.
    if (size > SIZE_MAX - overhead - 2 * page_sz) return nullptr;

    const size_t block_sz = size + overhead;
    const size_t total_sz
            = (block_sz + page_sz - 1) / page_sz * page_sz + page_sz;

    void *mem_ptr = nullptr;
    if (::posix_memalign(&mem_ptr, page_sz, total_sz) != 0) return nullptr;

    // Layout of the allocated region:
    //   ptr_start      <- start of the allocated region
    //   header[-2]     <- stores start address: ptr_start
    //   header[-1]     <- stores protected address: ptr_protect
    //   ptr            <- pointer returned to the caller (header is ptr
    //                     rounded down to pointer alignment)
    //   ptr_protect    <- last page of the region, protected
    uint8_t *ptr_start = static_cast<uint8_t *>(mem_ptr);
    // Page aligned: ptr_start and total_sz are both multiples of page_sz.
    uint8_t *ptr_protect = ptr_start + total_sz - page_sz;
    uint8_t *ptr = ptr_protect - size;

    // Protect one page right after the block of `size` bytes.
    if (mprotect(ptr_protect, page_sz, PROT_NONE) != 0) {
        ::free(ptr_start);
        return nullptr;
    }

    // Save pointers for zfree_protect at a pointer-aligned location.
    void **header = memory_check_header(ptr);
    header[-2] = ptr_start;
    header[-1] = ptr_protect;

    return ptr;
}

// Releases memory obtained from zmalloc_protect(). Like free(), a null
// pointer is accepted and ignored.
static void zfree_protect(void *ptr) {
    if (ptr == nullptr) return;

    void **header = memory_check_header(ptr);
    void *ptr_start = header[-2];
    void *ptr_protect = header[-1];

    // Restore read-write access to the guard page before handing the region
    // back to the allocator.
    const size_t page_sz = memory_check_page_size();
    mprotect(ptr_protect, page_sz, PROT_READ | PROT_WRITE);

    // Deallocate the whole region.
    ::free(ptr_start);
}
#endif
220 | |
221 | void *zmalloc(size_t size, size_t align) { |
222 | #ifdef BENCHDNN_MEMORY_CHECK |
223 | if (is_bench_mode(CORR)) { return zmalloc_protect(size); } |
224 | #endif |
225 | |
226 | void *ptr; |
227 | #ifdef _WIN32 |
228 | ptr = _aligned_malloc(size, align); |
229 | int rc = ((ptr) ? 0 : errno); |
230 | #else |
231 | // posix_memalign requires alignment to be |
232 | // a power of 2 and a multiple of sizeof(void *) |
233 | if (align < sizeof(void *)) align = sizeof(void *); |
234 | assert(((align & (align - 1)) == 0) && "align must be a power of 2" ); |
235 | |
236 | // TODO. Heuristics: Increasing the size to alignment increases |
237 | // the stability of performance results. |
238 | if (is_bench_mode(PERF) && (size < align)) size = align; |
239 | int rc = ::posix_memalign(&ptr, align, size); |
240 | #endif /* _WIN32 */ |
241 | return rc == 0 ? ptr : nullptr; |
242 | } |
243 | |
// Releases memory obtained from zmalloc(), choosing the deallocator that
// matches the allocation path zmalloc() took.
void zfree(void *ptr) {
#ifdef BENCHDNN_MEMORY_CHECK
    if (is_bench_mode(CORR)) return zfree_protect(ptr);
#endif

#ifdef _WIN32
    _aligned_free(ptr);
#else
    ::free(ptr);
#endif /* _WIN32 */
}
258 | |
// Interprets `str` as a boolean: only "true" (any letter case) and "1" are
// true; every other string is false.
bool str2bool(const char *str) {
    if (strcasecmp("true", str) == 0) return true;
    return strcasecmp("1", str) == 0;
}
262 | |
// Returns the literal "true" or "false" for `value`.
const char *bool2str(bool value) {
    if (value) return "true";
    return "false";
}
266 | |
#ifdef _WIN32
/* NOTE: this should be supported on linux as well, but currently
 * having issues for ICC170 and Clang*/
#include <regex>

// Returns true if `pattern` (std::regex default ECMAScript grammar) matches
// some substring of `str`.
bool match_regex(const char *str, const char *pattern) {
    std::regex re(pattern);
    return std::regex_search(str, re);
}
#else
#include <regex.h>
#include <sys/types.h>

// Returns true if the POSIX basic regular expression `pattern` matches some
// substring of `str`. If `pattern` fails to compile, an error is printed and
// true is returned, so nothing is filtered out.
//
// The most recently compiled expression is cached in function-local static
// state (not thread-safe). The cache is keyed on the pattern text, not its
// address, so a buffer that is reused for a different pattern is recompiled.
bool match_regex(const char *str, const char *pattern) {
    static regex_t regex;
    static bool regex_valid = false;
    static std::string cached_pattern;

    if (!regex_valid || cached_pattern != pattern) {
        if (regex_valid) {
            regfree(&regex);
            // Mark invalid first so a failed regcomp below never leads to a
            // second regfree()/regexec() on an uninitialized regex_t.
            regex_valid = false;
        }

        if (regcomp(&regex, pattern, 0)) {
            fprintf(stderr, "could not create regex\n");
            return true;
        }

        regex_valid = true;
        cached_pattern = pattern;
    }

    return !regexec(&regex, str, 0, nullptr, 0);
}
#endif /* _WIN32 */
297 | |
298 | bool maybe_skip(const std::string &impl_str) { |
299 | if (skip_impl.empty()) return false; |
300 | |
301 | size_t start_pos = 0; |
302 | // Iterate over impls in skip list. |
303 | while (start_pos != std::string::npos) { |
304 | const auto skip_impl_item |
305 | = parser::get_substr(skip_impl, start_pos, ','); |
306 | if (skip_impl_item.empty()) continue; |
307 | if (impl_str.find(skip_impl_item) != std::string::npos) return true; |
308 | } |
309 | |
310 | return false; |
311 | } |
312 | |
// Windows (non-GNU toolchain) stand-ins for the POSIX dirname() and
// readlink() used by locate_batch_file(); other platforms get the real ones
// from <libgen.h> and <unistd.h>.
#if defined(_WIN32) && !defined(__GNUC__)
#include <windows.h>
// PATH_MAX is presumably not provided by this toolchain; map it to Win32's
// MAX_PATH.
#define PATH_MAX MAX_PATH
// Replaces `path` in place with its drive + directory components (as split by
// _splitpath_s) and returns it; an empty result becomes ".". Split or
// concatenation failures are reported through SAFE_V.
// Assumes `path` points to a buffer of at least PATH_MAX bytes (the size
// passed to strncat_s).
// NOTE(review): the directory part from _splitpath_s presumably keeps its
// trailing separator, unlike POSIX dirname() -- confirm callers tolerate it.
static char *dirname(char *path) {
    char drive[_MAX_DRIVE];
    char dir[_MAX_DIR];
    SAFE_V(_splitpath_s(path, drive, sizeof(drive), dir, sizeof(dir), NULL, 0,
                   NULL, 0)
                    == 0
            ? OK
            : FAIL);
    path[0] = '\0';
    SAFE_V(strncat_s(path, PATH_MAX, drive, _MAX_DRIVE) == 0 ? OK : FAIL);
    SAFE_V(strncat_s(path, PATH_MAX, dir, _MAX_DIR) == 0 ? OK : FAIL);
    // No drive and no directory: the file lives in the current directory.
    if (path[0] == '\0') {
        path[0] = '.';
        path[1] = '\0';
    }
    return path;
}

// Emulates readlink("/proc/self/exe", ...): `path` is ignored and the full
// path of the running executable is written to `buf` instead. Returns
// GetModuleFileName()'s result (per Win32 docs: the length on success, 0 on
// failure, buf_max on truncation -- confirm). Assumes a narrow-char
// (non-UNICODE) build, since `buf` is char*.
int readlink(const char *path, char *buf, size_t buf_max) {
    (void)path;
    // NULL means take the path of myself
    return GetModuleFileName(NULL, buf, (DWORD)buf_max);
}
#else
#include <libgen.h>
#include <unistd.h>
#endif /* _WIN32 */
342 | |
343 | std::string locate_batch_file(const std::string &fname) { |
344 | SAFE_V(fname.length() < PATH_MAX ? OK : FAIL); |
345 | |
346 | const int max_paths = 4; |
347 | |
348 | static int n_paths = 0; |
349 | static std::string search_paths[max_paths]; |
350 | |
351 | std::string fdir; |
352 | { |
353 | std::string fname_copy = fname; |
354 | fname_copy.resize(PATH_MAX); |
355 | char *c_fdir = dirname(&fname_copy[0]); |
356 | fdir = std::string(c_fdir); |
357 | } |
358 | |
359 | bool dir_found = false; |
360 | for (int n = 0; n_paths < max_paths && n < n_paths; ++n) |
361 | if (search_paths[n].find(fdir) == 0) { |
362 | dir_found = true; |
363 | break; |
364 | } |
365 | if (!dir_found) { |
366 | SAFE_V(n_paths < max_paths ? OK : FAIL); |
367 | search_paths[n_paths++] = std::move(fdir); |
368 | } |
369 | |
370 | std::ifstream ifs(fname); |
371 | if (ifs.is_open()) return fname; |
372 | |
373 | for (int n = 0; n < n_paths; ++n) { |
374 | std::string fullname = search_paths[n] + "/" + fname; |
375 | ifs.open(fullname); |
376 | if (ifs.is_open()) { |
377 | BENCHDNN_PRINT(50, "batch file used: %s\n" , fullname.c_str()); |
378 | ifs.close(); |
379 | return fullname; |
380 | } |
381 | ifs.close(); |
382 | } |
383 | |
384 | // Search in default inputs directory |
385 | // Takes dirname(executable)/inputs/file_name on Linux |
386 | // Takes dirname(executable)/../inputs/file_name on Windows |
387 | fdir.resize(PATH_MAX); |
388 | int length = readlink("/proc/self/exe" , &fdir[0], PATH_MAX); |
389 | if (length) { |
390 | std::string s_fdir = dirname(&fdir[0]); |
391 | for (int i_try = 0; i_try < 2; ++i_try) { |
392 | fdir = s_fdir; |
393 | fdir.append(i_try == 1 ? "/../inputs/" : "/inputs/" ); |
394 | assert(!driver_name.empty()); |
395 | fdir.append(driver_name); |
396 | |
397 | std::string fullname = fdir + "/" ; |
398 | fullname += fname; |
399 | ifs.open(fullname); |
400 | if (ifs.is_open()) { |
401 | search_paths[n_paths++] = std::move(fdir); |
402 | BENCHDNN_PRINT(50, "batch file used: %s\n" , fullname.c_str()); |
403 | ifs.close(); |
404 | return fullname; |
405 | } |
406 | ifs.close(); |
407 | } |
408 | } |
409 | |
410 | fprintf(stderr, "cannot open file %s\n" , fname.c_str()); |
411 | return fname; |
412 | } |
413 | |
414 | int batch(const char *fname, bench_f bench) { |
415 | std::ifstream ifs(locate_batch_file(std::string(fname))); |
416 | SAFE(ifs.is_open() ? OK : FAIL, CRIT); |
417 | |
418 | std::vector<std::string> opts; |
419 | std::string str; |
420 | bool continued_line = false; |
421 | while (ifs >> str) { |
422 | if (str.length() == 0) continue; |
423 | |
424 | // shell style comments |
425 | if (str.front() == '#') { |
426 | std::getline(ifs, str); // take whole commented line out |
427 | continue; |
428 | } |
429 | |
430 | // shell style line break |
431 | if (continued_line) { |
432 | // NOLINTNEXTLINE(performance-inefficient-string-concatenation) |
433 | str = opts.back() + str; // update current line with previous |
434 | opts.pop_back(); // take previous line out |
435 | } |
436 | |
437 | if (str.back() == '\\') { |
438 | continued_line = true; |
439 | if (str.length() == 1) continue; // line break lives separately |
440 | str.erase(str.size() - 1); // otherwise remove it |
441 | } else { |
442 | continued_line = false; |
443 | } |
444 | |
445 | opts.push_back(std::move(str)); |
446 | } |
447 | |
448 | std::vector<char *> c_opts; |
449 | c_opts.reserve(opts.size()); |
450 | for (const auto &opt : opts) |
451 | c_opts.push_back(const_cast<char *>(opt.c_str())); |
452 | |
453 | return bench(static_cast<int>(c_opts.size()), c_opts.data()); |
454 | } |
455 | |
// Deterministic pseudo-random coin flip: returns 1 when the hashed `seed`
// falls below `probability` of the hash range, 0 otherwise.
int flip_coin(ptrdiff_t seed, float probability) {
    const ptrdiff_t big_prime = 1000003;
    const ptrdiff_t prime = 753737;
    // Equivalent to (seed * prime) % big_prime wherever that product fits,
    // but reducing `seed` first keeps the product below 2^40, so large seeds
    // no longer trigger signed-overflow UB.
    const int64_t residue
            = (static_cast<int64_t>(seed % big_prime) * prime) % big_prime;
    return residue < (probability * big_prime);
}
462 | |
463 | int64_t div_up(const int64_t a, const int64_t b) { |
464 | SAFE_V(b != 0 ? OK : FAIL); |
465 | return (a + b - 1) / b; |
466 | } |
467 | |
// Returns the smallest power of two that is >= `a` for a >= 2; for a == 1
// the result is 2. Requires 0 < a <= 2^62.
int64_t next_pow2(int64_t a) {
    assert(a > 0 && a <= ((int64_t)1 << 62));
    int64_t v = (a > 1) ? a - 1 : a;
    // Drop the lowest set bit until only the highest one is left.
    for (int64_t rest = v & (v - 1); rest != 0; rest = v & (v - 1))
        v = rest;
    return v << 1;
}
475 | |
#if defined(__x86_64__) || defined(_M_X64)
#include <immintrin.h>
#include <xmmintrin.h>

// Converts `f` to int with the rounding mode currently set in MXCSR.
int mxcsr_cvt(float f) {
    const __m128 v = _mm_set_ss(f);
    return _mm_cvtss_si32(v);
}

// Enables flush-to-zero so that denormal results do not distort performance
// measurements.
void init_fp_mode() {
    _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
}
#else
// Portable fallback: rounds with the current floating-point rounding mode.
int mxcsr_cvt(float f) {
    const float rounded = nearbyintf(f);
    return (int)rounded;
}

// No floating-point mode changes outside x86-64.
void init_fp_mode() {}
#endif
493 | |
// Zero-fills the first `size` bytes of `arr`.
void array_set(char *arr, size_t size) {
    for (char *const end = arr + size; arr != end; ++arr)
        *arr = 0;
}
498 | |
499 | void gemm(const char *layout, const char *transa, const char *transb, int64_t m, |
500 | int64_t n, int64_t k, const float alpha, const float *a, |
501 | const int64_t lda, const float *b, const int64_t ldb, const float beta, |
502 | float *c, const int64_t ldc) { |
503 | #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE |
504 | if (*layout == 'C') { |
505 | dnnl_sgemm( |
506 | *transa, *transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
507 | } else { |
508 | dnnl_sgemm( |
509 | *transb, *transa, n, m, k, alpha, b, ldb, a, lda, beta, c, ldc); |
510 | } |
511 | #else |
512 | if (std::toupper(*layout) != 'C') { |
513 | gemm("C" , transb, transa, n, m, k, alpha, b, ldb, a, lda, beta, c, ldc); |
514 | return; |
515 | } |
516 | |
517 | auto a_accessor = [&](int64_t i, int64_t j) { |
518 | if (std::toupper(*transa) == 'N') return a[i * lda + j]; |
519 | return a[j * lda + i]; |
520 | }; |
521 | |
522 | auto b_accessor = [&](int64_t i, int64_t j) { |
523 | if (std::toupper(*transb) == 'N') return b[i * ldb + j]; |
524 | return b[j * ldb + i]; |
525 | }; |
526 | |
527 | benchdnn_parallel_nd(m, n, [&](int64_t i, int64_t j) { |
528 | float ab = 0.0f; |
529 | for (int64_t _k = 0; _k < k; ++_k) |
530 | ab += a_accessor(i, _k) * b_accessor(_k, j); |
531 | float cij = (beta == 0) ? 0.0f : beta * c[i * ldc + j]; |
532 | c[i * ldc + j] = alpha * ab + cij; |
533 | }); |
534 | #endif |
535 | } |
536 | |
537 | int sanitize_desc(int &ndims, std::vector<std::reference_wrapper<int64_t>> d, |
538 | std::vector<std::reference_wrapper<int64_t>> h, |
539 | std::vector<std::reference_wrapper<int64_t>> w, |
540 | const std::vector<int64_t> &def_values, bool must_have_spatial) { |
541 | size_t N = d.size(); |
542 | assert(h.size() == N && w.size() == N && def_values.size() == N); |
543 | |
544 | ndims = 5; |
545 | |
546 | // check output spatial values |
547 | const bool no_d = d[0].get() == 0; |
548 | const bool no_h = h[0].get() == 0; |
549 | const bool no_w = w[0].get() == 0; |
550 | |
551 | if (no_d) ndims--; |
552 | if (no_d && no_h) ndims--; |
553 | if (no_d && no_h && no_w) ndims--; |
554 | if (must_have_spatial && ndims <= 2) return FAIL; |
555 | |
556 | if (ndims == 5) { |
557 | if (no_h && no_w) { |
558 | // User specified values for the d dimension but not values for h |
559 | // and w dimensions. Propagate d values to h and w dimensions. |
560 | for (size_t n = 0; n < N; ++n) |
561 | w[n].get() = h[n].get() = d[n].get(); |
562 | } else if (!no_h && !no_w) { |
563 | // User specified them all, good to go. |
564 | } else { |
565 | // Problem is not cubic and one of h or w dimension is missing. |
566 | return FAIL; |
567 | } |
568 | } else if (ndims == 4 && no_w) { |
569 | // User specified values for the h dimension but not values for the w |
570 | // dimension. Propagate h values to the w dimension. |
571 | for (size_t n = 0; n < N; ++n) |
572 | w[n].get() = h[n].get(); |
573 | } |
574 | |
575 | for (size_t n = 0; n < N; ++n) { |
576 | if (ndims < 5) d[n].get() = def_values[n]; |
577 | if (ndims < 4) h[n].get() = def_values[n]; |
578 | if (ndims < 3) w[n].get() = def_values[n]; |
579 | } |
580 | |
581 | return OK; |
582 | } |
583 | |
584 | void print_dhw(bool &print_d, bool &print_h, bool &print_w, int ndims, |
585 | const std::vector<int64_t> &d, const std::vector<int64_t> &h, |
586 | const std::vector<int64_t> &w) { |
587 | size_t N = d.size(); |
588 | assert(h.size() == N && w.size() == N); |
589 | |
590 | bool square_shape = true, cubic_shape = true; |
591 | for (size_t n = 0; n < N; ++n) { |
592 | square_shape = square_shape && h[n] == w[n]; |
593 | cubic_shape = cubic_shape && d[n] == h[n] && h[n] == w[n]; |
594 | } |
595 | |
596 | print_d = ndims == 5; |
597 | print_h = ndims == 4 || (ndims == 5 && (!cubic_shape || canonical)); |
598 | print_w = ndims == 3 || (ndims == 5 && (!cubic_shape || canonical)) |
599 | || (ndims == 4 && (!square_shape || canonical)); |
600 | } |
601 | |