/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is a set of different implementations of the basic matrix-by-matrix
// multiply function, commonly known as GEMM after the BLAS library's naming.
// Having a standard interface enables us to swap out implementations on
// different platforms, to make sure we're using the optimal version. They are
// implemented as C++ template functors, so they're easy to swap into all of
// the different kernels that use them.

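// A sketch of the shared interface (the real call sites live in the kernels
// that instantiate these functors): each functor computes c = a * b for
// row-major matrices, where a is m x k with row stride lda, b is k x n with
// row stride ldb, and c is m x n with row stride ldc. A hypothetical,
// densely packed call might look like:
//
//   FastGemmFunctor<float, float, float> gemm;
//   gemm(ctx, m, n, k, a_ptr, k, b_ptr, n, c_ptr, n);
//
// where ctx is the tensorflow::OpKernelContext* of the calling kernel.
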
#if !defined(EIGEN_USE_THREADS)
#error "EIGEN_USE_THREADS must be enabled by all .cc files including this."
#endif  // EIGEN_USE_THREADS

#ifndef TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
#define TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_

#include <string.h>
#include <map>
#include <vector>

#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

// Apple provides an optimized BLAS library that is better than Eigen for their
// devices, so use that if possible.
#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
#include <Accelerate/Accelerate.h>
#define USE_CBLAS_GEMM
#endif  // __APPLE__

// Older Raspberry Pi systems don't have NEON SIMD acceleration, so Eigen falls
// back to scalar code, but OpenBLAS has much faster GEMM support, so prefer
// that when it's available.
#if defined(RASPBERRY_PI) && defined(USE_GEMM_FOR_CONV) && defined(USE_OPENBLAS)
#include <cblas.h>
#define USE_CBLAS_GEMM
#endif

// A readable but slow implementation of matrix multiplication, useful for
// debugging and understanding the algorithm. To enable it, substitute it for
// FastGemmFunctor in the Im2ColConvFunctor template instantiation inside the
// op registration. Assumes row-major ordering of the values in memory.
template <class T1, class T2, class T3>
class ReferenceGemmFunctor {
 public:
  void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n,
                  size_t k, const T1* a, size_t lda, const T2* b, size_t ldb,
                  T3* c, size_t ldc) {
    const size_t a_i_stride = lda;
    const size_t a_l_stride = 1;
    const size_t b_j_stride = 1;
    const size_t b_l_stride = ldb;
    const size_t c_i_stride = ldc;
    const size_t c_j_stride = 1;
    size_t i, j, l;
    for (j = 0; j < n; j++) {
      for (i = 0; i < m; i++) {
        T3 total(0);
        for (l = 0; l < k; l++) {
          const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
          const T1 a_value = a[a_index];
          const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
          const T2 b_value = b[b_index];
          total += (a_value * b_value);
        }
        const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
        c[c_index] = total;
      }
    }
  }
};
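
// As an illustration of the expected layout (not part of the production code
// path): with m = n = k = 2 and densely packed inputs (lda = ldb = ldc = 2),
// a = {1, 2, 3, 4} and b = {5, 6, 7, 8} represent the row-major matrices
// [[1, 2], [3, 4]] and [[5, 6], [7, 8]], and the functor writes
// c = {19, 22, 43, 50}, i.e. [[19, 22], [43, 50]].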

// Uses the optimized Eigen Tensor library to implement the matrix
// multiplication required by the Im2ColConvFunctor class. We supply the two
// input and one output types so that the accumulator can potentially be
// higher-precision than the inputs, even though we don't currently take
// advantage of this.
template <class T1, class T2, class T3>
class FastGemmFunctor {
 public:
  void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n,
                  size_t k, const T1* a, size_t lda, const T2* b, size_t ldb,
                  T3* c, size_t ldc) {
    typename tensorflow::TTypes<const T1>::Matrix a_matrix(a, m, k);
    typename tensorflow::TTypes<const T2>::Matrix b_matrix(b, k, n);
    typename tensorflow::TTypes<T3>::Matrix c_matrix(c, m, n);

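    // Contract dimension 1 of a (its columns) against dimension 0 of b (its
    // rows), which is a standard matrix multiply: c = a * b.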
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = 1;
    dim_pair[0].second = 0;
    c_matrix.device(ctx->eigen_device<Eigen::ThreadPoolDevice>()) =
        a_matrix.contract(b_matrix, dim_pair);
  }
};

// Use float32 accumulation for bfloat16 inputs to avoid precision loss during
// accumulation.
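// (bfloat16 keeps only 8 bits of significand, so, for example, adding 1.0 to
// 512.0 in bfloat16 leaves the value at 512.0; contracting in float and
// casting the result back avoids dropping those low-order contributions.)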
template <>
class FastGemmFunctor<Eigen::bfloat16, Eigen::bfloat16, Eigen::bfloat16> {
 public:
  void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n,
                  size_t k, const Eigen::bfloat16* a, size_t lda,
                  const Eigen::bfloat16* b, size_t ldb, Eigen::bfloat16* c,
                  size_t ldc) {
    using ConstMatrix =
        typename tensorflow::TTypes<const Eigen::bfloat16>::Matrix;
    ConstMatrix a_matrix(a, m, k);
    ConstMatrix b_matrix(b, k, n);
    typename tensorflow::TTypes<Eigen::bfloat16>::Matrix c_matrix(c, m, n);

    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = 1;
    dim_pair[0].second = 0;
    c_matrix.device(ctx->eigen_device<Eigen::ThreadPoolDevice>()) =
        a_matrix.cast<float>()
            .contract(b_matrix.cast<float>(), dim_pair)
            .template cast<Eigen::bfloat16>();
  }
};

// If we have a fast CBLAS library, use its implementation through a wrapper.
#if defined(USE_CBLAS_GEMM)
template <>
class FastGemmFunctor<float, float, float> {
 public:
  void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n,
                  size_t k, const float* a, size_t lda, const float* b,
                  size_t ldb, float* c, size_t ldc) {
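    // cblas_sgemm computes c = alpha * a * b + beta * c; with alpha = 1.0f and
    // beta = 0.0f this overwrites c with the plain product. The leading
    // dimensions lda, ldb, and ldc are the row strides of the row-major
    // matrices.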
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
                lda, b, ldb, 0.0f, c, ldc);
  }
};
#endif  // USE_CBLAS_GEMM

#endif  // TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_