1 | #pragma once |
---|---|
2 | |
3 | #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
4 | #include <torch/csrc/distributed/c10d/comm.hpp> |
5 | |
6 | namespace c10d { |
7 | |
// Identifies one of the built-in C++ communication hooks.
//
// The underlying type is spelled out as `int`, which is what a scoped
// enumeration uses by default, so the representation is unchanged.
// NOTE(review): numbering starts at 1, not 0; keep these values stable in
// case they are mirrored outside C++ (e.g. in bindings) — confirm at callers.
enum class BuiltinCommHookType : int {
  // Presumably selects AllReduceCommHook — confirm at the dispatch site.
  ALLREDUCE = 1,
  // Presumably selects FP16CompressCommHook — confirm at the dispatch site.
  FP16_COMPRESS = 2
};
12 | |
13 | class AllReduceCommHook : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> { |
14 | public: |
15 | explicit AllReduceCommHook(const c10::intrusive_ptr<ProcessGroup>& state) |
16 | : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {} |
17 | |
18 | ~AllReduceCommHook() override = default; |
19 | |
20 | c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override; |
21 | }; |
22 | |
23 | class FP16CompressCommHook : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> { |
24 | public: |
25 | explicit FP16CompressCommHook(const c10::intrusive_ptr<ProcessGroup>& state) |
26 | : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {} |
27 | |
28 | ~FP16CompressCommHook() override = default; |
29 | |
30 | c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override; |
31 | }; |
32 | |
33 | // Almost same as AllReduceCommHook, but without division inside the hook. |
34 | // This enables the optimization of fusing copy and division and saves one scan |
35 | // over all the input parameters, when no communication hook is provided by the user. |
36 | // Only used internally and not released as a public built-in communication hook. |
37 | class _AllReduceBySumCommHook |
38 | : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> { |
39 | public: |
40 | explicit _AllReduceBySumCommHook(const c10::intrusive_ptr<ProcessGroup>& state) |
41 | : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {} |
42 | |
43 | ~_AllReduceBySumCommHook() override = default; |
44 | |
45 | c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override; |
46 | }; |
47 | |
48 | } // namespace c10d |
49 |