1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "oneapi/dnnl/dnnl.h" |
18 | |
19 | #include "c_types_map.hpp" |
20 | #include "utils.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | static setting_t<fpmath_mode_t> default_fpmath {fpmath_mode::strict}; |
25 | |
26 | void init_fpmath_mode() { |
27 | if (default_fpmath.initialized()) return; |
28 | |
29 | static std::string val = getenv_string_user("DEFAULT_FPMATH_MODE" ); |
30 | if (!val.empty()) { |
31 | if (val.compare("strict" ) == 0) default_fpmath.set(fpmath_mode::strict); |
32 | if (val.compare("bf16" ) == 0) default_fpmath.set(fpmath_mode::bf16); |
33 | if (val.compare("f16" ) == 0) default_fpmath.set(fpmath_mode::f16); |
34 | if (val.compare("tf32" ) == 0) default_fpmath.set(fpmath_mode::tf32); |
35 | if (val.compare("any" ) == 0) default_fpmath.set(fpmath_mode::any); |
36 | } |
37 | if (!default_fpmath.initialized()) default_fpmath.set(default_fpmath.get()); |
38 | } |
39 | |
40 | status_t check_fpmath_mode(fpmath_mode_t mode) { |
41 | if (utils::one_of(mode, fpmath_mode::strict, fpmath_mode::bf16, |
42 | fpmath_mode::f16, fpmath_mode::tf32, fpmath_mode::any)) |
43 | return status::success; |
44 | return status::invalid_arguments; |
45 | } |
46 | |
47 | bool is_fpsubtype(data_type_t sub_dt, data_type_t dt) { |
48 | using namespace dnnl::impl::utils; |
49 | using namespace dnnl::impl::data_type; |
50 | |
51 | if (sub_dt == dt) return true; |
52 | |
53 | // Check for strict subtype |
54 | if (dt == f32) return one_of(sub_dt, tf32, bf16, f16); |
55 | if (dt == tf32) return one_of(sub_dt, bf16, f16); |
56 | |
57 | // bf16 and f16 have no strict subtypes |
58 | return false; |
59 | } |
60 | |
61 | fpmath_mode_t get_fpmath_mode() { |
62 | init_fpmath_mode(); |
63 | auto mode = default_fpmath.get(); |
64 | // Should always be proper, since no way to set invalid mode |
65 | assert(check_fpmath_mode(mode) == status::success); |
66 | return mode; |
67 | } |
68 | |
69 | } // namespace impl |
70 | } // namespace dnnl |
71 | |
72 | dnnl_status_t dnnl_set_default_fpmath_mode(dnnl_fpmath_mode_t mode) { |
73 | using namespace dnnl::impl; |
74 | auto st = check_fpmath_mode(mode); |
75 | if (st == status::success) default_fpmath.set(mode); |
76 | return st; |
77 | } |
78 | |
79 | dnnl_status_t dnnl_get_default_fpmath_mode(dnnl_fpmath_mode_t *mode) { |
80 | using namespace dnnl::impl; |
81 | if (mode == nullptr) return status::invalid_arguments; |
82 | |
83 | auto m = get_fpmath_mode(); |
84 | // Should always be proper, since no way to set invalid mode |
85 | auto st = check_fpmath_mode(m); |
86 | if (st == status::success) *mode = m; |
87 | return st; |
88 | } |
89 | |