1#include <gtest/gtest.h>
2
3#include <c10/util/irange.h>
4#include <torch/torch.h>
5
6#include <test/cpp/api/support.h>
7
8#include <cmath>
9#include <cstdlib>
10#include <random>
11
12using namespace torch::nn;
13using namespace torch::test;
14
// Pi to full double precision (the previous literal was truncated at 13
// decimal places). Used to express the cart-pole angle limit in radians.
constexpr double kPi = 3.14159265358979323846;
16
17class CartPole {
18 // Translated from openai/gym's cartpole.py
19 public:
20 double gravity = 9.8;
21 double masscart = 1.0;
22 double masspole = 0.1;
23 double total_mass = (masspole + masscart);
24 double length = 0.5; // actually half the pole's length;
25 double polemass_length = (masspole * length);
26 double force_mag = 10.0;
27 double tau = 0.02; // seconds between state updates;
28
29 // Angle at which to fail the episode
30 double theta_threshold_radians = 12 * 2 * kPi / 360;
31 double x_threshold = 2.4;
32 int steps_beyond_done = -1;
33
34 torch::Tensor state;
35 double reward;
36 bool done;
37 int step_ = 0;
38
39 torch::Tensor getState() {
40 return state;
41 }
42
43 double getReward() {
44 return reward;
45 }
46
47 double isDone() {
48 return done;
49 }
50
51 void reset() {
52 state = torch::empty({4}).uniform_(-0.05, 0.05);
53 steps_beyond_done = -1;
54 step_ = 0;
55 }
56
57 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
58 CartPole() {
59 reset();
60 }
61
62 void step(int action) {
63 auto x = state[0].item<float>();
64 auto x_dot = state[1].item<float>();
65 auto theta = state[2].item<float>();
66 auto theta_dot = state[3].item<float>();
67
68 auto force = (action == 1) ? force_mag : -force_mag;
69 auto costheta = std::cos(theta);
70 auto sintheta = std::sin(theta);
71 auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) /
72 total_mass;
73 auto thetaacc = (gravity * sintheta - costheta * temp) /
74 (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
75 auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;
76
77 x = x + tau * x_dot;
78 x_dot = x_dot + tau * xacc;
79 theta = theta + tau * theta_dot;
80 theta_dot = theta_dot + tau * thetaacc;
81 state = torch::tensor({x, x_dot, theta, theta_dot});
82
83 done = x < -x_threshold || x > x_threshold ||
84 theta < -theta_threshold_radians || theta > theta_threshold_radians ||
85 step_ > 200;
86
87 if (!done) {
88 reward = 1.0;
89 } else if (steps_beyond_done == -1) {
90 // Pole just fell!
91 steps_beyond_done = 0;
92 reward = 0;
93 } else {
94 if (steps_beyond_done == 0) {
95 AT_ASSERT(false); // Can't do this
96 }
97 }
98 step_++;
99 }
100};
101
102template <typename M, typename F, typename O>
103bool test_mnist(
104 size_t batch_size,
105 size_t number_of_epochs,
106 bool with_cuda,
107 M&& model,
108 F&& forward_op,
109 O&& optimizer) {
110 std::string mnist_path = "mnist";
111 if (const char* user_mnist_path = getenv("TORCH_CPP_TEST_MNIST_PATH")) {
112 mnist_path = user_mnist_path;
113 }
114
115 auto train_dataset =
116 torch::data::datasets::MNIST(
117 mnist_path, torch::data::datasets::MNIST::Mode::kTrain)
118 .map(torch::data::transforms::Stack<>());
119
120 auto data_loader =
121 torch::data::make_data_loader(std::move(train_dataset), batch_size);
122
123 torch::Device device(with_cuda ? torch::kCUDA : torch::kCPU);
124 model->to(device);
125
126 for (const auto epoch : c10::irange(number_of_epochs)) {
127 (void)epoch; // Suppress unused variable warning
128 // NOLINTNEXTLINE(performance-for-range-copy)
129 for (torch::data::Example<> batch : *data_loader) {
130 auto data = batch.data.to(device);
131 auto targets = batch.target.to(device);
132 torch::Tensor prediction = forward_op(std::move(data));
133 // NOLINTNEXTLINE(performance-move-const-arg)
134 torch::Tensor loss = torch::nll_loss(prediction, std::move(targets));
135 AT_ASSERT(!torch::isnan(loss).any().item<int64_t>());
136 optimizer.zero_grad();
137 loss.backward();
138 optimizer.step();
139 }
140 }
141
142 torch::NoGradGuard guard;
143 torch::data::datasets::MNIST test_dataset(
144 mnist_path, torch::data::datasets::MNIST::Mode::kTest);
145 auto images = test_dataset.images().to(device),
146 targets = test_dataset.targets().to(device);
147
148 auto result = std::get<1>(forward_op(images).max(/*dim=*/1));
149 torch::Tensor correct = (result == targets).to(torch::kFloat32);
150 return correct.sum().item<float>() > (test_dataset.size().value() * 0.8);
151}
152
153struct IntegrationTest : torch::test::SeedingFixture {};
154
155TEST_F(IntegrationTest, CartPole) {
156 torch::manual_seed(0);
157 auto model = std::make_shared<SimpleContainer>();
158 auto linear = model->add(Linear(4, 128), "linear");
159 auto policyHead = model->add(Linear(128, 2), "policy");
160 auto valueHead = model->add(Linear(128, 1), "action");
161 auto optimizer = torch::optim::Adam(model->parameters(), 1e-3);
162
163 std::vector<torch::Tensor> saved_log_probs;
164 std::vector<torch::Tensor> saved_values;
165 std::vector<float> rewards;
166
167 auto forward = [&](torch::Tensor inp) {
168 auto x = linear->forward(inp).clamp_min(0);
169 torch::Tensor actions = policyHead->forward(x);
170 torch::Tensor value = valueHead->forward(x);
171 return std::make_tuple(torch::softmax(actions, -1), value);
172 };
173
174 auto selectAction = [&](torch::Tensor state) {
175 // Only work on single state right now, change index to gather for batch
176 auto out = forward(state);
177 auto probs = torch::Tensor(std::get<0>(out));
178 auto value = torch::Tensor(std::get<1>(out));
179 auto action = probs.multinomial(1)[0].item<int32_t>();
180 // Compute the log prob of a multinomial distribution.
181 // This should probably be actually implemented in autogradpp...
182 auto p = probs / probs.sum(-1, true);
183 auto log_prob = p[action].log();
184 saved_log_probs.emplace_back(log_prob);
185 saved_values.push_back(value);
186 return action;
187 };
188
189 auto finishEpisode = [&] {
190 auto R = 0.;
191 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
192 for (int i = rewards.size() - 1; i >= 0; i--) {
193 R = rewards[i] + 0.99 * R;
194 rewards[i] = R;
195 }
196 auto r_t = torch::from_blob(
197 rewards.data(), {static_cast<int64_t>(rewards.size())});
198 r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-5);
199
200 std::vector<torch::Tensor> policy_loss;
201 std::vector<torch::Tensor> value_loss;
202 for (const auto i : c10::irange(0U, saved_log_probs.size())) {
203 auto advantage = r_t[i] - saved_values[i].item<float>();
204 policy_loss.push_back(-advantage * saved_log_probs[i]);
205 value_loss.push_back(
206 torch::smooth_l1_loss(saved_values[i], torch::ones(1) * r_t[i]));
207 }
208
209 auto loss =
210 torch::stack(policy_loss).sum() + torch::stack(value_loss).sum();
211
212 optimizer.zero_grad();
213 loss.backward();
214 optimizer.step();
215
216 rewards.clear();
217 saved_log_probs.clear();
218 saved_values.clear();
219 };
220
221 auto env = CartPole();
222 double running_reward = 10.0;
223 for (size_t episode = 0;; episode++) {
224 env.reset();
225 auto state = env.getState();
226 int t = 0;
227 for (; t < 10000; t++) {
228 auto action = selectAction(state);
229 env.step(action);
230 state = env.getState();
231 auto reward = env.getReward();
232 auto done = env.isDone();
233
234 rewards.push_back(reward);
235 if (done)
236 break;
237 }
238
239 running_reward = running_reward * 0.99 + t * 0.01;
240 finishEpisode();
241 /*
242 if (episode % 10 == 0) {
243 printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n",
244 episode, t, running_reward);
245 }
246 */
247 if (running_reward > 150) {
248 break;
249 }
250 ASSERT_LT(episode, 3000);
251 }
252}
253
254TEST_F(IntegrationTest, MNIST_CUDA) {
255 torch::manual_seed(0);
256 auto model = std::make_shared<SimpleContainer>();
257 auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
258 auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
259 auto drop = Dropout(0.3);
260 auto drop2d = Dropout2d(0.3);
261 auto linear1 = model->add(Linear(320, 50), "linear1");
262 auto linear2 = model->add(Linear(50, 10), "linear2");
263
264 auto forward = [&](torch::Tensor x) {
265 x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu();
266 x = conv2->forward(x);
267 x = drop2d->forward(x);
268 x = torch::max_pool2d(x, {2, 2}).relu();
269
270 x = x.view({-1, 320});
271 x = linear1->forward(x).clamp_min(0);
272 x = drop->forward(x);
273 x = linear2->forward(x);
274 x = torch::log_softmax(x, 1);
275 return x;
276 };
277
278 auto optimizer = torch::optim::SGD(
279 model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));
280
281 ASSERT_TRUE(test_mnist(
282 32, // batch_size
283 3, // number_of_epochs
284 true, // with_cuda
285 model,
286 forward,
287 optimizer));
288}
289
290TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) {
291 torch::manual_seed(0);
292 auto model = std::make_shared<SimpleContainer>();
293 auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
294 auto batchnorm2d = model->add(BatchNorm2d(10), "batchnorm2d");
295 auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
296 auto linear1 = model->add(Linear(320, 50), "linear1");
297 auto batchnorm1 = model->add(BatchNorm1d(50), "batchnorm1");
298 auto linear2 = model->add(Linear(50, 10), "linear2");
299
300 auto forward = [&](torch::Tensor x) {
301 x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu();
302 x = batchnorm2d->forward(x);
303 x = conv2->forward(x);
304 x = torch::max_pool2d(x, {2, 2}).relu();
305
306 x = x.view({-1, 320});
307 x = linear1->forward(x).clamp_min(0);
308 x = batchnorm1->forward(x);
309 x = linear2->forward(x);
310 x = torch::log_softmax(x, 1);
311 return x;
312 };
313
314 auto optimizer = torch::optim::SGD(
315 model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));
316
317 ASSERT_TRUE(test_mnist(
318 32, // batch_size
319 3, // number_of_epochs
320 true, // with_cuda
321 model,
322 forward,
323 optimizer));
324}
325