Skip to content

Commit

Permalink
fix grad typo #517
Browse files Browse the repository at this point in the history
  • Loading branch information
AlongWY committed Jun 28, 2021
1 parent 3adf276 commit 09ea4e7
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions ltp/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
}


def no_gard(func):
def no_grad(func):
def wrapper(*args, **kwargs):
with torch.no_grad():
return func(*args, **kwargs)
Expand Down Expand Up @@ -204,7 +204,7 @@ def seg_with_dict(self, inputs: List[str], tokenized: BatchEncoding, batch_prefi
matching.append(matching_pos)
return matching

@no_gard
@no_grad
def _seg(self, tokenizerd, is_preseged=False):
input_ids = tokenizerd['input_ids'].to(self.device)
attention_mask = tokenizerd['attention_mask'].to(self.device)
Expand All @@ -228,7 +228,7 @@ def _seg(self, tokenizerd, is_preseged=False):
segment_output = segment_output.decoded or torch.argmax(segment_output.logits, dim=-1).cpu().numpy()
return word_cls, char_input, segment_output, length

@no_gard
@no_grad
def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True, is_preseged=False):
"""
分词
Expand Down Expand Up @@ -326,7 +326,7 @@ def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True
'word_cls_input': word_cls_input, 'word_cls_mask': word_cls_mask
}

@no_gard
@no_grad
def pos(self, hidden: dict):
"""
词性标注
Expand All @@ -343,7 +343,7 @@ def pos(self, hidden: dict):
postagger_output = convert_idx_to_name(postagger_output, hidden['word_length'], self.pos_vocab)
return postagger_output

@no_gard
@no_grad
def ner(self, hidden: dict, as_entities=True):
"""
命名实体识别
Expand All @@ -363,7 +363,7 @@ def ner(self, hidden: dict, as_entities=True):
ner_output = convert_idx_to_name(ner_output, hidden['word_length'], self.ner_vocab)
return [get_entities(ner) for ner in ner_output] if as_entities else ner_output

@no_gard
@no_grad
def srl(self, hidden: dict, keep_empty=True):
"""
语义角色标注
Expand Down Expand Up @@ -395,7 +395,7 @@ def srl(self, hidden: dict, keep_empty=True):
]
return srl_labels_res

@no_gard
@no_grad
def dep(self, hidden: dict, fast=True, as_tuple=True):
"""
依存句法树
Expand Down Expand Up @@ -441,7 +441,7 @@ def dep(self, hidden: dict, fast=True, as_tuple=True):
for heads, rels in zip(head_pred, rel_pred)
]

@no_gard
@no_grad
def sdp(self, hidden: dict, mode: str = 'graph'):
"""
语义依存图(树)
Expand Down

0 comments on commit 09ea4e7

Please sign in to comment.