Files
bertalign/bertalign/aligner.py
2022-04-17 00:15:50 +08:00

98 lines
3.8 KiB
Python

import numpy as np
from bertalign import model
from bertalign.corelib import *
from bertalign.utils import *
class Bertalign:
def __init__(self,
src,
tgt,
max_align=5,
top_k=3,
win=5,
skip=-0.1,
margin=True,
len_penalty=True,
is_split=False,
):
self.max_align = max_align
self.top_k = top_k
self.win = win
self.skip = skip
self.margin = margin
self.len_penalty = len_penalty
src = clean_text(src)
tgt = clean_text(tgt)
src_lang = detect_lang(src)
tgt_lang = detect_lang(tgt)
if is_split:
src_sents = src.splitlines()
tgt_sents = tgt.splitlines()
else:
src_sents = split_sents(src, src_lang)
tgt_sents = split_sents(tgt, tgt_lang)
src_num = len(src_sents)
tgt_num = len(tgt_sents)
src_lang = LANG.ISO[src_lang]
tgt_lang = LANG.ISO[tgt_lang]
print("Source language: {}, Number of sentences: {}".format(src_lang, src_num))
print("Target language: {}, Number of sentences: {}".format(tgt_lang, tgt_num))
print("Embedding source and target text using {} ...".format(model.model_name))
src_vecs, src_lens = model.transform(src_sents, max_align - 1)
tgt_vecs, tgt_lens = model.transform(tgt_sents, max_align - 1)
char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
self.src_lang = src_lang
self.tgt_lang = tgt_lang
self.src_sents = src_sents
self.tgt_sents = tgt_sents
self.src_num = src_num
self.tgt_num = tgt_num
self.src_lens = src_lens
self.tgt_lens = tgt_lens
self.char_ratio = char_ratio
self.src_vecs = src_vecs
self.tgt_vecs = tgt_vecs
def align_sents(self):
print("Performing first-step alignment ...")
D, I = find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k)
first_alignment_types = get_alignment_types(2) # 0-1, 1-0, 1-1
first_w, first_path = find_first_search_path(self.src_num, self.tgt_num)
first_pointers = first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I)
first_alignment = first_back_track(self.src_num, self.tgt_num, first_pointers, first_path, first_alignment_types)
print("Performing second-step alignment ...")
second_alignment_types = get_alignment_types(self.max_align)
second_w, second_path = find_second_search_path(first_alignment, self.win, self.src_num, self.tgt_num)
second_pointers = second_pass_align(self.src_vecs, self.tgt_vecs, self.src_lens, self.tgt_lens,
second_w, second_path, second_alignment_types,
self.char_ratio, self.skip, margin=self.margin, len_penalty=self.len_penalty)
second_alignment = second_back_track(self.src_num, self.tgt_num, second_pointers, second_path, second_alignment_types)
print("Finished! Successfuly aligning {} {} sentences to {} {} sentences\n".format(self.src_num, self.src_lang, self.tgt_num, self.tgt_lang))
self.result = second_alignment
def print_sents(self):
for bead in (self.result):
src_line = self._get_line(bead[0], self.src_sents)
tgt_line = self._get_line(bead[1], self.tgt_sents)
print(src_line + "\n" + tgt_line + "\n")
@staticmethod
def _get_line(bead, lines):
line = ''
if len(bead) > 0:
line = ' '.join(lines[bead[0]:bead[-1]+1])
return line