// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/Operators.h>

// NB: most activation functions fit the pointwise unary or binary batch rules.
// Only the ones that need a special batch rule live here, for organization.
namespace at { namespace functorch {
std::tuple<Tensor,optional<int64_t>>
glu_batch_rule(const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
  // Repeat glu's 0-dim error message here, because a 0D input becomes 1D when
  // batched. The op can't succeed anyway (a 0-dimensional tensor has "size" 1,
  // which can't be evenly halved), but this gives a nicer error message.
  TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");

  const auto rank = rankWithoutBatchDim(self, self_bdim);
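  // The batch dim gets moved to physical dim 0 below, so every logical dim of
  // the per-example tensor shifts up by one: wrap `dim` against the logical
  // rank, then add 1 to get the physical dim to hand to glu.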
  const auto dim_ = maybe_wrap_dim(dim, rank) + 1;

  const auto self_ = moveBatchDimToFront(self, self_bdim);

  const auto res = at::glu(self_, dim_);
  return std::make_tuple(res, 0);
}
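
// Shape sketch (illustrative, with assumed sizes B and M): per-example
// glu(x, /*dim=*/0) on x of shape [2*M] becomes at::glu(x_batched, /*dim=*/1)
// on [B, 2*M], producing [B, M] with the returned batch dim at position 0.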

std::tuple<Tensor,optional<int64_t>> glu_backward_batch_rule(
    const Tensor& grad_output, optional<int64_t> grad_output_bdim,
    const Tensor& self, optional<int64_t> self_bdim, int64_t dim) {
  if (self_bdim) {
    // Repeat glu's 0-dim error message here, because a 0D input becomes 1D
    // when batched. The op can't succeed anyway (a 0-dimensional tensor has
    // "size" 1, which can't be evenly halved), but this gives a nicer error
    // message.
    TORCH_CHECK(self.dim() > 1, "glu does not support 0-dimensional tensors");
  }

  const auto rank = rankWithoutBatchDim(self, self_bdim);
  const auto dim_ = maybe_wrap_dim(dim, rank) + 1;

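  // grad_output and self need not both carry a batch dim. Take the batch size
  // from whichever one is batched, then materialize a batch dim of that size
  // on the other so the two line up for glu_backward.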
  const auto batch_size = get_bdim_size2(grad_output, grad_output_bdim, self, self_bdim);
  const auto grad_output_ = ensure_has_bdim(
      moveBatchDimToFront(grad_output, grad_output_bdim), grad_output_bdim.has_value(), batch_size);
  const auto self_ = ensure_has_bdim(
      moveBatchDimToFront(self, self_bdim), self_bdim.has_value(), batch_size);

  const auto res = at::glu_backward(grad_output_, self_, dim_);
  return std::make_tuple(res, 0);
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(glu_backward, glu_backward_batch_rule);
  VMAP_SUPPORT(glu, glu_batch_rule);
}
}} // namespace at::functorch
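
// Usage sketch (assumed; functorch is typically driven from Python): with the
// rules above registered for the FuncTorchBatched key, something like
//   torch.func.vmap(torch.nn.functional.glu)(x)   # x: [B, 2*M]
// dispatches at::glu to glu_batch_rule instead of a slow per-example fallback.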