Merge pull request #15 from walkernr/main
General Update
walkernr committed Jan 17, 2022
2 parents 7a13a9e + b939802 commit da7b7cd
Showing 7 changed files with 360 additions and 260 deletions.
64 changes: 63 additions & 1 deletion README.md
@@ -4,7 +4,69 @@ A framework for materials science NER using the HuggingFace Transformers NLP Toolkit

# Installation

```bash
git clone https://github.com/walkernr/MatBERT_NER.git MatBERT_NER
cd MatBERT_NER
pip install -r requirements.txt .
```

# Example Usage

The following command will train the MatBERT model on the solid state dataset using default parameters:

```
python train.py -dv gpu:0 -ds solid_state -ml matbert
```

Additional parameters can be specified.

```
usage: train.py [-h] [-dv DEVICE] [-sd SEEDS] [-ts TAG_SCHEMES] [-st SPLITS] [-ds DATASETS] [-ml MODELS] [-sl] [-bs BATCH_SIZE] [-on OPTIMIZER_NAME] [-wd WEIGHT_DECAY] [-ne N_EPOCH]
                [-eu EMBEDDING_UNFREEZE] [-tu TRANSFORMER_UNFREEZE] [-el EMBEDDING_LEARNING_RATE] [-tl TRANSFORMER_LEARNING_RATE] [-cl CLASSIFIER_LEARNING_RATE] [-sf SCHEDULING_FUNCTION]
                [-km]

optional arguments:
  -h, --help            show this help message and exit
  -dv DEVICE, --device DEVICE
                        computation device for model (e.g. cpu, gpu:0, gpu:1)
  -sd SEEDS, --seeds SEEDS
                        comma-separated seeds for data shuffling and model initialization (e.g. 1,2,3 or 2,4,8)
  -ts TAG_SCHEMES, --tag_schemes TAG_SCHEMES
                        comma-separated tagging schemes to be considered (e.g. iob1,iob2,iobes)
  -st SPLITS, --splits SPLITS
                        comma-separated training splits to be considered, in percent (e.g. 80). test split will always be 10% and the validation split
                        will be 1/8 of the training split unless the training split is 100%
  -ds DATASETS, --datasets DATASETS
                        comma-separated datasets to be considered (e.g. solid_state,doping)
  -ml MODELS, --models MODELS
                        comma-separated models to be considered (e.g. matbert,scibert,bert)
  -sl, --sentence_level
                        switch for sentence-level learning instead of paragraph-level
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        number of samples in each batch
  -on OPTIMIZER_NAME, --optimizer_name OPTIMIZER_NAME
                        name of optimizer, add "_lookahead" to implement lookahead on top of optimizer (not recommended for ranger or rangerlars)
  -wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY
                        weight decay for optimizer (excluding bias, gamma, and beta)
  -ne N_EPOCH, --n_epoch N_EPOCH
                        number of training epochs
  -eu EMBEDDING_UNFREEZE, --embedding_unfreeze EMBEDDING_UNFREEZE
                        epoch (index) at which bert embeddings are unfrozen
  -tu TRANSFORMER_UNFREEZE, --transformer_unfreeze TRANSFORMER_UNFREEZE
                        comma-separated number of transformers (encoders) to unfreeze at each epoch
  -el EMBEDDING_LEARNING_RATE, --embedding_learning_rate EMBEDDING_LEARNING_RATE
                        embedding learning rate
  -tl TRANSFORMER_LEARNING_RATE, --transformer_learning_rate TRANSFORMER_LEARNING_RATE
                        transformer learning rate
  -cl CLASSIFIER_LEARNING_RATE, --classifier_learning_rate CLASSIFIER_LEARNING_RATE
                        pooler/classifier learning rate
  -sf SCHEDULING_FUNCTION, --scheduling_function SCHEDULING_FUNCTION
                        function for learning rate scheduler (linear, exponential, or cosine)
  -km, --keep_model     switch for saving the best model parameters to disk
```

To train on custom annotated datasets, the `train.py` script has a dictionary `data_files` where additional datasets can be specified. Similarly, alternative pre-trained models can be used by modifying the `model_files` dictionary.
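
As a rough sketch (the exact entries live in `train.py`; the `data_files` value format and every path below are hypothetical placeholders, not the repository's actual values), registering a custom dataset and an alternative pre-trained model might look like:

```python
# in matbert_ner/train.py -- sketch only; keys and paths are illustrative
data_files = {'solid_state': 'data/solid_state.json',
              'doping': 'data/doping.json',
              'my_dataset': '/path/to/my_annotated_dataset.json'}   # custom annotated dataset

model_files = {'bert': 'bert-base-uncased',
               'scibert': 'allenai/scibert_scivocab_uncased',
               'matbert': '../../matbert-base-uncased',
               'my_model': '/path/to/my_pretrained_bert'}           # alternative pre-trained model
```

The new entries can then be selected on the command line with `-ds my_dataset -ml my_model`.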

For prediction, the `predict` function in `predict.py` can be used. An example that was used internally can be found in `predict_script.py`, and an example utilizing MongoDB can be found in `predict_mongo.py`. Note that both examples will need to be edited for your specific use case before they are usable.
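
A minimal sketch of calling `predict` directly (the example text and all file paths are placeholders; the keyword arguments follow the `predict` signature in `predict.py`):

```python
from matbert_ner.predict import predict  # predict_script.py, which sits inside matbert_ner/, uses `from predict import predict`

texts = ['LiCoO2 was synthesized by solid-state reaction and annealed at 800 C.']

annotations = predict(texts,
                      is_file=False,                              # texts is a list of strings, not a JSON file
                      model_file='path/to/matbert-base-uncased',  # placeholder BERT model directory
                      state_path='path/to/best.pt',               # placeholder fine-tuned NER state
                      predict_path='path/to/predictions.json',    # optional; omit to skip saving
                      return_full_dict=False,                     # False returns only the summarized entities
                      device='cpu')
print(annotations)
```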

# License
454 changes: 227 additions & 227 deletions matbert_ner/data/doping.json

Large diffs are not rendered by default.

38 changes: 30 additions & 8 deletions matbert_ner/models/model_trainer.py
@@ -7,6 +7,18 @@
from torchtools.optim import RangerLars, Ralamb, Ranger, Novograd, RAdam, Lamb, Lookahead
from seqeval.scheme import IOB1, IOB2, IOBES
from seqeval.metrics import accuracy_score, classification_report
import json


class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NpEncoder, self).default(obj)


class StateCacher(object):
@@ -177,7 +189,8 @@ def save_history(self, history_path):
None
'''
# save epoch metrics
torch.save(self.epoch_metrics, history_path)
with open(history_path, 'w') as f:
    f.write(json.dumps(self.epoch_metrics, indent=2, cls=NpEncoder))


def load_history(self, history_path):
@@ -189,7 +202,8 @@ def load_history(self, history_path):
None
'''
# load epoch metrics from path
self.epoch_metrics = torch.load(history_path)
with open(history_path, 'r') as f:
    self.epoch_metrics = json.load(f)
# set past epochs
self.past_epoch = len(self.epoch_metrics['training'].keys())

@@ -441,7 +455,7 @@ def process_summaries(self, annotations):
# append joined entity to dictionary
entry_entities[sentence[entity_idx[k]]['annotation']].add(' '.join([sentence[u]['text'] for u in range(entity_idx[k], entity_idx[k+1])]))
# append entry entity dictionary
annotation['entities'] = {class_type: list(entry_entities[class_type]) for class_type in class_types}
annotation['entities'] = {class_type: sorted(list(set(entry_entities[class_type]))) for class_type in class_types}
return annotations


@@ -712,7 +726,8 @@ def test(self, test_iter, test_path=None, state_path=None):
metrics, test_results = self.train_evaluate_epoch(0, 1, test_iter, 'test')
# save the test metrics and results
if test_path is not None:
torch.save((metrics, test_results), test_path)
with open(test_path, 'w') as f:
    f.write(json.dumps({'metrics': metrics, 'results': test_results}, indent=2, cls=NpEncoder))
# return the test metrics and results
return metrics, test_results

@@ -739,15 +754,18 @@ def merge_split_entries(self, prediction_results):
return merged_prediction_results


def predict(self, predict_iter, original_data=None, predict_path=None, state_path=None):
def predict(self, predict_iter, original_data=None, state_path=None, predict_path=None, return_full_dict=False):
'''
Predicts classifications for a dataset
Arguments:
predict_iter: Prediction dataloader
predict_path: Path to save the predictions to
original_data: Original data before pre-processing
state_path: Path to load the model state from
predict_path: Path to save the predictions to
return_full_dict: Toggle for returning full JSON entries or only the detected entities
Returns:
Dictionary of text and annotations by word, sentence, paragraph e.g. [[[{'text': text, 'annotation': annotation},...],...],...]
or dictionary of entity summaries
'''
# if state path provided, load state (excluding optimizer)
if state_path is not None:
@@ -775,6 +793,10 @@ def predict(self, predict_iter, original_data=None, predict_path=None, state_pat
annotations = self.process_summaries(output_annotations)
# save annotations
if predict_path is not None:
torch.save(annotations, predict_path)
with open(predict_path, 'w') as f:
    f.write(json.dumps(annotations, indent=2))
# return the annotations
return annotations
if return_full_dict:
    return annotations
else:
    return [annotation['entities'] for annotation in annotations]
10 changes: 7 additions & 3 deletions matbert_ner/predict.py
@@ -6,15 +6,18 @@
from matbert_ner.models.model_trainer import NERTrainer


def predict(texts, is_file, model_file, state_path, predict_path=None, scheme="IOBES", batch_size=256, device="cpu", seed=None):
def predict(texts, is_file, model_file, state_path, predict_path=None, return_full_dict=False, scheme="IOBES", batch_size=256, device="cpu", seed=None):
"""
Predict labels for texts. Please limit input to 512 tokens or less.
Args:
texts ([str]): List of string texts to predict labels for. Limit to 512 estimated tokens. Untokenized text will be tokenized interally with
texts ([str]): JSON filename, list of JSON entries, or list of string texts to predict labels for. Untokenized text will be tokenized interally with
the Materials Tokenizer.
is_file (bool): Toggle for whether the texts are a JSON file or list of JSON entries/strings
model_file (str): Path to BERT model file.
state_path (str): Path to model state for NER task, fine tuned for specific task (e.g., gold nanoparticles).
predict_path (str): Name of output file
return_full_dict (bool): Toggle for returning the full JSON entry or just the summarized entities detected by the model
scheme (str): IOBES or IOB2.
batch_size (int): Number of samples to predict in one batch pass.
device (str): Select 'cpu', 'gpu', or torch specific logic for running on multiple GPUs.
@@ -49,6 +52,7 @@ def predict(texts, is_file, model_file, state_path, predict_path=None, scheme="I
bert_ner_trainer = NERTrainer(bert_ner, device)
annotations = bert_ner_trainer.predict(ner_data.dataloaders['predict'],
original_data=ner_data.data['predict'],
state_path=state_path,
predict_path=predict_path,
state_path=state_path)
return_full_dict=return_full_dict)
return annotations
14 changes: 10 additions & 4 deletions matbert_ner/predict_script.py
@@ -1,13 +1,19 @@
from predict import predict
import json

doping_data = '../../datasets/dop_toparse_169828.json'
aunp_data = '../../datasets/aunp_recipes_characterization_filtered.json'
example_data = '../../../LBL_NER_DATASETS/example.json'
model = '../../matbert-base-uncased'
solid_state_state = './matbert_solid_state_paragraph_iobes_crf_10_lamb_5_1_012_1e-04_2e-03_1e-02_0e+00_exponential_256_100/best.pt'
doping_state = './matbert_doping_paragraph_iobes_crf_10_lamb_5_1_012_1e-04_2e-03_1e-02_0e+00_exponential_256_100/best.pt'
aunp6_state = './matbert_aunp6_paragraph_iobes_crf_10_lamb_5_1_012_1e-04_2e-03_1e-02_0e+00_exponential_256_100/best.pt'
solid_state_state = '../../MatBERT_NER_models/matbert_solid_state_paragraph_iobes_crf_10_lamb_5_1_012_1e-04_2e-03_1e-02_0e+00_exponential_256_100/best.pt'
doping_state = '../../MatBERT_NER_models/matbert_doping_paragraph_iobes_crf_10_lamb_5_1_012_1e-04_2e-03_1e-02_0e+00_exponential_256_100/best.pt'
aunp6_state = '../../MatBERT_NER_models/matbert_aunp6_paragraph_iobes_crf_10_lamb_5_1_012_1e-04_2e-03_1e-02_0e+00_exponential_256_100/best.pt'

# predict(doping_data, True, model, solid_state_state, predict_path=solid_state_state.replace('best.pt', 'predict_doping_solid_state_169828.pt'), device='gpu:0')
# predict(doping_data, True, model, doping_state, predict_path=doping_state.replace('best.pt', 'predict_doping_doping_169828.pt'), device='gpu:0')
# predict(aunp_data, True, model, solid_state_state, predict_path=solid_state_state.replace('best.pt', 'predict_aunp_solid_state.pt'), device='gpu:0')
predict(aunp_data, True, model, aunp6_state, predict_path=aunp6_state.replace('best.pt', 'predict_aunp_aunp6.pt'), device='gpu:0')
# predict(aunp_data, True, model, aunp6_state, predict_path=aunp6_state.replace('best.pt', 'predict_aunp_aunp6.pt'), device='gpu:0')
predict_path = solid_state_state.replace('best.pt', 'predict_example.json')
example_data = json.load(open(example_data, 'r'))
x = predict(example_data, False, model, solid_state_state, predict_path=predict_path, device='cpu')
print(json.dumps(x, indent=2))
26 changes: 16 additions & 10 deletions matbert_ner/train.py
@@ -132,7 +132,9 @@ def parse_args():
# model file dictionary
model_files = {'bert': 'bert-base-uncased',
'scibert': 'allenai/scibert_scivocab_uncased',
'matbert': '../../matbert-base-uncased'}
'matbert': '../../matbert-base-uncased',
'matbert_uncased': '../../matbert-base-uncased',
'matbert_cased': '../../matbert-base-cased'}
# loop through command line lists
for seed in seeds:
for scheme in schemes:
@@ -167,9 +169,10 @@ def parse_args():
print('Classes: {}'.format(' '.join(ner_data.classes)))
# if test file already exists, skip, otherwise, train
succeeded = True
if os.path.exists(save_dir+'history.pt'):
if os.path.exists(save_dir+'history.json'):
print('Already trained {}'.format(alias))
history = torch.load(save_dir+'history.pt')
with open(save_dir+'history.json', 'r') as f:
    history = json.load(f)
if split == 100:
print('{:<10}{:<10}'.format('epoch', 'training'))
for i in range(len(history['training'].keys())):
@@ -192,7 +195,7 @@ def parse_args():
embedding_unfreeze=embedding_unfreeze, encoder_schedule=encoder_schedule, scheduling_function=scheduling_function,
save_dir=save_dir, use_cache=use_cache)
# save model history
bert_ner_trainer.save_history(history_path=save_dir+'history.pt')
bert_ner_trainer.save_history(history_path=save_dir+'history.json')
# if cache was used and the model should be kept, the state must be saved directly after loading best parameters
if use_cache:
bert_ner_trainer.load_state_from_cache('best')
@@ -206,14 +209,17 @@ def parse_args():
if ner_data.dataloaders['test'] is not None and succeeded:
if os.path.exists(save_dir+'best.pt'):
# predict test results
metrics, test_results = bert_ner_trainer.test(ner_data.dataloaders['test'], test_path=save_dir+'test.pt', state_path=save_dir+'best.pt')
metrics, test_results = bert_ner_trainer.test(ner_data.dataloaders['test'], test_path=save_dir+'test.json', state_path=save_dir+'best.pt')
# predict classifications
annotations = bert_ner_trainer.predict(ner_data.dataloaders['test'], original_data=ner_data.data['test'], predict_path=save_dir+'predict.pt', state_path=save_dir+'best.pt')
elif os.path.exists(save_dir+'test.pt'):
annotations = bert_ner_trainer.predict(ner_data.dataloaders['test'], original_data=ner_data.data['test'], predict_path=save_dir+'predict.json', state_path=save_dir+'best.pt', return_full_dict=True)
elif os.path.exists(save_dir+'test.json'):
# retrieve test results
metrics, test_results = torch.load(save_dir+'test.pt')
# retireve classifications
annotations = torch.load(save_dir+'predict.pt')
with open(save_dir+'test.json', 'r') as f:
    test = json.load(f)
metrics, test_results = test['metrics'], test['results']
# retrieve classifications
with open(save_dir+'predict.json', 'r') as f:
    annotations = json.load(f)
# print classification report over test results
print(classification_report(test_results['labels'], test_results['predictions'], mode='strict', scheme=bert_ner_trainer.metric_scheme))
# save tokens/annotations to text file
14 changes: 7 additions & 7 deletions matbert_ner/utils/data.py
@@ -77,11 +77,11 @@ def filter_data(self, data):
data_filt = []
for i, entry in enumerate(tqdm(data, desc='| filtering entries |')):
try:
identifier = entry['meta']['doi']+'/'+str(entry['meta']['par'])+'/'+str(entry['meta']['split'])
d = {'meta': {'doi': entry['meta']['doi'], 'par': entry['meta']['par'], 'split': entry['meta']['split']}}
identifier = '{}/{}/{}'.format(entry['meta']['doi'], str(entry['meta']['par']), str(entry['meta']['split']))
d = {'meta': entry['meta']}
except:
try:
identifier = entry['meta']['doi']+'/'+str(entry['meta']['par'])
identifier = '{}/{}'.format(entry['meta']['doi'], str(entry['meta']['par']))
d = {'meta': {'doi': entry['meta']['doi'], 'par': entry['meta']['par'], 'split': 0}}
except:
try:
@@ -202,14 +202,14 @@ def load_from_file(self, data_file, annotated=True):
'''
# open data file
try:
with open(data_file, 'r') as f:
content = f.read()
entries = json.loads(content)
except:
with open(data_file, 'r') as f:
entries = []
for l in tqdm(f, desc='| loading entries from file |'):
entries.append(json.loads(l))
except:
with open(data_file, 'r') as f:
content = f.read()
entries = json.loads(content)
data_raw = self.load_from_memory(entries, annotated)
return data_raw

