Skip to content

Latest commit

 

History

History
64 lines (47 loc) · 2.04 KB

File metadata and controls

64 lines (47 loc) · 2.04 KB

Chinese-Text-Classification-Pytorch-Tuning

LICENSE

中文文本分类,TextCNN,TextRNN,FastText,TextRCNN,BiLSTM_Attention, DPCNN, Transformer, 基于pytorch,开箱即用。

现也已加入对Bert的支持。

基于ray.tune实现了对不同模型进行超参数优化的功能。简单易用。

环境

python 3.7
pytorch 1.1
tqdm
sklearn
tensorboardX
ray

使用说明

第一步:安装ray - pip install ray

第二步:选定要做超参数优化的模型: 如TextRNN
(Bert需要参照此处额外下载文件,不用Bert可跳过)

第三步:根据第二步选中的模型,在run.py中设定相关超参数的search_space。具体的语法可参照这里。如

search_space = {
    'learning_rate': tune.loguniform(1e-5, 1e-2),
    'num_epochs': tune.randint(5, 21),
    'dropout': tune.uniform(0, 0.5),
    'hidden_size': tune.randint(32, 257),
    'num_layers': tune.randint(1,3)
}

此处请注意确认相关参数是否适用于选择的模型,否则会报错

第四步:启动50次超参数优化实验

python run.py --model TextCNN --tune_param True --tune_samples 50 

第五步:在自动生成的实验结果文件tune_results_.csv中查看实验记录


更多用法

# 使用GPU
python run.py --model TextRNN --tune_param True --tune_gpu True

# 使用Bert
python run.py --model bert --tune_param True --tune_gpu True

# 自定义实验结果文件后缀名
python run.py --model TextRNN --tune_param True --tune_file rnn_char

# 使用ASHA scheduler来做early stopping
python run.py --model TextRNN --tune_param True --tune_asha True

# 使用当前的超参数进行模型训练,不进行超参数优化
python run.py --model TextRNN --tune_param False

更多细节请参照源文档