1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ |
18 | |
#include <stdlib.h>

#include <cstddef>
#include <cstdint>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/types.h"
23 | |
24 | // This is an unoptimized but debuggable implementation of the GEMM matrix |
25 | // multiply function, used to compare to faster but more opaque versions, or |
26 | // for bit depths or argument combinations that aren't supported by optimized |
27 | // code. |
28 | // It assumes the row-major convention used by TensorFlow, and implements |
29 | // C = A * B, like the standard BLAS GEMM interface. If the transpose flags are |
30 | // true, then the relevant matrix is treated as stored in column-major order. |
31 | |
32 | namespace tensorflow { |
33 | template <class T1, class T2, class T3> |
34 | void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c, |
35 | size_t m, size_t n, size_t k, const T1* a, int32_t offset_a, |
36 | size_t lda, const T2* b, int32_t offset_b, size_t ldb, T3* c, |
37 | int32_t shift_c, int32_t offset_c, int32_t mult_c, |
38 | size_t ldc) { |
39 | int a_i_stride; |
40 | int a_l_stride; |
41 | if (transpose_a) { |
42 | a_i_stride = 1; |
43 | a_l_stride = lda; |
44 | } else { |
45 | a_i_stride = lda; |
46 | a_l_stride = 1; |
47 | } |
48 | int b_j_stride; |
49 | int b_l_stride; |
50 | if (transpose_b) { |
51 | b_j_stride = ldb; |
52 | b_l_stride = 1; |
53 | } else { |
54 | b_j_stride = 1; |
55 | b_l_stride = ldb; |
56 | } |
57 | int c_i_stride; |
58 | int c_j_stride; |
59 | if (transpose_c) { |
60 | c_i_stride = 1; |
61 | c_j_stride = ldc; |
62 | } else { |
63 | c_i_stride = ldc; |
64 | c_j_stride = 1; |
65 | } |
66 | |
67 | const int32_t highest = static_cast<int32>(Eigen::NumTraits<T3>::highest()); |
68 | const int32_t lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest()); |
69 | const int32_t rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1)); |
70 | |
71 | int i, j, l; |
72 | for (j = 0; j < n; j++) { |
73 | for (i = 0; i < m; i++) { |
74 | int32_t total = 0; |
75 | for (l = 0; l < k; l++) { |
76 | const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); |
77 | const int32_t a_value = static_cast<int32>(a[a_index]) - offset_a; |
78 | const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); |
79 | const int32_t b_value = static_cast<int32>(b[b_index]) - offset_b; |
80 | total += (a_value * b_value); |
81 | } |
82 | const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); |
83 | int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c); |
84 | if (output > highest) { |
85 | output = highest; |
86 | } |
87 | if (output < lowest) { |
88 | output = lowest; |
89 | } |
90 | c[c_index] = static_cast<T3>(output); |
91 | } |
92 | } |
93 | } |
94 | } // namespace tensorflow |
95 | |
96 | #endif // TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ |
97 | |