1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"

#include <cstdlib>
#include <cstring>
#include <mutex>  // NOLINT(build/c++11)

#include "absl/base/call_once.h"
21
// We need a pair of compile-time and runtime flags: the compile-time flag
// disables compilation of custom contraction kernels on unsupported
// architectures (e.g. Android, iOS, ARM and PPC CPUs), and the runtime flag
// makes it possible to fall back on the default Eigen matrix multiplication.
//
// Using the absl flags library is not allowed in TensorFlow, so the runtime
// configuration is passed through an environment variable.
//
// Example:
//   bazel test \
//     --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \
//     //path/to:test
34
35#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
36
37namespace Eigen {
38namespace internal {
39
40// TODO(ezhulenev): This is a temporary workaround for disabling custom kernels
41// at runtime in tests. We should always rely on compile time flags for that.
42//
43// Example:
44// bazel test \
45// --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \
46// //path/to:test
47EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels() {
48 static bool use_custom_contraction_kernel = true;
49
50// This subroutine should not be used in GPU. In case it is, a custom kernel
51// should always be used
52#if !defined __NVCC__ && !defined __HIP_DEVICE_COMPILE__
53 static absl::once_flag initialized;
54 absl::call_once(initialized, [&] {
55 char* flag = std::getenv("TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL");
56 if (flag && (strcmp(flag, "false") == 0 || strcmp(flag, "0") == 0)) {
57 use_custom_contraction_kernel = false;
58 }
59 });
60#endif
61
62 return use_custom_contraction_kernel;
63}
64
65} // namespace internal
66} // namespace Eigen
67#endif
68