This file contains information on the 20 pre-trained models provided with this repository.
The provided models were trained with the following architectures:
- Forward RNN, NADE and FB-RNN: five layers (BatchNormalization, LSTM Layer 1, LSTM Layer 2, BatchNormalization, Linear).
- BIMODAL: seven layers (BatchNormalization, LSTM Layer 1 - forward, LSTM Layer 1 - backward, LSTM Layer 2 - forward, LSTM Layer 2 - backward, BatchNormalization, Linear).
A dropout value of 0.3 was used for the output weights of the first LSTM layer. Models were trained for 10 epochs with the Adam optimization algorithm, using the cross-entropy loss; performance was assessed by five-fold cross-validation (random partitioning protocol). Additional details are given in Tables 1-3.
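For orientation, the sketch below reconstructs the five-layer unidirectional stack in PyTorch. It is a minimal sketch based only on the layer list and Table 1 (55-token one-hot input, 256 LSTM units, dropout 0.3 on the first LSTM's output), not the repository's training code; the class name and the `bias=False` choice for the linear layer (which matches the 14080-parameter count) are assumptions.

```python
import torch
import torch.nn as nn

class ForwardRNNSketch(nn.Module):
    """Minimal sketch of the five-layer architecture (BatchNorm, LSTM, LSTM,
    BatchNorm, Linear); dimensions follow Table 1 (55 tokens, 256 units)."""

    def __init__(self, n_tokens=55, hidden=256, dropout=0.3):
        super().__init__()
        self.bn1 = nn.BatchNorm1d(n_tokens)       # 2 * 55  = 110 parameters
        self.lstm1 = nn.LSTM(n_tokens, hidden, batch_first=True)  # 320512 parameters
        self.drop = nn.Dropout(dropout)           # dropout on LSTM 1 outputs
        self.lstm2 = nn.LSTM(hidden, hidden, batch_first=True)    # 526336 parameters
        self.bn2 = nn.BatchNorm1d(hidden)         # 2 * 256 = 512 parameters
        self.out = nn.Linear(hidden, n_tokens, bias=False)        # 14080 parameters

    def forward(self, x):
        # x: (batch, seq_len, n_tokens) one-hot encoded SMILES
        x = self.bn1(x.transpose(1, 2)).transpose(1, 2)  # BatchNorm over the token channels
        x, _ = self.lstm1(x)
        x = self.drop(x)
        x, _ = self.lstm2(x)
        x = self.bn2(x.transpose(1, 2)).transpose(1, 2)
        return self.out(x)  # per-position logits over the 55 tokens
```

Training as described would pair this with `torch.optim.Adam` and a cross-entropy loss over the per-position logits.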
Table 1. Details on the architecture of the Forward RNN and NADE.
| Type | No. Units | No. Parameters |
|---|---|---|
| BatchNormalization 1 | 55 | 110 |
| LSTM 1 | 256 or 512 | 320512 |
| LSTM 2 | 256 or 512 | 526336 |
| BatchNormalization 2 | 256 | 512 |
| Linear Layer | 55 | 14080 |
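Although two unit counts are possible per layer, a single parameter count is listed: the printed numbers correspond to the 256-unit configuration (128 units per direction in Table 3). They match the standard PyTorch LSTM parameter formula, 4·(n_in·n_h + n_h² + 2·n_h), and the linear-layer counts correspond to a weight matrix without a bias term; the 110-dimensional FB-RNN input is consistent with concatenating the one-hot codes of the forward- and backward-read tokens. The check below illustrates this:

```python
def lstm_params(n_in, n_h):
    # 4 gates, each with input weights, recurrent weights, and two bias vectors
    return 4 * (n_in * n_h + n_h * n_h + 2 * n_h)

# Table 1 (Forward RNN / NADE): 55-dim one-hot input, 256 units
assert lstm_params(55, 256) == 320512
assert lstm_params(256, 256) == 526336
assert 256 * 55 == 14080        # Linear layer, weights only

# Table 2 (FB-RNN): 110-dim input (2 x 55, consistent with concatenated tokens)
assert lstm_params(110, 256) == 376832
assert 2 * 256 * 55 == 28160    # Linear layer on both directions concatenated

# Table 3 (BIMODAL): 128 units per direction
assert lstm_params(55, 128) == 94720
assert lstm_params(128, 128) == 132096
assert 2 * 128 * 55 == 14080    # Linear layer on both directions concatenated
```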
Table 2. Details on the architecture of the FB-RNN models.
| Type | No. Units | No. Parameters |
|---|---|---|
| BatchNormalization 1 | 110 | 220 |
| LSTM 1 | 256 or 512 | 376832 |
| LSTM 2 | 256 or 512 | 526336 |
| BatchNormalization 2 | 256 | 512 |
| Linear Layer | 55 | 28160 |
Table 3. Details on the architecture of the BIMODAL models.
| Type | No. Units | No. Parameters |
|---|---|---|
| BatchNormalization 1 | 55 | 110 |
| LSTM 1 Forward | 128 or 256 | 94720 |
| LSTM 1 Backward | 128 or 256 | 94720 |
| LSTM 2 Forward | 128 or 256 | 132096 |
| LSTM 2 Backward | 128 or 256 | 132096 |
| BatchNormalization 2 | 256 | 512 |
| Linear Layer | 55 | 14080 |
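Analogously, a minimal sketch of the seven-layer BIMODAL stack, reconstructed from Table 3: two separate 128-unit LSTM stacks whose outputs are concatenated before the second BatchNormalization. How the forward and backward passes are synchronized during generation is specific to the BIMODAL method and handled by the repository code; only the layer dimensions are mirrored here.

```python
import torch
import torch.nn as nn

class BIMODALSketch(nn.Module):
    """Seven-layer sketch following Table 3: separate forward/backward LSTM
    stacks, concatenated (128 + 128 = 256) before BatchNorm 2."""

    def __init__(self, n_tokens=55, hidden=128, dropout=0.3):
        super().__init__()
        self.bn1 = nn.BatchNorm1d(n_tokens)                          # 110 parameters
        self.lstm1_f = nn.LSTM(n_tokens, hidden, batch_first=True)   # 94720 parameters
        self.lstm1_b = nn.LSTM(n_tokens, hidden, batch_first=True)   # 94720 parameters
        self.drop = nn.Dropout(dropout)                              # on LSTM 1 outputs
        self.lstm2_f = nn.LSTM(hidden, hidden, batch_first=True)     # 132096 parameters
        self.lstm2_b = nn.LSTM(hidden, hidden, batch_first=True)     # 132096 parameters
        self.bn2 = nn.BatchNorm1d(2 * hidden)                        # 512 parameters
        self.out = nn.Linear(2 * hidden, n_tokens, bias=False)       # 14080 parameters

    def forward(self, x_fwd, x_bwd):
        # x_fwd / x_bwd: (batch, seq, n_tokens), the sequence read in each direction
        x_fwd = self.bn1(x_fwd.transpose(1, 2)).transpose(1, 2)
        x_bwd = self.bn1(x_bwd.transpose(1, 2)).transpose(1, 2)
        h_f, _ = self.lstm1_f(x_fwd)
        h_b, _ = self.lstm1_b(x_bwd)
        h_f, _ = self.lstm2_f(self.drop(h_f))
        h_b, _ = self.lstm2_b(self.drop(h_b))
        h = torch.cat([h_f, h_b], dim=-1)                # 256-dim per position
        h = self.bn2(h.transpose(1, 2)).transpose(1, 2)
        return self.out(h)                               # logits over the 55 tokens
```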
The ID contained in the field "model name" can be used to sample from the pre-trained models, as explained in the README (a generic sampling sketch also follows the table below).
| model name | method | starting point | no. hidden | augmentation |
|---|---|---|---|---|
| 'BIMODAL_fixed_1024' | BIMODAL | fixed | 1024 | none |
| 'BIMODAL_fixed_512' | BIMODAL | fixed | 512 | none |
| 'BIMODAL_random_1024' | BIMODAL | random | 1024 | none |
| 'BIMODAL_random_1024_aug_5' | BIMODAL | random | 1024 | 5-fold |
| 'BIMODAL_random_512' | BIMODAL | random | 512 | none |
| 'BIMODAL_random_512_aug_5' | BIMODAL | random | 512 | 5-fold |
| 'FBRNN_fixed_1024' | FB-RNN | fixed | 1024 | none |
| 'FBRNN_fixed_512' | FB-RNN | fixed | 512 | none |
| 'FBRNN_random_1024' | FB-RNN | random | 1024 | none |
| 'FBRNN_random_1024_aug_5' | FB-RNN | random | 1024 | 5-fold |
| 'FBRNN_random_512' | FB-RNN | random | 512 | none |
| 'FBRNN_random_512_aug_5' | FB-RNN | random | 512 | 5-fold |
| 'ForwardRNN_1024' | Forward RNN | fixed | 1024 | none |
| 'ForwardRNN_512' | Forward RNN | fixed | 512 | none |
| 'NADE_fixed_1024' | NADE | fixed | 1024 | none |
| 'NADE_fixed_512' | NADE | fixed | 512 | none |
| 'NADE_random_1024' | NADE | random | 1024 | none |
| 'NADE_random_1024_aug_5' | NADE | random | 1024 | 5-fold |
| 'NADE_random_512' | NADE | random | 512 | none |
| 'NADE_random_512_aug_5' | NADE | random | 512 | 5-fold |
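The exact sampling commands are given in the README. Independent of that interface, sampling from one of these models is a standard autoregressive loop over the 55-token vocabulary; the sketch below continues the hypothetical `ForwardRNNSketch` from above, with untrained weights and illustrative token indices and temperature, purely to show the mechanics.

```python
import torch

torch.manual_seed(0)
model = ForwardRNNSketch().eval()  # untrained weights, for illustration only

start_token = 0   # illustrative index of the start-of-SMILES token
end_token = 1     # illustrative index of the end token
tokens = [start_token]

with torch.no_grad():
    for _ in range(100):  # cap the generated length
        x = torch.nn.functional.one_hot(
            torch.tensor([tokens]), num_classes=55
        ).float()                                     # (1, len, 55) one-hot input
        logits = model(x)[0, -1]                      # logits for the next position
        probs = torch.softmax(logits / 0.7, dim=-1)   # temperature 0.7 (illustrative)
        nxt = torch.multinomial(probs, 1).item()      # sample the next token
        if nxt == end_token:
            break
        tokens.append(nxt)

# `tokens` would then be mapped back to SMILES characters via the model's vocabulary.
```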