1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// This is the main Ruy public header.
17
18#ifndef RUY_RUY_RUY_H_
19#define RUY_RUY_RUY_H_
20
21#include "ruy/context.h"
22#include "ruy/context_get_ctx.h"
23#include "ruy/frontend.h"
24#include "ruy/mat.h"
25#include "ruy/matrix.h"
26#include "ruy/mul_params.h"
27#include "ruy/path.h"
28#include "ruy/trace.h"
29
30namespace ruy {
31
32// Entry point allowing to specify a custom OR-ed set of Path's to
33// compile. See the comments in path.h for more details about that.
34// Most users should use the other ruy::Mul overload not taking a Path template
35// parameter, and the main documentation comment is on that overload.
36template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
37 typename AccumScalar, typename DstScalar>
38void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
39 const MulParams<AccumScalar, DstScalar>& mul_params, Context* context,
40 Matrix<DstScalar>* dst) {
41 RUY_TRACE_SCOPE;
42 RUY_TRACE_INFO(MUL);
43 Mat<LhsScalar> internal_lhs = ToInternal(lhs);
44 Mat<RhsScalar> internal_rhs = ToInternal(rhs);
45 Mat<DstScalar> internal_dst = ToInternal(*dst);
46 MulFrontEnd<CompiledPaths>(internal_lhs, internal_rhs, mul_params,
47 get_ctx(context), &internal_dst);
48}
49
50// Performs a multiplication of matrices, with some extra features for
51// neural network applications. The basic operation is:
52//
53// dst = lhs * rhs // matrix multiplication
54//
55// The `mul_params` argument conveys additional parameters that are not
56// naturally associated with lhs, rhs, dst. That includes typical neural network
57// application domain specific features such as a bias-vector and clamp bounds,
58// as well as integer quantization parameters.
59//
60// A simple reference implementation of the operation performed by ruy::Mul
61// is provided by the ruy::ReferenceMul function in reference_mul.h.
62//
63// The `context` argument can be any ruy::Context object as long as no other
64// thread is going to concurrently access that ruy::Context. The simplest
65// correct (but not efficient) calling pattern is
66//
67// ruy::Context context;
68// ruy::Mul(lhs, rhs, mul_params, &context, dst);
69//
// However, creating and destroying a new context every time is inefficient
71// because it doesn't allow for resources to persist across ruy calls. Such
72// resources may include heap allocations, a thread pool, and hardware detection
73// results, and can be expensive to obtain. So the recommended usage pattern is
74// more like this:
75//
76// // Once during initialization:
77// ruy::Context* context = new ruy::Context;
78//
79// // Many times
80// ruy::Mul(lhs, rhs, mul_params, context, dst);
81//
82// If multiple threads may concurrently be calling ruy::Mul, they must either
83// use separate Contexts, or use a lock to ensure that no two threads are
84// concurrently accessing the Context object. There is no lock inside Context,
85// nothing is done to ensure reentrancy with shared Context objects.
86//
87// Ruy defaults to using only 1 thread. Multi-threading is always opted in to,
88// by calling Context::set_max_num_threads() with an explicit thread count.
89// If multiple threads may concurrently be calling ruy::Mul, it is advisable
90// to set up their respective Context objects with set_max_num_threads so that
91// the overall number of threads doesn't exceed the overall number of threads
92// that the system can usefully execute concurrently
// (e.g. the number of CPU cores in typical scenarios). At least ruy forces
// each invocation to make an explicit decision here; there is no automatic
// detection of the best number of threads to use in ruy.
96template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
97 typename DstScalar>
98void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
99 const MulParams<AccumScalar, DstScalar>& mul_params, Context* context,
100 Matrix<DstScalar>* dst) {
101 Mul<kDefaultPaths>(lhs, rhs, mul_params, context, dst);
102}
103
104} // namespace ruy
105
106#endif // RUY_RUY_RUY_H_
107