Skip to content

Commit

Permalink
修复了Batch下length计算错误的问题 (Fix incorrect length calculation in batch mode) #376
Browse files Browse the repository at this point in the history
  • Loading branch information
AlongWY committed Jul 2, 2020
1 parent df9f862 commit f6944d8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
4 changes: 3 additions & 1 deletion ltp/fastltp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(self, *args, **kwargs):
import onnxruntime as rt
so = rt.SessionOptions()
so.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# fixme should auto detect
providers = ['CPUExecutionProvider'] if self.device.type == 'cpu' else ['GPUExecutionProvider']

onnx_path = os.path.join(self.cache_dir, "ltp.onnx")
Expand Down Expand Up @@ -100,9 +102,9 @@ def forward(self, *args, **kwargs):

@no_gard
def seg(self, inputs: List[str]):
length = [len(text) for text in inputs]
tokenizerd = self.tokenizer.batch_encode_plus(inputs, padding=True)
pretrained_inputs = {key: convert(value) for key, value in tokenizerd.items()}
length = np.sum(pretrained_inputs['attention_mask'], axis=-1) - 2

# todo: io binding
cls, hidden, seg = self.onnx.run(None, pretrained_inputs)
Expand Down
13 changes: 9 additions & 4 deletions ltp/ltp.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,17 @@ def sent_split(inputs: List[str], flag: str = "all", limit: int = 510):

@no_gard
def seg(self, inputs: List[str]):
length = torch.as_tensor([len(text) for text in inputs], device=self.device)
tokenizerd = self.tokenizer.batch_encode_plus(inputs, return_tensors='pt', padding=True)

input_ids = tokenizerd['input_ids'].to(self.device)
attention_mask = tokenizerd['attention_mask'].to(self.device)
token_type_ids = tokenizerd['token_type_ids'].to(self.device)
length = torch.sum(attention_mask, dim=-1) - 2

pretrained_output, *_ = self.model.pretrained(
input_ids=tokenizerd['input_ids'].to(self.device),
attention_mask=tokenizerd['attention_mask'].to(self.device),
token_type_ids=tokenizerd['token_type_ids'].to(self.device)
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)

# remove [CLS] [SEP]
Expand Down

0 comments on commit f6944d8

Please sign in to comment.