1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
// This is a set of different implementations of the basic matrix-by-matrix
// multiply function, commonly known as GEMM after the BLAS library's naming
// convention. Having a standard interface lets us swap in the optimal
// implementation for each platform. The implementations are C++ template
// functors, so they're easy to plug into all of the different kernels that
// use them.
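//
// All of the functors share the same contract: given a row-major (m x k)
// matrix `a` with leading dimension lda and a row-major (k x n) matrix `b`
// with leading dimension ldb, write the product into the row-major (m x n)
// matrix `c` with leading dimension ldc, so that
// c[i, j] = sum over l of a[i, l] * b[l, j].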
22 | |
23 | #if !defined(EIGEN_USE_THREADS) |
24 | #error "EIGEN_USE_THREADS must be enabled by all .cc files including this." |
25 | #endif // EIGEN_USE_THREADS |
26 | |
27 | #ifndef TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_ |
28 | #define TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_ |
29 | |
30 | #include <string.h> |
31 | #include <map> |
32 | #include <vector> |
33 | |
34 | #include "tensorflow/core/common_runtime/threadpool_device.h" |
35 | #include "tensorflow/core/framework/op_kernel.h" |
36 | #include "tensorflow/core/framework/tensor.h" |
37 | #include "tensorflow/core/framework/tensor_types.h" |
38 | |
39 | #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) |
40 | #include "tensorflow/core/kernels/eigen_contraction_kernel.h" |
#endif  // TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL
42 | |
// Apple provides an optimized BLAS library (Accelerate) that outperforms
// Eigen on their devices, so use it when possible.
45 | #if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV) |
46 | #include <Accelerate/Accelerate.h> |
47 | #define USE_CBLAS_GEMM |
#endif  // __APPLE__ && USE_GEMM_FOR_CONV
49 | |
// Older Raspberry Pi systems don't have NEON SIMD acceleration, so Eigen
// falls back to scalar code there. OpenBLAS has much faster hand-tuned
// kernels for those chips, so prefer it when it's available.
52 | #if defined(RASPBERRY_PI) && defined(USE_GEMM_FOR_CONV) && defined(USE_OPENBLAS) |
53 | #include <cblas.h> |
54 | #define USE_CBLAS_GEMM |
#endif  // RASPBERRY_PI && USE_GEMM_FOR_CONV && USE_OPENBLAS
56 | |
// A readable but slow implementation of matrix multiplication, useful for
// debugging and understanding the algorithm. To enable it, substitute this
// class for FastGemmFunctor in the Im2ColConvFunctor template instantiation
// inside the op registration. Assumes row-major ordering of the values in
// memory.
61 | template <class T1, class T2, class T3> |
62 | class ReferenceGemmFunctor { |
63 | public: |
64 | void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n, |
65 | size_t k, const T1* a, size_t lda, const T2* b, size_t ldb, |
66 | T3* c, size_t ldc) { |
67 | const size_t a_i_stride = lda; |
68 | const size_t a_l_stride = 1; |
69 | const size_t b_j_stride = 1; |
70 | const size_t b_l_stride = ldb; |
71 | const size_t c_i_stride = ldc; |
72 | const size_t c_j_stride = 1; |
    for (size_t j = 0; j < n; j++) {
      for (size_t i = 0; i < m; i++) {
        T3 total(0);
        for (size_t l = 0; l < k; l++) {
          const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
          const T1 a_value = a[a_index];
          const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
          const T2 b_value = b[b_index];
          total += (a_value * b_value);
        }
        const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
        c[c_index] = total;
      }
    }
88 | } |
89 | }; |
90 | |
// Uses the optimized Eigen Tensor library to implement the matrix
// multiplication required by the Im2ColConvFunctor class. We supply two input
// types and one output type so that the accumulator can potentially be higher
// precision than the inputs, although the generic version doesn't take
// advantage of this; the bfloat16 specialization below instead accumulates in
// float internally.
95 | template <class T1, class T2, class T3> |
96 | class FastGemmFunctor { |
97 | public: |
98 | void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n, |
99 | size_t k, const T1* a, size_t lda, const T2* b, size_t ldb, |
100 | T3* c, size_t ldc) { |
101 | typename tensorflow::TTypes<const T1>::Matrix a_matrix(a, m, k); |
102 | typename tensorflow::TTypes<const T2>::Matrix b_matrix(b, k, n); |
103 | typename tensorflow::TTypes<T3>::Matrix c_matrix(c, m, n); |
104 | |
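    // dim_pair below contracts dimension 1 of a (its columns) with dimension
    // 0 of b (its rows), which is a standard matrix multiplication:
    // c = a * b, evaluated on the context's thread pool. Note that this path
    // ignores lda, ldb, and ldc and assumes densely packed matrices
    // (lda == k, ldb == n, ldc == n).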
105 | Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; |
106 | dim_pair[0].first = 1; |
107 | dim_pair[0].second = 0; |
108 | c_matrix.device(ctx->eigen_device<Eigen::ThreadPoolDevice>()) = |
109 | a_matrix.contract(b_matrix, dim_pair); |
110 | } |
111 | }; |
112 | |
// Use float32 accumulation for bfloat16 inputs to avoid losing precision
// when reducing over the inner dimension.
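//
// For example, bfloat16 keeps only 8 significand bits, so once a running
// total reaches 256 the spacing between adjacent representable values is 2,
// and adding a product of 1.0 rounds straight back to 256. Accumulating in
// float and casting to bfloat16 only at the end avoids this.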
115 | template <> |
116 | class FastGemmFunctor<Eigen::bfloat16, Eigen::bfloat16, Eigen::bfloat16> { |
117 | public: |
118 | void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n, |
119 | size_t k, const Eigen::bfloat16* a, size_t lda, |
120 | const Eigen::bfloat16* b, size_t ldb, Eigen::bfloat16* c, |
121 | size_t ldc) { |
122 | using ConstMatrix = |
123 | typename tensorflow::TTypes<const Eigen::bfloat16>::Matrix; |
124 | ConstMatrix a_matrix(a, m, k); |
125 | ConstMatrix b_matrix(b, k, n); |
126 | typename tensorflow::TTypes<Eigen::bfloat16>::Matrix c_matrix(c, m, n); |
127 | |
128 | Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; |
129 | dim_pair[0].first = 1; |
130 | dim_pair[0].second = 0; |
131 | c_matrix.device(ctx->eigen_device<Eigen::ThreadPoolDevice>()) = |
132 | a_matrix.cast<float>() |
133 | .contract(b_matrix.cast<float>(), dim_pair) |
134 | .template cast<Eigen::bfloat16>(); |
135 | } |
136 | }; |
137 | |
// If a fast CBLAS library is available, use its sgemm through a thin wrapper
// for the all-float case.
139 | #if defined(USE_CBLAS_GEMM) |
140 | template <> |
141 | class FastGemmFunctor<float, float, float> { |
142 | public: |
143 | void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n, |
144 | size_t k, const float* a, size_t lda, const float* b, |
145 | size_t ldb, float* c, size_t ldc) { |
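    // Row-major storage, no transposes, alpha = 1.0f, beta = 0.0f: computes
    // c = a * b, overwriting any existing contents of c.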
146 | cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a, |
147 | lda, b, ldb, 0.0f, c, ldc); |
148 | } |
149 | }; |
150 | #endif // USE_CBLAS_GEMM |
151 | |
152 | #endif // TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_ |
153 | |