1// Copyright (c) Facebook, Inc. and its affiliates.
2// All rights reserved.
3//
4// This source code is licensed under the BSD-style license found in the
5// LICENSE file in the root directory of this source tree.
6
7#include <ATen/functorch/BatchRulesHelper.h>
8#include <ATen/functorch/PlumbingHelper.h>
9#include <ATen/Operators.h>
10
// NB: most activation functions are covered by the generic pointwise unary or
// binary batch rules. This file only holds the activations that need their own
// special batch rules; they are grouped here to keep things organized.
13namespace at { namespace functorch {
14std::tuple<Tensor,optional<int64_t>>
15glu_batch_rule(const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
16 // repeated error message from glu because 0D -> 1D when batched
17 // this can't pass anyway because a 0-dimensional tensor has "size" 1, which
18 // can't be evenly halved, but give a nicer error message here.
19 TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");
20
21 const auto rank = rankWithoutBatchDim(self, self_bdim);
22 const auto dim_ = maybe_wrap_dim(dim, rank) + 1;
23
24 const auto self_ = moveBatchDimToFront(self, self_bdim);
25
26 const auto res = at::glu(self_, dim_);
27 return std::make_tuple(res, 0);
28}
29
30std::tuple<Tensor,optional<int64_t>> glu_backward_batch_rule(
31 const Tensor& grad_output, optional<int64_t> grad_output_bdim,
32 const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
33 if (self_bdim) {
34 // repeated error message from glu because 0D -> 1D when batched
35 // this can't pass anyway because a 0-dimensional tensor has "size" 1, which
36 // can't be evenly halved, but give a nicer error message here.
37 TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");
38 }
39
40 const auto rank = rankWithoutBatchDim(self, self_bdim);
41 const auto dim_ = maybe_wrap_dim(dim, rank) + 1;
42
43 const auto batch_size = get_bdim_size2(grad_output, grad_output_bdim, self, self_bdim);
44 const auto grad_output_ = ensure_has_bdim(moveBatchDimToFront(grad_output, grad_output_bdim), grad_output_bdim.has_value(), batch_size);
45 const auto self_ = ensure_has_bdim(moveBatchDimToFront(self, self_bdim), self_bdim.has_value(), batch_size);
46
47 const auto res = at::glu_backward(grad_output_, self_, dim_);
48 return std::make_tuple(res, 0);
49}
50
51
52TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
53 VMAP_SUPPORT(glu_backward, glu_backward_batch_rule);
54 VMAP_SUPPORT(glu, glu_batch_rule);
55}
56}} // namespace at::functorch
57