Tariff corr mlm 20250915 v1
In [18]:
import os, json, pathlib

RES_DIR = "/home/jovyan/work/resources"
CONF_PATH = f"{RES_DIR}/confusion.json"
MAP_PATH = f"{RES_DIR}/manual_map.json"
DOM_PATH = f"{RES_DIR}/domain_terms.txt"

os.makedirs(RES_DIR, exist_ok=True)

def _ensure_file(path, default_text):
    if not os.path.exists(path):
        with open(path, "w", encoding="utf-8") as f:
            f.write(default_text)

# Start these files empty; entries accumulate over time.
_ensure_file(CONF_PATH, "{}\n")
_ensure_file(MAP_PATH, "{}\n")
_ensure_file(DOM_PATH, "\n")

def load_resources():
    with open(CONF_PATH, "r", encoding="utf-8") as f:
        CONFUSION = json.load(f)   # dict[str, list[str]]
    with open(MAP_PATH, "r", encoding="utf-8") as f:
        MANUAL_MAP = json.load(f)  # dict[str, str]
    with open(DOM_PATH, "r", encoding="utf-8") as f:
        DOMAIN = set(w.strip() for w in f if w.strip())
    return CONFUSION, MANUAL_MAP, DOMAIN

CONFUSION, MANUAL_MAP, DOMAIN = load_resources()
len(CONFUSION), len(MANUAL_MAP), len(DOMAIN)
Out[18]:
(0, 0, 4086)
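Both JSON resources start out empty (hence the `(0, 0, 4086)` above), so the word-level pass and the character-level confusion pass defined later have nothing to apply until entries accumulate. A minimal sketch of seeding them by hand, using the shapes load_resources() expects; the entries below are illustrative only (generic homophone confusions, plus the map pair from the commented example in the final cell), not corrections taken from the course material:

example_confusion = {"在": ["再"], "得": ["的", "地"]}   # char -> plausible mis-recognized characters (illustrative)
example_manual_map = {"上至": "上字"}                     # wrong word -> correct word (illustrative)

with open(CONF_PATH, "w", encoding="utf-8") as f:
    json.dump(example_confusion, f, ensure_ascii=False, indent=2)
with open(MAP_PATH, "w", encoding="utf-8") as f:
    json.dump(example_manual_map, f, ensure_ascii=False, indent=2)

CONFUSION, MANUAL_MAP, DOMAIN = load_resources()
len(CONFUSION), len(MANUAL_MAP), len(DOMAIN)   # would now be (2, 1, 4086)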
In [19]:
# Requires: pip install -q pymupdf regex
import os, re, fitz
from collections import Counter

# Path settings
RES_DIR = "/home/jovyan/work/resources"
DOM_PATH = f"{RES_DIR}/domain_terms.txt"
PDF_DIR = "/home/jovyan/work/_Material/法源"                  # ← your PDF root directory
MD_DIR = "/home/jovyan/work/mkdocs/My_Notes/01課程筆記"       # ← root of your already-corrected .md/.mdx files

# Parameters
MIN_LEN = 2
TOP_K = 4000
BONUS_KW = ("關稅","估價","關稅估價","完稅","價格","協定","報關","稅捐","條","第","款","章","項")

os.makedirs(RES_DIR, exist_ok=True)

# Helper: enumerate files
def list_files(root: str, exts: tuple[str, ...]):
    out = []
    for dp, _, fs in os.walk(root):
        for f in fs:
            if f.lower().endswith(exts):
                out.append(os.path.join(dp, f))
    return sorted(out)

# Helper: Markdown cleanup
fm_pat = re.compile(r"^---[\s\S]*?---\s*", re.MULTILINE)  # front matter
fence_pat = re.compile(r"```[\s\S]*?```|~~~[\s\S]*?~~~", re.MULTILINE)
indent_pat = re.compile(r"(^ {4}.+?$)", re.MULTILINE)
inline_pat = re.compile(r"`[^`]+`")
link_pat = re.compile(r"!\[[^\]]*\]\([^\)]*\)|\[[^\]]*\]\([^\)]*\)")  # Markdown images and links
url_pat = re.compile(r"https?://\S+")
html_pat = re.compile(r"<[^>]+>")
hash_pat = re.compile(r"^#+\s*", re.MULTILINE)
tbl_pat = re.compile(r"^\|.*\|$", re.MULTILINE)

def clean_md(txt: str) -> str:
    txt = fm_pat.sub(" ", txt)
    txt = fence_pat.sub(" ", txt)
    txt = indent_pat.sub(" ", txt)
    txt = inline_pat.sub(" ", txt)
    txt = link_pat.sub(" ", txt)
    txt = url_pat.sub(" ", txt)
    txt = html_pat.sub(" ", txt)
    txt = hash_pat.sub("", txt)
    txt = tbl_pat.sub(" ", txt)
    return txt

token_pat = re.compile(rf'[\u4e00-\u9fffA-Za-z0-9]{{{MIN_LEN},}}')

# Extract terms from PDFs
def extract_terms_from_pdfs(paths, top_k=4000):
    freq = Counter()
    for p in paths:
        try:
            doc = fitz.open(p)
            for pg in doc:
                t = pg.get_text("text") or ""
                for w in token_pat.findall(t):
                    freq[w] += 3 if any(b in w for b in BONUS_KW) else 1
        except Exception as e:
            print("略過無法讀取:", p, "|", e)
    return [w for w, _ in freq.most_common(top_k)]

# Extract terms from Markdown files
def extract_terms_from_mds(paths, top_k=4000):
    freq = Counter()
    for p in paths:
        try:
            t = open(p, "r", encoding="utf-8", errors="ignore").read()
            t = clean_md(t)
            for w in token_pat.findall(t):
                freq[w] += 3 if any(b in w for b in BONUS_KW) else 1
        except Exception as e:
            print("略過無法讀取:", p, "|", e)
    return [w for w, _ in freq.most_common(top_k)]

# Run
pdfs = list_files(PDF_DIR, (".pdf",))
mds = list_files(MD_DIR, (".md", ".mdx"))
print(f"PDF 檔數:{len(pdfs)};MD 檔數:{len(mds)}")

terms_pdf = extract_terms_from_pdfs(pdfs, top_k=TOP_K)
terms_md = extract_terms_from_mds(mds, top_k=TOP_K)

# Merge with the existing DOMAIN list
domain = set()
if os.path.exists(DOM_PATH):
    with open(DOM_PATH, "r", encoding="utf-8") as f:
        domain |= {x.strip() for x in f if x.strip()}
before = len(domain)
domain |= set(terms_pdf)
domain |= set(terms_md)
with open(DOM_PATH, "w", encoding="utf-8") as f:
    f.write("\n".join(sorted(domain)))
print(f"已更新 DOMAIN 條目數:{len(domain)}(+{len(domain)-before})")
PDF 檔數:4;MD 檔數:4
已更新 DOMAIN 條目數:4086(+0)
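Before trusting the merged DOMAIN list, it helps to eyeball what the frequency weighting actually favors. A quick check (a sketch; `terms_pdf` and `BONUS_KW` come from the cell above):

# Top raw terms, then the subset that contains a BONUS_KW keyword.
print(terms_pdf[:20])
print([w for w in terms_pdf if any(b in w for b in BONUS_KW)][:20])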
In [20]:
# pip install -q transformers torch rapidfuzz regex pangu zhon
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch, re, difflib, pangu
from rapidfuzz import process, fuzz

tok = AutoTokenizer.from_pretrained("bert-base-chinese")
mlm = AutoModelForMaskedLM.from_pretrained("bert-base-chinese").eval()

ZH_PUNC = {",": ",", ".": "。", "?": "?", "!": "!", ":": ":", ";": ";"}

def norm_punc(s):
    s = s.translate(str.maketrans(ZH_PUNC))
    for p in ",。?!:;":
        s = s.replace(" " + p, p)
    return pangu.spacing_text(s)

def mlm_topk_char(sent_chars, i, k=5):
    # Restrict context to the model max length with a centered window.
    max_len = getattr(mlm.config, "max_position_embeddings", 512)
    window = max(2, max_len - 2)  # reserve [CLS], [SEP]
    left = (window - 1) // 2
    right = window - 1 - left
    start = max(0, i - left)
    end = min(len(sent_chars), i + 1 + right)
    sub_chars = sent_chars[start:end]
    # Tokenize with truncation and get masks/type ids to avoid buffer mismatch.
    t = tok(sub_chars, is_split_into_words=True, return_tensors="pt",
            add_special_tokens=True, truncation=True, max_length=max_len)
    ids = t["input_ids"][0].clone()
    pos = (i - start) + 1  # [CLS] offset within the window
    ids[pos] = tok.mask_token_id
    # Prepare inputs with attention and token_type ids.
    inputs = {"input_ids": ids.unsqueeze(0)}
    if "token_type_ids" in t:
        inputs["token_type_ids"] = t["token_type_ids"]
    else:
        inputs["token_type_ids"] = torch.zeros_like(ids).unsqueeze(0)
    if "attention_mask" in t:
        inputs["attention_mask"] = t["attention_mask"]
    else:
        inputs["attention_mask"] = torch.ones_like(ids).unsqueeze(0)
    with torch.inference_mode():
        logits = mlm(**inputs).logits[0, pos]
    topk = torch.topk(logits, k)
    return [tok.convert_ids_to_tokens(int(x)) for x in topk.indices]

def protect_terms(text, terms):
    for t in sorted(terms, key=len, reverse=True):
        text = text.replace(t, f"⟦{t}⟧")
    return text

def unprotect_terms(text):
    return text.replace("⟦", "").replace("⟧", "")

def correct_sentence(sent, CONFUSION, MANUAL_MAP, DOMAIN):
    # Word-level replacements first (your manual corrections accumulate in MANUAL_MAP).
    for a, b in sorted(MANUAL_MAP.items(), key=lambda x: len(x[0]), reverse=True):
        sent = re.sub(re.escape(a), b, sent)
    s = protect_terms(sent, DOMAIN)
    chars = list(s)
    for i, ch in enumerate(chars):
        if ch in "⟦⟧":
            continue
        # Take candidates from the confusion table if present; always allow the original character.
        cand_pool = CONFUSION.get(ch, [])
        if ch not in cand_pool:
            cand_pool = [ch] + list(cand_pool)
        # Score with the MLM and intersect its top-k with the candidate pool.
        topk = mlm_topk_char(chars, i, k=6)
        cands = [c for c in topk if c in cand_pool]
        if cands and cands[0] != ch:
            chars[i] = cands[0]
    out = unprotect_terms("".join(chars))
    # Optionally pull near-miss words back onto DOMAIN terms.
    if DOMAIN:
        words = set(re.findall(r'[\u4e00-\u9fffA-Za-z0-9]{2,}', out))
        for w in words:
            if w not in DOMAIN:
                hit = process.extractOne(w, DOMAIN, scorer=fuzz.WRatio)
                if hit and hit[1] >= 93:
                    out = re.sub(re.escape(w), hit[0], out)
    return norm_punc(out)

def split_sents(text):
    return [p for p in re.split(r'(?<=[。!?\?])\s+|\n+', text) if p.strip()]

def correct_text(text, CONFUSION, MANUAL_MAP, DOMAIN):
    return "\n".join(correct_sentence(s, CONFUSION, MANUAL_MAP, DOMAIN) for s in split_sents(text))

def show_diff(a, b):
    return "".join(difflib.unified_diff(a.splitlines(1), b.splitlines(1),
                                        fromfile="orig.txt", tofile="fixed.txt", lineterm=""))
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
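A one-sentence smoke test is a cheap way to confirm the pieces fit together before running a whole transcript. The sentence below is made up for illustration; with the confusion and manual-map resources still essentially empty, the visible effect is mostly the half-width punctuation being normalized by norm_punc:

demo = "進口貨物之完稅價格,應以交易價格作為計算依據."   # hypothetical test sentence with half-width punctuation
print(correct_sentence(demo, CONFUSION, MANUAL_MAP, DOMAIN))
# Expected: "," becomes "," and the trailing "." becomes "。"; domain terms pass through unchanged.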
In [ ]:
IN_PATH = "/home/jovyan/work/mkdocs/My_Notes/逐字稿/W03-0915-1逐字稿初稿.md"  # change this to your file
OUT_PATH = "/home/jovyan/work/mkdocs/My_Notes/逐字稿/W03-0915-1逐字稿-auto_fixed.txt"
raw = open(IN_PATH, "r", encoding="utf-8").read()
CONFUSION, MANUAL_MAP, DOMAIN = load_resources()
fixed = correct_text(raw, CONFUSION, MANUAL_MAP, DOMAIN)
os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)
open(OUT_PATH, "w", encoding="utf-8").write(fixed)
print(show_diff(raw, fixed))
print("輸出:", OUT_PATH)
In [ ]:
def update_manual_map(pairs):
    """
    pairs: List[Tuple['wrong', 'correct']]
    """
    mp = json.load(open(MAP_PATH, "r", encoding="utf-8"))
    for a, b in pairs:
        if a and b and a != b:
            mp[a] = b
    with open(MAP_PATH, "w", encoding="utf-8") as f:
        json.dump(mp, f, ensure_ascii=False, indent=2)
    print("已更新 manual_map.json,共", len(mp), "條")

# Example: write back this round's manual corrections
# update_manual_map([("上至", "上字"), ("完稅價", "完稅價格")])
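After a round of proofreading, the loop closes like this: write the confirmed corrections back with update_manual_map, reload the resources, and rerun the correction on the same transcript. A sketch reusing the hypothetical pairs from the commented example above and the `raw` / `OUT_PATH` variables from the earlier cell:

update_manual_map([("上至", "上字"), ("完稅價", "完稅價格")])
CONFUSION, MANUAL_MAP, DOMAIN = load_resources()
fixed = correct_text(raw, CONFUSION, MANUAL_MAP, DOMAIN)
open(OUT_PATH, "w", encoding="utf-8").write(fixed)
print("輸出:", OUT_PATH)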