1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
17#define TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
18
19#include <stdlib.h>
20
21#include "third_party/eigen3/Eigen/Core"
22#include "tensorflow/core/platform/types.h"
23
24// This is an unoptimized but debuggable implementation of the GEMM matrix
25// multiply function, used to compare to faster but more opaque versions, or
26// for bit depths or argument combinations that aren't supported by optimized
27// code.
28// It assumes the row-major convention used by TensorFlow, and implements
29// C = A * B, like the standard BLAS GEMM interface. If the transpose flags are
30// true, then the relevant matrix is treated as stored in column-major order.
31
32namespace tensorflow {
33template <class T1, class T2, class T3>
34void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c,
35 size_t m, size_t n, size_t k, const T1* a, int32_t offset_a,
36 size_t lda, const T2* b, int32_t offset_b, size_t ldb, T3* c,
37 int32_t shift_c, int32_t offset_c, int32_t mult_c,
38 size_t ldc) {
39 int a_i_stride;
40 int a_l_stride;
41 if (transpose_a) {
42 a_i_stride = 1;
43 a_l_stride = lda;
44 } else {
45 a_i_stride = lda;
46 a_l_stride = 1;
47 }
48 int b_j_stride;
49 int b_l_stride;
50 if (transpose_b) {
51 b_j_stride = ldb;
52 b_l_stride = 1;
53 } else {
54 b_j_stride = 1;
55 b_l_stride = ldb;
56 }
57 int c_i_stride;
58 int c_j_stride;
59 if (transpose_c) {
60 c_i_stride = 1;
61 c_j_stride = ldc;
62 } else {
63 c_i_stride = ldc;
64 c_j_stride = 1;
65 }
66
67 const int32_t highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
68 const int32_t lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());
69 const int32_t rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1));
70
71 int i, j, l;
72 for (j = 0; j < n; j++) {
73 for (i = 0; i < m; i++) {
74 int32_t total = 0;
75 for (l = 0; l < k; l++) {
76 const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
77 const int32_t a_value = static_cast<int32>(a[a_index]) - offset_a;
78 const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
79 const int32_t b_value = static_cast<int32>(b[b_index]) - offset_b;
80 total += (a_value * b_value);
81 }
82 const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
83 int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c);
84 if (output > highest) {
85 output = highest;
86 }
87 if (output < lowest) {
88 output = lowest;
89 }
90 c[c_index] = static_cast<T3>(output);
91 }
92 }
93}
94} // namespace tensorflow
95
96#endif // TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
97