1 | // @generated from optim_baseline.py |
2 | |
3 | #include <torch/types.h> |
4 | |
5 | #include <vector> |
6 | |
7 | namespace expected_parameters { |
8 | |
9 | inline std::vector<std::vector<torch::Tensor>> LBFGS() { |
10 | return { |
11 | { |
12 | torch::tensor( |
13 | {-0.20959197386869663, |
14 | -0.49580870398532073, |
15 | -0.1313442585372408, |
16 | -0.3287331939506787, |
17 | -0.24613947168465267, |
18 | 0.705889510763571}), |
19 | torch::tensor( |
20 | {-0.10412662274500666, -0.2644705062031845, 0.7102859961803084}), |
21 | torch::tensor( |
22 | {-0.19787984636009417, -0.5320223708266223, -0.5396083236337847}), |
23 | torch::tensor({-0.43108206822505857}), |
24 | }, |
25 | { |
26 | torch::tensor( |
27 | {0.4377600774755075, |
28 | 0.3828823919505896, |
29 | 0.5308031277873992, |
30 | 0.5752746453369446, |
31 | 0.23943592910168343, |
32 | 1.3739197373644627}), |
33 | torch::tensor( |
34 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
35 | torch::tensor( |
36 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
37 | torch::tensor({-4.776742087865583}), |
38 | }, |
39 | { |
40 | torch::tensor( |
41 | {0.4377600774755075, |
42 | 0.3828823919505896, |
43 | 0.5308031277873992, |
44 | 0.5752746453369446, |
45 | 0.23943592910168343, |
46 | 1.3739197373644627}), |
47 | torch::tensor( |
48 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
49 | torch::tensor( |
50 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
51 | torch::tensor({-4.776742087865583}), |
52 | }, |
53 | { |
54 | torch::tensor( |
55 | {0.4377600774755075, |
56 | 0.3828823919505896, |
57 | 0.5308031277873992, |
58 | 0.5752746453369446, |
59 | 0.23943592910168343, |
60 | 1.3739197373644627}), |
61 | torch::tensor( |
62 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
63 | torch::tensor( |
64 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
65 | torch::tensor({-4.776742087865583}), |
66 | }, |
67 | { |
68 | torch::tensor( |
69 | {0.4377600774755075, |
70 | 0.3828823919505896, |
71 | 0.5308031277873992, |
72 | 0.5752746453369446, |
73 | 0.23943592910168343, |
74 | 1.3739197373644627}), |
75 | torch::tensor( |
76 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
77 | torch::tensor( |
78 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
79 | torch::tensor({-4.776742087865583}), |
80 | }, |
81 | { |
82 | torch::tensor( |
83 | {0.4377600774755075, |
84 | 0.3828823919505896, |
85 | 0.5308031277873992, |
86 | 0.5752746453369446, |
87 | 0.23943592910168343, |
88 | 1.3739197373644627}), |
89 | torch::tensor( |
90 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
91 | torch::tensor( |
92 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
93 | torch::tensor({-4.776742087865583}), |
94 | }, |
95 | { |
96 | torch::tensor( |
97 | {0.4377600774755075, |
98 | 0.3828823919505896, |
99 | 0.5308031277873992, |
100 | 0.5752746453369446, |
101 | 0.23943592910168343, |
102 | 1.3739197373644627}), |
103 | torch::tensor( |
104 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
105 | torch::tensor( |
106 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
107 | torch::tensor({-4.776742087865583}), |
108 | }, |
109 | { |
110 | torch::tensor( |
111 | {0.4377600774755075, |
112 | 0.3828823919505896, |
113 | 0.5308031277873992, |
114 | 0.5752746453369446, |
115 | 0.23943592910168343, |
116 | 1.3739197373644627}), |
117 | torch::tensor( |
118 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
119 | torch::tensor( |
120 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
121 | torch::tensor({-4.776742087865583}), |
122 | }, |
123 | { |
124 | torch::tensor( |
125 | {0.4377600774755075, |
126 | 0.3828823919505896, |
127 | 0.5308031277873992, |
128 | 0.5752746453369446, |
129 | 0.23943592910168343, |
130 | 1.3739197373644627}), |
131 | torch::tensor( |
132 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
133 | torch::tensor( |
134 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
135 | torch::tensor({-4.776742087865583}), |
136 | }, |
137 | { |
138 | torch::tensor( |
139 | {0.4377600774755075, |
140 | 0.3828823919505896, |
141 | 0.5308031277873992, |
142 | 0.5752746453369446, |
143 | 0.23943592910168343, |
144 | 1.3739197373644627}), |
145 | torch::tensor( |
146 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
147 | torch::tensor( |
148 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
149 | torch::tensor({-4.776742087865583}), |
150 | }, |
151 | { |
152 | torch::tensor( |
153 | {0.4377600774755075, |
154 | 0.3828823919505896, |
155 | 0.5308031277873992, |
156 | 0.5752746453369446, |
157 | 0.23943592910168343, |
158 | 1.3739197373644627}), |
159 | torch::tensor( |
160 | {2.209263823172053, 2.154134023426646, 2.534834254325867}), |
161 | torch::tensor( |
162 | {-4.091952315741579, -4.67916063385269, -4.781279234594454}), |
163 | torch::tensor({-4.776742087865583}), |
164 | }, |
165 | }; |
166 | } |
167 | |
168 | inline std::vector<std::vector<torch::Tensor>> LBFGS_with_line_search() { |
169 | return { |
170 | { |
171 | torch::tensor( |
172 | {-0.2108988568338871, |
173 | -0.4975560466422629, |
174 | -0.14129216202471762, |
175 | -0.3420288967865903, |
176 | -0.2523635082803723, |
177 | 0.6975570255493777}), |
178 | torch::tensor( |
179 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
180 | torch::tensor( |
181 | {-0.05080313011597659, |
182 | -0.39413518751058996, |
183 | -0.28433759745928844}), |
184 | torch::tensor({-0.07113812430174116}), |
185 | }, |
186 | { |
187 | torch::tensor( |
188 | {-0.2108988568338871, |
189 | -0.4975560466422629, |
190 | -0.14129216202471762, |
191 | -0.3420288967865903, |
192 | -0.2523635082803723, |
193 | 0.6975570255493777}), |
194 | torch::tensor( |
195 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
196 | torch::tensor( |
197 | {-0.05080313011597659, |
198 | -0.39413518751058996, |
199 | -0.28433759745928844}), |
200 | torch::tensor({-0.07113812430174116}), |
201 | }, |
202 | { |
203 | torch::tensor( |
204 | {-0.2108988568338871, |
205 | -0.4975560466422629, |
206 | -0.14129216202471762, |
207 | -0.3420288967865903, |
208 | -0.2523635082803723, |
209 | 0.6975570255493777}), |
210 | torch::tensor( |
211 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
212 | torch::tensor( |
213 | {-0.05080313011597659, |
214 | -0.39413518751058996, |
215 | -0.28433759745928844}), |
216 | torch::tensor({-0.07113812430174116}), |
217 | }, |
218 | { |
219 | torch::tensor( |
220 | {-0.2108988568338871, |
221 | -0.4975560466422629, |
222 | -0.14129216202471762, |
223 | -0.3420288967865903, |
224 | -0.2523635082803723, |
225 | 0.6975570255493777}), |
226 | torch::tensor( |
227 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
228 | torch::tensor( |
229 | {-0.05080313011597659, |
230 | -0.39413518751058996, |
231 | -0.28433759745928844}), |
232 | torch::tensor({-0.07113812430174116}), |
233 | }, |
234 | { |
235 | torch::tensor( |
236 | {-0.2108988568338871, |
237 | -0.4975560466422629, |
238 | -0.14129216202471762, |
239 | -0.3420288967865903, |
240 | -0.2523635082803723, |
241 | 0.6975570255493777}), |
242 | torch::tensor( |
243 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
244 | torch::tensor( |
245 | {-0.05080313011597659, |
246 | -0.39413518751058996, |
247 | -0.28433759745928844}), |
248 | torch::tensor({-0.07113812430174116}), |
249 | }, |
250 | { |
251 | torch::tensor( |
252 | {-0.2108988568338871, |
253 | -0.4975560466422629, |
254 | -0.14129216202471762, |
255 | -0.3420288967865903, |
256 | -0.2523635082803723, |
257 | 0.6975570255493777}), |
258 | torch::tensor( |
259 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
260 | torch::tensor( |
261 | {-0.05080313011597659, |
262 | -0.39413518751058996, |
263 | -0.28433759745928844}), |
264 | torch::tensor({-0.07113812430174116}), |
265 | }, |
266 | { |
267 | torch::tensor( |
268 | {-0.2108988568338871, |
269 | -0.4975560466422629, |
270 | -0.14129216202471762, |
271 | -0.3420288967865903, |
272 | -0.2523635082803723, |
273 | 0.6975570255493777}), |
274 | torch::tensor( |
275 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
276 | torch::tensor( |
277 | {-0.05080313011597659, |
278 | -0.39413518751058996, |
279 | -0.28433759745928844}), |
280 | torch::tensor({-0.07113812430174116}), |
281 | }, |
282 | { |
283 | torch::tensor( |
284 | {-0.2108988568338871, |
285 | -0.4975560466422629, |
286 | -0.14129216202471762, |
287 | -0.3420288967865903, |
288 | -0.2523635082803723, |
289 | 0.6975570255493777}), |
290 | torch::tensor( |
291 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
292 | torch::tensor( |
293 | {-0.05080313011597659, |
294 | -0.39413518751058996, |
295 | -0.28433759745928844}), |
296 | torch::tensor({-0.07113812430174116}), |
297 | }, |
298 | { |
299 | torch::tensor( |
300 | {-0.2108988568338871, |
301 | -0.4975560466422629, |
302 | -0.14129216202471762, |
303 | -0.3420288967865903, |
304 | -0.2523635082803723, |
305 | 0.6975570255493777}), |
306 | torch::tensor( |
307 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
308 | torch::tensor( |
309 | {-0.05080313011597659, |
310 | -0.39413518751058996, |
311 | -0.28433759745928844}), |
312 | torch::tensor({-0.07113812430174116}), |
313 | }, |
314 | { |
315 | torch::tensor( |
316 | {-0.2108988568338871, |
317 | -0.4975560466422629, |
318 | -0.14129216202471762, |
319 | -0.3420288967865903, |
320 | -0.2523635082803723, |
321 | 0.6975570255493777}), |
322 | torch::tensor( |
323 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
324 | torch::tensor( |
325 | {-0.05080313011597659, |
326 | -0.39413518751058996, |
327 | -0.28433759745928844}), |
328 | torch::tensor({-0.07113812430174116}), |
329 | }, |
330 | { |
331 | torch::tensor( |
332 | {-0.2108988568338871, |
333 | -0.4975560466422629, |
334 | -0.14129216202471762, |
335 | -0.3420288967865903, |
336 | -0.2523635082803723, |
337 | 0.6975570255493777}), |
338 | torch::tensor( |
339 | {-0.10853121966252377, -0.297948499687533, 0.6892015099955717}), |
340 | torch::tensor( |
341 | {-0.05080313011597659, |
342 | -0.39413518751058996, |
343 | -0.28433759745928844}), |
344 | torch::tensor({-0.07113812430174116}), |
345 | }, |
346 | }; |
347 | } |
348 | |
349 | inline std::vector<std::vector<torch::Tensor>> Adam() { |
350 | return { |
351 | { |
352 | torch::tensor( |
353 | {0.7890972864438472, |
354 | 0.5024410688121617, |
355 | 0.8587073313055582, |
356 | 0.6579707241208395, |
357 | 0.7476356819075531, |
358 | 1.697556420651692}), |
359 | torch::tensor( |
360 | {0.891467636010675, 0.7020513497567501, 1.6892012709428947}), |
361 | torch::tensor( |
362 | {-1.0508030958460797, -1.3941351509567657, -1.284337577714353}), |
363 | torch::tensor({-1.071138110298716}), |
364 | }, |
365 | { |
366 | torch::tensor( |
367 | {8.233039313231828, |
368 | 7.971150747377481, |
369 | 6.6436209506776, |
370 | 6.470977407900541, |
371 | 6.170125488259256, |
372 | 7.1507391033435015}), |
373 | torch::tensor( |
374 | {8.417695070103735, 6.597188212844593, 7.23175710827678}), |
375 | torch::tensor( |
376 | {-6.729624357635757, -7.09743493108154, -6.753301896575352}), |
377 | torch::tensor({-6.435639096011218}), |
378 | }, |
379 | { |
380 | torch::tensor( |
381 | {8.233424596059296, |
382 | 7.971537360032308, |
383 | 6.643920150720394, |
384 | 6.47127807553724, |
385 | 6.170405874224489, |
386 | 7.151021086137982}), |
387 | torch::tensor( |
388 | {8.418084791214294, 6.597493171180545, 7.232043740621598}), |
389 | torch::tensor( |
390 | {-6.729918250724671, -7.097730102046093, -6.753584809755359}), |
391 | torch::tensor({-6.4359165566974985}), |
392 | }, |
393 | { |
394 | torch::tensor( |
395 | {8.233424610557648, |
396 | 7.971537374586563, |
397 | 6.643920161995285, |
398 | 6.471278086877829, |
399 | 6.170405884785074, |
400 | 7.151021096766405}), |
401 | torch::tensor( |
402 | {8.418084805901902, 6.597493182713584, 7.2320437514477875}), |
403 | torch::tensor( |
404 | {-6.72991829363266, -7.097730147102975, -6.753584838821182}), |
405 | torch::tensor({-6.435916580217771}), |
406 | }, |
407 | { |
408 | torch::tensor( |
409 | {8.233424610575101, |
410 | 7.971537374611125, |
411 | 6.643920162027962, |
412 | 6.471278086923278, |
413 | 6.170405884809245, |
414 | 7.15102109680004}), |
415 | torch::tensor( |
416 | {8.418084805946389, 6.597493182796847, 7.232043751509309}), |
417 | torch::tensor( |
418 | {-6.729918332327653, -7.097730188349552, -6.753584861205486}), |
419 | torch::tensor({-6.435916596115672}), |
420 | }, |
421 | { |
422 | torch::tensor( |
423 | {8.233424610594858, |
424 | 7.971537374639166, |
425 | 6.643920162065571, |
426 | 6.471278086975759, |
427 | 6.170405884836981, |
428 | 7.1510210968387975}), |
429 | torch::tensor( |
430 | {8.418084805997614, 6.59749318289335, 7.232043751580523}), |
431 | torch::tensor( |
432 | {-6.72991837738045, -7.097730236373201, -6.753584887267492}), |
433 | torch::tensor({-6.43591661462546}), |
434 | }, |
435 | { |
436 | torch::tensor( |
437 | {8.233424610617288, |
438 | 7.971537374671012, |
439 | 6.643920162108285, |
440 | 6.471278087035362, |
441 | 6.170405884868481, |
442 | 7.151021096882811}), |
443 | torch::tensor( |
444 | {8.418084806055795, 6.59749318300295, 7.232043751661401}), |
445 | torch::tensor( |
446 | {-6.729918428547273, -7.09773029091405, -6.753584916866329}), |
447 | torch::tensor({-6.4359166356471755}), |
448 | }, |
449 | { |
450 | torch::tensor( |
451 | {8.233424610642352, |
452 | 7.9715373747065925, |
453 | 6.6439201621560064, |
454 | 6.471278087101955, |
455 | 6.1704058849036745, |
456 | 7.151021096931989}), |
457 | torch::tensor( |
458 | {8.418084806120799, 6.597493183125404, 7.232043751751764}), |
459 | torch::tensor( |
460 | {-6.729918485714688, -7.0977303518511805, -6.753584949936365}), |
461 | torch::tensor({-6.43591665913422}), |
462 | }, |
463 | { |
464 | torch::tensor( |
465 | {8.233424610670035, |
466 | 7.97153737474589, |
467 | 6.6439201622087145, |
468 | 6.471278087175502, |
469 | 6.170405884942545, |
470 | 7.151021096986302}), |
471 | torch::tensor( |
472 | {8.418084806192592, 6.597493183260647, 7.232043751851564}), |
473 | torch::tensor( |
474 | {-6.729918548853505, -7.097730419153473, -6.753584986460725}), |
475 | torch::tensor({-6.435916685074594}), |
476 | }, |
477 | { |
478 | torch::tensor( |
479 | {8.233424610700348, |
480 | 7.971537374788922, |
481 | 6.643920162266433, |
482 | 6.4712780872560405, |
483 | 6.17040588498511, |
484 | 7.151021097045779}), |
485 | torch::tensor( |
486 | {8.418084806271214, 6.597493183408747, 7.232043751960854}), |
487 | torch::tensor( |
488 | {-6.7299186179943, -7.097730492853521, -6.753585026457088}), |
489 | torch::tensor({-6.435916713480863}), |
490 | }, |
491 | { |
492 | torch::tensor( |
493 | {8.233424610733326, |
494 | 7.971537374835737, |
495 | 6.643920162329225, |
496 | 6.471278087343659, |
497 | 6.170405885031416, |
498 | 7.151021097110483}), |
499 | torch::tensor( |
500 | {8.418084806356743, 6.597493183569867, 7.232043752079749}), |
501 | torch::tensor( |
502 | {-6.729918693213275, -7.097730573032567, -6.753585069969552}), |
503 | torch::tensor({-6.43591674438434}), |
504 | }, |
505 | }; |
506 | } |
507 | |
508 | inline std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay() { |
509 | return { |
510 | { |
511 | torch::tensor( |
512 | {0.7890990163499767, |
513 | 0.5024427688479549, |
514 | 0.858707365154099, |
515 | 0.65797076763247, |
516 | 0.7476358193232038, |
517 | 1.6975559791029715}), |
518 | torch::tensor( |
519 | {0.8914677624298939, 0.7020513562204098, 1.6892012237887575}), |
520 | torch::tensor( |
521 | {-1.050803095786311, -1.3941351504224309, -1.2843375776028747}), |
522 | torch::tensor({-1.0711381102847533}), |
523 | }, |
524 | { |
525 | torch::tensor( |
526 | {0.17835734655765323, |
527 | 0.2542117171890537, |
528 | 0.19681971909229715, |
529 | 0.23522651199260597, |
530 | 0.17806083719648957, |
531 | 0.22943655675307303}), |
532 | torch::tensor( |
533 | {0.6227676931552837, 0.6058596954431213, 0.6077176546857177}), |
534 | torch::tensor( |
535 | {-1.4259755901844118, -1.4333355461952704, -1.408545526635006}), |
536 | torch::tensor({-2.0710783081666215}), |
537 | }, |
538 | { |
539 | torch::tensor( |
540 | {0.17965695035191162, |
541 | 0.24254352340441693, |
542 | 0.17964663531482672, |
543 | 0.24250834976541322, |
544 | 0.17962893833698693, |
545 | 0.24249920074277215}), |
546 | torch::tensor( |
547 | {0.6287144967638043, 0.6286955805603279, 0.6286563093833837}), |
548 | torch::tensor( |
549 | {-1.4123887230853596, -1.4124007126659273, -1.4122701589749163}), |
550 | torch::tensor({-2.063357041247863}), |
551 | }, |
552 | { |
553 | torch::tensor( |
554 | {0.1796366651819, |
555 | 0.24250861931831874, |
556 | 0.17963731759793083, |
557 | 0.24250861142436989, |
558 | 0.1796372002681969, |
559 | 0.24250890248031373}), |
560 | torch::tensor( |
561 | {0.6287221269294724, 0.6287225821354421, 0.6287220274975922}), |
562 | torch::tensor( |
563 | {-1.4123466103044011, -1.4123465669572683, -1.4123462614739388}), |
564 | torch::tensor({-2.063368365143669}), |
565 | }, |
566 | { |
567 | torch::tensor( |
568 | {0.17963666103165563, |
569 | 0.24250882317446784, |
570 | 0.17963665831217887, |
571 | 0.24250882481082656, |
572 | 0.17963666029066117, |
573 | 0.24250882426223175}), |
574 | torch::tensor( |
575 | {0.6287216329900817, 0.6287216340515608, 0.6287216326960158}), |
576 | torch::tensor( |
577 | {-1.4123467542623926, -1.4123467542350234, -1.4123467478191443}), |
578 | torch::tensor({-2.0633690432440437}), |
579 | }, |
580 | { |
581 | torch::tensor( |
582 | {0.17963666098500394, |
583 | 0.24250882442377164, |
584 | 0.17963666099348902, |
585 | 0.2425088244120223, |
586 | 0.1796366609725109, |
587 | 0.24250882441058697}), |
588 | torch::tensor( |
589 | {0.6287216343798432, 0.6287216343800675, 0.6287216343742645}), |
590 | torch::tensor( |
591 | {-1.4123467490742723, -1.412346749072554, -1.4123467490678536}), |
592 | torch::tensor({-2.0633690434425396}), |
593 | }, |
594 | { |
595 | torch::tensor( |
596 | {0.17963666098407144, |
597 | 0.24250882442250174, |
598 | 0.17963666098407347, |
599 | 0.2425088244224233, |
600 | 0.17963666098409325, |
601 | 0.2425088244225157}), |
602 | torch::tensor( |
603 | {0.6287216343836609, 0.6287216343836147, 0.6287216343836255}), |
604 | torch::tensor( |
605 | {-1.412346749067226, -1.412346749067243, -1.412346749067169}), |
606 | torch::tensor({-2.063369043434909}), |
607 | }, |
608 | { |
609 | torch::tensor( |
610 | {0.17963666098406988, |
611 | 0.2425088244224408, |
612 | 0.17963666098407077, |
613 | 0.24250882442244073, |
614 | 0.17963666098407008, |
615 | 0.2425088244224409}), |
616 | torch::tensor( |
617 | {0.6287216343837067, 0.6287216343837065, 0.6287216343837069}), |
618 | torch::tensor( |
619 | {-1.4123467490671706, -1.412346749067171, -1.4123467490671713}), |
620 | torch::tensor({-2.0633690434349057}), |
621 | }, |
622 | { |
623 | torch::tensor( |
624 | {0.17963666098407038, |
625 | 0.24250882442244104, |
626 | 0.17963666098407027, |
627 | 0.24250882442244104, |
628 | 0.17963666098407025, |
629 | 0.24250882442244098}), |
630 | torch::tensor( |
631 | {0.6287216343837067, 0.628721634383707, 0.6287216343837067}), |
632 | torch::tensor( |
633 | {-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}), |
634 | torch::tensor({-2.0633690434349052}), |
635 | }, |
636 | { |
637 | torch::tensor( |
638 | {0.1796366609840706, |
639 | 0.24250882442244143, |
640 | 0.17963666098407047, |
641 | 0.24250882442244096, |
642 | 0.17963666098407025, |
643 | 0.24250882442244098}), |
644 | torch::tensor( |
645 | {0.6287216343837069, 0.6287216343837067, 0.6287216343837067}), |
646 | torch::tensor( |
647 | {-1.4123467490671706, -1.4123467490671706, -1.4123467490671708}), |
648 | torch::tensor({-2.0633690434349052}), |
649 | }, |
650 | { |
651 | torch::tensor( |
652 | {0.1796366609840692, |
653 | 0.24250882442244046, |
654 | 0.17963666098407022, |
655 | 0.24250882442244082, |
656 | 0.17963666098407, |
657 | 0.24250882442244104}), |
658 | torch::tensor( |
659 | {0.6287216343837063, 0.6287216343837068, 0.6287216343837067}), |
660 | torch::tensor( |
661 | {-1.4123467490671708, -1.4123467490671706, -1.4123467490671708}), |
662 | torch::tensor({-2.0633690434349052}), |
663 | }, |
664 | }; |
665 | } |
666 | |
667 | inline std::vector<std::vector<torch::Tensor>> |
668 | Adam_with_weight_decay_and_amsgrad() { |
669 | return { |
670 | { |
671 | torch::tensor( |
672 | {0.7890972867575196, |
673 | 0.5024410692260988, |
674 | 0.8587073313091852, |
675 | 0.6579707241257546, |
676 | 0.7476356819241026, |
677 | 1.6975564206261673}), |
678 | torch::tensor( |
679 | {0.8914676360248869, 0.7020513497574256, 1.6892012709389561}), |
680 | torch::tensor( |
681 | {-1.050803095846074, -1.3941351509567128, -1.284337577714342}), |
682 | torch::tensor({-1.0711381102987145}), |
683 | }, |
684 | { |
685 | torch::tensor( |
686 | {6.790598887618061, |
687 | 6.914995398136696, |
688 | 6.41533478566264, |
689 | 6.297644005485053, |
690 | 5.845162499872375, |
691 | 6.862229173597117}), |
692 | torch::tensor( |
693 | {7.958707058914726, 6.511338624975532, 7.100969502256063}), |
694 | torch::tensor( |
695 | {-6.690689640539306, -7.056584601121166, -6.72114879738572}), |
696 | torch::tensor({-6.406608295022552}), |
697 | }, |
698 | { |
699 | torch::tensor( |
700 | {4.707506618354547, |
701 | 5.291519064582759, |
702 | 6.0451502264500006, |
703 | 6.024403702678936, |
704 | 5.309533822430375, |
705 | 6.388110918107735}), |
706 | torch::tensor( |
707 | {7.200495189000188, 6.398387074819269, 6.904125817198589}), |
708 | torch::tensor( |
709 | {-6.664150387053514, -7.026716194929788, -6.705821732490459}), |
710 | torch::tensor({-6.396310969025695}), |
711 | }, |
712 | { |
713 | torch::tensor( |
714 | {2.9508632109188633, |
715 | 3.7657755643994775, |
716 | 5.60741774331852, |
717 | 5.6957903028180565, |
718 | 4.70145185833677, |
719 | 5.835064148607041}), |
720 | torch::tensor( |
721 | {6.343524400109462, 6.258242740866945, 6.663022973860484}), |
722 | torch::tensor( |
723 | {-6.630461605133603, -6.988854886932907, -6.686194352796841}), |
724 | torch::tensor({-6.383010489575922}), |
725 | }, |
726 | { |
727 | torch::tensor( |
728 | {1.7128944635692829, |
729 | 2.536365345915568, |
730 | 5.140416924817106, |
731 | 5.33803266121343, |
732 | 4.083921806116116, |
733 | 5.254596369238127}), |
734 | torch::tensor( |
735 | {5.477917349690043, 6.100068192681452, 6.394747035239918}), |
736 | torch::tensor( |
737 | {-6.591742325353548, -6.945355504749947, -6.663584873152447}), |
738 | torch::tensor({-6.367675348676994}), |
739 | }, |
740 | { |
741 | torch::tensor( |
742 | {0.9341502247285258, |
743 | 1.6339620765410685, |
744 | 4.6679910940755835, |
745 | 4.967688979298023, |
746 | 3.4933141073198866, |
747 | 4.678195347615295}), |
748 | torch::tensor( |
749 | {4.655117321178743, 5.9294836450698245, 6.110061909503652}), |
750 | torch::tensor( |
751 | {-6.549116242899458, -6.897486328511809, -6.638629863638681}), |
752 | torch::tensor({-6.350731975696792}), |
753 | }, |
754 | { |
755 | torch::tensor( |
756 | {0.483518223008081, |
757 | 1.014261673266094, |
758 | 4.205928052060015, |
759 | 4.596195204751035, |
760 | 2.9502123780175378, |
761 | 4.125826973031755}), |
762 | torch::tensor( |
763 | {3.90359317709392, 5.750505227817309, 5.816617309738603}), |
764 | torch::tensor( |
765 | {-6.503371263778851, -6.846137347295816, -6.611773185243846}), |
766 | torch::tensor({-6.3324769650576656}), |
767 | }, |
768 | { |
769 | torch::tensor( |
770 | {0.2393576446248765, |
771 | 0.6100241101779533, |
772 | 3.764601561942264, |
773 | 4.231602962540335, |
774 | 2.4647709193637635, |
775 | 3.6096961114614476}), |
776 | torch::tensor( |
777 | {3.236721556734532, 5.566168160977344, 5.520085344708356}), |
778 | torch::tensor( |
779 | {-6.455100840665729, -6.791979259673106, -6.583347815648856}), |
780 | torch::tensor({-6.3131323905801136}), |
781 | }, |
782 | { |
783 | torch::tensor( |
784 | {0.11401265024593016, |
785 | 0.3570521972760832, |
786 | 3.350517925954931, |
787 | 3.8795333009419823, |
788 | 2.0402068661130683, |
789 | 3.136550602110189}), |
790 | torch::tensor( |
791 | {2.6579570162250215, 5.378834741966309, 5.224742933241745}), |
792 | torch::tensor( |
793 | {-6.4047717084756375, -6.7355397098532155, -6.55361495837462}), |
794 | torch::tensor({-6.2928722155069945}), |
795 | }, |
796 | { |
797 | torch::tensor( |
798 | {0.05251515193791458, |
799 | 0.20410212473600725, |
800 | 2.9673680881961273, |
801 | 3.543794405777883, |
802 | 1.6752677855061209, |
803 | 2.709287985107431}), |
804 | torch::tensor( |
805 | {2.164479166686583, 5.190372657839918, 4.9338240234040756}), |
806 | torch::tensor( |
807 | {-6.352761531270841, -6.6772456859648175, -6.52278570167088}), |
808 | torch::tensor({-6.271836898876738}), |
809 | }, |
810 | { |
811 | torch::tensor( |
812 | {0.023489947480426834, |
813 | 0.11428338573638941, |
814 | 2.616797245623764, |
815 | 3.226821571853439, |
816 | 1.3659994589608537, |
817 | 2.328136084816453}), |
818 | torch::tensor( |
819 | {1.74978620664146, 5.002269811977871, 4.649756802441968}), |
820 | torch::tensor( |
821 | {-6.299382007948917, -6.617449564196286, -6.491034254081261}), |
822 | torch::tensor({-6.250142306631259}), |
823 | }, |
824 | }; |
825 | } |
826 | |
827 | inline std::vector<std::vector<torch::Tensor>> AdamW() { |
828 | return { |
829 | { |
830 | torch::tensor( |
831 | {0.7912062750121864, |
832 | 0.5074166292785842, |
833 | 0.8601202529258052, |
834 | 0.6613910130887053, |
835 | 0.7501593169903569, |
836 | 1.6905808503961983}), |
837 | torch::tensor( |
838 | {0.8925529482073002, 0.7050308347536254, 1.682309255842939}), |
839 | torch::tensor( |
840 | {-1.05029506454492, -1.3901937990816595, -1.2814942017397601}), |
841 | torch::tensor({-1.0704267290556988}), |
842 | }, |
843 | { |
844 | torch::tensor( |
845 | {3.3165329599188507, |
846 | 3.223120441823618, |
847 | 2.665544565239194, |
848 | 2.6044341406663225, |
849 | 2.479859063483047, |
850 | 2.836831717112226}), |
851 | torch::tensor( |
852 | {3.3885192024669744, 2.6544147219174556, 2.8709245656887328}), |
853 | torch::tensor( |
854 | {-2.70172647102137, -2.836731459490802, -2.69652471546253}), |
855 | torch::tensor({-2.575239255076019}), |
856 | }, |
857 | { |
858 | torch::tensor( |
859 | {2.231471944853865, |
860 | 2.3549328325971755, |
861 | 1.5699078054795328, |
862 | 1.6160272935884685, |
863 | 1.5339085081403547, |
864 | 1.7397405105941612}), |
865 | torch::tensor( |
866 | {2.8552579170807926, 1.8369866847839356, 1.9735168512425862}), |
867 | torch::tensor( |
868 | {-2.6042083360293855, -2.6996673713262336, -1.8976087706977893}), |
869 | torch::tensor({-1.6180915942867784}), |
870 | }, |
871 | { |
872 | torch::tensor( |
873 | {2.084688381515552, |
874 | 2.3141612674892946, |
875 | 1.4850714710140511, |
876 | 1.5961047256668386, |
877 | 1.440300645879787, |
878 | 1.6065354941586025}), |
879 | torch::tensor( |
880 | {3.0111385685659444, 1.955556497153507, 1.9596562467797627}), |
881 | torch::tensor( |
882 | {-2.889337305884852, -2.965249100126337, -1.7721676671605975}), |
883 | torch::tensor({-1.4001341655590005}), |
884 | }, |
885 | { |
886 | torch::tensor( |
887 | {2.0465343456006604, |
888 | 2.311613891239368, |
889 | 1.4666717526896398, |
890 | 1.601383980913499, |
891 | 1.4223660595993763, |
892 | 1.5711552625612757}), |
893 | torch::tensor( |
894 | {3.07151984580744, 2.0112690538174802, 1.9592484602763875}), |
895 | torch::tensor( |
896 | {-3.0186469726426863, -3.093855445542849, -1.7367953899738784}), |
897 | torch::tensor({-1.3299011560804312}), |
898 | }, |
899 | { |
900 | torch::tensor( |
901 | {2.039659777412556, |
902 | 2.3178034179536273, |
903 | 1.4654302718412722, |
904 | 1.6094701969162322, |
905 | 1.4230510816446773, |
906 | 1.565168902852383}), |
907 | torch::tensor( |
908 | {3.1007583934270064, 2.039757113618415, 1.9652096140698696}), |
909 | torch::tensor( |
910 | {-3.0880626664330832, -3.166705422245348, -1.73538367534238}), |
911 | torch::tensor({-1.3130428735015893}), |
912 | }, |
913 | { |
914 | torch::tensor( |
915 | {2.0413773043991963, |
916 | 2.3251469369586366, |
917 | 1.4690808101517236, |
918 | 1.6174065798291044, |
919 | 1.4280274009117935, |
920 | 1.5682418226469732}), |
921 | torch::tensor( |
922 | {3.118843540209399, 2.057729936485249, 1.9742319629710936}), |
923 | torch::tensor( |
924 | {-3.1331019663177013, -3.2154332694373107, -1.7459831639793468}), |
925 | torch::tensor({-1.3148644134154366}), |
926 | }, |
927 | { |
928 | torch::tensor( |
929 | {2.0452604138357113, |
930 | 2.332074253989847, |
931 | 1.4738845773449165, |
932 | 1.6246403004735728, |
933 | 1.4335712611625357, |
934 | 1.573826630920094}), |
935 | torch::tensor( |
936 | {3.1324088069784093, 2.0711763619826575, 1.9841582498316732}), |
937 | torch::tensor( |
938 | {-3.16737058959847, -3.2529206463859146, -1.7602788393925501}), |
939 | torch::tensor({-1.32281766461531}), |
940 | }, |
941 | { |
942 | torch::tensor( |
943 | {2.0495243704493262, |
944 | 2.338413341249581, |
945 | 1.4787599440132637, |
946 | 1.631210274009555, |
947 | 1.438849155552895, |
948 | 1.5798736919537595}), |
949 | torch::tensor( |
950 | {3.1438209015414227, 2.0823943437659658, 1.9940075805973108}), |
951 | torch::tensor( |
952 | {-3.19628690363529, -3.2845941643030367, -1.7752333900055153}), |
953 | torch::tensor({-1.332456718314933}), |
954 | }, |
955 | { |
956 | torch::tensor( |
957 | {2.0536979895206295, |
958 | 2.3442520601250334, |
959 | 1.4834272584224222, |
960 | 1.6372462654983486, |
961 | 1.4437517398490174, |
962 | 1.585780877834892}), |
963 | torch::tensor( |
964 | {3.1540081461072447, 2.0923381262560454, 2.0034284957296107}), |
965 | torch::tensor( |
966 | {-3.222142968201519, -3.312867602521477, -1.7898220261118043}), |
967 | torch::tensor({-1.3422692037690986}), |
968 | }, |
969 | { |
970 | torch::tensor( |
971 | {2.0576784836825315, |
972 | 2.3496934395759377, |
973 | 1.4878413407927933, |
974 | 1.6428612479757005, |
975 | 1.4483225979568104, |
976 | 1.5914034339763325}), |
977 | torch::tensor( |
978 | {3.163383747232199, 2.101446878895216, 2.012344413569353}), |
979 | torch::tensor( |
980 | {-3.246000281299229, -3.338904166978488, -1.8037666936489785}), |
981 | torch::tensor({-1.3517884775416527}), |
982 | }, |
983 | }; |
984 | } |
985 | |
986 | inline std::vector<std::vector<torch::Tensor>> AdamW_without_weight_decay() { |
987 | return { |
988 | { |
989 | torch::tensor( |
990 | {0.7890972864438476, |
991 | 0.5024410688121617, |
992 | 0.858707331305558, |
993 | 0.6579707241208395, |
994 | 0.7476356819075531, |
995 | 1.6975564206516922}), |
996 | torch::tensor( |
997 | {0.891467636010675, 0.70205134975675, 1.689201270942895}), |
998 | torch::tensor( |
999 | {-1.0508030958460797, -1.3941351509567654, -1.284337577714353}), |
1000 | torch::tensor({-1.071138110298716}), |
1001 | }, |
1002 | { |
1003 | torch::tensor( |
1004 | {8.233039313231831, |
1005 | 7.971150747377481, |
1006 | 6.643620950677599, |
1007 | 6.47097740790054, |
1008 | 6.170125488259256, |
1009 | 7.150739103343502}), |
1010 | torch::tensor( |
1011 | {8.417695070103738, 6.597188212844593, 7.23175710827678}), |
1012 | torch::tensor( |
1013 | {-6.729624357635757, -7.09743493108154, -6.753301896575352}), |
1014 | torch::tensor({-6.435639096011218}), |
1015 | }, |
1016 | { |
1017 | torch::tensor( |
1018 | {8.233424596059299, |
1019 | 7.971537360032308, |
1020 | 6.643920150720393, |
1021 | 6.471278075537239, |
1022 | 6.170405874224489, |
1023 | 7.151021086137983}), |
1024 | torch::tensor( |
1025 | {8.418084791214298, 6.597493171180545, 7.232043740621598}), |
1026 | torch::tensor( |
1027 | {-6.729918250724671, -7.097730102046093, -6.753584809755359}), |
1028 | torch::tensor({-6.4359165566974985}), |
1029 | }, |
1030 | { |
1031 | torch::tensor( |
1032 | {8.233424610557652, |
1033 | 7.971537374586563, |
1034 | 6.643920161995284, |
1035 | 6.471278086877828, |
1036 | 6.170405884785074, |
1037 | 7.151021096766406}), |
1038 | torch::tensor( |
1039 | {8.418084805901906, 6.597493182713584, 7.2320437514477875}), |
1040 | torch::tensor( |
1041 | {-6.72991829363266, -7.097730147102975, -6.753584838821182}), |
1042 | torch::tensor({-6.435916580217771}), |
1043 | }, |
1044 | { |
1045 | torch::tensor( |
1046 | {8.233424610575105, |
1047 | 7.971537374611125, |
1048 | 6.643920162027961, |
1049 | 6.471278086923277, |
1050 | 6.170405884809245, |
1051 | 7.151021096800041}), |
1052 | torch::tensor( |
1053 | {8.418084805946393, 6.597493182796847, 7.232043751509309}), |
1054 | torch::tensor( |
1055 | {-6.729918332327653, -7.097730188349552, -6.753584861205486}), |
1056 | torch::tensor({-6.435916596115672}), |
1057 | }, |
1058 | { |
1059 | torch::tensor( |
1060 | {8.233424610594861, |
1061 | 7.971537374639166, |
1062 | 6.64392016206557, |
1063 | 6.471278086975758, |
1064 | 6.170405884836981, |
1065 | 7.151021096838798}), |
1066 | torch::tensor( |
1067 | {8.418084805997617, 6.59749318289335, 7.232043751580523}), |
1068 | torch::tensor( |
1069 | {-6.72991837738045, -7.097730236373201, -6.753584887267492}), |
1070 | torch::tensor({-6.43591661462546}), |
1071 | }, |
1072 | { |
1073 | torch::tensor( |
1074 | {8.233424610617291, |
1075 | 7.971537374671012, |
1076 | 6.643920162108284, |
1077 | 6.471278087035361, |
1078 | 6.170405884868481, |
1079 | 7.151021096882812}), |
1080 | torch::tensor( |
1081 | {8.418084806055798, 6.59749318300295, 7.232043751661401}), |
1082 | torch::tensor( |
1083 | {-6.729918428547273, -7.09773029091405, -6.753584916866329}), |
1084 | torch::tensor({-6.4359166356471755}), |
1085 | }, |
1086 | { |
1087 | torch::tensor( |
1088 | {8.233424610642356, |
1089 | 7.9715373747065925, |
1090 | 6.643920162156006, |
1091 | 6.471278087101954, |
1092 | 6.1704058849036745, |
1093 | 7.15102109693199}), |
1094 | torch::tensor( |
1095 | {8.418084806120802, 6.597493183125404, 7.232043751751764}), |
1096 | torch::tensor( |
1097 | {-6.729918485714688, -7.0977303518511805, -6.753584949936365}), |
1098 | torch::tensor({-6.43591665913422}), |
1099 | }, |
1100 | { |
1101 | torch::tensor( |
1102 | {8.233424610670038, |
1103 | 7.97153737474589, |
1104 | 6.643920162208714, |
1105 | 6.471278087175501, |
1106 | 6.170405884942545, |
1107 | 7.151021096986303}), |
1108 | torch::tensor( |
1109 | {8.418084806192596, 6.597493183260647, 7.232043751851564}), |
1110 | torch::tensor( |
1111 | {-6.729918548853505, -7.097730419153473, -6.753584986460725}), |
1112 | torch::tensor({-6.435916685074594}), |
1113 | }, |
1114 | { |
1115 | torch::tensor( |
1116 | {8.233424610700352, |
1117 | 7.971537374788922, |
1118 | 6.643920162266432, |
1119 | 6.47127808725604, |
1120 | 6.17040588498511, |
1121 | 7.1510210970457795}), |
1122 | torch::tensor( |
1123 | {8.418084806271217, 6.597493183408747, 7.232043751960854}), |
1124 | torch::tensor( |
1125 | {-6.7299186179943, -7.097730492853521, -6.753585026457088}), |
1126 | torch::tensor({-6.435916713480863}), |
1127 | }, |
1128 | { |
1129 | torch::tensor( |
1130 | {8.23342461073333, |
1131 | 7.971537374835737, |
1132 | 6.643920162329224, |
1133 | 6.471278087343658, |
1134 | 6.170405885031416, |
1135 | 7.151021097110484}), |
1136 | torch::tensor( |
1137 | {8.418084806356747, 6.597493183569867, 7.232043752079749}), |
1138 | torch::tensor( |
1139 | {-6.729918693213275, -7.097730573032567, -6.753585069969552}), |
1140 | torch::tensor({-6.43591674438434}), |
1141 | }, |
1142 | }; |
1143 | } |
1144 | |
1145 | inline std::vector<std::vector<torch::Tensor>> AdamW_with_amsgrad() { |
1146 | return { |
1147 | { |
1148 | torch::tensor( |
1149 | {0.7912062750121864, |
1150 | 0.5074166292785842, |
1151 | 0.8601202529258052, |
1152 | 0.6613910130887053, |
1153 | 0.7501593169903569, |
1154 | 1.6905808503961983}), |
1155 | torch::tensor( |
1156 | {0.8925529482073002, 0.7050308347536254, 1.682309255842939}), |
1157 | torch::tensor( |
1158 | {-1.05029506454492, -1.3901937990816595, -1.2814942017397601}), |
1159 | torch::tensor({-1.0704267290556988}), |
1160 | }, |
1161 | { |
1162 | torch::tensor( |
1163 | {3.3017259270507915, |
1164 | 3.2082991753694565, |
1165 | 2.653930978510442, |
1166 | 2.5927674339810585, |
1167 | 2.4689608790182933, |
1168 | 2.825873703467739}), |
1169 | torch::tensor( |
1170 | {3.373698198112671, 2.6425942964586664, 2.8597930424244304}), |
1171 | torch::tensor( |
1172 | {-2.690360632302962, -2.8253191596069525, -2.6855499873057473}), |
1173 | torch::tensor({-2.5644658591929406}), |
1174 | }, |
1175 | { |
1176 | torch::tensor( |
1177 | {2.222607725541013, |
1178 | 2.3447188854637004, |
1179 | 1.5614270655258826, |
1180 | 1.606610018462357, |
1181 | 1.5260497191448619, |
1182 | 1.7309643622674138}), |
1183 | torch::tensor( |
1184 | {2.84137462783552, 1.824806600633721, 1.9620493659996037}), |
1185 | torch::tensor( |
1186 | {-2.576642773625787, -2.6706153846815766, -1.8799876863754623}), |
1187 | torch::tensor({-1.6044722984810953}), |
1188 | }, |
1189 | { |
1190 | torch::tensor( |
1191 | {2.0739558768648205, |
1192 | 2.3008338863863496, |
1193 | 1.4738888208638767, |
1194 | 1.5829485271829449, |
1195 | 1.4296176764284294, |
1196 | 1.5939984909850073}), |
1197 | torch::tensor( |
1198 | {2.9908013612792415, 1.936590940953305, 1.941691630199464}), |
1199 | torch::tensor( |
1200 | {-2.846562884997548, -2.9195962101501203, -1.746484716887341}), |
1201 | torch::tensor({-1.381525131003179}), |
1202 | }, |
1203 | { |
1204 | torch::tensor( |
1205 | {2.0333926094256953, |
1206 | 2.294977109171754, |
1207 | 1.452870514716895, |
1208 | 1.584853677999522, |
1209 | 1.4086299433181402, |
1210 | 1.5548201727855224}), |
1211 | torch::tensor( |
1212 | {3.0454817801193976, 1.9867169062383696, 1.935312753106444}), |
1213 | torch::tensor( |
1214 | {-2.9612116762746394, -3.0322275992001084, -1.7026114905180725}), |
1215 | torch::tensor({-1.30563541393247}), |
1216 | }, |
1217 | { |
1218 | torch::tensor( |
1219 | {2.02392417168201, |
1220 | 2.2977279587859, |
1221 | 1.4488511120131309, |
1222 | 1.5894646930743725, |
1223 | 1.406253686759073, |
1224 | 1.5450647949022756}), |
1225 | torch::tensor( |
1226 | {3.069025602955343, 2.0096872138967488, 1.935438546309299}), |
1227 | torch::tensor( |
1228 | {-3.016103148166836, -3.0893062953033583, -1.6925290685615872}), |
1229 | torch::tensor({-1.282870120405012}), |
1230 | }, |
1231 | { |
1232 | torch::tensor( |
1233 | {2.0230257817348316, |
1234 | 2.3016167065040647, |
1235 | 1.4496901978629444, |
1236 | 1.5939034289777392, |
1237 | 1.4082421794430946, |
1238 | 1.5444538756003756}), |
1239 | torch::tensor( |
1240 | {3.0814016132787954, 2.022150201844143, 1.9387429991308658}), |
1241 | torch::tensor( |
1242 | {-3.0466485946438406, -3.1223144611322446, -1.6944083009127773}), |
1243 | torch::tensor({-1.2786980736911064}), |
1244 | }, |
1245 | { |
1246 | torch::tensor( |
1247 | {2.024305922065404, |
1248 | 2.305101549510461, |
1249 | 1.4516863420588493, |
1250 | 1.5976447954882376, |
1251 | 1.4108596097183552, |
1252 | 1.5464236284303425}), |
1253 | torch::tensor( |
1254 | {3.0892574233545065, 2.0300944858242236, 1.943040321845021}), |
1255 | torch::tensor( |
1256 | {-3.0664189567897306, -3.1440820888166425, -1.6999750448893618}), |
1257 | torch::tensor({-1.2806281811826203}), |
1258 | }, |
1259 | { |
1260 | torch::tensor( |
1261 | {2.0259854116237364, |
1262 | 2.3080169152218293, |
1263 | 1.4537680296813915, |
1264 | 1.6007369432392426, |
1265 | 1.4132529823064277, |
1266 | 1.5489035046525346}), |
1267 | torch::tensor( |
1268 | {3.0949717180851337, 2.0358251379915764, 1.9473249654893}), |
1269 | torch::tensor( |
1270 | {-3.0808231377905426, -3.160021699873689, -1.7062031001273494}), |
1271 | torch::tensor({-1.28423401369705}), |
1272 | }, |
1273 | { |
1274 | torch::tensor( |
1275 | {2.0275923948348638, |
1276 | 2.3104512601389637, |
1277 | 1.455657715721078, |
1278 | 1.6033123357613526, |
1279 | 1.4153003204463288, |
1280 | 1.5512775896116622}), |
1281 | torch::tensor( |
1282 | {3.099479021299846, 2.0403012223048775, 1.9512285931847464}), |
1283 | torch::tensor( |
1284 | {-3.092151979336299, -3.1725453680885267, -1.7120689614428697}), |
1285 | torch::tensor({-1.2880095517062655}), |
1286 | }, |
1287 | { |
1288 | torch::tensor( |
1289 | {2.029022468328371, |
1290 | 2.3125066985045892, |
1291 | 1.4573100228823295, |
1292 | 1.605484259933419, |
1293 | 1.417038245960655, |
1294 | 1.5533932056240227}), |
1295 | torch::tensor( |
1296 | {3.103195011518616, 2.043964003458376, 1.9546640840748621}), |
1297 | torch::tensor( |
1298 | {-3.1014680843131184, -3.1828179298513968, -1.7172933346797972}), |
1299 | torch::tensor({-1.2914899987134136}), |
1300 | }, |
1301 | }; |
1302 | } |
1303 | |
1304 | inline std::vector<std::vector<torch::Tensor>> Adagrad() { |
1305 | return { |
1306 | { |
1307 | torch::tensor( |
1308 | {0.7891011045987429, |
1309 | 0.502443924512199, |
1310 | 0.8587078329085825, |
1311 | 0.6579710994224826, |
1312 | 0.7476364836215006, |
1313 | 1.697557019500397}), |
1314 | torch::tensor( |
1315 | {0.8914687688941954, 0.7020514988069096, 1.6892015076050444}), |
1316 | torch::tensor( |
1317 | {-1.0508031297732776, -1.3941351871450518, -1.284337597261839}), |
1318 | torch::tensor({-1.071138124161711}), |
1319 | }, |
1320 | { |
1321 | torch::tensor( |
1322 | {2.4079229696892583, |
1323 | 2.2346803754764286, |
1324 | 1.6967885588547365, |
1325 | 1.552279695827649, |
1326 | 1.2259044248443602, |
1327 | 2.221279696180243}), |
1328 | torch::tensor( |
1329 | {2.9334079162217193, 1.7619824934767887, 2.3464577179091473}), |
1330 | torch::tensor( |
1331 | {-2.221396083069719, -2.549950976011168, -1.9709315957317095}), |
1332 | torch::tensor({-1.5858816837541876}), |
1333 | }, |
1334 | { |
1335 | torch::tensor( |
1336 | {2.510404433941812, |
1337 | 2.3522584510262887, |
1338 | 1.7921695110761213, |
1339 | 1.657755825836846, |
1340 | 1.2891186618593045, |
1341 | 2.291878516133922}), |
1342 | torch::tensor( |
1343 | {3.092171180776419, 1.8971624370952997, 2.438734251283465}), |
1344 | torch::tensor( |
1345 | {-2.437641633486504, -2.7704264590526573, -2.0949471699460225}), |
1346 | torch::tensor({-1.6769121890401757}), |
1347 | }, |
1348 | { |
1349 | torch::tensor( |
1350 | {2.5652648968109415, |
1351 | 2.4155313947260972, |
1352 | 1.844241233613541, |
1353 | 1.7156513351246399, |
1354 | 1.3245206506797171, |
1355 | 2.3315409972138825}), |
1356 | torch::tensor( |
1357 | {3.178399916514377, 1.9721945764936502, 2.4909037706250428}), |
1358 | torch::tensor( |
1359 | {-2.5658710403147933, -2.901921821645266, -2.168560672193225}), |
1360 | torch::tensor({-1.7307903926154131}), |
1361 | }, |
1362 | { |
1363 | torch::tensor( |
1364 | {2.6021584494332592, |
1365 | 2.4582101324909065, |
1366 | 1.8796060082750778, |
1367 | 1.7550965207414717, |
1368 | 1.3489253597999988, |
1369 | 2.3589345190118247}), |
1370 | torch::tensor( |
1371 | {3.2368674310041516, 2.0236468833666894, 2.52707132741292}), |
1372 | torch::tensor( |
1373 | {-2.6573969292994164, -2.9960731060650505, -2.2211375717304076}), |
1374 | torch::tensor({-1.7692090167089707}), |
1375 | }, |
1376 | { |
1377 | torch::tensor( |
1378 | {2.629700772579208, |
1379 | 2.4901377017698683, |
1380 | 1.906173477530586, |
1381 | 1.7847957161833832, |
1382 | 1.3674517119505822, |
1383 | 2.3797578857769905}), |
1384 | torch::tensor( |
1385 | {3.2807643102638546, 2.062561811940094, 2.5546379424362775}), |
1386 | torch::tensor( |
1387 | {-2.7286379977755035, -3.0695109399636236, -2.262081199960513}), |
1388 | torch::tensor({-1.7990936323432214}), |
1389 | }, |
1390 | { |
1391 | torch::tensor( |
1392 | {2.6515471766995247, |
1393 | 2.51550257362603, |
1394 | 1.927341363452414, |
1395 | 1.8084994719811576, |
1396 | 1.3823309942932445, |
1397 | 2.3964995243914373}), |
1398 | torch::tensor( |
1399 | {3.3157334001309473, 2.093728023484945, 2.5768468697402924}), |
1400 | torch::tensor( |
1401 | {-2.786981763434855, -3.129746439571402, -2.29562487034177}), |
1402 | torch::tensor({-1.8235564908139104}), |
1403 | }, |
1404 | { |
1405 | torch::tensor( |
1406 | {2.6695780544837886, |
1407 | 2.53646401614724, |
1408 | 1.9448721033433505, |
1409 | 1.828157582353901, |
1410 | 1.3947329882074622, |
1411 | 2.4104657178934947}), |
1412 | torch::tensor( |
1413 | {3.344694775590452, 2.1196465761628516, 2.5954050923596252}), |
1414 | torch::tensor( |
1415 | {-2.8363936812536537, -3.1808219609745194, -2.32404190866147}), |
1416 | torch::tensor({-1.8442667636913117}), |
1417 | }, |
1418 | { |
1419 | torch::tensor( |
1420 | {2.684883801533072, |
1421 | 2.5542762192735515, |
1422 | 1.9597939532350015, |
1423 | 1.844909608012419, |
1424 | 1.4053459079217485, |
1425 | 2.4224257790968386}), |
1426 | torch::tensor( |
1427 | {3.369349515259956, 2.1417845308976795, 2.611319989214332}), |
1428 | torch::tensor( |
1429 | {-2.879251075341889, -3.225165734647855, -2.3486956737228057}), |
1430 | torch::tensor({-1.86222449978646}), |
1431 | }, |
1432 | { |
1433 | torch::tensor( |
1434 | {2.698151012423769, |
1435 | 2.56972998600169, |
1436 | 1.972757472697587, |
1437 | 1.8594775691681182, |
1438 | 1.4146081751022495, |
1439 | 2.43287021079559}), |
1440 | torch::tensor( |
1441 | {3.390772758897601, 2.1610741754331757, 2.6252349489549824}), |
1442 | torch::tensor( |
1443 | {-2.917092322961074, -3.264351563375218, -2.370468664387175}), |
1444 | torch::tensor({-1.8780765115117757}), |
1445 | }, |
1446 | { |
1447 | torch::tensor( |
1448 | {2.7098389356033783, |
1449 | 2.5833548721723747, |
1450 | 1.9841994925173085, |
1451 | 1.8723468731726323, |
1452 | 1.4228158926355312, |
1453 | 2.4421305315945085}), |
1454 | torch::tensor( |
1455 | {3.4096859099156673, 2.178143852041279, 2.6375854547611364}), |
1456 | torch::tensor( |
1457 | {-2.9509704554208467, -3.2994581338995044, -2.3899651139415874}), |
1458 | torch::tensor({-1.8922653655195538}), |
1459 | }, |
1460 | }; |
1461 | } |
1462 | |
1463 | inline std::vector<std::vector<torch::Tensor>> Adagrad_with_weight_decay() { |
1464 | return { |
1465 | { |
1466 | torch::tensor( |
1467 | {0.7891011218979068, |
1468 | 0.5024439415126254, |
1469 | 0.8587078332470682, |
1470 | 0.6579710998575992, |
1471 | 0.7476364849956589, |
1472 | 1.6975570150849029}), |
1473 | torch::tensor( |
1474 | {0.8914687701583902, 0.7020514988715463, 1.6892015071335027}), |
1475 | torch::tensor( |
1476 | {-1.0508031297726799, -1.3941351871397083, -1.2843375972607243}), |
1477 | torch::tensor({-1.0711381241615712}), |
1478 | }, |
1479 | { |
1480 | torch::tensor( |
1481 | {0.1846116678522213, |
1482 | 0.24944077103107917, |
1483 | 0.18651745437755768, |
1484 | 0.25219093533041764, |
1485 | 0.18712037968446713, |
1486 | 0.25289206444055234}), |
1487 | torch::tensor( |
1488 | {0.6482869597891656, 0.6580215784646755, 0.6581256007663537}), |
1489 | torch::tensor( |
1490 | {-1.454709711443681, -1.4748063405174818, -1.4811625946604765}), |
1491 | torch::tensor({-1.905292836544363}), |
1492 | }, |
1493 | { |
1494 | torch::tensor( |
1495 | {0.18059895999281475, |
1496 | 0.2438515539257779, |
1497 | 0.18067177884778182, |
1498 | 0.24397186395008694, |
1499 | 0.18168388351830797, |
1500 | 0.24533853846052017}), |
1501 | torch::tensor( |
1502 | {0.6325250261983028, 0.6331827793513023, 0.6366659383355598}), |
1503 | torch::tensor( |
1504 | {-1.420803333750877, -1.4215627240541653, -1.4320264544533396}), |
1505 | torch::tensor({-2.030135641848322}), |
1506 | }, |
1507 | { |
1508 | torch::tensor( |
1509 | {0.17981392697398363, |
1510 | 0.2427571544305695, |
1511 | 0.17981150414451733, |
1512 | 0.24275725992310523, |
1513 | 0.18014798619115763, |
1514 | 0.2432144956227816}), |
1515 | torch::tensor( |
1516 | {0.6294321320817985, 0.6294873737410742, 0.6306958589251878}), |
1517 | torch::tensor( |
1518 | {-1.4139253354785764, -1.413902680470981, -1.4173628530293867}), |
1519 | torch::tensor({-2.056210117690093}), |
1520 | }, |
1521 | { |
1522 | torch::tensor( |
1523 | {0.17967006242163747, |
1524 | 0.24255582734557277, |
1525 | 0.17966873677301953, |
1526 | 0.24255462870545766, |
1527 | 0.1797588230898851, |
1528 | 0.24267729072765756}), |
1529 | torch::tensor( |
1530 | {0.6288576295241085, 0.6288643132826753, 0.6291921485342001}), |
1531 | torch::tensor( |
1532 | {-1.4126465879787569, -1.4126335126907266, -1.4135586793353685}), |
1533 | torch::tensor({-2.0618018405404825}), |
1534 | }, |
1535 | { |
1536 | torch::tensor( |
1537 | {0.17964321284685653, |
1538 | 0.24251808241139364, |
1539 | 0.17964291377171066, |
1540 | 0.24251779104198598, |
1541 | 0.1796651574178102, |
1542 | 0.2425481059085456}), |
1543 | torch::tensor( |
1544 | {0.628748693136779, 0.6287498167193976, 0.6288312441271762}), |
1545 | torch::tensor( |
1546 | {-1.412405895385289, -1.4124029484481164, -1.4126313051315378}), |
1547 | torch::tensor({-2.0630223163099304}), |
1548 | }, |
1549 | { |
1550 | torch::tensor( |
1551 | {0.1796379973927849, |
1552 | 0.2425107191246215, |
1553 | 0.1796379363134342, |
1554 | 0.2425106592536174, |
1555 | 0.17964321802205502, |
1556 | 0.24251786094585864}), |
1557 | torch::tensor( |
1558 | {0.6287272170354161, 0.6287274414587727, 0.6287468362309863}), |
1559 | torch::tensor( |
1560 | {-1.4123588626342634, -1.412358263650784, -1.412412480569672}), |
1561 | torch::tensor({-2.0632918101480255}), |
1562 | }, |
1563 | { |
1564 | torch::tensor( |
1565 | {0.17963694231402444, |
1566 | 0.24250922426893615, |
1567 | 0.1796369298071621, |
1568 | 0.2425092121007451, |
1569 | 0.17963815759528073, |
1570 | 0.24251088666939838}), |
1571 | torch::tensor( |
1572 | {0.6287228195255881, 0.6287228675172439, 0.628727383936762}), |
1573 | torch::tensor( |
1574 | {-1.4123493065102781, -1.412349184462438, -1.4123617872243597}), |
1575 | torch::tensor({-2.063351765096138}), |
1576 | }, |
1577 | { |
1578 | torch::tensor( |
1579 | {0.1796367215904632, |
1580 | 0.24250891070978667, |
1581 | 0.17963671897003045, |
1582 | 0.24250890818318074, |
1583 | 0.17963700091084855, |
1584 | 0.24250929278196054}), |
1585 | torch::tensor( |
1586 | {0.6287218911936107, 0.6287219017313679, 0.6287229399204574}), |
1587 | torch::tensor( |
1588 | {-1.4123473011084142, -1.412347275640343, -1.41235016959507}), |
1589 | torch::tensor({-2.0633651674043505}), |
1590 | }, |
1591 | { |
1592 | torch::tensor( |
1593 | {0.17963667424978783, |
1594 | 0.24250884333120687, |
1595 | 0.17963667368829764, |
1596 | 0.24250884279379462, |
1597 | 0.17963673796131557, |
1598 | 0.24250893047794023}), |
1599 | torch::tensor( |
1600 | {0.6287216908150736, 0.6287216931558691, 0.6287219299749583}), |
1601 | torch::tensor( |
1602 | {-1.4123468700596724, -1.4123468646187736, -1.4123475243360133}), |
1603 | torch::tensor({-2.0633681724342527}), |
1604 | }, |
1605 | { |
1606 | torch::tensor( |
1607 | {0.1796366639185348, |
1608 | 0.24250882860835257, |
1609 | 0.17963666379614568, |
1610 | 0.24250882849182892, |
1611 | 0.17963667838367053, |
1612 | 0.2425088483939741}), |
1613 | torch::tensor( |
1614 | {0.6287216468984888, 0.6287216474215305, 0.6287217011907862}), |
1615 | torch::tensor( |
1616 | {-1.4123467758545658, -1.412346774671038, -1.4123469244007658}), |
1617 | torch::tensor({-2.0633688474977467}), |
1618 | }, |
1619 | }; |
1620 | } |
1621 | |
1622 | inline std::vector<std::vector<torch::Tensor>> |
1623 | Adagrad_with_weight_decay_and_lr_decay() { |
1624 | return { |
1625 | { |
1626 | torch::tensor( |
1627 | {0.7891011046018798, |
1628 | 0.5024439245163383, |
1629 | 0.8587078329086189, |
1630 | 0.6579710994225316, |
1631 | 0.747636483621666, |
1632 | 1.697557019500142}), |
1633 | torch::tensor( |
1634 | {0.8914687688943375, 0.7020514988069164, 1.6892015076050049}), |
1635 | torch::tensor( |
1636 | {-1.0508031297732776, -1.3941351871450511, -1.284337597261839}), |
1637 | torch::tensor({-1.0711381241617108}), |
1638 | }, |
1639 | { |
1640 | torch::tensor( |
1641 | {2.346218944110103, |
1642 | 2.191939439502003, |
1643 | 1.683355201740813, |
1644 | 1.5405520021635604, |
1645 | 1.2137800230828062, |
1646 | 2.205283463717303}), |
1647 | torch::tensor( |
1648 | {2.9090564593404, 1.7509657336815554, 2.336166413186925}), |
1649 | torch::tensor( |
1650 | {-2.206159683368316, -2.5344318233445415, -1.9622783535807609}), |
1651 | torch::tensor({-1.5796101463783623}), |
1652 | }, |
1653 | { |
1654 | torch::tensor( |
1655 | {2.3889328781057233, |
1656 | 2.2678221038007296, |
1657 | 1.7667624725138267, |
1658 | 1.6358015176639822, |
1659 | 1.2655767687152566, |
1660 | 2.261088056711282}), |
1661 | torch::tensor( |
1662 | {3.045569451994985, 1.8770196253823253, 2.4192707519566765}), |
1663 | torch::tensor( |
1664 | {-2.4079300017528613, -2.7399112002234305, -2.0780613510632375}), |
1665 | torch::tensor({-1.664722108226537}), |
1666 | }, |
1667 | { |
1668 | torch::tensor( |
1669 | {2.3886137557806384, |
1670 | 2.2922158071009178, |
1671 | 1.8078384116424007, |
1672 | 1.684352474440932, |
1673 | 1.290353948335789, |
1674 | 2.2870715509706496}), |
1675 | torch::tensor( |
1676 | {3.111110355394278, 1.9438501730282314, 2.4630249355872826}), |
1677 | torch::tensor( |
1678 | {-2.5226122034499263, -2.857315093916292, -2.143964860243905}), |
1679 | torch::tensor({-1.7130685809905042}), |
1680 | }, |
1681 | { |
1682 | torch::tensor( |
1683 | {2.374703352203156, |
1684 | 2.298804499257456, |
1685 | 1.8330249458212446, |
1686 | 1.7151661013307244, |
1687 | 1.3048586226945842, |
1688 | 2.3017650590464274}), |
1689 | torch::tensor( |
1690 | {3.150318222034133, 1.9877926185369321, 2.491399976401679}), |
1691 | torch::tensor( |
1692 | {-2.601415913361488, -2.938203895113964, -2.1892988334550028}), |
1693 | torch::tensor({-1.7462964261966805}), |
1694 | }, |
1695 | { |
1696 | torch::tensor( |
1697 | {2.3553658567303812, |
1698 | 2.297191758042688, |
1699 | 1.8501154749072124, |
1700 | 1.736836058688188, |
1701 | 1.3141313000193942, |
1702 | 2.3107452592153854}), |
1703 | torch::tensor( |
1704 | {3.1762315339155434, 2.0197585204578647, 2.5117041377790197}), |
1705 | torch::tensor( |
1706 | {-2.6606644002288697, -2.9991216074293856, -2.223413376189609}), |
1707 | torch::tensor({-1.7712905233118807}), |
1708 | }, |
1709 | { |
1710 | torch::tensor( |
1711 | {2.3338052201696207, |
1712 | 2.2913023710914993, |
1713 | 1.8624163948044772, |
1714 | 1.7530300731725454, |
1715 | 1.3203313209234842, |
1716 | 2.3163969478854747}), |
1717 | torch::tensor( |
1718 | {3.1943525925688934, 2.044447386769377, 2.527109724607397}), |
1719 | torch::tensor( |
1720 | {-2.7076634717294894, -3.047500808469036, -2.250495807208967}), |
1721 | torch::tensor({-1.7911288238757486}), |
1722 | }, |
1723 | { |
1724 | torch::tensor( |
1725 | {2.3114979154644892, |
1726 | 2.2830501835377808, |
1727 | 1.871616142999356, |
1728 | 1.765632597660841, |
1729 | 1.324565631636651, |
1730 | 2.319939234205203}), |
1731 | torch::tensor( |
1732 | {3.2074779925809085, 2.0642940833670544, 2.5392671301471235}), |
1733 | torch::tensor( |
1734 | {-2.7463093287485925, -3.087315541134716, -2.272780318857348}), |
1735 | torch::tensor({-1.8074516661537263}), |
1736 | }, |
1737 | { |
1738 | torch::tensor( |
1739 | {2.2891841627387346, |
1740 | 2.2734716995793693, |
1741 | 1.8786818999895825, |
1742 | 1.7757301317117602, |
1743 | 1.3274682997719436, |
1744 | 2.3220679353993825}), |
1745 | torch::tensor( |
1746 | {3.2172019619454075, 2.0807140893178175, 2.5491374815141876}), |
1747 | torch::tensor( |
1748 | {-2.7789204504423823, -3.1209351402429175, -2.2915969523376867}), |
1749 | torch::tensor({-1.821234722948421}), |
1750 | }, |
1751 | { |
1752 | torch::tensor( |
1753 | {2.2672498238343066, |
1754 | 2.2631678037928893, |
1755 | 1.8842131287032622, |
1756 | 1.7840007705383882, |
1757 | 1.3294311820750493, |
1758 | 2.323211243034543}), |
1759 | torch::tensor( |
1760 | {3.224507488068445, 2.094598223519413, 2.5573257155791715}), |
1761 | torch::tensor( |
1762 | {-2.8069849199086647, -3.1498826045022925, -2.3077996970997727}), |
1763 | torch::tensor({-1.8331040438272388}), |
1764 | }, |
1765 | { |
1766 | torch::tensor( |
1767 | {2.2458961718688957, |
1768 | 2.2525031725114775, |
1769 | 1.8886034384961112, |
1770 | 1.7908930341267952, |
1771 | 1.3307102435291205, |
1772 | 2.323647497600462}), |
1773 | torch::tensor( |
1774 | {3.230036385413078, 2.1065407459636134, 2.5642349249609664}), |
1775 | torch::tensor( |
1776 | {-2.83151249424399, -3.1751926295316566, -2.3219682378974036}), |
1777 | torch::tensor({-1.8434843744626483}), |
1778 | }, |
1779 | }; |
1780 | } |
1781 | |
1782 | inline std::vector<std::vector<torch::Tensor>> RMSprop() { |
1783 | return { |
1784 | { |
1785 | torch::tensor( |
1786 | {0.7890625772821005, |
1787 | 0.502415108650816, |
1788 | 0.8587027713011453, |
1789 | 0.657967312300643, |
1790 | 0.7476283936579036, |
1791 | 1.6975509766054537}), |
1792 | torch::tensor( |
1793 | {0.8914573371873159, 0.7020499947573374, 1.6891991194739453}), |
1794 | torch::tensor( |
1795 | {-1.0508027874171133, -1.3941348219724659, -1.2843374000099703}), |
1796 | torch::tensor({-1.0711379842715099}), |
1797 | }, |
1798 | { |
1799 | torch::tensor( |
1800 | {2.448571858277443, |
1801 | 2.2809152044417678, |
1802 | 1.7346424449151965, |
1803 | 1.5940004770230667, |
1804 | 1.250761131839982, |
1805 | 2.248993270255382}), |
1806 | torch::tensor( |
1807 | {2.994661478530102, 1.8150485290864256, 2.382542610897819}), |
1808 | torch::tensor( |
1809 | {-2.3036981738757825, -2.6337299521275646, -2.018370122358821}), |
1810 | torch::tensor({-1.620787559800898}), |
1811 | }, |
1812 | { |
1813 | torch::tensor( |
1814 | {2.5837582475607785, |
1815 | 2.4365737242301537, |
1816 | 1.8622886519354538, |
1817 | 1.7357065282848232, |
1818 | 1.3369695670141974, |
1819 | 2.3454934716983695}), |
1820 | torch::tensor( |
1821 | {3.2061266499381618, 1.9981112525417788, 2.5092495986614}), |
1822 | torch::tensor( |
1823 | {-2.6110809365525958, -2.9484807193016787, -2.194898560798439}), |
1824 | torch::tensor({-1.7501043480625826}), |
1825 | }, |
1826 | { |
1827 | torch::tensor( |
1828 | {2.669969051134511, |
1829 | 2.536559412710799, |
1830 | 1.9456091681389671, |
1831 | 1.828914948091767, |
1832 | 1.3952956766999587, |
1833 | 2.4110816686341923}), |
1834 | torch::tensor( |
1835 | {3.343672975593657, 2.1204057198913002, 2.5961524902119497}), |
1836 | torch::tensor( |
1837 | {-2.8372329851331006, -3.1817729538857207, -2.3249971853996954}), |
1838 | torch::tensor({-1.8450422173907486}), |
1839 | }, |
1840 | { |
1841 | torch::tensor( |
1842 | {2.7375365004059122, |
1843 | 2.6153071545358633, |
1844 | 2.0117493624534313, |
1845 | 1.9033001982031035, |
1846 | 1.4427501882445097, |
1847 | 2.4646213743186127}), |
1848 | torch::tensor( |
1849 | {3.452912454199796, 2.2190451524127535, 2.667552790123282}), |
1850 | torch::tensor( |
1851 | {-3.0329479456731505, -3.384582488936652, -2.4377299824997136}), |
1852 | torch::tensor({-1.9271014784118226}), |
1853 | }, |
1854 | { |
1855 | torch::tensor( |
1856 | {2.7952372917068753, |
1857 | 2.682820220375722, |
1858 | 2.0687223272686994, |
1859 | 1.967654548778711, |
1860 | 1.4844410726622166, |
1861 | 2.5117888904510117}), |
1862 | torch::tensor( |
1863 | {3.5471904628565745, 2.305113548262141, 2.7307948248967304}), |
1864 | torch::tensor( |
1865 | {-3.2141190290332537, -3.572944633614449, -2.5421970206546827}), |
1866 | torch::tensor({-2.0029976985219666}), |
1867 | }, |
1868 | { |
1869 | torch::tensor( |
1870 | {2.8467333937519483, |
1871 | 2.7432785177110395, |
1872 | 2.119898810135385, |
1873 | 2.0256805416741255, |
1874 | 1.5225256464280221, |
1875 | 2.554983108087885}), |
1876 | torch::tensor( |
1877 | {3.632098323876194, 2.383289304179778, 2.7889864719999222}), |
1878 | torch::tensor( |
1879 | {-3.387579944167926, -3.7537658010839294, -2.6423123266260427}), |
1880 | torch::tensor({-2.075616951445725}), |
1881 | }, |
1882 | { |
1883 | torch::tensor( |
1884 | {2.8938600498060203, |
1885 | 2.798776984183615, |
1886 | 2.16697381475156, |
1887 | 2.079238538430203, |
1888 | 1.5580820887115123, |
1889 | 2.5954023023969692}), |
1890 | torch::tensor( |
1891 | {3.71043881435304, 2.455919099321953, 2.8436784441941008}), |
1892 | torch::tensor( |
1893 | {-3.5567368287146417, -3.930484868709691, -2.740026479434574}), |
1894 | torch::tensor({-2.146398256871758}), |
1895 | }, |
1896 | { |
1897 | torch::tensor( |
1898 | {2.937649394399939, |
1899 | 2.850492965312311, |
1900 | 2.210902777232446, |
1901 | 2.1293746183147633, |
1902 | 1.5917084661873866, |
1903 | 2.6337100421079533}), |
1904 | torch::tensor( |
1905 | {3.7837853328443516, 2.5243155701130604, 2.8957265009949373}), |
1906 | torch::tensor( |
1907 | {-3.7234268485210475, -4.104949193318518, -2.836390693799751}), |
1908 | torch::tensor({-2.216118773360611}), |
1909 | }, |
1910 | { |
1911 | torch::tensor( |
1912 | {2.9787316798887558, |
1913 | 2.8991451078473207, |
1914 | 2.252272487010975, |
1915 | 2.1767288729202767, |
1916 | 1.6237627746697592, |
1917 | 2.670302507579268}), |
1918 | torch::tensor( |
1919 | {3.8530980655045086, 2.589275531553025, 2.9456388178450936}), |
1920 | torch::tensor( |
1921 | {-3.8886856459619636, -4.278191888396593, -2.9319964928350313}), |
1922 | torch::tensor({-2.285217699505124}), |
1923 | }, |
1924 | { |
1925 | torch::tensor( |
1926 | {3.017515620579049, |
1927 | 2.9452004251453268, |
1928 | 2.291469925522962, |
1929 | 2.2217217025782916, |
1930 | 1.6544730927621272, |
1931 | 2.705431441204422}), |
1932 | torch::tensor( |
1933 | {3.9190041004420166, 2.651317624465938, 2.993736489599992}), |
1934 | torch::tensor( |
1935 | {-4.053111559341913, -4.450801801162238, -3.0271845519131957}), |
1936 | torch::tensor({-2.3539498905973444}), |
1937 | }, |
1938 | }; |
1939 | } |
1940 | |
1941 | inline std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay() { |
1942 | return { |
1943 | { |
1944 | torch::tensor( |
1945 | {0.7890798754118442, |
1946 | 0.5024321083861885, |
1947 | 0.8587031097835685, |
1948 | 0.6579677474141494, |
1949 | 0.7476297677960806, |
1950 | 1.6975465611838714}), |
1951 | torch::tensor( |
1952 | {0.891458601354904, 0.7020500593937647, 1.6891986479348047}), |
1953 | torch::tensor( |
1954 | {-1.0508027868194278, -1.3941348166291232, -1.2843373988951865}), |
1955 | torch::tensor({-1.0711379841318796}), |
1956 | }, |
1957 | { |
1958 | torch::tensor( |
1959 | {0.2139892652405453, |
1960 | 0.2779011713353896, |
1961 | 0.18684802794665187, |
1962 | 0.2507569370785562, |
1963 | 0.19145335235130007, |
1964 | 0.2557687813140708}), |
1965 | torch::tensor( |
1966 | {0.6720959116689083, 0.6480734848064099, 0.654263007004671}), |
1967 | torch::tensor( |
1968 | {-1.4357633640899097, -1.4493557950073235, -1.4619011018357073}), |
1969 | torch::tensor({-1.9673083558727926}), |
1970 | }, |
1971 | { |
1972 | torch::tensor( |
1973 | {0.23961935744660673, |
1974 | 0.30354236865888945, |
1975 | 0.19567694278583514, |
1976 | 0.2544696440133763, |
1977 | 0.21982879020261814, |
1978 | 0.27495711471979495}), |
1979 | torch::tensor( |
1980 | {0.6927895724658635, 0.638015535479105, 0.6523245375960234}), |
1981 | torch::tensor( |
1982 | {-1.413722583500382, -1.4170291001633526, -1.4166977298480703}), |
1983 | torch::tensor({-2.0626651115437147}), |
1984 | }, |
1985 | { |
1986 | torch::tensor( |
1987 | {0.2506635865272117, |
1988 | 0.314639511428635, |
1989 | 0.25116892910818034, |
1990 | 0.30431399579592144, |
1991 | 0.25219625048710015, |
1992 | 0.3160110008170742}), |
1993 | torch::tensor( |
1994 | {0.7051419232960522, 0.6699011906397543, 0.699097284678438}), |
1995 | torch::tensor( |
1996 | {-1.4206083241624232, -1.4257037444100107, -1.4171061826065132}), |
1997 | torch::tensor({-2.075537874763694}), |
1998 | }, |
1999 | { |
2000 | torch::tensor( |
2001 | {0.23285924743063605, |
2002 | 0.29652494304777544, |
2003 | 0.2335322002738168, |
2004 | 0.2969991261380461, |
2005 | 0.23358272245229555, |
2006 | 0.2973997498166104}), |
2007 | torch::tensor( |
2008 | {0.6855589594925036, 0.6796983775695974, 0.6864174803983276}), |
2009 | torch::tensor( |
2010 | {-1.43110762794651, -1.4334934742818164, -1.422739552145125}), |
2011 | torch::tensor({-2.0842642493046184}), |
2012 | }, |
2013 | { |
2014 | torch::tensor( |
2015 | {0.23356397699389828, |
2016 | 0.29737142391985355, |
2017 | 0.23367622061822368, |
2018 | 0.29749447597160267, |
2019 | 0.23418481357395918, |
2020 | 0.29818122925156104}), |
2021 | torch::tensor( |
2022 | {0.6866530583001205, 0.6858933385102559, 0.6883944045412603}), |
2023 | torch::tensor( |
2024 | {-1.4564955509607018, -1.4583548131500643, -1.4418225445708595}), |
2025 | torch::tensor({-2.1064103749186183}), |
2026 | }, |
2027 | { |
2028 | torch::tensor( |
2029 | {0.2318717301174723, |
2030 | 0.2952904159872858, |
2031 | 0.23194024439476665, |
2032 | 0.29537019824987687, |
2033 | 0.2316421336904657, |
2034 | 0.2951041425983894}), |
2035 | torch::tensor( |
2036 | {0.6834813130194509, 0.6834401711464199, 0.6837275457100463}), |
2037 | torch::tensor( |
2038 | {-1.4647835805276763, -1.4653452408179053, -1.4571142112777709}), |
2039 | torch::tensor({-2.1209598505912086}), |
2040 | }, |
2041 | { |
2042 | torch::tensor( |
2043 | {0.2308683396504178, |
2044 | 0.2940474629750448, |
2045 | 0.23089067678260966, |
2046 | 0.29407615110959306, |
2047 | 0.23064069314214175, |
2048 | 0.29379043611390243}), |
2049 | torch::tensor( |
2050 | {0.6815062281792611, 0.6815233687209215, 0.6812759203026146}), |
2051 | torch::tensor( |
2052 | {-1.4643013018530682, -1.4644523635284246, -1.4617493939684878}), |
2053 | torch::tensor({-2.1247293635678854}), |
2054 | }, |
2055 | { |
2056 | torch::tensor( |
2057 | {0.23066464201462678, |
2058 | 0.29376059273730426, |
2059 | 0.23067069245857366, |
2060 | 0.2937690399784267, |
2061 | 0.23057551211606675, |
2062 | 0.2936477517373107}), |
2063 | torch::tensor( |
2064 | {0.6809028781780304, 0.6809134105028244, 0.6807404613096301}), |
2065 | torch::tensor( |
2066 | {-1.4637927352177986, -1.4638374228010727, -1.46299287102643}), |
2067 | torch::tensor({-2.1258082720638107}), |
2068 | }, |
2069 | { |
2070 | torch::tensor( |
2071 | {0.23062625079199173, |
2072 | 0.2936990787425707, |
2073 | 0.2306278924729115, |
2074 | 0.2937014834661651, |
2075 | 0.23059813368157003, |
2076 | 0.29366073890476396}), |
2077 | torch::tensor( |
2078 | {0.6807251804689082, 0.6807295616357246, 0.6806523640328994}), |
2079 | torch::tensor( |
2080 | {-1.4635790398985618, -1.463592926902286, -1.4633272688236565}), |
2081 | torch::tensor({-2.1261396358141798}), |
2082 | }, |
2083 | { |
2084 | torch::tensor( |
2085 | {0.23061701122700193, |
2086 | 0.293683983817782, |
2087 | 0.23061747865501653, |
2088 | 0.29368467998690806, |
2089 | 0.23060855595638208, |
2090 | 0.2936719507340021}), |
2091 | torch::tensor( |
2092 | {0.6806714673830832, 0.6806730903175793, 0.6806434720800856}), |
2093 | torch::tensor( |
2094 | {-1.4635008778278134, -1.4635052859178375, -1.4634208375068285}), |
2095 | torch::tensor({-2.1262432969587723}), |
2096 | }, |
2097 | }; |
2098 | } |
2099 | |
2100 | inline std::vector<std::vector<torch::Tensor>> |
2101 | RMSprop_with_weight_decay_and_centered() { |
2102 | return { |
2103 | { |
2104 | torch::tensor( |
2105 | {0.7941000061626792, |
2106 | 0.507452636734552, |
2107 | 0.8637405354185987, |
2108 | 0.663005089317529, |
2109 | 0.7526661272860107, |
2110 | 1.7025887305065852}), |
2111 | torch::tensor( |
2112 | {0.8964950370033696, 0.7070877948157552, 1.6942369105467197}), |
2113 | torch::tensor( |
2114 | {-1.055840599214661, -1.3991726335388424, -1.2893752132746332}), |
2115 | torch::tensor({-1.0761757981162612}), |
2116 | }, |
2117 | { |
2118 | torch::tensor( |
2119 | {2.3762999876885833, |
2120 | 2.239095829416783, |
2121 | 1.726175067071914, |
2122 | 1.5891569459230444, |
2123 | 1.2410074108588462, |
2124 | 2.2345431036725723}), |
2125 | torch::tensor( |
2126 | {2.990896455635836, 1.8152108764849464, 2.377985429759037}), |
2127 | torch::tensor( |
2128 | {-2.3071822180635286, -2.636859516619699, -2.0198181394256642}), |
2129 | torch::tensor({-1.622583045791722}), |
2130 | }, |
2131 | { |
2132 | torch::tensor( |
2133 | {2.372800588647971, |
2134 | 2.3022753207224254, |
2135 | 1.836028714221617, |
2136 | 1.7190937269287105, |
2137 | 1.3068955839895078, |
2138 | 2.3035835673200364}), |
2139 | torch::tensor( |
2140 | {3.1656599892042343, 1.9942937608209466, 2.4947143457182657}), |
2141 | torch::tensor( |
2142 | {-2.6139790332516775, -2.9507738987695404, -2.1954425128779516}), |
2143 | torch::tensor({-1.7513053380188806}), |
2144 | }, |
2145 | { |
2146 | torch::tensor( |
2147 | {2.2398453700818455, |
2148 | 2.2513384246965904, |
2149 | 1.8892176431436287, |
2150 | 1.7921873754661686, |
2151 | 1.3310951408713538, |
2152 | 2.3236392222350397}), |
2153 | torch::tensor( |
2154 | {3.240166119454613, 2.109742813600189, 2.5651614461576973}), |
2155 | torch::tensor( |
2156 | {-2.8388734382997454, -3.1824200770676123, -2.324831397600949}), |
2157 | torch::tensor({-1.8460315737386976}), |
2158 | }, |
2159 | { |
2160 | torch::tensor( |
2161 | {1.9829606312242465, |
2162 | 2.097356567850692, |
2163 | 1.9050263843525033, |
2164 | 1.8325835415812346, |
2165 | 1.3222762370713104, |
2166 | 2.3024963133870147}), |
2167 | torch::tensor( |
2168 | {3.2465360572089974, 2.1967266045869915, 2.6091992649970672}), |
2169 | torch::tensor( |
2170 | {-3.0326878099587207, -3.3827004807595005, -2.436989182250496}), |
2171 | torch::tensor({-1.928273216206344}), |
2172 | }, |
2173 | { |
2174 | torch::tensor( |
2175 | {1.6051175329080525, |
2176 | 1.8332107491649114, |
2177 | 1.8794767349053179, |
2178 | 1.8403588051948856, |
2179 | 1.273824111314107, |
2180 | 2.2296571379436823}), |
2181 | torch::tensor( |
2182 | {3.1814362940910437, 2.263019214072847, 2.6273016977574013}), |
2183 | torch::tensor( |
2184 | {-3.210932646440219, -3.567153254014387, -2.541016943923914}), |
2185 | torch::tensor({-2.0049155134617154}), |
2186 | }, |
2187 | { |
2188 | torch::tensor( |
2189 | {1.1588059349082709, |
2190 | 1.477861379523226, |
2191 | 1.7992410089026634, |
2192 | 1.806460009198667, |
2193 | 1.1739931551629919, |
2194 | 2.08647960875392}), |
2195 | torch::tensor( |
2196 | {3.03843703712275, 2.308203068375877, 2.6125393914734083}), |
2197 | torch::tensor( |
2198 | {-3.379830678608588, -3.741970414470626, -2.6410082400846546}), |
2199 | torch::tensor({-2.079294995910487}), |
2200 | }, |
2201 | { |
2202 | torch::tensor( |
2203 | {0.7701433312419088, |
2204 | 1.1105026677424745, |
2205 | 1.646507516936639, |
2206 | 1.71625269098179, |
2207 | 1.013748545414221, |
2208 | 1.8532966501655352}), |
2209 | torch::tensor( |
2210 | {2.827176875885245, 2.327401948159928, 2.5535309398603405}), |
2211 | torch::tensor( |
2212 | {-3.54193329850986, -3.9096652952123145, -2.739408870192437}), |
2213 | torch::tensor({-2.1537939241668997}), |
2214 | }, |
2215 | { |
2216 | torch::tensor( |
2217 | {0.5598923129351211, |
2218 | 0.8460500042788701, |
2219 | 1.4084175549165017, |
2220 | 1.5547314210944563, |
2221 | 0.8019580519338424, |
2222 | 1.5258384663629627}), |
2223 | torch::tensor( |
2224 | {2.5774950379490265, 2.313101306699127, 2.4388695757441745}), |
2225 | torch::tensor( |
2226 | {-3.6974974230160087, -4.070190514312716, -2.8378932675718405}), |
2227 | torch::tensor({-2.2307225014430423}), |
2228 | }, |
2229 | { |
2230 | torch::tensor( |
2231 | {0.5016784472836648, |
2232 | 0.7258690889265433, |
2233 | 1.0976902935953956, |
2234 | 1.319949187972513, |
2235 | 0.5853930356154851, |
2236 | 1.1446978015944624}), |
2237 | torch::tensor( |
2238 | {2.3235249877284945, 2.2592840970420176, 2.2681461698609375}), |
2239 | torch::tensor( |
2240 | {-3.8444921272569115, -4.22021051361099, -2.9373192115434263}), |
2241 | torch::tensor({-2.312733063937045}), |
2242 | }, |
2243 | { |
2244 | torch::tensor( |
2245 | {0.4875468895095056, |
2246 | 0.6878747871467128, |
2247 | 0.7787871237567606, |
2248 | 1.0462592546102176, |
2249 | 0.4416468896022397, |
2250 | 0.8122992916762792}), |
2251 | torch::tensor( |
2252 | {2.1078734515587483, 2.17034337037527, 2.0666325968568535}), |
2253 | torch::tensor( |
2254 | {-3.9782695475825216, -4.352093055115415, -3.0377809502927033}), |
2255 | torch::tensor({-2.403496388200805}), |
2256 | }, |
2257 | }; |
2258 | } |
2259 | |
2260 | inline std::vector<std::vector<torch::Tensor>> |
2261 | RMSprop_with_weight_decay_and_centered_and_momentum() { |
2262 | return { |
2263 | { |
2264 | torch::tensor( |
2265 | {0.7941000061626794, |
2266 | 0.507452636734552, |
2267 | 0.8637405354185985, |
2268 | 0.663005089317529, |
2269 | 0.7526661272860107, |
2270 | 1.7025887305065852}), |
2271 | torch::tensor( |
2272 | {0.8964950370033699, 0.7070877948157552, 1.6942369105467197}), |
2273 | torch::tensor( |
2274 | {-1.055840599214661, -1.3991726335388424, -1.2893752132746332}), |
2275 | torch::tensor({-1.0761757981162612}), |
2276 | }, |
2277 | { |
2278 | torch::tensor( |
2279 | {11.587263945492355, |
2280 | 12.552112516667206, |
2281 | 10.773002960161074, |
2282 | 10.782117868337808, |
2283 | 9.675467654064093, |
2284 | 10.830689360054789}), |
2285 | torch::tensor( |
2286 | {15.298238342006444, 11.252244653209866, 11.423905295074075}), |
2287 | torch::tensor( |
2288 | {-11.287147147258441, -11.673871066494183, -11.143068139029769}), |
2289 | torch::tensor({-10.744790465364126}), |
2290 | }, |
2291 | { |
2292 | torch::tensor( |
2293 | {5.993130757784388, |
2294 | 7.778269455146452, |
2295 | 9.705741295559012, |
2296 | 9.974952848613889, |
2297 | 8.171307305871647, |
2298 | 9.551498426643077}), |
2299 | torch::tensor( |
2300 | {12.811268477045155, 10.912201832960703, 10.87477550647832}), |
2301 | torch::tensor( |
2302 | {-11.20842921856976, -11.58706973895515, -11.098172235374586}), |
2303 | torch::tensor({-10.714110383698559}), |
2304 | }, |
2305 | { |
2306 | torch::tensor( |
2307 | {1.917316794757853, |
2308 | 3.442098373003915, |
2309 | 8.160846071267297, |
2310 | 8.76673426856121, |
2311 | 6.163892823252042, |
2312 | 7.748894752821816}), |
2313 | torch::tensor( |
2314 | {9.52929937981379, 10.371703621802425, 10.02242566317017}), |
2315 | torch::tensor( |
2316 | {-11.07914626767133, -11.444639737948599, -11.02397978065452}), |
2317 | torch::tensor({-10.663204622623406}), |
2318 | }, |
2319 | { |
2320 | torch::tensor( |
2321 | {0.24211162925745067, |
2322 | 0.8235150923738451, |
2323 | 6.109652191353378, |
2324 | 7.070860554523037, |
2325 | 3.8366635637770212, |
2326 | 5.46037058418296}), |
2327 | torch::tensor( |
2328 | {5.7908039507441, 9.534309069066389, 8.752252906881251}), |
2329 | torch::tensor( |
2330 | {-10.868651889371552, -11.212965695734527, -10.90242744782103}), |
2331 | torch::tensor({-10.579596899816439}), |
2332 | }, |
2333 | { |
2334 | torch::tensor( |
2335 | {0.0024206009020476234, |
2336 | 0.05521740497689468, |
2337 | 3.753606156332189, |
2338 | 4.9331546064599685, |
2339 | 1.7094621184709604, |
2340 | 3.022224882400484}), |
2341 | torch::tensor( |
2342 | {2.4729429920325234, 8.290211439306459, 6.983317870704776}), |
2343 | torch::tensor( |
2344 | {-10.529133489023623, -10.839885990130032, -10.704345435808353}), |
2345 | torch::tensor({-10.44279235413811}), |
2346 | }, |
2347 | { |
2348 | torch::tensor( |
2349 | {8.523664833406631e-06, |
2350 | -0.00018498015809617104, |
2351 | 1.6343074841140277, |
2352 | 2.683608480982546, |
2353 | 0.41425107807132744, |
2354 | 1.092111816609512}), |
2355 | torch::tensor( |
2356 | {0.553119873538318, 6.566845593450314, 4.783317472190566}), |
2357 | torch::tensor( |
2358 | {-9.990101114696575, -10.24914448933998, -10.38447825909146}), |
2359 | torch::tensor({-10.220382375374728}), |
2360 | }, |
2361 | { |
2362 | torch::tensor( |
2363 | {5.3669182339397725e-08, |
2364 | -2.899704029399283e-07, |
2365 | 0.37916783268568177, |
2366 | 0.9399553431452395, |
2367 | 0.02859528129337607, |
2368 | 0.17650614337704745}), |
2369 | torch::tensor( |
2370 | {0.03166973497545419, 4.442846994093523, 2.5203464928754724}), |
2371 | torch::tensor( |
2372 | {-9.15653357178671, -9.339631853060773, -9.875729313751442}), |
2373 | torch::tensor({-9.862669711962374}), |
2374 | }, |
2375 | { |
2376 | torch::tensor( |
2377 | {2.1133356499004335e-06, |
2378 | 2.4524630407768025e-06, |
2379 | 0.023655729923601883, |
2380 | 0.14273709578291396, |
2381 | -8.950192389690758e-05, |
2382 | 0.004237697008964042}), |
2383 | torch::tensor( |
2384 | {-0.00012364097582548376, 2.291191859107928, 0.8331414409602524}), |
2385 | torch::tensor( |
2386 | {-7.922566174765117, -8.003055545094796, -9.086673634672907}), |
2387 | torch::tensor({-9.297519364373224}), |
2388 | }, |
2389 | { |
2390 | torch::tensor( |
2391 | {0.0023497430294992434, |
2392 | 0.0028611316714725037, |
2393 | 0.0006998739627296072, |
2394 | 0.003657156536057531, |
2395 | 0.001654303471369622, |
2396 | 0.0018171459470053366}), |
2397 | torch::tensor( |
2398 | {0.004569191565477355, 0.7292466599711233, 0.11475431260766135}), |
2399 | torch::tensor( |
2400 | {-6.223834483308681, -6.185383631607397, -7.912955414853613}), |
2401 | torch::tensor({-8.430731662958186}), |
2402 | }, |
2403 | { |
2404 | torch::tensor( |
2405 | {0.10393820340367545, |
2406 | 0.13982074666181732, |
2407 | 0.0831407198272949, |
2408 | 0.10183584198629944, |
2409 | 0.13949594516972202, |
2410 | 0.17822672100147108}), |
2411 | torch::tensor( |
2412 | {0.340394645020639, 0.24860888862359687, 0.3191404515531066}), |
2413 | torch::tensor( |
2414 | {-4.174294597914298, -4.037528929635062, -6.297198700024484}), |
2415 | torch::tensor({-7.182093090194918}), |
2416 | }, |
2417 | }; |
2418 | } |
2419 | |
2420 | inline std::vector<std::vector<torch::Tensor>> SGD() { |
2421 | return { |
2422 | { |
2423 | torch::tensor( |
2424 | {-0.21063957030131192, |
2425 | -0.4972093725858961, |
2426 | -0.13931849072410168, |
2427 | -0.33939101965581686, |
2428 | -0.25112865488453673, |
2429 | 0.6992101966874735}), |
2430 | torch::tensor( |
2431 | {-0.1076573444246077, -0.2913064413859577, 0.6933846874181748}), |
2432 | torch::tensor( |
2433 | {-0.07998325778863398, |
2434 | -0.42149210515421365, |
2435 | -0.33498349553944556}), |
2436 | torch::tensor({-0.14255126505509488}), |
2437 | }, |
2438 | { |
2439 | torch::tensor( |
2440 | {-0.15543131540224012, |
2441 | -0.42351103963720343, |
2442 | -0.04196796248622072, |
2443 | -0.2095223178068499, |
2444 | -0.16031407286541022, |
2445 | 0.8209742464453325}), |
2446 | torch::tensor( |
2447 | {0.07724343607160136, 0.03387529472490231, 1.0028793648054941}), |
2448 | torch::tensor( |
2449 | {-0.8213382425894498, -1.1570800333254736, -1.615476033165743}), |
2450 | torch::tensor({-1.8734090731084845}), |
2451 | }, |
2452 | { |
2453 | torch::tensor( |
2454 | {-0.13342791770744886, |
2455 | -0.3941509709488104, |
2456 | -0.011470356542661934, |
2457 | -0.16885142516066962, |
2458 | -0.13306680693528108, |
2459 | 0.8576491729785701}), |
2460 | torch::tensor( |
2461 | {0.15081014600761677, 0.13560816175111742, 1.0971559708365837}), |
2462 | torch::tensor( |
2463 | {-0.9780975407869251, -1.3215153697157924, -1.876021387605152}), |
2464 | torch::tensor({-2.2024413056528886}), |
2465 | }, |
2466 | { |
2467 | torch::tensor( |
2468 | {-0.11963097684681223, |
2469 | -0.37573675130134543, |
2470 | 0.0069987166413883715, |
2471 | -0.14420855651125972, |
2472 | -0.11733423659038758, |
2473 | 0.8788673419128562}), |
2474 | torch::tensor( |
2475 | {0.1969829338759005, 0.1973461164047132, 1.1520119567305152}), |
2476 | torch::tensor( |
2477 | {-1.0677802792431819, -1.4166561260631119, -2.022033753216991}), |
2478 | torch::tensor({-2.383452427292781}), |
2479 | }, |
2480 | { |
2481 | torch::tensor( |
2482 | {-0.10950806441156272, |
2483 | -0.3622226699218595, |
2484 | 0.02028489243523426, |
2485 | -0.1264725422838007, |
2486 | -0.10635775660996463, |
2487 | 0.8936912722040982}), |
2488 | torch::tensor( |
2489 | {0.23089462331826793, 0.24184450074084418, 1.1904864598387046}), |
2490 | torch::tensor( |
2491 | {-1.1306213044009719, -1.4837186483578142, -2.122884602514208}), |
2492 | torch::tensor({-2.5071352505158395}), |
2493 | }, |
2494 | { |
2495 | torch::tensor( |
2496 | {-0.10149090356585248, |
2497 | -0.3515172115812867, |
2498 | 0.030662536099764083, |
2499 | -0.11261325211798616, |
2500 | -0.09797248308626623, |
2501 | 0.905027632401109}), |
2502 | torch::tensor( |
2503 | {0.25777759826689434, 0.2766609657536915, 1.2199973265718322}), |
2504 | torch::tensor( |
2505 | {-1.1789655573653979, -1.5355073692636774, -2.199612583884608}), |
2506 | torch::tensor({-2.600529541471662}), |
2507 | }, |
2508 | { |
2509 | torch::tensor( |
2510 | {-0.09484472748389533, |
2511 | -0.3426405023243085, |
2512 | 0.03917399284640637, |
2513 | -0.10124188994381228, |
2514 | -0.09121264836307835, |
2515 | 0.9141743475340721}), |
2516 | torch::tensor( |
2517 | {0.2800829300171032, 0.3052600200290069, 1.2438661306695873}), |
2518 | torch::tensor( |
2519 | {-1.2182324765944266, -1.5776851394085492, -2.2613704866316295}), |
2520 | torch::tensor({-2.6752743361973184}), |
2521 | }, |
2522 | { |
2523 | torch::tensor( |
2524 | {-0.08916446117741175, |
2525 | -0.33505233521798666, |
2526 | 0.04638527943959316, |
2527 | -0.09160422984057517, |
2528 | -0.08556486270584644, |
2529 | 0.9218219103015535}), |
2530 | torch::tensor( |
2531 | {0.2991619380154852, 0.3295237551295101, 1.2638639017720827}), |
2532 | torch::tensor( |
2533 | {-1.251282493526328, -1.6132564639504312, -2.3129529937213853}), |
2534 | torch::tensor({-2.73741957239466}), |
2535 | }, |
2536 | { |
2537 | torch::tensor( |
2538 | {-0.08420245801272856, |
2539 | -0.3284224385121882, |
2540 | 0.05263847708646642, |
2541 | -0.08324438788845245, |
2542 | -0.08072424164719598, |
2543 | 0.9283806476306355}), |
2544 | torch::tensor( |
2545 | {0.31584087342663564, 0.35059019818200393, 1.2810450644764015}), |
2546 | torch::tensor( |
2547 | {-1.2798091496372141, -1.6440072538210193, -2.357180462961105}), |
2548 | torch::tensor({-2.7905023459395872}), |
2549 | }, |
2550 | { |
2551 | torch::tensor( |
2552 | {-0.07979600214534928, |
2553 | -0.3225337978155753, |
2554 | 0.0581562720006689, |
2555 | -0.07586555700667826, |
2556 | -0.07649523955108037, |
2557 | 0.9341138824526719}), |
2558 | torch::tensor( |
2559 | {0.3306627217189733, 0.3692005578577212, 1.2960873917356066}), |
2560 | torch::tensor( |
2561 | {-1.3048976883823566, -1.6710855742501123, -2.395849898454614}), |
2562 | torch::tensor({-2.836765085555123}), |
2563 | }, |
2564 | { |
2565 | torch::tensor( |
2566 | {-0.07583232846497832, |
2567 | -0.3172360102461862, |
2568 | 0.06309179259248046, |
2569 | -0.06926361352067158, |
2570 | -0.07274510848082802, |
2571 | 0.9392004636935606}), |
2572 | torch::tensor( |
2573 | {0.3440038606091545, 0.3858647867996722, 1.3094518934419668}), |
2574 | torch::tensor( |
2575 | {-1.3272851146877218, -1.6952731308502653, -2.4301754289421598}), |
2576 | torch::tensor({-2.8777164728823017}), |
2577 | }, |
2578 | }; |
2579 | } |
2580 | |
2581 | inline std::vector<std::vector<torch::Tensor>> SGD_with_weight_decay() { |
2582 | return { |
2583 | { |
2584 | torch::tensor( |
2585 | {-0.21042867144447805, |
2586 | -0.49671181653925384, |
2587 | -0.13917719856207697, |
2588 | -0.3390489907590303, |
2589 | -0.2508762913762564, |
2590 | 0.6985126396619242}), |
2591 | torch::tensor( |
2592 | {-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}), |
2593 | torch::tensor( |
2594 | {-0.079932454658518, -0.42109796996670307, -0.33469915794198624}), |
2595 | torch::tensor({-0.14248012693079315}), |
2596 | }, |
2597 | { |
2598 | torch::tensor( |
2599 | {-0.13579982290274883, |
2600 | -0.3765456284475787, |
2601 | -0.03166970700350034, |
2602 | -0.18102559254681197, |
2603 | -0.1373234786735746, |
2604 | 0.7522156177001302}), |
2605 | torch::tensor( |
2606 | {0.08550003826014418, 0.051563225553454196, 0.9321399061276381}), |
2607 | torch::tensor( |
2608 | {-0.796312238882584, -1.1010063686038731, -1.5363716774172782}), |
2609 | torch::tensor({-1.8045854907382846}), |
2610 | }, |
2611 | { |
2612 | torch::tensor( |
2613 | {-0.09659168723529124, |
2614 | -0.30562076936588267, |
2615 | 0.0067128671455129185, |
2616 | -0.1166002367977548, |
2617 | -0.09012083166238948, |
2618 | 0.7264953102453368}), |
2619 | torch::tensor( |
2620 | {0.16531808496504802, 0.16488328577596398, 0.9610743966573317}), |
2621 | torch::tensor( |
2622 | {-0.9202466399245914, -1.2052829272891832, -1.7049756710541348}), |
2623 | torch::tensor({-2.0415977924493043}), |
2624 | }, |
2625 | { |
2626 | torch::tensor( |
2627 | {-0.06728100597713035, |
2628 | -0.2496589601654196, |
2629 | 0.03186158526394668, |
2630 | -0.07105441484407878, |
2631 | -0.056478595544178806, |
2632 | 0.6910758436366733}), |
2633 | torch::tensor( |
2634 | {0.21707768347081777, 0.23575238192099465, 0.9564382346520686}), |
2635 | torch::tensor( |
2636 | {-0.9788195039029999, -1.2447191597975946, -1.762020156061963}), |
2637 | torch::tensor({-2.131504419683077}), |
2638 | }, |
2639 | { |
2640 | torch::tensor( |
2641 | {-0.04304955053155505, |
2642 | -0.20206572730420902, |
2643 | 0.050959513946324475, |
2644 | -0.034700093557440984, |
2645 | -0.02922465201167018, |
2646 | 0.6547611705604361}), |
2647 | torch::tensor( |
2648 | {0.2563898231537708, 0.2878867158887637, 0.9414221685252802}), |
2649 | torch::tensor( |
2650 | {-1.0143969472996655, -1.2623288365082088, -1.7800471460065668}), |
2651 | torch::tensor({-2.170255083720924}), |
2652 | }, |
2653 | { |
2654 | torch::tensor( |
2655 | {-0.022154717038262738, |
2656 | -0.16036518660639862, |
2657 | 0.06644401410758827, |
2658 | -0.004183373274651896, |
2659 | -0.005965877978527781, |
2660 | 0.6200298215101535}), |
2661 | torch::tensor( |
2662 | {0.2886406829874717, 0.32924516791460257, 0.9230983700837223}), |
2663 | torch::tensor( |
2664 | {-1.0397895250773481, -1.2710914166240181, -1.7807758009603087}), |
2665 | torch::tensor({-2.1862978976514738}), |
2666 | }, |
2667 | { |
2668 | torch::tensor( |
2669 | {-0.0037439139848317077, |
2670 | -0.12328293308251938, |
2671 | 0.07944696186805641, |
2672 | 0.022100305718442022, |
2673 | 0.014399113804332037, |
2674 | 0.587697912745227}), |
2675 | torch::tensor( |
2676 | {0.3162871074692008, 0.36346293565421134, 0.9042402154310412}), |
2677 | torch::tensor( |
2678 | {-1.060234961430088, -1.2762264965487675, -1.7731268727630662}), |
2679 | torch::tensor({-2.191253945056341}), |
2680 | }, |
2681 | { |
2682 | torch::tensor( |
2683 | {0.012675985938854726, |
2684 | -0.09003711893222131, |
2685 | 0.09059095692632844, |
2686 | 0.04506778924310349, |
2687 | 0.03247299240601001, |
2688 | 0.5579755127260052}), |
2689 | torch::tensor( |
2690 | {0.3406226998933173, 0.3924947745885882, 0.8860121369119325}), |
2691 | torch::tensor( |
2692 | {-1.0781407849705034, -1.2800528898634018, -1.7613120374342215}), |
2693 | torch::tensor({-2.190575043873577}), |
2694 | }, |
2695 | { |
2696 | torch::tensor( |
2697 | {0.027425440985777993, |
2698 | -0.06008809958617219, |
2699 | 0.10026092920861808, |
2700 | 0.06531092947039244, |
2701 | 0.048628754907931976, |
2702 | 0.5308215072596255}), |
2703 | torch::tensor( |
2704 | {0.36239744520280553, 0.4175162387638887, 0.8688788105023479}), |
2705 | torch::tensor( |
2706 | {-1.0946579691370502, -1.283610342226948, -1.7474706191775764}), |
2707 | torch::tensor({-2.1870021744944763}), |
2708 | }, |
2709 | { |
2710 | torch::tensor( |
2711 | {0.04073250980147411, |
2712 | -0.03303024103555013, |
2713 | 0.1087177047593139, |
2714 | 0.08324870459183518, |
2715 | 0.0631222868881554, |
2716 | 0.5060892094042873}), |
2717 | torch::tensor( |
2718 | {0.38208249693950175, 0.4393002654989596, 0.8529817924677643}), |
2719 | torch::tensor( |
2720 | {-1.1103326127955466, -1.287332405916359, -1.73273866274852}), |
2721 | torch::tensor({-2.1819672316721337}), |
2722 | }, |
2723 | { |
2724 | torch::tensor( |
2725 | {0.05277160918732605, |
2726 | -0.008539186625351441, |
2727 | 0.1161515444487197, |
2728 | 0.09919929206676087, |
2729 | 0.07614530177703588, |
2730 | 0.48359250162323586}), |
2731 | torch::tensor( |
2732 | {0.3999968617221315, 0.45839442009256354, 0.8383132966805791}), |
2733 | torch::tensor( |
2734 | {-1.1254107858333455, -1.2913604197768889, -1.717739109221235}), |
2735 | torch::tensor({-2.1762368071604308}), |
2736 | }, |
2737 | }; |
2738 | } |
2739 | |
2740 | inline std::vector<std::vector<torch::Tensor>> |
2741 | SGD_with_weight_decay_and_momentum() { |
2742 | return { |
2743 | { |
2744 | torch::tensor( |
2745 | {-0.21042867144447805, |
2746 | -0.49671181653925384, |
2747 | -0.13917719856207697, |
2748 | -0.3390489907590303, |
2749 | -0.2508762913762564, |
2750 | 0.6985126396619242}), |
2751 | torch::tensor( |
2752 | {-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}), |
2753 | torch::tensor( |
2754 | {-0.079932454658518, -0.42109796996670307, -0.33469915794198624}), |
2755 | torch::tensor({-0.14248012693079315}), |
2756 | }, |
2757 | { |
2758 | torch::tensor( |
2759 | {0.0056118487251954775, |
2760 | -0.0710915563059199, |
2761 | 0.07701400891926036, |
2762 | 0.047067327035013866, |
2763 | 0.0428654052972598, |
2764 | 0.4352977220593751}), |
2765 | torch::tensor( |
2766 | {0.23834837300214828, 0.32366382503704183, 0.7128321016634689}), |
2767 | torch::tensor( |
2768 | {-1.041947788394885, -1.1730950187020548, -1.7648205873351157}), |
2769 | torch::tensor({-2.3359277661920594}), |
2770 | }, |
2771 | { |
2772 | torch::tensor( |
2773 | {0.11520007183759418, |
2774 | 0.1289453768763286, |
2775 | 0.14586845555951963, |
2776 | 0.1775341535876219, |
2777 | 0.15614155642578995, |
2778 | 0.33379126147460536}), |
2779 | torch::tensor( |
2780 | {0.465853656413685, 0.520197917876909, 0.7274876508280723}), |
2781 | torch::tensor( |
2782 | {-1.2034746444882527, -1.286126969233868, -1.604528340632377}), |
2783 | torch::tensor({-2.203215909196624}), |
2784 | }, |
2785 | { |
2786 | torch::tensor( |
2787 | {0.15331258730374997, |
2788 | 0.197909036233604, |
2789 | 0.16663814647374195, |
2790 | 0.2183320498727895, |
2791 | 0.1803274550482287, |
2792 | 0.28362745794417826}), |
2793 | torch::tensor( |
2794 | {0.5532312776994917, 0.5834224152126115, 0.6903579410976886}), |
2795 | torch::tensor( |
2796 | {-1.3052171323471546, -1.3514190497186434, -1.5153574535010634}), |
2797 | torch::tensor({-2.123181139806548}), |
2798 | }, |
2799 | { |
2800 | torch::tensor( |
2801 | {0.16814113185552507, |
2802 | 0.22386572201448868, |
2803 | 0.17413795101952864, |
2804 | 0.23280515326261633, |
2805 | 0.1839142207976228, |
2806 | 0.2614499495870909}), |
2807 | torch::tensor( |
2808 | {0.592282876576759, 0.6083877519652824, 0.663438748699906}), |
2809 | torch::tensor( |
2810 | {-1.3591143274292896, -1.383673065830997, -1.467157893517277}), |
2811 | torch::tensor({-2.087859547998447}), |
2812 | }, |
2813 | { |
2814 | torch::tensor( |
2815 | {0.1743742243877178, |
2816 | 0.2343126153059798, |
2817 | 0.17716942927642254, |
2818 | 0.23838669643330088, |
2819 | 0.18308461132092924, |
2820 | 0.25149544624452974}), |
2821 | torch::tensor( |
2822 | {0.6108281747800746, 0.6192657661217672, 0.6475519545045926}), |
2823 | torch::tensor( |
2824 | {-1.3860527054444407, -1.398816664238087, -1.4412527948055516}), |
2825 | torch::tensor({-2.0731939075659627}), |
2826 | }, |
2827 | { |
2828 | torch::tensor( |
2829 | {0.1771465478751462, |
2830 | 0.23875859951719522, |
2831 | 0.1784868271584857, |
2832 | 0.2406786372566496, |
2833 | 0.18181103291606765, |
2834 | 0.24687877342069478}), |
2835 | torch::tensor( |
2836 | {0.6198586021174767, 0.6242349464856269, 0.638736845373371}), |
2837 | torch::tensor( |
2838 | {-1.3993307716862977, -1.4058965193851591, -1.42747775986796}), |
2839 | torch::tensor({-2.0672675843404598}), |
2840 | }, |
2841 | { |
2842 | torch::tensor( |
2843 | {0.17843093585357683, |
2844 | 0.24073954802700465, |
2845 | 0.17908697027440873, |
2846 | 0.2416675839909268, |
2847 | 0.18088350526559058, |
2848 | 0.24467193314356378}), |
2849 | torch::tensor( |
2850 | {0.6243071074374693, 0.6265628975677455, 0.6339840865876518}), |
2851 | torch::tensor( |
2852 | {-1.4058750036106915, -1.4092362337714568, -1.4202202926903085}), |
2853 | torch::tensor({-2.0649062340635584}), |
2854 | }, |
2855 | { |
2856 | torch::tensor( |
2857 | {0.17904350645021613, |
2858 | 0.24165496946247034, |
2859 | 0.17936920658487726, |
2860 | 0.24211164489776849, |
2861 | 0.18031858582735988, |
2862 | 0.2435923992630521}), |
2863 | torch::tensor( |
2864 | {0.626513445507806, 0.6276715667697311, 0.6314641991686346}), |
2865 | torch::tensor( |
2866 | {-1.409113940967948, -1.410830795235453, -1.4164247285253404}), |
2867 | torch::tensor({-2.0639728292802046}), |
2868 | }, |
2869 | { |
2870 | torch::tensor( |
2871 | {0.17934167113683835, |
2872 | 0.242089962404631, |
2873 | 0.17950490408309286, |
2874 | 0.24231745350706005, |
2875 | 0.17999989292556767, |
2876 | 0.2430557755257577}), |
2877 | torch::tensor( |
2878 | {0.6276131793232345, 0.6282062328090801, 0.6301427155170752}), |
2879 | torch::tensor( |
2880 | {-1.4107251789010826, -1.4116011824171857, -1.4144511767962422}), |
2881 | torch::tensor({-2.0636056316673934}), |
2882 | }, |
2883 | { |
2884 | torch::tensor( |
2885 | {0.17948886155124505, |
2886 | 0.24230096332204806, |
2887 | 0.17957117450689372, |
2888 | 0.242415213133214, |
2889 | 0.17982712042628357, |
2890 | 0.2427862039224869}), |
2891 | torch::tensor( |
2892 | {0.6281635672171683, 0.6284667582211864, 0.6294549191500093}), |
2893 | torch::tensor( |
2894 | {-1.4115305541843781, -1.4119772978756444, -1.4134296522818641}), |
2895 | torch::tensor({-2.0634616066978615}), |
2896 | }, |
2897 | }; |
2898 | } |
2899 | |
2900 | inline std::vector<std::vector<torch::Tensor>> |
2901 | SGD_with_weight_decay_and_nesterov_momentum() { |
2902 | return { |
2903 | { |
2904 | torch::tensor( |
2905 | {-0.21040617235121148, |
2906 | -0.49689727139951717, |
2907 | -0.13754215970803657, |
2908 | -0.33701686525263036, |
2909 | -0.2500172388792182, |
2910 | 0.700697918175925}), |
2911 | torch::tensor( |
2912 | {-0.1068708360895515, -0.2853285323043249, 0.6971494161502307}), |
2913 | torch::tensor( |
2914 | {-0.10624536304143092, -0.4461132561477894, -0.3805647497874434}), |
2915 | torch::tensor({-0.2068230782168696}), |
2916 | }, |
2917 | { |
2918 | torch::tensor( |
2919 | {-0.1262387113548655, |
2920 | -0.3844658218758334, |
2921 | 0.03124406856508884, |
2922 | -0.11170532152425781, |
2923 | -0.09823268522398332, |
2924 | 0.9040698525178972}), |
2925 | torch::tensor( |
2926 | {0.17551336074135096, 0.27976614792027166, 1.2138399680985128}), |
2927 | torch::tensor( |
2928 | {-1.592840413595591, -1.8986806244521564, -2.966181914454827}), |
2929 | torch::tensor({-3.7728444542017687}), |
2930 | }, |
2931 | { |
2932 | torch::tensor( |
2933 | {-0.11614716303292183, |
2934 | -0.3709539909720773, |
2935 | 0.04307078045512772, |
2936 | -0.09588329367245825, |
2937 | -0.08795603365024904, |
2938 | 0.9178771227283019}), |
2939 | torch::tensor( |
2940 | {0.20944042006388683, 0.3195483889401668, 1.2500270348310718}), |
2941 | torch::tensor( |
2942 | {-1.635011052494502, -1.9463243375558272, -3.035708036973984}), |
2943 | torch::tensor({-3.8570351018212796}), |
2944 | }, |
2945 | { |
2946 | torch::tensor( |
2947 | {-0.10793942832760066, |
2948 | -0.35995697973682966, |
2949 | 0.05260329955808716, |
2950 | -0.08312010825923577, |
2951 | -0.07986326997915319, |
2952 | 0.9287409473303162}), |
2953 | torch::tensor( |
2954 | {0.2370574459090396, 0.35168415020524857, 1.278618438127574}), |
2955 | torch::tensor( |
2956 | {-1.669141810658011, -1.984894370767313, -3.091259532917102}), |
2957 | torch::tensor({-3.923827025320545}), |
2958 | }, |
2959 | { |
2960 | torch::tensor( |
2961 | {-0.1010142826857921, |
2962 | -0.35067247612415425, |
2963 | 0.06058642765135953, |
2964 | -0.07242353828264116, |
2965 | -0.07320722520220559, |
2966 | 0.9376663294528951}), |
2967 | torch::tensor( |
2968 | {0.26037531373638517, 0.37864768429039036, 1.3021925174954938}), |
2969 | torch::tensor( |
2970 | {-1.6978623668013235, -2.017346013780729, -3.137511248751908}), |
2971 | torch::tensor({-3.9791368472670334}), |
2972 | }, |
2973 | { |
2974 | torch::tensor( |
2975 | {-0.0950223925827384, |
2976 | -0.34263425874631004, |
2977 | 0.06744912149060932, |
2978 | -0.0632219668955612, |
2979 | -0.06756850374320933, |
2980 | 0.9452179348012486}), |
2981 | torch::tensor( |
2982 | {0.2805629021730173, 0.40186559210837897, 1.322201974233735}), |
2983 | torch::tensor( |
2984 | {-1.7226667375672964, -2.0453651314263936, -3.177094625235675}), |
2985 | torch::tensor({-4.02626353351958}), |
2986 | }, |
2987 | { |
2988 | torch::tensor( |
2989 | {-0.08974074058929343, |
2990 | -0.3355446553621404, |
2991 | 0.07346375443579244, |
2992 | -0.055152336279101065, |
2993 | -0.06268648871001672, |
2994 | 0.9517469338705254}), |
2995 | torch::tensor( |
2996 | {0.2983667593362755, 0.42224471824689497, 1.3395523443811077}), |
2997 | torch::tensor( |
2998 | {-1.7445022249557358, -2.070022023204061, -3.2116640112699977}), |
2999 | torch::tensor({-4.0672678681014025}), |
3000 | }, |
3001 | { |
3002 | torch::tensor( |
3003 | {-0.08501805567029425, |
3004 | -0.32920173535901165, |
3005 | 0.07881418855733362, |
3006 | -0.04796950604202488, |
3007 | -0.058388411064112675, |
3008 | 0.9574862878804136}), |
3009 | torch::tensor( |
3010 | {0.314293480998118, 0.44039784591234016, 1.3548455497581404}), |
3011 | torch::tensor( |
3012 | {-1.7640079833055848, -2.092039516337239, -3.2423272727017984}), |
3013 | torch::tensor({-4.1035226231344275}), |
3014 | }, |
3015 | { |
3016 | torch::tensor( |
3017 | {-0.08074691916762683, |
3018 | -0.32346209125355385, |
3019 | 0.08363041954091821, |
3020 | -0.04150011326440421, |
3021 | -0.05455400195824525, |
3022 | 0.9625982669377223}), |
3023 | torch::tensor( |
3024 | {0.3287029554083447, 0.456758647445438, 1.3685016029692183}), |
3025 | torch::tensor( |
3026 | {-1.781635969965323, -2.1119291060463254, -3.2698625668054278}), |
3027 | torch::tensor({-4.135987715622241}), |
3028 | }, |
3029 | { |
3030 | torch::tensor( |
3031 | {-0.07684825926741544, |
3032 | -0.3182201487496606, |
3033 | 0.08800775942949753, |
3034 | -0.03561702050647377, |
3035 | -0.05109626940874274, |
3036 | 0.967200308063431}), |
3037 | torch::tensor( |
3038 | {0.3418602073865105, 0.47164537681262036, 1.3808249543962092}), |
3039 | torch::tensor( |
3040 | {-1.7977177360253929, -2.1300662313234127, -3.2948372910743537}), |
3041 | torch::tensor({-4.165360368453936}), |
3042 | }, |
3043 | { |
3044 | torch::tensor( |
3045 | {-0.07326215742204903, |
3046 | -0.31339589848358795, |
3047 | 0.09201816976416921, |
3048 | -0.030224217178854797, |
3049 | -0.0479503174605994, |
3050 | 0.9713800632469923}), |
3051 | torch::tensor( |
3052 | {0.35396614509118257, 0.485298524494989, 1.3920431643924076}), |
3053 | torch::tensor( |
3054 | {-1.8125038126190638, -2.146734711618823, -3.3176778240157505}), |
3055 | torch::tensor({-4.192162739857097}), |
3056 | }, |
3057 | }; |
3058 | } |
3059 | |
3060 | } // namespace expected_parameters |
3061 | |