Commit cb793a1

update readme about num_epoch and mems=None move to first seq step

1 parent ce5d7fa

2 files changed: +3 additions, −4 deletions

README.md

Lines changed: 2 additions & 3 deletions

````diff
@@ -15,15 +15,14 @@ $ pip install pytorch_pretrained_bert
 $ python main.py --data ./data.txt --tokenizer bert-base-uncased \
     --seq_len 512 --reuse_len 256 --perm_size 256 \
     --bi_data True --mask_alpha 6 --mask_beta 1 \
-    --num_predict 85 --mem_len 384 --num_step 100
+    --num_predict 85 --mem_len 384 --num_epoch 100
 ```
 
 Also, You can run code in [Google Colab](https://colab.research.google.com/github/graykode/xlnet-Pytorch/blob/master/XLNet.ipynb) easily.
 
 - Hyperparameters for Pretraining in Paper.
 
 <p align="center"><img width="300" src="images/hyperparameters.png" /> </p>
-
 #### Option
 
 - `—data`(String) : `.txt` file to train. It doesn't matter multiline text. Also, one file will be one batch tensor. Default : `data.txt`
@@ -37,7 +36,7 @@ Also, You can run code in [Google Colab](https://colab.research.google.com/githu
 - `—mask_beta`(Integer) : How many tokens to mask within each group. Default : `1`
 - `—num_predict`(Interger) : Num of tokens to predict. In Paper, it mean Partial Prediction. Default : `85`
 - `—mem_len`(Interger) : Number of steps to cache in Transformer-XL Architecture. Default : `384`
-- `number_step`(Interger) : Number of Step(Epoch). Default : `100`
+- `num_epoch`(Interger) : Number of Epoch. Default : `100`
````

main.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -67,9 +67,9 @@
 criterion = nn.CrossEntropyLoss()
 optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
-mems = None
 
 for num_epoch in range(args.num_epoch):
+    mems = None
 
     features = data_utils._create_data(sp=sp,
                                        input_paths=args.data,
```
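The main.py change moves the Transformer-XL memory reset from before the training loop to the top of each epoch, so cached hidden states left over from the end of one epoch are not fed into the first step of the next. A minimal, self-contained sketch of why the per-epoch reset matters (`toy_step` and `run_epochs` are hypothetical stand-ins, not the repository's actual model code):

```python
def toy_step(x, mems, mem_len=3):
    """Stand-in forward pass: consume the cached memory, return an
    output plus the updated memory, like a Transformer-XL step."""
    if mems is None:        # first step after a reset: no cached states yet
        mems = []
    new_mems = (mems + [x])[-mem_len:]  # keep a fixed-length cache (mem_len)
    out = sum(new_mems)                 # toy "prediction" from the cache
    return out, new_mems

def run_epochs(data, num_epoch):
    """Toy training loop mirroring the commit: mems is reset per epoch."""
    outputs = []
    for _ in range(num_epoch):
        mems = None                     # reset at the first step of each epoch
        for x in data:
            out, mems = toy_step(x, mems)
            outputs.append(out)
    return outputs
```

With the reset inside the loop, every epoch sees the same inputs with the same (empty) starting memory, so the per-step outputs repeat identically across epochs; with the reset hoisted outside, the second epoch's first step would instead start from the previous epoch's stale cache.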
