/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda/dense.h
 * \brief CUDA schedule for dense operation
 */
#ifndef TVM_TOPI_CUDA_DENSE_H_
#define TVM_TOPI_CUDA_DENSE_H_

#include <tvm/target/generic_func.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/topi/contrib/cublas.h>
#include <tvm/topi/detail/array_utils.h>
#include <tvm/topi/generic/extern.h>
#include <tvm/topi/nn/dense.h>
#include <tvm/topi/tags.h>

namespace tvm {
namespace topi {

using namespace tvm::te;

namespace cuda {
/*!
 * \brief Implementation of dense for CUDA backend
 *
 * \param target The target device
 * \param data Tensor with shape [batch, in_dim]
 * \param weight Tensor with shape [out_dim, in_dim]
 * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
 * \param out_dtype Output data type. Used for mixed precision.
 *
 * \return Tensor with shape [batch, out_dim]
 */
inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& data,
                                  const tvm::te::Tensor& weight, const tvm::te::Tensor& bias,
                                  const DataType& out_dtype) {
  ICHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
  ICHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
  if (bias.defined()) {
    ICHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
  }

  auto batch = data->shape[0];
  auto in_dim = data->shape[1];
  auto out_dim = weight->shape[0];

  if (target->GetLibs().count("cublas")) {
    // Offload the matrix multiply to cuBLAS when the target enables the library.
    ICHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
    auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
    if (bias.defined()) {
      // Add the bias in a trailing broadcast stage.
      mm = tvm::te::compute(
          {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast);
    }

    return mm;
  } else {
    // Fall back to the generic TOPI dense compute.
    return topi::nn::dense(data, weight, bias, out_dtype);
  }
}
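// A minimal usage sketch (illustrative only): the tensor names and shapes below
// are hypothetical. The cuBLAS branch is taken only when the target enables the
// library, e.g. "cuda -libs=cublas"; otherwise the generic dense compute is used.
//
//   Target target("cuda -libs=cublas");
//   te::Tensor data = te::placeholder({16, 64}, DataType::Float(32), "data");
//   te::Tensor weight = te::placeholder({32, 64}, DataType::Float(32), "weight");
//   te::Tensor bias = te::placeholder({32}, DataType::Float(32), "bias");
//   te::Tensor out = dense_cuda(target, data, weight, bias, DataType::Float(32));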

/*!
 * \brief Create a CUDA schedule for dense
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_dense(const Target& target, const Array<Tensor>& outs) {
  if (target->kind->name == "cuda" && target->GetLibs().count("cublas")) {
    return topi::generic::schedule_extern(target, outs);
  }

  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  auto s = create_schedule(out_ops);
  auto _schedule = [&](const Tensor& dense) {
    auto num_thread = 64;
    // Split the reduction axis and rfactor it so the partial sums can be
    // reduced cooperatively across threadIdx.x.
    auto k = dense->op.as<ComputeOpNode>()->reduce_axis[0];
    IterVar ko, kf;
    s[dense].split(k, num_thread, &ko, &kf);
    auto dense_f = s.rfactor(dense, kf)[0];

    Tensor out;
    if (detail::contains(s->outputs, dense->op)) {
      out = dense;
    } else {
      // dense is fused into a later elementwise stage; compute it at that stage.
      out = outs[0]->op.output(0);
      s[dense].compute_at(s[out], s[out]->op.as<ComputeOpNode>()->axis[1]);
    }
    // Map the batch and output dimensions onto the block grid.
    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[0],
                tvm::te::thread_axis(Range(), "blockIdx.y"));
    s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[1],
                tvm::te::thread_axis(Range(), "blockIdx.x"));

    // Bind the remaining reduction axis to threadIdx.x; only thread 0 stores
    // the final reduced value.
    auto tx = s[dense]->op.as<ComputeOpNode>()->reduce_axis[0];
    auto thread_x = tvm::te::thread_axis(Range(), "threadIdx.x");
    s[dense].bind(tx, thread_x);
    s[dense_f].compute_at(s[dense], tx);
    s[dense].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
    s[out].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
  };

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output)
    if (is_broadcast(op->tag)) {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag == "dense") {
      // Apply the dense-specific schedule defined above.
      auto dense = op.output(0);
      _schedule(dense);
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  return s;
}
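// A minimal usage sketch (hypothetical tensors, continuing the example above):
// on a cuBLAS-enabled CUDA target this returns the generic extern schedule;
// otherwise it builds the cross-thread reduction schedule defined above.
//
//   te::Tensor out = dense_cuda(target, data, weight, bias, DataType::Float(32));
//   te::Schedule s = schedule_dense(target, {out});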

}  // namespace cuda
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_CUDA_DENSE_H_