1#pragma once
2
3#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
4#include <torch/csrc/distributed/c10d/comm.hpp>
5
6namespace c10d {
7
// Identifiers for the built-in C++ communication hooks. Every enumerator has
// an explicitly assigned value, so the numbers stay fixed even if the
// enumerators are reordered. `int` is already the default underlying type of a
// scoped enum; it is written out here only to make the representation obvious.
// NOTE(review): presumably ALLREDUCE selects AllReduceCommHook and
// FP16_COMPRESS selects FP16CompressCommHook; that mapping is not visible in
// this header -- confirm at the site where the hooks are registered.
enum class BuiltinCommHookType : int {
  ALLREDUCE = 1,
  FP16_COMPRESS = 2,
};
12
13class AllReduceCommHook : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
14 public:
15 explicit AllReduceCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
16 : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}
17
18 ~AllReduceCommHook() override = default;
19
20 c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
21};
22
23class FP16CompressCommHook : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
24 public:
25 explicit FP16CompressCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
26 : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}
27
28 ~FP16CompressCommHook() override = default;
29
30 c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
31};
32
// Almost the same as AllReduceCommHook, but without the division inside the
// hook. This enables the optimization of fusing copy and division, which saves
// one scan over all the input parameters when no communication hook is
// provided by the user. Only used internally; it is not released as a public
// built-in communication hook.
37class _AllReduceBySumCommHook
38 : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
39 public:
40 explicit _AllReduceBySumCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
41 : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}
42
43 ~_AllReduceBySumCommHook() override = default;
44
45 c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
46};
47
48} // namespace c10d
49