diff --git a/ltp/frontend.py b/ltp/frontend.py index 5358d6e4..9f1a0b69 100644 --- a/ltp/frontend.py +++ b/ltp/frontend.py @@ -47,7 +47,7 @@ } -def no_gard(func): +def no_grad(func): def wrapper(*args, **kwargs): with torch.no_grad(): return func(*args, **kwargs) @@ -204,7 +204,7 @@ def seg_with_dict(self, inputs: List[str], tokenized: BatchEncoding, batch_prefi matching.append(matching_pos) return matching - @no_gard + @no_grad def _seg(self, tokenizerd, is_preseged=False): input_ids = tokenizerd['input_ids'].to(self.device) attention_mask = tokenizerd['attention_mask'].to(self.device) @@ -228,7 +228,7 @@ def _seg(self, tokenizerd, is_preseged=False): segment_output = segment_output.decoded or torch.argmax(segment_output.logits, dim=-1).cpu().numpy() return word_cls, char_input, segment_output, length - @no_gard + @no_grad def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True, is_preseged=False): """ 分词 @@ -326,7 +326,7 @@ def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True 'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask } - @no_gard + @no_grad def pos(self, hidden: dict): """ 词性标注 @@ -343,7 +343,7 @@ def pos(self, hidden: dict): postagger_output = convert_idx_to_name(postagger_output, hidden['word_length'], self.pos_vocab) return postagger_output - @no_gard + @no_grad def ner(self, hidden: dict, as_entities=True): """ 命名实体识别 @@ -363,7 +363,7 @@ def ner(self, hidden: dict, as_entities=True): ner_output = convert_idx_to_name(ner_output, hidden['word_length'], self.ner_vocab) return [get_entities(ner) for ner in ner_output] if as_entities else ner_output - @no_gard + @no_grad def srl(self, hidden: dict, keep_empty=True): """ 语义角色标注 @@ -395,7 +395,7 @@ def srl(self, hidden: dict, keep_empty=True): ] return srl_labels_res - @no_gard + @no_grad def dep(self, hidden: dict, fast=True, as_tuple=True): """ 依存句法树 @@ -441,7 +441,7 @@ def dep(self, hidden: dict, fast=True, as_tuple=True): for heads, rels in zip(head_pred, rel_pred) ] - @no_gard + @no_grad def sdp(self, hidden: dict, mode: str = 'graph'): """ 语义依存图(树)