1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
6 | #include <test/cpp/api/support.h> |
7 | |
8 | #include <cmath> |
9 | #include <cstdlib> |
10 | #include <random> |
11 | |
12 | using namespace torch::nn; |
13 | using namespace torch::test; |
14 | |
// Value of pi; used below to convert the CartPole failure angle from degrees
// to radians.
constexpr double kPi = 3.1415926535898;
16 | |
17 | class CartPole { |
18 | // Translated from openai/gym's cartpole.py |
19 | public: |
20 | double gravity = 9.8; |
21 | double masscart = 1.0; |
22 | double masspole = 0.1; |
23 | double total_mass = (masspole + masscart); |
24 | double length = 0.5; // actually half the pole's length; |
25 | double polemass_length = (masspole * length); |
26 | double force_mag = 10.0; |
27 | double tau = 0.02; // seconds between state updates; |
28 | |
29 | // Angle at which to fail the episode |
30 | double theta_threshold_radians = 12 * 2 * kPi / 360; |
31 | double x_threshold = 2.4; |
32 | int steps_beyond_done = -1; |
33 | |
34 | torch::Tensor state; |
35 | double reward; |
36 | bool done; |
37 | int step_ = 0; |
38 | |
39 | torch::Tensor getState() { |
40 | return state; |
41 | } |
42 | |
43 | double getReward() { |
44 | return reward; |
45 | } |
46 | |
47 | double isDone() { |
48 | return done; |
49 | } |
50 | |
51 | void reset() { |
52 | state = torch::empty({4}).uniform_(-0.05, 0.05); |
53 | steps_beyond_done = -1; |
54 | step_ = 0; |
55 | } |
56 | |
57 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
58 | CartPole() { |
59 | reset(); |
60 | } |
61 | |
62 | void step(int action) { |
63 | auto x = state[0].item<float>(); |
64 | auto x_dot = state[1].item<float>(); |
65 | auto theta = state[2].item<float>(); |
66 | auto theta_dot = state[3].item<float>(); |
67 | |
68 | auto force = (action == 1) ? force_mag : -force_mag; |
69 | auto costheta = std::cos(theta); |
70 | auto sintheta = std::sin(theta); |
71 | auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / |
72 | total_mass; |
73 | auto thetaacc = (gravity * sintheta - costheta * temp) / |
74 | (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass)); |
75 | auto xacc = temp - polemass_length * thetaacc * costheta / total_mass; |
76 | |
77 | x = x + tau * x_dot; |
78 | x_dot = x_dot + tau * xacc; |
79 | theta = theta + tau * theta_dot; |
80 | theta_dot = theta_dot + tau * thetaacc; |
81 | state = torch::tensor({x, x_dot, theta, theta_dot}); |
82 | |
83 | done = x < -x_threshold || x > x_threshold || |
84 | theta < -theta_threshold_radians || theta > theta_threshold_radians || |
85 | step_ > 200; |
86 | |
87 | if (!done) { |
88 | reward = 1.0; |
89 | } else if (steps_beyond_done == -1) { |
90 | // Pole just fell! |
91 | steps_beyond_done = 0; |
92 | reward = 0; |
93 | } else { |
94 | if (steps_beyond_done == 0) { |
95 | AT_ASSERT(false); // Can't do this |
96 | } |
97 | } |
98 | step_++; |
99 | } |
100 | }; |
101 | |
102 | template <typename M, typename F, typename O> |
103 | bool test_mnist( |
104 | size_t batch_size, |
105 | size_t number_of_epochs, |
106 | bool with_cuda, |
107 | M&& model, |
108 | F&& forward_op, |
109 | O&& optimizer) { |
110 | std::string mnist_path = "mnist" ; |
111 | if (const char* user_mnist_path = getenv("TORCH_CPP_TEST_MNIST_PATH" )) { |
112 | mnist_path = user_mnist_path; |
113 | } |
114 | |
115 | auto train_dataset = |
116 | torch::data::datasets::MNIST( |
117 | mnist_path, torch::data::datasets::MNIST::Mode::kTrain) |
118 | .map(torch::data::transforms::Stack<>()); |
119 | |
120 | auto data_loader = |
121 | torch::data::make_data_loader(std::move(train_dataset), batch_size); |
122 | |
123 | torch::Device device(with_cuda ? torch::kCUDA : torch::kCPU); |
124 | model->to(device); |
125 | |
126 | for (const auto epoch : c10::irange(number_of_epochs)) { |
127 | (void)epoch; // Suppress unused variable warning |
128 | // NOLINTNEXTLINE(performance-for-range-copy) |
129 | for (torch::data::Example<> batch : *data_loader) { |
130 | auto data = batch.data.to(device); |
131 | auto targets = batch.target.to(device); |
132 | torch::Tensor prediction = forward_op(std::move(data)); |
133 | // NOLINTNEXTLINE(performance-move-const-arg) |
134 | torch::Tensor loss = torch::nll_loss(prediction, std::move(targets)); |
135 | AT_ASSERT(!torch::isnan(loss).any().item<int64_t>()); |
136 | optimizer.zero_grad(); |
137 | loss.backward(); |
138 | optimizer.step(); |
139 | } |
140 | } |
141 | |
142 | torch::NoGradGuard guard; |
143 | torch::data::datasets::MNIST test_dataset( |
144 | mnist_path, torch::data::datasets::MNIST::Mode::kTest); |
145 | auto images = test_dataset.images().to(device), |
146 | targets = test_dataset.targets().to(device); |
147 | |
148 | auto result = std::get<1>(forward_op(images).max(/*dim=*/1)); |
149 | torch::Tensor correct = (result == targets).to(torch::kFloat32); |
150 | return correct.sum().item<float>() > (test_dataset.size().value() * 0.8); |
151 | } |
152 | |
153 | struct IntegrationTest : torch::test::SeedingFixture {}; |
154 | |
155 | TEST_F(IntegrationTest, CartPole) { |
156 | torch::manual_seed(0); |
157 | auto model = std::make_shared<SimpleContainer>(); |
158 | auto linear = model->add(Linear(4, 128), "linear" ); |
159 | auto policyHead = model->add(Linear(128, 2), "policy" ); |
160 | auto valueHead = model->add(Linear(128, 1), "action" ); |
161 | auto optimizer = torch::optim::Adam(model->parameters(), 1e-3); |
162 | |
163 | std::vector<torch::Tensor> saved_log_probs; |
164 | std::vector<torch::Tensor> saved_values; |
165 | std::vector<float> rewards; |
166 | |
167 | auto forward = [&](torch::Tensor inp) { |
168 | auto x = linear->forward(inp).clamp_min(0); |
169 | torch::Tensor actions = policyHead->forward(x); |
170 | torch::Tensor value = valueHead->forward(x); |
171 | return std::make_tuple(torch::softmax(actions, -1), value); |
172 | }; |
173 | |
174 | auto selectAction = [&](torch::Tensor state) { |
175 | // Only work on single state right now, change index to gather for batch |
176 | auto out = forward(state); |
177 | auto probs = torch::Tensor(std::get<0>(out)); |
178 | auto value = torch::Tensor(std::get<1>(out)); |
179 | auto action = probs.multinomial(1)[0].item<int32_t>(); |
180 | // Compute the log prob of a multinomial distribution. |
181 | // This should probably be actually implemented in autogradpp... |
182 | auto p = probs / probs.sum(-1, true); |
183 | auto log_prob = p[action].log(); |
184 | saved_log_probs.emplace_back(log_prob); |
185 | saved_values.push_back(value); |
186 | return action; |
187 | }; |
188 | |
189 | auto finishEpisode = [&] { |
190 | auto R = 0.; |
191 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
192 | for (int i = rewards.size() - 1; i >= 0; i--) { |
193 | R = rewards[i] + 0.99 * R; |
194 | rewards[i] = R; |
195 | } |
196 | auto r_t = torch::from_blob( |
197 | rewards.data(), {static_cast<int64_t>(rewards.size())}); |
198 | r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-5); |
199 | |
200 | std::vector<torch::Tensor> policy_loss; |
201 | std::vector<torch::Tensor> value_loss; |
202 | for (const auto i : c10::irange(0U, saved_log_probs.size())) { |
203 | auto advantage = r_t[i] - saved_values[i].item<float>(); |
204 | policy_loss.push_back(-advantage * saved_log_probs[i]); |
205 | value_loss.push_back( |
206 | torch::smooth_l1_loss(saved_values[i], torch::ones(1) * r_t[i])); |
207 | } |
208 | |
209 | auto loss = |
210 | torch::stack(policy_loss).sum() + torch::stack(value_loss).sum(); |
211 | |
212 | optimizer.zero_grad(); |
213 | loss.backward(); |
214 | optimizer.step(); |
215 | |
216 | rewards.clear(); |
217 | saved_log_probs.clear(); |
218 | saved_values.clear(); |
219 | }; |
220 | |
221 | auto env = CartPole(); |
222 | double running_reward = 10.0; |
223 | for (size_t episode = 0;; episode++) { |
224 | env.reset(); |
225 | auto state = env.getState(); |
226 | int t = 0; |
227 | for (; t < 10000; t++) { |
228 | auto action = selectAction(state); |
229 | env.step(action); |
230 | state = env.getState(); |
231 | auto reward = env.getReward(); |
232 | auto done = env.isDone(); |
233 | |
234 | rewards.push_back(reward); |
235 | if (done) |
236 | break; |
237 | } |
238 | |
239 | running_reward = running_reward * 0.99 + t * 0.01; |
240 | finishEpisode(); |
241 | /* |
242 | if (episode % 10 == 0) { |
243 | printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n", |
244 | episode, t, running_reward); |
245 | } |
246 | */ |
247 | if (running_reward > 150) { |
248 | break; |
249 | } |
250 | ASSERT_LT(episode, 3000); |
251 | } |
252 | } |
253 | |
254 | TEST_F(IntegrationTest, MNIST_CUDA) { |
255 | torch::manual_seed(0); |
256 | auto model = std::make_shared<SimpleContainer>(); |
257 | auto conv1 = model->add(Conv2d(1, 10, 5), "conv1" ); |
258 | auto conv2 = model->add(Conv2d(10, 20, 5), "conv2" ); |
259 | auto drop = Dropout(0.3); |
260 | auto drop2d = Dropout2d(0.3); |
261 | auto linear1 = model->add(Linear(320, 50), "linear1" ); |
262 | auto linear2 = model->add(Linear(50, 10), "linear2" ); |
263 | |
264 | auto forward = [&](torch::Tensor x) { |
265 | x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu(); |
266 | x = conv2->forward(x); |
267 | x = drop2d->forward(x); |
268 | x = torch::max_pool2d(x, {2, 2}).relu(); |
269 | |
270 | x = x.view({-1, 320}); |
271 | x = linear1->forward(x).clamp_min(0); |
272 | x = drop->forward(x); |
273 | x = linear2->forward(x); |
274 | x = torch::log_softmax(x, 1); |
275 | return x; |
276 | }; |
277 | |
278 | auto optimizer = torch::optim::SGD( |
279 | model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5)); |
280 | |
281 | ASSERT_TRUE(test_mnist( |
282 | 32, // batch_size |
283 | 3, // number_of_epochs |
284 | true, // with_cuda |
285 | model, |
286 | forward, |
287 | optimizer)); |
288 | } |
289 | |
290 | TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) { |
291 | torch::manual_seed(0); |
292 | auto model = std::make_shared<SimpleContainer>(); |
293 | auto conv1 = model->add(Conv2d(1, 10, 5), "conv1" ); |
294 | auto batchnorm2d = model->add(BatchNorm2d(10), "batchnorm2d" ); |
295 | auto conv2 = model->add(Conv2d(10, 20, 5), "conv2" ); |
296 | auto linear1 = model->add(Linear(320, 50), "linear1" ); |
297 | auto batchnorm1 = model->add(BatchNorm1d(50), "batchnorm1" ); |
298 | auto linear2 = model->add(Linear(50, 10), "linear2" ); |
299 | |
300 | auto forward = [&](torch::Tensor x) { |
301 | x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu(); |
302 | x = batchnorm2d->forward(x); |
303 | x = conv2->forward(x); |
304 | x = torch::max_pool2d(x, {2, 2}).relu(); |
305 | |
306 | x = x.view({-1, 320}); |
307 | x = linear1->forward(x).clamp_min(0); |
308 | x = batchnorm1->forward(x); |
309 | x = linear2->forward(x); |
310 | x = torch::log_softmax(x, 1); |
311 | return x; |
312 | }; |
313 | |
314 | auto optimizer = torch::optim::SGD( |
315 | model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5)); |
316 | |
317 | ASSERT_TRUE(test_mnist( |
318 | 32, // batch_size |
319 | 3, // number_of_epochs |
320 | true, // with_cuda |
321 | model, |
322 | forward, |
323 | optimizer)); |
324 | } |
325 | |