[FEATURE] Upgrade version to 1.0.0 #152

Merged · 85 commits · Jan 13, 2024

Commits
977defc
[DOC] rectify some attributes and classes
KenelmQLH Oct 24, 2021
d4656e4
[DOC] rectify parser comments
KenelmQLH Oct 24, 2021
da47222
[DOC] test
KenelmQLH Oct 24, 2021
3047977
[DOC] test remote
KenelmQLH Oct 24, 2021
1b9e835
Update sif.rst
KenelmQLH Oct 24, 2021
11a3490
Update sif.rst
KenelmQLH Oct 24, 2021
68e3883
test collection api
KenelmQLH Oct 30, 2021
4027c52
Merge remote-tracking branch 'origin/dev' into dev
KenelmQLH Oct 30, 2021
4608fa5
Merge branch 'bigdata-ustc:dev' into dev
KenelmQLH Nov 11, 2021
834ea98
Merge branch 'dev' of github.com:bigdata-ustc/EduNLP into dev
KenelmQLH Jan 20, 2022
f352c30
Merge branch 'bigdata-ustc:dev' into dev
KenelmQLH Apr 27, 2022
ccf153b
Merge remote-tracking branch 'upstream/dev' into dev
KenelmQLH May 11, 2022
bc1ddfd
Merge remote-tracking branch 'upstream/dev' into dev
KenelmQLH May 13, 2022
64fe539
Merge remote-tracking branch 'upstream/dev' into dev
KenelmQLH Jun 7, 2022
e761e6e
Merge branch 'dev' of github.com:KenelmQLH/EduNLP into dev
KenelmQLH Jun 7, 2022
25ce404
[fix] support gpu for Vector
nnnyt Dec 19, 2022
1e1cf2f
[feat] support gpu for I2V
nnnyt Dec 19, 2022
8da2009
[fix] use ubuntu-20.04 for py3.6
nnnyt Dec 19, 2022
c9be8fa
[fix] remove useless import
nnnyt Dec 19, 2022
1350ff6
update device for t2v
KenelmQLH Mar 24, 2023
50de155
merge fix_device
KenelmQLH Mar 24, 2023
16cf3ac
fix gensim with empty_vector
KenelmQLH Mar 27, 2023
4d93fbf
add downstream tasks for DisenQ
nnnyt Mar 27, 2023
6f9aa8c
modify downstream task for BERT
nnnyt Mar 27, 2023
572d5e9
finish test
KenelmQLH Mar 28, 2023
99ed226
add finetuning for disenq
nnnyt Mar 28, 2023
701d01d
update .github
KenelmQLH Mar 28, 2023
c9afda1
update test
KenelmQLH Mar 28, 2023
d6e60cd
[BUG] Fix i2v.inter_vector error when params are provided
KenelmQLH Mar 29, 2023
4808d88
Fix load_error when training and loading model on different device
KenelmQLH Mar 29, 2023
407c6d5
Merge branch 'bigdata-ustc:master' into LMforTask
nnnyt Mar 30, 2023
5e0d5e9
[fix] unify output format for downstream models
nnnyt Mar 30, 2023
cacc7cf
Merge pull request #138 from nnnyt/LMforTask
KenelmQLH Mar 30, 2023
ea176dc
Merge pull request #137 from KenelmQLH/feat_i2v
KenelmQLH Mar 30, 2023
e36dac4
[doc] add similarity prediction demo
nnnyt Jun 26, 2023
d9e3187
add demos for difficulty & discrimination prediction
ShangziXue Jun 29, 2023
19ca5db
[fix] fix dependencies
nnnyt Jul 2, 2023
ff26a84
remove print and fix tokenizer
KenelmQLH Jul 2, 2023
67c70ef
add demos for difficulty prediction and discrimination prediction
ShangziXue Jul 2, 2023
eb49643
update gitignore
KenelmQLH Jul 2, 2023
be76573
Merge branch 'feat_i2v' into update_doc
KenelmQLH Jul 2, 2023
9394c61
update AUTHORS.md
ShangziXue Jul 2, 2023
19b540f
[DOC] Add paper segmentation
KenelmQLH Jul 2, 2023
f2ba01b
Fix bugs for quesnet figure loading
wintermelon008 Jul 2, 2023
b9604e4
Add quesnet_new
wintermelon008 Jul 2, 2023
dda1037
[Doc] Add knowledge_prediction
KenelmQLH Jul 2, 2023
dea9530
[docs] Add demo for quality evaluation
wintermelon008 Jul 2, 2023
8087e36
fix demos
ShangziXue Jul 4, 2023
f66b2a7
Merge remote-tracking branch 'nyt/update_doc' into update_doc
KenelmQLH Jul 5, 2023
024d997
update examples
KenelmQLH Jul 5, 2023
310586c
Merge pull request #141 from nnnyt/update_doc
KenelmQLH Jul 6, 2023
01d97ec
update demos
ShangziXue Jul 11, 2023
1b16353
Merge pull request #144 from wintermelon008/demo
nnnyt Jul 11, 2023
a3c281c
clean the extra code
KenelmQLH Jul 12, 2023
0629ec3
update demos
ShangziXue Jul 12, 2023
932d602
Fixed some bugs
ShangziXue Jul 14, 2023
214edcf
Merge pull request #140 from ShangziXue/master
nnnyt Jul 14, 2023
cfe00a7
Merge pull request #143 from KenelmQLH/update_doc
nnnyt Jul 14, 2023
d3ca962
[doc] reorganize downstream tasks
nnnyt Jul 14, 2023
be1aef1
Merge pull request #145 from nnnyt/update_doc
KenelmQLH Jul 14, 2023
68918b3
rewrite the Dataset
wintermelon008 Jul 14, 2023
bb5b3f0
Update quesnet_vec
wintermelon008 Jul 21, 2023
422c4b0
Merge branch 'bigdata-ustc:dev' into dev
wintermelon008 Jul 21, 2023
be5119f
Merge branches 'dev' and 'dev' of github.com:wintermelon008/EduNLP in…
wintermelon008 Jul 21, 2023
21a7d6d
[fix] small bugs fixed
wintermelon008 Jul 31, 2023
729e506
fix bugs for quesnet
wintermelon008 Jul 31, 2023
35c1057
update codes for flake8
wintermelon008 Aug 1, 2023
9bd85db
fix bugs for flake8
wintermelon008 Aug 1, 2023
92776a6
fix bugs for flake8
wintermelon008 Aug 1, 2023
00e3726
Update tutorial docs and api docs
ShangziXue Aug 2, 2023
8d0c377
update tutorial and api references
karin0018 Aug 2, 2023
fc760d7
fix the flake8 bug
karin0018 Aug 2, 2023
362df0b
fix the flake8 bug
karin0018 Aug 2, 2023
8f103a1
fix the examples bug
karin0018 Aug 2, 2023
c840248
fix bug
karin0018 Aug 2, 2023
06461fd
[fix] update some codes for quesnet
wintermelon008 Aug 4, 2023
4ba7ec4
Merge pull request #146 from wintermelon008/dev
nnnyt Aug 5, 2023
5f2f71f
Merge pull request #150 from karin0018/dev
nnnyt Aug 5, 2023
3e6204c
Merge pull request #151 from ShangziXue/dev
nnnyt Aug 5, 2023
ace1960
Fix bugs for quesnet_vec
wintermelon008 Oct 13, 2023
d65242c
Update
wintermelon008 Oct 13, 2023
d786570
Merge pull request #153 from wintermelon008/dev
KenelmQLH Jan 13, 2024
9a3f5f7
Update change.txt
KenelmQLH Jan 13, 2024
22a9e10
Merge branch 'dev' of github.com:KenelmQLH/EduNLP into dev
KenelmQLH Jan 13, 2024
598d788
Merge pull request #154 from KenelmQLH/dev
nnnyt Jan 13, 2024
Files changed
8 changes: 6 additions & 2 deletions .github/workflows/python-test.yml
@@ -6,10 +6,14 @@ on: [push, pull_request]
 jobs:
   build:
 
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.os }}
     strategy:
       matrix:
         python-version: [3.6, 3.7, 3.8, 3.9]
+        include:
+          - os: "ubuntu-latest"
+          - os: "ubuntu-20.04"
+            python-version: "3.6"
 
     steps:
     - uses: actions/checkout@v2
@@ -24,4 +28,4 @@ jobs:
     - name: Test with pytest
       run: |
         pytest
-        codecov
\ No newline at end of file
+        codecov
1 change: 1 addition & 0 deletions AUTHORS.md
@@ -20,5 +20,6 @@
 
 [Jundong Wu](https://github.com/wintermelon008)
 
+[Shangzi Xue](https://github.com/ShangziXue)
 
 The starred contributors are the corresponding authors.
6 changes: 6 additions & 0 deletions CHANGE.txt
@@ -1,3 +1,9 @@
+v1.0.0
+1. Support cuda for I2V and T2V.
+2. Add demos for downstream tasks including knowledge & difficulty & discrimination prediction, similarity prediction and paper segmentation.
+3. Refactor quesnet for pretrain and vectorization.
+4. Update documents about tutorials and API.
+
 v0.0.9
 1. Refactor tokenizer Basic Tokenizer and Pretrained Tokenizer
 2. Refactor model structures following huggingface styles for Elmo, BERT, DisenQNet and QuesNet
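A minimal sketch of changelog item 1 as it surfaces in the diffs below; the import path, exact signature, and pretrained-model name are illustrative assumptions, not part of this PR:

import torch
from EduNLP.T2V import get_pretrained_t2v  # assumed public counterpart of get_t2v_pretrained_model

# The I2V diff below forwards `device` into the T2V layer
# (get_t2v_pretrained_model(t2v, model_dir, device)), so the model can be placed on a GPU at load time.
device = "cuda" if torch.cuda.is_available() else "cpu"
t2v = get_pretrained_t2v("elmo_pub_math", model_dir="./models", device=device)  # model name is a placeholder
item_vectors = t2v([["token", "sequence", "of", "one", "item"]])  # tokenized items in, vectors out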
60 changes: 44 additions & 16 deletions EduNLP/I2V/i2v.py
@@ -1,6 +1,7 @@
 # coding: utf-8
 # 2021/8/1 @ tongshiwei
 
+import torch
 import json
 import os.path
 from typing import List, Tuple
@@ -59,12 +60,12 @@ class I2V(object):
     """
 
     def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None,
-                 pretrained_t2v=False, model_dir=MODEL_DIR, **kwargs):
+                 pretrained_t2v=False, model_dir=MODEL_DIR, device='cpu', **kwargs):
         if pretrained_t2v:
             logger.info("Use pretrained t2v model %s" % t2v)
-            self.t2v = get_t2v_pretrained_model(t2v, model_dir)
+            self.t2v = get_t2v_pretrained_model(t2v, model_dir, device)
         else:
-            self.t2v = T2V(t2v, *args, **kwargs)
+            self.t2v = T2V(t2v, device=device, *args, **kwargs)
         if tokenizer == 'bert':
             self.tokenizer = BertTokenizer.from_pretrained(
                 **tokenizer_kwargs if tokenizer_kwargs is not None else {})
@@ -82,31 +83,53 @@ def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None,
                 **tokenizer_kwargs if tokenizer_kwargs is not None else {})
         self.params = {
             "tokenizer": tokenizer,
+            "tokenizer_kwargs": tokenizer_kwargs,
             "t2v": t2v,
             "args": args,
-            "tokenizer_kwargs": tokenizer_kwargs,
+            "pretrained_t2v": pretrained_t2v,
+            "model_dir": model_dir,
             "kwargs": kwargs,
-            "pretrained_t2v": pretrained_t2v
         }
+        self.device = torch.device(device)
 
     def __call__(self, items, *args, **kwargs):
        """transfer item to vector"""
         return self.infer_vector(items, *args, **kwargs)
 
     def tokenize(self, items, *args, key=lambda x: x, **kwargs) -> list:
-        # """tokenize item"""
+        """
+        tokenize item
+        Parameter
+        ----------
+        items: a list of questions
+        Return
+        ----------
+        tokens: list
+        """
         return self.tokenizer(items, *args, key=key, **kwargs)
 
     def infer_vector(self, items, key=lambda x: x, **kwargs) -> tuple:
         """
         get question embedding
         NotImplemented
         """
         raise NotImplementedError
 
+    def infer_item_vector(self, tokens, *args, **kwargs) -> ...:
+        """NotImplemented"""
+        return self.infer_vector(tokens, *args, **kwargs)[0]
+
+    def infer_token_vector(self, tokens, *args, **kwargs) -> ...:
+        """NotImplemented"""
+        return self.infer_vector(tokens, *args, **kwargs)[1]
+
     def save(self, config_path):
         """
         save model weights in config_path
+        Parameter:
+        ----------
+        config_path: str
+        """
         with open(config_path, "w", encoding="utf-8") as wf:
             json.dump(self.params, wf, ensure_ascii=False, indent=2)

@@ -123,6 +146,7 @@ def load(cls, config_path, *args, **kwargs):
 
     @classmethod
     def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+        """NotImplemented"""
         raise NotImplementedError
 
     @property
@@ -327,13 +351,13 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         return self.t2v.infer_vector(inputs, *args, **kwargs), self.t2v.infer_tokens(inputs, *args, **kwargs)
 
     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {"tokenizer_config_dir": model_path}
-        return cls("elmo", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("elmo", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)


@@ -386,17 +410,19 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         --------
         vector:list
         """
+        is_batch = isinstance(items, list)
+        items = items if is_batch else [items]
         inputs = self.tokenize(items, key=key, return_tensors=return_tensors)
         return self.t2v.infer_vector(inputs, *args, **kwargs), self.t2v.infer_tokens(inputs, *args, **kwargs)
 
     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {"tokenizer_config_dir": model_path}
-        return cls("bert", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("bert", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)


@@ -452,7 +478,7 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         return i_vec, t_vec
 
     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
@@ -461,7 +487,7 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, **kwargs):
         tokenizer_kwargs = {
             "tokenizer_config_dir": model_path,
         }
-        return cls("disenq", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("disenq", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs, **kwargs)


@@ -495,18 +521,20 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         token embeddings
         question embedding
         """
+        is_batch = isinstance(items, list)
+        items = items if is_batch else [items]
         encodes = self.tokenize(items, key=key, meta=meta, *args, **kwargs)
         return self.t2v.infer_vector(encodes), self.t2v.infer_tokens(encodes)
 
     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {
             "tokenizer_config_dir": model_path}
-        return cls("quesnet", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("quesnet", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)


@@ -520,7 +548,7 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
 }
 
 
-def get_pretrained_i2v(name, model_dir=MODEL_DIR):
+def get_pretrained_i2v(name, model_dir=MODEL_DIR, device='cpu'):
     """
     It is a good idea if you want to switch item to vector easily.
 
@@ -560,4 +588,4 @@ def get_pretrained_i2v(name, model_dir=MODEL_DIR):
     )
     _, t2v = get_pretrained_model_info(name)
     _class, *params = MODEL_MAP[t2v], name
-    return _class.from_pretrained(*params, model_dir=model_dir)
+    return _class.from_pretrained(*params, model_dir=model_dir, device=device)
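Putting the updated entry point together with the wrapper classes above, a hedged usage sketch (the model name and CUDA availability are assumptions for illustration, not taken from this PR):

from EduNLP.I2V import get_pretrained_i2v

i2v = get_pretrained_i2v("elmo_pub_math", model_dir="./models", device="cuda")  # name is a placeholder

# With the new is_batch handling in infer_vector, a bare item and a list of items both work:
item_vec, token_vec = i2v.infer_vector(["Calculate: $1+1=$ ?"])
first_item_vec = i2v.infer_item_vector(["Calculate: $1+1=$ ?"])  # element [0] of the tuple above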
2 changes: 2 additions & 0 deletions EduNLP/ModelZoo/__init__.py
@@ -1,3 +1,5 @@
 from .utils import *
 from .bert import *
 from .rnn import *
+from .disenqnet import *
+from .quesnet import *
2 changes: 1 addition & 1 deletion EduNLP/ModelZoo/base_model.py
@@ -31,7 +31,7 @@ def from_pretrained(cls, pretrained_model_path, *args, **kwargs):
         config_path = os.path.join(pretrained_model_path, "config.json")
         model_path = os.path.join(pretrained_model_path, "pytorch_model.bin")
         model = cls.from_config(config_path, *args, **kwargs)
-        loaded_state_dict = torch.load(model_path)
+        loaded_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
         loaded_keys = loaded_state_dict.keys()
         expected_keys = model.state_dict().keys()
 
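This one-line change is what the commit "Fix load_error when training and loading model on different device" refers to: a checkpoint saved on a GPU stores CUDA tensors, and a plain torch.load on a CPU-only machine fails while deserializing them. A self-contained sketch of the behavior (the model here is a stand-in, not an EduNLP class):

import torch
from torch import nn

model = nn.Linear(8, 2)  # stand-in for any BaseModel subclass
torch.save(model.state_dict(), "pytorch_model.bin")

# map_location forces every tensor onto the CPU at load time, wherever it was saved...
state = torch.load("pytorch_model.bin", map_location=torch.device("cpu"))
model.load_state_dict(state)
# ...and the model can then be moved explicitly to whatever device is actually available.
model.to("cuda" if torch.cuda.is_available() else "cpu")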
71 changes: 42 additions & 29 deletions EduNLP/ModelZoo/bert/bert.py
@@ -1,36 +1,35 @@
 import torch
 from torch import nn
-from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
-from baize.torch import load_net
-import torch.nn.functional as F
 import json
 import os
 from ..base_model import BaseModel
-from transformers.modeling_outputs import ModelOutput
-from transformers import BertModel, PretrainedConfig
-from typing import List, Optional
+from ..utils import PropertyPredictionOutput, KnowledgePredictionOutput
+from transformers import BertModel, PretrainedConfig, BertConfig
+from typing import List
 from ..rnn.harnn import HAM
 
-__all__ = ["BertForPropertyPrediction", "BertForKnowledgePrediction"]
-
-
-class BertForPPOutput(ModelOutput):
-    loss: torch.FloatTensor = None
-    logits: torch.FloatTensor = None
+__all__ = ["BertForPropertyPrediction", "BertForKnowledgePrediction"]
 
 
 class BertForPropertyPrediction(BaseModel):
-    def __init__(self, pretrained_model_dir=None, head_dropout=0.5):
+    def __init__(self, pretrained_model_dir=None, head_dropout=0.5, init=True):
         super(BertForPropertyPrediction, self).__init__()
-        self.bert = BertModel.from_pretrained(pretrained_model_dir)
+        bert_config = BertConfig.from_pretrained(pretrained_model_dir)
+        if init:
+            print(f'Load BertModel from checkpoint: {pretrained_model_dir}')
+            self.bert = BertModel.from_pretrained(pretrained_model_dir)
+        else:
+            print(f'Load BertModel from config: {pretrained_model_dir}')
+            self.bert = BertModel(bert_config)
         self.hidden_size = self.bert.config.hidden_size
         self.head_dropout = head_dropout
         self.dropout = nn.Dropout(head_dropout)
         self.classifier = nn.Linear(self.hidden_size, 1)
         self.sigmoid = nn.Sigmoid()
         self.criterion = nn.MSELoss()
 
-        self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__"]}
+        self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__", "bert_config"]}
         self.config['architecture'] = 'BertForPropertyPrediction'
         self.config = PretrainedConfig.from_dict(self.config)
 
@@ -47,44 +46,54 @@ def forward(self,
         loss = None
         if labels is not None:
             loss = self.criterion(logits, labels) if labels is not None else None
-        return BertForPPOutput(
+        return PropertyPredictionOutput(
             loss=loss,
             logits=logits,
         )
 
     @classmethod
     def from_config(cls, config_path, **kwargs):
+        config_path = os.path.join(os.path.dirname(config_path), 'model_config.json')
         with open(config_path, "r", encoding="utf-8") as rf:
             model_config = json.load(rf)
+            model_config['pretrained_model_dir'] = os.path.dirname(config_path)
             model_config.update(kwargs)
             return cls(
                 pretrained_model_dir=model_config['pretrained_model_dir'],
-                head_dropout=model_config.get("head_dropout", 0.5)
+                head_dropout=model_config.get("head_dropout", 0.5),
+                init=model_config.get('init', False)
             )
 
-    # @classmethod
-    # def from_pretrained(cls):
-    #     NotImplementedError
-    #     # need to verify compatibility with huggingface models
+    def save_config(self, config_dir):
+        config_path = os.path.join(config_dir, "model_config.json")
+        with open(config_path, "w", encoding="utf-8") as wf:
+            json.dump(self.config.to_dict(), wf, ensure_ascii=False, indent=2)
+        self.bert.config.save_pretrained(config_dir)
 
 
 class BertForKnowledgePrediction(BaseModel):
     def __init__(self,
+                 pretrained_model_dir=None,
                  num_classes_list: List[int] = None,
                  num_total_classes: int = None,
-                 pretrained_model_dir=None,
                  head_dropout=0.5,
                  flat_cls_weight=0.5,
                  attention_unit_size=256,
                  fc_hidden_size=512,
                  beta=0.5,
+                 init=True
                  ):
         super(BertForKnowledgePrediction, self).__init__()
-        self.bert = BertModel.from_pretrained(pretrained_model_dir)
+        bert_config = BertConfig.from_pretrained(pretrained_model_dir)
+        if init:
+            print(f'Load BertModel from checkpoint: {pretrained_model_dir}')
+            self.bert = BertModel.from_pretrained(pretrained_model_dir)
+        else:
+            print(f'Load BertModel from config: {pretrained_model_dir}')
+            self.bert = BertModel(bert_config)
         self.hidden_size = self.bert.config.hidden_size
         self.head_dropout = head_dropout
         self.dropout = nn.Dropout(head_dropout)
         self.classifier = nn.Linear(self.hidden_size, 1)
         self.sigmoid = nn.Sigmoid()
         self.criterion = nn.MSELoss()
         self.flat_classifier = nn.Linear(self.hidden_size, num_total_classes)
@@ -101,7 +110,7 @@ def __init__(self,
         self.num_classes_list = num_classes_list
         self.num_total_classes = num_total_classes
 
-        self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__"]}
+        self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__", "bert_config"]}
         self.config['architecture'] = 'BertForKnowledgePrediction'
         self.config = PretrainedConfig.from_dict(self.config)

@@ -124,15 +133,17 @@ def forward(self,
             labels = torch.sum(torch.nn.functional.one_hot(labels, num_classes=self.num_total_classes), dim=1)
             labels = labels.float()
             loss = self.criterion(logits, labels) if labels is not None else None
-        return BertForPPOutput(
+        return KnowledgePredictionOutput(
             loss=loss,
             logits=logits,
         )
 
     @classmethod
     def from_config(cls, config_path, **kwargs):
+        config_path = os.path.join(os.path.dirname(config_path), 'model_config.json')
         with open(config_path, "r", encoding="utf-8") as rf:
             model_config = json.load(rf)
+            model_config['pretrained_model_dir'] = os.path.dirname(config_path)
             model_config.update(kwargs)
             return cls(
                 pretrained_model_dir=model_config['pretrained_model_dir'],
@@ -143,9 +154,11 @@ def from_config(cls, config_path, **kwargs):
                 attention_unit_size=model_config.get('attention_unit_size', 256),
                 fc_hidden_size=model_config.get('fc_hidden_size', 512),
                 beta=model_config.get('beta', 0.5),
+                init=model_config.get('init', False)
             )
 
-    # @classmethod
-    # def from_pretrained(cls):
-    #     NotImplementedError
-    #     # need to verify compatibility with huggingface models
+    def save_config(self, config_dir):
+        config_path = os.path.join(config_dir, "model_config.json")
+        with open(config_path, "w", encoding="utf-8") as wf:
+            json.dump(self.config.to_dict(), wf, ensure_ascii=False, indent=2)
+        self.bert.config.save_pretrained(config_dir)
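A hedged sketch of the new init flag in the two classes above (directory names are placeholders): init=True pulls encoder weights from a HuggingFace-style checkpoint for a fresh run, while from_config rebuilds the architecture with init=False so that BaseModel.from_pretrained can restore fine-tuned weights without re-downloading the encoder.

from EduNLP.ModelZoo import BertForPropertyPrediction  # re-exported via ModelZoo/__init__.py above

# Fresh run: load the pretrained encoder from a local HuggingFace checkpoint directory.
model = BertForPropertyPrediction(pretrained_model_dir="./bert-base-chinese", init=True)

# save_config writes model_config.json (plus the encoder config) into an output directory;
# from_config later rebuilds the same skeleton from it, with init defaulting to False there.
model.save_config("./output")
rebuilt = BertForPropertyPrediction.from_config("./output/config.json")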