Text fix
In [3]:
Copied!
!pip install -q transformers torch regex rapidfuzz pangu zhon
!pip install -q transformers torch regex rapidfuzz pangu zhon
In [4]:
Copied!
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch, re, difflib
from rapidfuzz import process, fuzz
import pangu
# Public checkpoint; once downloaded it can be cached for offline use.
MODEL_ID = "bert-base-chinese"

# The tokenizer and the masked-LM model are both loaded from MODEL_ID.
tok = AutoTokenizer.from_pretrained(MODEL_ID)
mlm = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
mlm = mlm.eval()  # switch to evaluation mode; eval() returns the module itself
@torch.inference_mode()
def mlm_topk(sent, i, k=5):
    """Return the masked-LM's top-k predictions for the i-th character of ``sent``.

    Args:
        sent: The sentence, as a string or a sequence of single characters.
            Each element is handed to the tokenizer as one pre-split word.
        i: Zero-based index of the character to mask.
        k: Number of predictions to return (capped at the vocabulary size).

    Returns:
        A ``(tokens, scores)`` pair of equal-length lists ordered best first.
        ``scores`` are the raw logits at the masked position. Both lists are
        empty when character ``i`` cannot be aligned to exactly one token:
        the index is out of range, the tokenizer produced no token (or several)
        for that character, or the character was cut off by truncation.
    """
    chars = list(sent)
    if not 0 <= i < len(chars):
        return [], []

    # truncation=True keeps over-long inputs within the tokenizer's configured
    # maximum length instead of letting the model fail on them.
    enc = tok(chars, is_split_into_words=True, truncation=True, return_tensors="pt")
    input_ids = enc["input_ids"][0]

    if tok.is_fast:
        # Ask the tokenizer which token belongs to character i rather than
        # assuming "one token per character": characters that yield no token
        # (e.g. whitespace) would otherwise shift every later position.
        hits = [p for p, w in enumerate(enc.word_ids(0)) if w == i]
        if len(hits) != 1:
            return [], []
        pos = hits[0]
    else:
        # Slow tokenizers expose no alignment; fall back to the one-token-per-
        # character assumption ([CLS] occupies position 0) and never mask the
        # final special token.
        pos = i + 1
        if pos >= input_ids.shape[0] - 1:
            return [], []

    masked = input_ids.clone()
    masked[pos] = tok.mask_token_id
    logits = mlm(input_ids=masked.unsqueeze(0)).logits[0, pos]
    best = torch.topk(logits, min(k, logits.shape[-1]))
    return tok.convert_ids_to_tokens(best.indices.tolist()), best.values.tolist()
from transformers import AutoTokenizer, AutoModelForMaskedLM import torch, re, difflib from rapidfuzz import process, fuzz import pangu MODEL_ID = "bert-base-chinese" # 公開模型,離線可快取 tok = AutoTokenizer.from_pretrained(MODEL_ID) mlm = AutoModelForMaskedLM.from_pretrained(MODEL_ID).eval() @torch.inference_mode() def mlm_topk(sent, i, k=5): tokens = tok(list(sent), is_split_into_words=True, return_tensors="pt") input_ids = tokens["input_ids"][0] mask_id = tok.mask_token_id pos = i + 1 # [CLS] 在第0位 ids = input_ids.clone() ids[pos] = mask_id out = mlm(input_ids=ids.unsqueeze(0)).logits[0, pos] topk = torch.topk(out, k) return [tok.convert_ids_to_tokens(int(t)) for t in topk.indices], topk.values.tolist()
tokenizer_config.json: 0%| | 0.00/49.0 [00:00<?, ?B/s]
config.json: 0%| | 0.00/624 [00:00<?, ?B/s]
vocab.txt: 0%| | 0.00/110k [00:00<?, ?B/s]
tokenizer.json: 0%| | 0.00/269k [00:00<?, ?B/s]
model.safetensors: 0%| | 0.00/412M [00:00<?, ?B/s]
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight'] - This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
In [5]:
Copied!
# Minimal table of visually similar characters (extend as needed); every entry
# is a Traditional Chinese character.
CONFUSION: dict[str, list[str]] = {
    # Placeholder rows that only repeat their key; in practice, extend them with
    # the typo set of your own domain.
    "稅": ["稅"] * 3,
    "關": ["關", "觀", "官"],
    "條": ["條"] * 4,
    "價": ["價"] * 3,
    "貨": ["貨"] * 3,
    # Example: collect common wrong -> right pairs into a mapping.
}
# Word-level fixes, written as "wrong term": "correct term".
MANUAL_MAP: dict[str, str] = {
    # Example: the court case-number marker 「上字」 is often mistyped as 「上至」
    # in first drafts.
    "上至": "上字",
}
# Domain terms that correct_sentence hands to protect_terms; the literal below
# is split on whitespace, one term per token.
DOMAIN: set[str] = {
    term
    for term in """
關稅法 關稅估價 合理方法 第29條 第35條 移轉訂價 關稅估價協定
進口 完稅價格 報關業者 委任書 空運快遞貨物通關辦法
""".split()
}
def protect_terms(text:str, terms:set[str])->str:
    """Wrap every occurrence of a protected term in ⟦…⟧ markers.

    All terms are matched in a single left-to-right pass, with longer terms
    preferred at the same position. Each stretch of text is therefore wrapped
    at most once: a short term that is a substring of an already matched longer
    term is not wrapped again, so markers never nest.

    Args:
        text: Text to annotate.
        terms: Terms to protect; empty strings are ignored.

    Returns:
        ``text`` with each matched term replaced by ``⟦term⟧``.
    """
    # Longest first, so the alternation prefers "關稅估價協定" over "關稅估價".
    ordered = sorted((t for t in terms if t), key=len, reverse=True)
    if not ordered:
        return text
    pattern = "|".join(re.escape(t) for t in ordered)
    return re.sub(pattern, lambda m: f"⟦{m.group(0)}⟧", text)
def unprotect_terms(text:str)->str:
    """Return ``text`` with every ⟦ and ⟧ marker character removed."""
    # A translation table that maps both marker characters to "delete".
    drop_markers = {ord("⟦"): None, ord("⟧"): None}
    return text.translate(drop_markers)
# 最小形近表(可擴充),全部繁體字元 CONFUSION = { "稅": ["稅","稅","稅"], # 佔位,實務請擴充你的領域錯字集 "關": ["關","觀","官"], "條": ["條","條","條","條"], "價": ["價","價","價"], "貨": ["貨","貨","貨"], # 範例:把常見錯→正收集成映射 } MANUAL_MAP = { # "錯誤詞": "正確詞" "上至": "上字", # 例:法院案號「上字」→初稿常被誤作「上至」 } DOMAIN = set(""" 關稅法 關稅估價 合理方法 第29條 第35條 移轉訂價 關稅估價協定 進口 完稅價格 報關業者 委任書 空運快遞貨物通關辦法 """.split()) def protect_terms(text:str, terms:set[str])->str: for t in sorted(terms, key=len, reverse=True): text = text.replace(t, f"⟦{t}⟧") return text def unprotect_terms(text:str)->str: return text.replace("⟦","").replace("⟧","")
In [6]:
Copied!
ZH_PUNC = {",":",",".":"。","?":"?","!":"!",":":":",";":";"}
def norm_punc(s:str)->str:
    """Normalise punctuation and spacing in ``s``.

    Applies the ZH_PUNC character substitutions, removes one space directly in
    front of each full-width mark in ",。?!:;", then passes the result through
    ``pangu.spacing_text``.
    """
    substitutions = str.maketrans(ZH_PUNC)
    converted = s.translate(substitutions)
    for mark in ",。?!:;":
        converted = converted.replace(" " + mark, mark)
    # presumably pangu adjusts spacing between CJK and half-width text — see its docs
    spaced = pangu.spacing_text(converted)
    return spaced
def candidates_for_char(ch:str, table=None):
    """Return the characters that ``ch`` may be replaced with, ``ch`` itself first.

    The confusion table is treated as a collection of groups (a key together
    with its listed variants). The lookup is symmetric: a character that only
    appears as a variant (e.g. "觀" under "關") still gets the rest of its group
    as candidates, so a wrong variant can be corrected back to the key.

    Args:
        ch: The character to look up.
        table: Mapping of character -> list of confusable characters.
            Defaults to the module-level CONFUSION table.

    Returns:
        De-duplicated list of single-character candidates in first-seen order.
        ``ch`` comes first so the original character is always allowed and is
        never forced to change.
    """
    if table is None:
        table = CONFUSION

    pool = [ch]
    for key, variants in table.items():
        group = [key, *variants]
        if ch in group:
            pool.extend(group)

    # dict.fromkeys removes repeats while preserving order; multi-character
    # entries are dropped because the correction works on single characters.
    return [c for c in dict.fromkeys(pool) if len(c) == 1]
def correct_sentence(sent:str, k=5, prob_thresh=-8.0):
    """Correct confusable characters in one sentence with the masked LM.

    Pipeline:
      1. DOMAIN terms are marked with ``protect_terms`` and left untouched.
      2. Every other character that has at least one alternative in the
         confusion table is masked and re-predicted; it is replaced by the
         highest-ranked prediction that belongs to its candidate set, provided
         that prediction differs from the original and scores at least
         ``prob_thresh``.
      3. Word-level MANUAL_MAP replacements are applied (longest key first).
      4. Punctuation and spacing are normalised with ``norm_punc``.

    Args:
        sent: Sentence to correct. Any literal ⟦ / ⟧ characters in it are
            removed, because they are reserved as protection markers.
        k: Number of masked-LM predictions inspected per character.
        prob_thresh: Minimum score, on the scale reported by ``mlm_topk``, that
            a replacement must reach. ``mlm_topk`` reports raw logits, so the
            default is permissive.

    Returns:
        ``(fixed, changed)``: the normalised sentence, and whether a
        character-level or MANUAL_MAP correction was made. Punctuation and
        spacing normalisation alone does not set the flag.
    """
    # Drop the markers immediately but remember which characters they enclosed,
    # so the model sees clean context and protected characters are skipped.
    # A depth counter also copes with nested markers.
    chars, locked = [], []
    depth = 0
    for ch in protect_terms(sent, DOMAIN):
        if ch == "⟦":
            depth += 1
        elif ch == "⟧":
            depth = max(depth - 1, 0)
        else:
            chars.append(ch)
            locked.append(depth > 0)

    changed = False
    for i, ch in enumerate(chars):
        if locked[i]:
            continue
        allowed = set(candidates_for_char(ch))
        if not (allowed - {ch}):
            # No alternative exists, so the outcome could only be "keep ch";
            # skip the forward pass.
            continue
        tokens, scores = mlm_topk(chars, i, k=k)
        # Highest-ranked prediction that is also an allowed candidate.
        best = next(((t, sc) for t, sc in zip(tokens, scores) if t in allowed), None)
        if best is None:
            continue
        token, score = best
        if token != ch and score >= prob_thresh:
            chars[i] = token
            changed = True

    out = "".join(chars)

    # Word-level manual mapping. Plain str.replace is used so that backslashes
    # in the replacement text are taken literally.
    for wrong, right in sorted(MANUAL_MAP.items(), key=lambda kv: len(kv[0]), reverse=True):
        if not wrong:
            continue
        replaced = out.replace(wrong, right)
        if replaced != out:
            out = replaced
            changed = True

    return norm_punc(out), changed
ZH_PUNC = {",":",",".":"。","?":"?","!":"!",":":":",";":";"} def norm_punc(s:str)->str: s = s.translate(str.maketrans(ZH_PUNC)) for p in ",。?!:;": s = s.replace(" "+p, p) return pangu.spacing_text(s) def candidates_for_char(ch:str): cand = CONFUSION.get(ch, []) # 加入原字,避免被強改 if ch not in cand: cand = [ch] + cand # 過濾成單字且非特殊符號 return [c for c in cand if len(c)==1] def correct_sentence(sent:str, k=5, prob_thresh=-8.0): s = protect_terms(sent, DOMAIN) chars = list(s) changed = False for i,ch in enumerate(chars): if ch in "⟦⟧": # 保護塊跳過 continue # 用 MLM 估該位置的前 k 候選 topk_tokens, topk_scores = mlm_topk(chars, i, k=k) # 僅考慮「與混淆表交集」的候選 cand = [c for c in topk_tokens if c in candidates_for_char(ch)] if not cand: continue # 若原字不在高分候選且最高分候選分數顯著更好,才替換 # 這裡用簡單啟發:若 top1 != 原字 且分數明顯高於閾值 top1 = cand[0] if top1 != ch: chars[i] = top1 changed = True out = unprotect_terms("".join(chars)) # 手動映射(詞級) for a,b in sorted(MANUAL_MAP.items(), key=lambda x: len(x[0]), reverse=True): out = re.sub(re.escape(a), b, out) return norm_punc(out), changed
In [ ]:
Copied!
import difflib, re
def split_sents(text:str):
# 保守斷句:句末標點 + 換行作為界線
parts = re.split(r'(?<=[。!?\?])\s+|\n+', text)
return [p for p in parts if p.strip()]
def correct_text(text:str):
    """Correct a whole text sentence by sentence, keeping its original layout.

    Each sentence returned by ``split_sents`` is corrected with
    ``correct_sentence`` and written back at the position it came from. The
    text between sentences (spaces, newlines, blank lines) is copied through
    unchanged, so a diff against the input shows only the corrections and not
    re-flowed line breaks.

    Args:
        text: The text to correct.

    Returns:
        The corrected text, with the same inter-sentence separators as ``text``.
    """
    pieces = []
    cursor = 0
    for sent in split_sents(text):
        # Sentences are in-order, non-overlapping substrings of text, so a
        # forward search from the cursor finds exactly where this one sits.
        start = text.find(sent, cursor)
        if start < 0:
            # Not expected; fall back to appending the sentence at the cursor.
            start = cursor
        pieces.append(text[cursor:start])
        fixed, _ = correct_sentence(sent)
        pieces.append(fixed)
        cursor = start + len(sent)
    pieces.append(text[cursor:])
    return "".join(pieces)
def show_diff(a:str, b:str)->str:
    """Return a unified diff between ``a`` (orig.txt) and ``b`` (fixed.txt).

    Both texts are split into lines without their line endings, and the diff
    lines are joined with "\\n". Every header, hunk marker and content line
    therefore sits on its own line, including a final line that has no
    trailing newline.

    Args:
        a: Original text.
        b: Corrected text.

    Returns:
        The diff as one string with no trailing newline, or "" when the texts
        have identical lines.
    """
    diff_lines = difflib.unified_diff(
        a.splitlines(),
        b.splitlines(),
        fromfile="orig.txt",
        tofile="fixed.txt",
        lineterm="",  # unified_diff adds no terminator; the join below supplies it
    )
    return "\n".join(diff_lines)
import difflib, re def split_sents(text:str): # 保守斷句:句末標點 + 換行作為界線 parts = re.split(r'(?<=[。!?\?])\s+|\n+', text) return [p for p in parts if p.strip()] def correct_text(text:str): sents = split_sents(text) fixed = [] for s in sents: fs, _ = correct_sentence(s) fixed.append(fs) return "\n".join(fixed) def show_diff(a:str, b:str)->str: return "".join(difflib.unified_diff( a.splitlines(keepends=True), b.splitlines(keepends=True), fromfile="orig.txt", tofile="fixed.txt", lineterm="" ))