1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief External function interface to rocBLAS libraries
 * \file rocblas.h
23 */
24#ifndef TVM_TOPI_CONTRIB_ROCBLAS_H_
25#define TVM_TOPI_CONTRIB_ROCBLAS_H_
26
27#include <tvm/te/operation.h>
28#include <tvm/topi/detail/extern.h>
29
30namespace tvm {
31namespace topi {
32namespace contrib {
33
34using namespace tvm::te;
35/*!
36 * \brief Create an op that multiplies lhs and rhs with rocBLAS
37 *
38 * \param lhs The left matrix operand
39 * \param rhs The right matrix operand
40 * \param transa Whether to transpose lhs
41 * \param transb Whether to transpose rhs
42 *
43 * \return The output tensor
44 */
45inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
46 auto n = transa ? lhs->shape[1] : lhs->shape[0];
47 auto m = transb ? rhs->shape[0] : rhs->shape[1];
48
49 return make_extern(
50 {{n, m}}, {lhs->dtype}, {lhs, rhs},
51 [&](Array<Buffer> ins, Array<Buffer> outs) {
52 return call_packed({StringImm("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]),
53 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
54 },
55 "C", "", {})[0];
56}
57/*!
58 * \brief Create an op that batch multiplies lhs and rhs with rocBLAS
59 *
60 * \param lhs The left matrix operand e.g. (batch_size, M, K)
61 * \param rhs The right matrix operand e.g. (batch_size, K, N)
62 * \param transa Whether to transpose lhs
63 * \param transb Whether to transpose rhs
64 *
65 * \return The output tensor
66 */
67inline Tensor rocblas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
68 auto batch_size = lhs->shape[0];
69 auto n = transa ? lhs->shape[2] : lhs->shape[1];
70 auto m = transb ? rhs->shape[1] : rhs->shape[2];
71
72 return make_extern(
73 {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs},
74 [&](Array<Buffer> ins, Array<Buffer> outs) {
75 return call_packed({StringImm("tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]),
76 pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
77 },
78 "C", "", {})[0];
79}
80
81} // namespace contrib
82} // namespace topi
83} // namespace tvm
84
85#endif // TVM_TOPI_CONTRIB_ROCBLAS_H_
86