-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
136 lines (128 loc) · 5 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Global Variables
import os  # stdlib; used for environment-variable fallbacks for secrets below

# Global RNG seed shared across all experiment sections in this file.
SEED = 42

# Secrets: prefer environment variables over hard-coding keys in source
# control. Each falls back to the original placeholder string, so behavior
# is unchanged when the corresponding variable is unset.
WANDB_API = os.getenv("WANDB_API", "ENTER_YOUR_OWN_KEY")
HF_API = os.getenv("HF_API", "ENTER_YOUR_OWN_KEY")
OPENAI_API = os.getenv("OPENAI_API", "ENTER_YOUR_OWN_KEY")
# Multi Class Classification Variables
# W&B project for the multi-class (section-header) classification runs.
MULTI_CLASS_WANDB_PROJECT = "clef-mediqa-chat-multi-class-classification"
# MULTI_CLASS_WANDB_PROJECT = "mediqa-chat-mc-test"
MULTI_CLASS_EPOCHS = 150
MULTI_CLASS_MICRO_BATCH_SIZE = 16
MULTI_CLASS_GRADIENT_ACCUMULATION_STEPS = 1
MULTI_CLASS_GRADIENT_CLIPPING = 1.0
MULTI_CLASS_LEARNING_RATE = 1e-3
# NOTE(review): 0.1 looks like a warm-up *ratio* rather than a step count
# despite the name — confirm against the scheduler that consumes it.
MULTI_CLASS_WARM_UP_STEPS = 0.1
MULTI_CLASS_ADAMW_WEIGHT_DECAY = 0.01
MULTI_CLASS_ADAMW_EPS = 1e-6
MULTI_CLASS_ADAM_BIAS_CORRECTION = True
MULTI_CLASS_MAX_LENGTH = 512  # tokenizer max sequence length
MULTI_CLASS_MODEL_CHECKPOINT = "emilyalsentzer/Bio_ClinicalBERT"
MULTI_CLASS_N_SPLITS = 3  # stratified CV folds
MULTI_CLASS_JOB_TYPE = "accuracy_baseline"
MULTI_CLASS_SAMPLING = False
# "MUTLI" is a typo; the misspelled name is kept so existing callers keep
# working, and a correctly spelled alias is provided below. Prefer the alias.
MUTLI_CLASS_BALANCE_LOSS = False
MULTI_CLASS_BALANCE_LOSS = MUTLI_CLASS_BALANCE_LOSS
MULTI_CLASS_SEED = 42
# Derived run name, e.g. "3-stratified-cv-Bio_ClinicalBERT-lora".
MULTI_CLASS_MODEL_NAME = f"{MULTI_CLASS_N_SPLITS}-stratified-cv-{MULTI_CLASS_MODEL_CHECKPOINT.split('/')[-1]}-lora"
MULTI_CLASS_NOTES = \
"""
Multi Class Classification with 3 fold stratified cross validation on section-header. Uses complete data with 20 classes with Cross Entropy Loss and seed set to 42.
Accuracy for every class.
Model is trained using PEFT from Huggingface
"""
# +
# Multi Label Classification Variables
MULTI_LABEL_WANDB_PROJECT = "mediqa-chat-multi-label-classification"
# MULTI_LABEL_WANDB_PROJECT = "mediqa-chat-ml-test"
MULTI_LABEL_EPOCHS = 30
MULTI_LABEL_LEARNING_RATE = 2e-5
# NOTE(review): presumably a warm-up ratio, not an absolute step count —
# confirm against the scheduler that consumes it.
MULTI_LABEL_WARM_UP_STEPS = 0.1
MULTI_LABEL_ADAMW_WEIGHT_DECAY = 0.01
MULTI_LABEL_ADAMW_EPS = 1e-6
MULTI_LABEL_MAX_LENGTH = 512  # tokenizer max sequence length
MULTI_LABEL_ATTRIBUTION_LENGTH = 200
MULTI_LABEL_MODEL_CHECKPOINT = "emilyalsentzer/Bio_ClinicalBERT"
MULTI_LABEL_N_SPLITS = 3
MULTI_LABEL_JOB_TYPE = "roc_auc_pr_baseline"
MULTI_LABEL_SAMPLING = False
# "MUTLI" is a typo; the misspelled name is kept so existing callers keep
# working, and a correctly spelled alias is provided below. Prefer the alias.
MUTLI_LABEL_BALANCE_LOSS = False
MULTI_LABEL_BALANCE_LOSS = MUTLI_LABEL_BALANCE_LOSS
MULTI_LABEL_SEED = 42
# NOTE(review): name says "5-fold" but MULTI_LABEL_N_SPLITS above is 3 —
# confirm which is intended; string left unchanged to avoid renaming
# existing runs/artifacts.
MULTI_LABEL_MODEL_NAME = "5-fold-multilabel-cv-bio-clinicalbert-multilabel-focal-loss-seed-42-complete-data"
MULTI_LABEL_NOTES = \
"""
Multi Label Classification on complete data with 20 classes with Focal Loss and seed set to 42.
MultiLabel Stratification has been used as cross validation strategy.
ROC-AUC and PR Score for every class
"""
# Summary Generation Variables
TASKA_SUMMARY_WANDB_PROJECT = "clef-mediqa-sum-summarization-taska"
# TASKA_SUMMARY_WANDB_PROJECT = "mediqa-chat-taska-summary-test"
TASKA_SUMMARY_EPOCHS = 150
TASKA_SUMMARY_TRAIN_MICRO_BATCH_SIZE_PER_GPU = 1
TASKA_SUMMARY_GRADIENT_ACCUMULATION_STEPS = 16
TASKA_SUMMARY_GRADIENT_CLIPPING = 1.0
TASKA_SUMMARY_LEARNING_RATE = 1e-3
TASKA_SUMMARY_MAX_SOURCE_LENGTH = 512  # max input (dialogue) tokens
TASKA_SUMMARY_MIN_TARGET_LENGTH = 8    # generation length bounds
TASKA_SUMMARY_MAX_TARGET_LENGTH = 512
TASKA_SUMMARY_PADDING = "max_length"
TASKA_SUMMARY_IGNORE_PAD_TOKEN_FOR_LOSS = True
TASKA_SUMMARY_PROMPT = "summarize: "  # T5-style task prefix
TASKA_SUMMARY_MODEL_CHECKPOINT = "google/flan-t5-large"
TASKA_SUMMARY_N_SPLITS = 3
# Derived run name, e.g. "3-fold-stratified-cv-flan-t5-large-lora".
TASKA_SUMMARY_SINGLE_MODEL_NAME = \
f"{TASKA_SUMMARY_N_SPLITS}-fold-stratified-cv-{TASKA_SUMMARY_MODEL_CHECKPOINT.split('/')[-1]}-lora"
TASKA_SUMMARY_NOTES = \
"""
Summarization of complete dialogues.
The data has been stratified on Section Header for 3 folds
Metric for early stopping is Log Loss.
Metric for text generation evaluation is ROGUE, BERTScore, BLEURT
"""
# "bluert" misspelling kept intentionally: it may tag existing W&B runs.
TASKA_SUMMARY_JOB_TYPE = "rouge_bertscore_bluert_baseline"
# "SUMMRAY" is a typo; the misspelled name is kept so existing callers keep
# working, and a correctly spelled alias is provided below. Prefer the alias.
TASKA_SUMMRAY_USE_STEMMER = True
TASKA_SUMMARY_USE_STEMMER = TASKA_SUMMRAY_USE_STEMMER
TASKA_SUMMARY_OPTIMIZER_WEIGHT_DECAY = 0.01
TASKA_SUMMARY_OPTIMIZER_EPS = 1e-6
TASKA_SUMMARY_OPTIMIZER_BIAS_CORRECTION = True
# NOTE(review): presumably a warm-up ratio, not a step count — confirm.
TASKA_SUMMARY_NUM_WARMUP_STEPS = 0.1
TASKA_SUMMARY_SAMPLING = False
TASKA_SUMMARY_SEED = 42
# Mutually exclusive input-formatting switches: prepend the raw section code
# vs. its description — TODO confirm against the data-prep code.
TASKA_SUMMARY_DIALOGUE_W_SECTION_CODE = False
TASKA_SUMMARY_DIALOGUE_W_SECTION_CODE_DESC = True
# # +
# Task B Summary Generation Variables
# NOTE(review): header says "Task B" but every constant below is TASKC_* —
# confirm which task this section actually configures.
TASKC_SUMMARY_WANDB_PROJECT = "mediqa-chat-TaskC-summarization"
# TASKC_SUMMARY_WANDB_PROJECT = "mediqa-chat-TaskC-summarY-test"
TASKC_SUMMARY_EPOCHS = 30
TASKC_SUMMARY_BATCH_SIZE = 4
TASKC_SUMMARY_LEARNING_RATE = 2e-5
TASKC_SUMMARY_MAX_SOURCE_LENGTH = 3400  # long-dialogue input budget
# Per-section generation length bounds (SOAP-note sections).
TASKC_SUBJECTIVE_MIN_TARGET_LENGTH = 50
TASKC_OBJECTIVE_EXAM_MIN_TARGET_LENGTH = 5
TASKC_OBJECTIVE_RESULT_MIN_TARGET_LENGTH = 5
TASKC_ASSESSMENT_AND_PLAN_MIN_TARGET_LENGTH = 50
TASKC_SUBJECTIVE_MAX_TARGET_LENGTH = 768
TASKC_OBJECTIVE_EXAM_MAX_TARGET_LENGTH = 256
TASKC_OBJECTIVE_RESULT_MAX_TARGET_LENGTH = 256
TASKC_ASSESSMENT_AND_PLAN_MAX_TARGET_LENGTH = 640
TASKC_SUMMARY_PADDING = "max_length"
TASKC_SUMMARY_IGNORE_PAD_TOKEN_FOR_LOSS = True
TASKC_SUMMARY_MODEL_CHECKPOINT = "MingZhong/DialogLED-large-5120"
# NOTE(review): name says "5-KFold" but TASKC_SUMMARY_N_SPLITS below is 3 —
# confirm which is intended; string left unchanged to avoid renaming
# existing runs/artifacts.
TASKC_SUMMARY_MODEL_NAME = "5-KFold-dialogled-large-with-section-information"
TASKC_SUMMARY_PREFIX = ""
TASKC_SUMMARY_N_SPLITS = 3
TASKC_SUMMARY_NOTES = \
"""
Summarization of Long Dialogues with section code description.
Early Stopping criteria is Loss
Metrics are Rouge, Bertscore, BlueRT
"""
# "bluert" misspelling kept intentionally: it may tag existing W&B runs.
TASKC_SUMMARY_JOB_TYPE = "rouge_bertscore_bluert_baseline"
# "SUMMRAY" is a typo; the misspelled name is kept so existing callers keep
# working, and a correctly spelled alias is provided below. Prefer the alias.
TASKC_SUMMRAY_USE_STEMMER = True
TASKC_SUMMARY_USE_STEMMER = TASKC_SUMMRAY_USE_STEMMER
TASKC_SUMMARY_WEIGHT_DECAY = 0.01
# NOTE(review): presumably a warm-up ratio, not a step count — confirm.
TASKC_SUMMARY_NUM_WARMUP_STEPS = 0.1
TASKC_SUMMARY_NUM_BEAMS = 2
# Mutually exclusive input-formatting switches: prepend the raw section code
# vs. its description — TODO confirm against the data-prep code.
TASKC_SUMMARY_DIALOGUE_W_SECTION_CODE = False
TASKC_SUMMARY_DIALOGUE_W_SECTION_CODE_DESC = True
TASKC_SUMMARY_SAMPLING = False
TASKC_SUMMARY_SEED = 42
TASKC_GRADIENT_ACCUMULATION_STEPS = 4