1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief External function interface to cuBLAS libraries
22 * \file cublas.h
23 */
24#ifndef TVM_TOPI_CONTRIB_CUBLAS_H_
25#define TVM_TOPI_CONTRIB_CUBLAS_H_
26
27#include <tvm/te/operation.h>
28#include <tvm/topi/detail/extern.h>
29
30namespace tvm {
31namespace topi {
32namespace contrib {
33
34using namespace tvm::te;
35using namespace topi::detail;
36/*!
37 * \brief Create an op that multiplies lhs and rhs with cuBLAS
38 *
39 * \param lhs The left matrix operand
40 * \param rhs The right matrix operand
41 * \param transa Whether to transpose lhs
42 * \param transb Whether to transpose rhs
43 *
44 * \return The output tensor
45 */
46inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
47 auto n = transa ? lhs->shape[1] : lhs->shape[0];
48 auto m = transb ? rhs->shape[0] : rhs->shape[1];
49
50 return make_extern(
51 {{n, m}}, {lhs->dtype}, {lhs, rhs},
52 [&](Array<Buffer> ins, Array<Buffer> outs) {
53 return call_packed({StringImm("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]),
54 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
55 },
56 "C", "", {})[0];
57}
58
59/*!
60 * \brief Create an op that multiplies batch matrices
61 * lhs and rhs with cuBLAS
62 *
63 * \param lhs The left matrix operand
64 * \param rhs The right matrix operand
65 * \param transa Whether to transpose lhs
66 * \param transb Whether to transpose rhs
67 *
68 * \return The output tensor
69 */
70inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
71 auto b = lhs->shape[0];
72 auto n = transa ? lhs->shape[2] : lhs->shape[1];
73 auto m = transb ? rhs->shape[1] : rhs->shape[2];
74
75 return make_extern(
76 {{b, n, m}}, {lhs->dtype}, {lhs, rhs},
77 [&](Array<Buffer> ins, Array<Buffer> outs) {
78 return call_packed({StringImm("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]),
79 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
80 },
81 "C", "", {})[0];
82}
83
84} // namespace contrib
85} // namespace topi
86} // namespace tvm
87
88#endif // TVM_TOPI_CONTRIB_CUBLAS_H_
89