/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/runtime/contrib/cudnn/softmax.cc
 * \brief Use external cudnn softmax function
 */
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "cudnn_utils.h"
namespace tvm {
namespace contrib {

using namespace runtime;

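/*!
 * \brief Compute cuDNN softmax forward for a packed-call invocation.
 *
 * Packed args are (x, y, axis): the input DLTensor, the output DLTensor, and the
 * axis to normalize over. The input shape is mapped onto a 4-D cuDNN tensor
 * descriptor so the reduction runs over the requested axis, then
 * cudnnSoftmaxForward is called with the given algorithm.
 */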
void softmax_impl(cudnnSoftmaxAlgorithm_t alg, TVMArgs args, TVMRetValue* ret) {
  DLTensor* x = args[0];
  DLTensor* y = args[1];
  int axis = args[2];
  int ndim = x->ndim;
  int64_t* shape = x->shape;
  if (axis < 0) axis += ndim;
  ICHECK(axis >= 0 && axis < ndim);

  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

  // Set mode and shape descriptor
  if (axis == ndim - 1) {
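    // Softmax over the last axis: flatten all leading dimensions into N and use the
    // last axis as C, viewing the tensor as (N, C, 1, 1). With H = W = 1,
    // CUDNN_SOFTMAX_MODE_INSTANCE normalizes each of the N rows over exactly that axis.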
    int64_t N = 1;
    for (int i = 0; i < ndim - 1; ++i) {
      N *= shape[i];
    }
    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
                                          entry_ptr->softmax_entry.data_type, static_cast<int>(N),
                                          static_cast<int>(shape[ndim - 1]), 1, 1));
  } else {
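    // Softmax over an interior axis: fold the dimensions before the axis into N,
    // the axis itself into C, and the dimensions after the axis into H, viewing the
    // tensor as (pre_axis_dim, shape[axis], post_axis_dim, 1). CUDNN_SOFTMAX_MODE_CHANNEL
    // then normalizes over C independently at every (N, H, W) position.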
    int64_t pre_axis_dim = 1;
    int64_t post_axis_dim = 1;
    for (int i = 0; i < ndim; ++i) {
      if (i < axis) {
        pre_axis_dim *= shape[i];
      } else if (i > axis) {
        post_axis_dim *= shape[i];
      }
    }
    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
    CUDNN_CALL(cudnnSetTensor4dDescriptor(
        entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
        static_cast<int>(pre_axis_dim), static_cast<int>(shape[axis]),
        static_cast<int>(post_axis_dim), 1));
  }

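  // cudnnSoftmaxForward computes y = alpha * softmax(x) + beta * y, so alpha = 1 and
  // beta = 0 store the result without blending in the previous contents of y.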
  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, alg, entry_ptr->softmax_entry.mode, alpha,
                                 entry_ptr->softmax_entry.shape_desc, x->data, beta,
                                 entry_ptr->softmax_entry.shape_desc, y->data));
}

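// Register the two algorithm variants as packed functions in the global registry so the
// runtime (e.g. the Python-side cuDNN contrib wrappers) can look them up by name.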
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret);
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); });

}  // namespace contrib
}  // namespace tvm