Bertalign and evaluation scripts
450
bin/bert_align.py
Normal file
@@ -0,0 +1,450 @@
# 2021/11/27
# bfsujason@163.com

"""
Usage:

python bin/bert_align.py \
  -s data/mac/dev/zh \
  -t data/mac/dev/en \
  -o data/mac/dev/auto \
  -m data/mac/dev/meta_data.tsv \
  --src_embed data/mac/dev/zh/overlap data/mac/dev/zh/overlap.emb \
  --tgt_embed data/mac/dev/en/overlap data/mac/dev/en/overlap.emb \
  --max_align 8 --margin
"""

import os
import sys
import time
import logging
import torch
import faiss
import shutil
import argparse
import numpy as np
import numba as nb

logger = logging.getLogger(__name__)


def main():
    # user-defined parameters
    parser = argparse.ArgumentParser('Sentence alignment using Vecalign')
    parser.add_argument('-s', '--src', type=str, required=True, help='Preprocessed source file to align.')
    parser.add_argument('-t', '--tgt', type=str, required=True, help='Preprocessed target file to align.')
    parser.add_argument('-o', '--out', type=str, required=True, help='Output directory.')
    parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.')
    parser.add_argument('--src_embed', type=str, nargs=2, required=True,
                        help='Source embeddings. Requires two arguments: first is a text file, second is a binary embeddings file.')
    parser.add_argument('--tgt_embed', type=str, nargs=2, required=True,
                        help='Target embeddings. Requires two arguments: first is a text file, second is a binary embeddings file.')
    parser.add_argument('--max_align', type=int, default=5, help='Maximum alignment types, n + m <= this value.')
    parser.add_argument('--win', type=int, default=5, help='Window size for the second-pass alignment.')
    parser.add_argument('--top_k', type=int, default=3, help='Top-k target neighbors of each source sentence.')
    parser.add_argument('--skip', type=float, default=-0.1, help='Similarity score for 0-1 and 1-0 alignments.')
    parser.add_argument('--margin', action='store_true', help='Margin-based cosine similarity.')
    args = parser.parse_args()

    # fixed parameters that determine the
    # window size for the first-pass alignment
    min_win_size = 10
    max_win_size = 600
    win_per_100 = 8

    # read in embeddings
    src_sent2line, src_line_embeddings = read_in_embeddings(args.src_embed[0], args.src_embed[1])
    tgt_sent2line, tgt_line_embeddings = read_in_embeddings(args.tgt_embed[0], args.tgt_embed[1])
    embedding_size = src_line_embeddings.shape[1]

    make_dir(args.out)
    jobs = create_jobs(args.meta, args.src, args.tgt, args.out)

    # start alignment
    for rec in jobs:
        src_file, tgt_file, align_file = rec.split("\t")
        print("Aligning {} to {}".format(src_file, tgt_file))

        # read in source and target sentences
        src_lines = open(src_file, 'rt', encoding="utf-8").readlines()
        tgt_lines = open(tgt_file, 'rt', encoding="utf-8").readlines()

        # convert source and target texts into embeddings
        # and calculate sentence lengths
        t_0 = time.time()
        src_vecs, src_lens = doc2feats(src_sent2line, src_line_embeddings, src_lines, args.max_align - 1)
        tgt_vecs, tgt_lens = doc2feats(tgt_sent2line, tgt_line_embeddings, tgt_lines, args.max_align - 1)
        char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
        print("Reading embeddings takes {:.3f} seconds.".format(time.time() - t_0))

        # using faiss, find in the target text
        # the k nearest neighbors of each source sentence
        t_1 = time.time()
        if torch.cuda.is_available():  # GPU version
            res = faiss.StandardGpuResources()
            index = faiss.IndexFlatIP(embedding_size)
            gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
            gpu_index.add(tgt_vecs[0, :])
            xq = src_vecs[0, :]
            D, I = gpu_index.search(xq, args.top_k)
        else:  # CPU version
            index = faiss.IndexFlatIP(embedding_size)  # use inner product to build the index
            index.add(tgt_vecs[0, :])
            xq = src_vecs[0, :]
            D, I = index.search(xq, args.top_k)
        print("Finding top-k neighbors takes {:.3f} seconds.".format(time.time() - t_1))

        # find 1-to-1 alignments
        t_2 = time.time()
        src_len = len(src_lines)
        tgt_len = len(tgt_lines)
        first_alignment_types = make_alignment_types(2)  # 0-1, 1-0 and 1-1
        first_w, first_search_path = find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100)
        first_pointers = first_pass_align(src_len, tgt_len, first_w, first_search_path, first_alignment_types, D, I, args.top_k)
        first_alignment = first_back_track(src_len, tgt_len, first_pointers, first_search_path, first_alignment_types)
        print("First pass alignment takes {:.3f} seconds.".format(time.time() - t_2))

        # find m-to-n alignments
        t_3 = time.time()
        second_w, second_search_path = find_second_search_path(first_alignment, args.win, src_len, tgt_len)
        second_alignment_types = make_alignment_types(args.max_align)
        second_pointers = second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, second_w, second_search_path, second_alignment_types, char_ratio, args.skip, margin=args.margin)
        second_alignment = second_back_track(src_len, tgt_len, second_pointers, second_search_path, second_alignment_types)
        print("Second pass alignment takes {:.3f} seconds.".format(time.time() - t_3))

        # save alignments
        print_alignments(second_alignment, align_file)

def second_back_track(i, j, b, search_path, a_types):
    alignment = []
    while i != 0 and j != 0:
        j_offset = j - search_path[i][0]
        a = b[i][j_offset]
        s = a_types[a][0]
        t = a_types[a][1]
        src_range = [i - offset - 1 for offset in range(s)][::-1]
        tgt_range = [j - offset - 1 for offset in range(t)][::-1]
        alignment.append((src_range, tgt_range))

        i = i - s
        j = j - t

    return alignment[::-1]


@nb.jit(nopython=True, fastmath=True, cache=True)
def second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, w, search_path, align_types, char_ratio, skip, margin=False):
    src_len = src_vecs.shape[1]
    tgt_len = tgt_vecs.shape[1]

    # initialize cost and back-pointer matrices
    cost = np.zeros((src_len + 1, w))
    back = np.zeros((src_len + 1, w), dtype=nb.int64)
    cost[0][0] = 0
    back[0][0] = -1

    for i in range(1, src_len + 1):
        i_start = search_path[i][0]
        i_end = search_path[i][1]
        for j in range(i_start, i_end + 1):
            if i + j == 0:
                continue
            best_score = -np.inf
            best_a = -1
            for a in range(align_types.shape[0]):
                a_1 = align_types[a][0]
                a_2 = align_types[a][1]
                prev_i = i - a_1
                prev_j = j - a_2

                if prev_i < 0 or prev_j < 0:  # no previous cell in the DP table
                    continue
                prev_i_start = search_path[prev_i][0]
                prev_i_end = search_path[prev_i][1]
                if prev_j < prev_i_start or prev_j > prev_i_end:  # out of bounds of cost matrix
                    continue
                prev_j_offset = prev_j - prev_i_start
                score = cost[prev_i][prev_j_offset]
                if score == -np.inf:
                    continue

                if a_1 == 0 or a_2 == 0:  # deletion or insertion
                    cur_score = skip
                else:
                    src_v = src_vecs[a_1 - 1, i - 1, :]
                    tgt_v = tgt_vecs[a_2 - 1, j - 1, :]
                    src_l = src_lens[a_1 - 1, i - 1]
                    tgt_l = tgt_lens[a_2 - 1, j - 1]
                    cur_score = get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=margin)
                    # penalize beads with unbalanced source/target lengths
                    tgt_l = tgt_l * char_ratio
                    min_len = min(src_l, tgt_l)
                    max_len = max(src_l, tgt_l)
                    len_p = np.log2(1 + min_len / max_len)
                    cur_score *= len_p

                score += cur_score
                if score > best_score:
                    best_score = score
                    best_a = a

            j_offset = j - i_start
            cost[i][j_offset] = best_score
            back[i][j_offset] = best_a

    return back

@nb.jit(nopython=True, fastmath=True, cache=True)
def get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=False):
    similarity = nb_dot(src_v, tgt_v)
    if margin:
        tgt_neighbor_ave_sim = get_neighbor_sim(src_v, a_2, j, tgt_len, tgt_vecs)
        src_neighbor_ave_sim = get_neighbor_sim(tgt_v, a_1, i, src_len, src_vecs)
        neighbor_ave_sim = (tgt_neighbor_ave_sim + src_neighbor_ave_sim) / 2
        similarity -= neighbor_ave_sim

    return similarity


@nb.jit(nopython=True, fastmath=True, cache=True)
def get_neighbor_sim(vec, a, j, sent_num, db):
    left_idx = j - a
    right_idx = j + 1

    if right_idx > sent_num:
        neighbor_right_sim = 0
    else:
        right_embed = db[0, right_idx - 1, :]
        neighbor_right_sim = nb_dot(vec, right_embed)

    if left_idx == 0:
        neighbor_left_sim = 0
    else:
        left_embed = db[0, left_idx - 1, :]
        neighbor_left_sim = nb_dot(vec, left_embed)

    # only one neighbor exists at the document boundary
    if right_idx > sent_num or left_idx == 0:
        neighbor_ave_sim = neighbor_left_sim + neighbor_right_sim
    else:
        neighbor_ave_sim = (neighbor_left_sim + neighbor_right_sim) / 2

    return neighbor_ave_sim


@nb.jit(nopython=True, fastmath=True, cache=True)
def nb_dot(x, y):
    return np.dot(x, y)

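# How the three scoring helpers above fit together (a rough sketch, assuming
# the sentence vectors are L2-normalized -- true for LaBSE embeddings from
# sentence-transformers -- so that the dot product equals cosine similarity):
#
#   sim = nb_dot(src_v, tgt_v)
#   margin_sim = sim - (src_neighbor_ave_sim + tgt_neighbor_ave_sim) / 2
#
# With --margin, a candidate bead scores well only if its similarity stands
# out against the average similarity of its immediate neighbor sentences.
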
def find_second_search_path(align, w, src_len, tgt_len):
    '''
    Convert the 1-1 alignment from the first pass into the search path
    for the second-pass alignment.
    The indexes along the X-axis and Y-axis must be consecutive.
    '''
    last_bead_src = align[-1][0]
    last_bead_tgt = align[-1][1]

    # make sure the search path ends at (src_len, tgt_len)
    if last_bead_src != src_len:
        if last_bead_tgt == tgt_len:
            align.pop()
        align.append((src_len, tgt_len))
    else:
        if last_bead_tgt != tgt_len:
            align.pop()
            align.append((src_len, tgt_len))

    prev_src, prev_tgt = 0, 0
    path = []
    max_w = -np.inf
    for src, tgt in align:
        lower_bound = max(0, prev_tgt - w)
        upper_bound = min(tgt_len, tgt + w)
        path.extend([(lower_bound, upper_bound) for _ in range(prev_src + 1, src + 1)])
        prev_src, prev_tgt = src, tgt

        width = upper_bound - lower_bound
        if width > max_w:
            max_w = width
    path = [path[0]] + path

    return max_w + 1, np.array(path)


def first_back_track(i, j, b, search_path, a_types):
    alignment = []
    while i != 0 and j != 0:
        j_offset = j - search_path[i][0]
        a = b[i][j_offset]
        s = a_types[a][0]
        t = a_types[a][1]
        if a == 2:  # 1-1 alignment bead
            alignment.append((i, j))

        i = i - s
        j = j - t

    return alignment[::-1]


@nb.jit(nopython=True, fastmath=True, cache=True)
def first_pass_align(src_len, tgt_len, w, search_path, align_types, dist, index, top_k):

    # initialize cost and back-pointer matrices
    cost = np.zeros((src_len + 1, 2 * w + 1))
    pointers = np.zeros((src_len + 1, 2 * w + 1), dtype=nb.int64)
    cost[0][0] = 0
    pointers[0][0] = -1

    for i in range(1, src_len + 1):
        i_start = search_path[i][0]
        i_end = search_path[i][1]
        for j in range(i_start, i_end + 1):
            if i + j == 0:
                continue
            best_score = -np.inf
            best_a = -1
            for a in range(align_types.shape[0]):
                a_1 = align_types[a][0]
                a_2 = align_types[a][1]
                prev_i = i - a_1
                prev_j = j - a_2
                if prev_i < 0 or prev_j < 0:  # no previous cell
                    continue
                prev_i_start = search_path[prev_i][0]
                prev_i_end = search_path[prev_i][1]
                if prev_j < prev_i_start or prev_j > prev_i_end:  # out of bounds of cost matrix
                    continue
                prev_j_offset = prev_j - prev_i_start
                score = cost[prev_i][prev_j_offset]
                if score == -np.inf:
                    continue

                # reward a 1-1 bead whose target sentence is among the
                # top-k neighbors of its source sentence
                if a_1 > 0 and a_2 > 0:
                    for k in range(top_k):
                        if index[i - 1][k] == j - 1:
                            score += dist[i - 1][k]
                if score > best_score:
                    best_score = score
                    best_a = a

            j_offset = j - i_start
            cost[i][j_offset] = best_score
            pointers[i][j_offset] = best_a

    return pointers

@nb.jit(nopython=True, fastmath=True, cache=True)
def find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
    yx_ratio = tgt_len / src_len
    win_size_1 = int(yx_ratio * tgt_len * win_per_100 / 100)
    win_size_2 = int(abs(tgt_len - src_len) * 3 / 4)
    w_1 = min(max(min_win_size, max(win_size_1, win_size_2)), max_win_size)
    w_2 = int(max(src_len, tgt_len) * 0.06)
    w = max(w_1, w_2)
    search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
    for i in range(0, src_len + 1):
        center = int(yx_ratio * i)
        w_start = max(0, center - w)
        w_end = min(center + w, tgt_len)
        search_path[i] = [w_start, w_end]

    return w, search_path

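# Worked example for the band above: with src_len=100 and tgt_len=120,
# yx_ratio = 1.2, win_size_1 = int(1.2 * 120 * 8 / 100) = 11 and
# win_size_2 = int(20 * 3 / 4) = 15, so w = 15. The first-pass DP is then
# confined to a diagonal band j = int(1.2 * i) +/- 15, filling
# O(src_len * w) cells instead of the full O(src_len * tgt_len) table.
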
def doc2feats(sent2line, line_embeddings, lines, num_overlaps):
    lines = [preprocess_line(line) for line in lines]
    vecsize = line_embeddings.shape[1]
    vecs0 = np.empty((num_overlaps, len(lines), vecsize), dtype=np.float32)
    vecs1 = np.empty((num_overlaps, len(lines)), dtype=np.int64)

    for ii, overlap in enumerate(range(1, num_overlaps + 1)):
        for jj, out_line in enumerate(layer(lines, overlap)):
            try:
                line_id = sent2line[out_line]
            except KeyError:
                logger.warning('Failed to find overlap=%d line "%s". Will use random vector.', overlap, out_line)
                line_id = None

            if line_id is not None:
                vec = line_embeddings[line_id]
            else:
                vec = np.random.random(vecsize) - 0.5
                vec = vec / np.linalg.norm(vec)

            vecs0[ii, jj, :] = vec
            vecs1[ii, jj] = len(out_line.encode("utf-8"))  # sentence length in UTF-8 bytes

    return vecs0, vecs1


def preprocess_line(line):
    line = line.strip()
    if len(line) == 0:
        line = 'BLANK_LINE'

    return line


def layer(lines, num_overlaps, comb=' '):
    """
    Make front-padded overlapping sentences.
    """
    if num_overlaps < 1:
        raise Exception('num_overlaps must be >= 1')
    out = ['PAD', ] * min(num_overlaps - 1, len(lines))
    for ii in range(len(lines) - num_overlaps + 1):
        out.append(comb.join(lines[ii:ii + num_overlaps]))

    return out

def read_in_embeddings(text_file, embed_file):
    sent2line = dict()
    with open(text_file, 'rt', encoding="utf-8") as fin:
        for ii, line in enumerate(fin):
            if line.strip() in sent2line:
                raise Exception('got multiple embeddings for the same line')
            sent2line[line.strip()] = ii

    line_embeddings = np.fromfile(embed_file, dtype=np.float32, count=-1)
    if line_embeddings.size == 0:
        raise Exception('Got empty embedding file')

    embedding_size = line_embeddings.size // len(sent2line)
    line_embeddings.resize(line_embeddings.shape[0] // embedding_size, embedding_size)

    return sent2line, line_embeddings

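# The .emb file is expected to be a raw float32 dump with one vector per
# line of the overlap text file, e.g. as written by bin/embed_sents.py via
# model.encode(overlap).tofile(...). A quick sanity check (assuming
# 768-dimensional LaBSE vectors):
#
#   embs = np.fromfile('data/mac/dev/zh/overlap.emb', dtype=np.float32)
#   assert embs.size % 768 == 0
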
def make_alignment_types(max_alignment_size):
    # return the list of all (m, n) alignment types with m + n <= max_alignment_size
    alignment_types = []
    for x in range(1, max_alignment_size):
        for y in range(1, max_alignment_size):
            if x + y <= max_alignment_size:
                alignment_types.append([x, y])
    alignment_types = [[0, 1], [1, 0]] + alignment_types

    return np.array(alignment_types)

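# For example, make_alignment_types(3) returns
#   [[0, 1], [1, 0], [1, 1], [1, 2], [2, 1]]
# i.e. the 0-1 and 1-0 null beads plus every m-n bead with m, n >= 1
# and m + n <= 3.
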
def create_jobs(meta, src, tgt, out):
    jobs = []
    fns = get_fns(meta)
    for file in fns:
        src_path = os.path.abspath(os.path.join(src, file))
        tgt_path = os.path.abspath(os.path.join(tgt, file))

        out_path = os.path.abspath(os.path.join(out, file + '.align'))
        jobs.append('\t'.join([src_path, tgt_path, out_path]))

    return jobs


def get_fns(meta):
    fns = []
    with open(meta, 'rt', encoding='utf-8') as f:
        next(f)  # skip header
        for line in f:
            recs = line.strip().split('\t')
            fns.append(recs[0])

    return fns


def print_alignments(alignments, out):
    with open(out, 'wt', encoding='utf-8') as f:
        for x, y in alignments:
            f.write("{}:{}\n".format(x, y))

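# Each line of the resulting .align file pairs a list of 0-based source
# indices with the target indices they align to, one bead per line, e.g.
# (hypothetical output):
#   [0]:[0]
#   [1, 2]:[1]
#   [3]:[2, 3]
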
def make_dir(path):
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


if __name__ == '__main__':
    t_0 = time.time()
    main()
    print("It takes {:.3f} seconds to align all the sentences.".format(time.time() - t_0))
107
bin/embed_sents.py
Normal file
@@ -0,0 +1,107 @@
# 2021/11/27
# bfsujason@163.com

'''
Usage (Linux):

python bin/embed_sents.py \
  -i data/mac/dev/zh \
  -o data/mac/dev/zh/overlap data/mac/dev/zh/overlap.emb \
  -m data/mac/test/meta_data.tsv \
  -n 8
'''

import os
import time
import shutil
import argparse
import numpy as np
from sentence_transformers import SentenceTransformer


def main():
    parser = argparse.ArgumentParser(description='Multilingual sentence embeddings')
    parser.add_argument('-i', '--input', type=str, required=True, help='Data directory.')
    parser.add_argument('-o', '--output', type=str, required=True, nargs=2, help='Overlap and embedding file.')
    parser.add_argument('-n', '--num_overlaps', type=int, default=5, help='Maximum number of allowed overlaps.')
    parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.')
    args = parser.parse_args()

    fns = get_fns(args.meta)
    overlap = get_overlap(args.input, fns, args.num_overlaps)
    write_overlap(overlap, args.output[0])

    model = load_model()
    embed_overlap(model, overlap, args.output[1])


def embed_overlap(model, overlap, fout):
    print("Embedding text ...")
    t_0 = time.time()
    embed = model.encode(overlap)
    embed.tofile(fout)
    print("It takes {:.3f} seconds to embed text.".format(time.time() - t_0))

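# model.encode returns a float32 numpy array of shape (len(overlap), dim),
# so .tofile writes exactly the raw binary layout that read_in_embeddings
# in bin/bert_align.py reads back with np.fromfile.
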
def write_overlap(overlap, outfile):
    with open(outfile, 'wt', encoding="utf-8") as fout:
        for line in overlap:
            fout.write(line + '\n')


def get_overlap(dir, fns, n):
    overlap = set()
    for file in fns:
        in_path = os.path.join(dir, file)
        lines = open(in_path, 'rt', encoding="utf-8").readlines()
        for out_line in yield_overlaps(lines, n):
            overlap.add(out_line)

    # sort for reproducibility
    overlap = list(overlap)
    overlap.sort()

    return overlap


def yield_overlaps(lines, num_overlaps):
    lines = [preprocess_line(line) for line in lines]
    for overlap in range(1, num_overlaps + 1):
        for out_line in layer(lines, overlap):
            # truncate here, before deduplication, so we do not encode
            # arbitrarily long sentences and later lookups stay consistent
            out_line2 = out_line[:10000]
            yield out_line2


def layer(lines, num_overlaps, comb=' '):
    if num_overlaps < 1:
        raise Exception('num_overlaps must be >= 1')
    out = ['PAD', ] * min(num_overlaps - 1, len(lines))
    for ii in range(len(lines) - num_overlaps + 1):
        out.append(comb.join(lines[ii:ii + num_overlaps]))
    return out

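# Worked example: layer(['a', 'b', 'c'], 2) returns ['PAD', 'a b', 'b c'],
# i.e. position j holds the concatenation of the num_overlaps sentences
# ending at line j, front-padded with 'PAD' so the list stays aligned with
# the original line indices.
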
def preprocess_line(line):
    line = line.strip()
    if len(line) == 0:
        line = 'BLANK_LINE'
    return line


def load_model():
    print("Loading embedding model ...")
    t0 = time.time()
    model = SentenceTransformer('LaBSE')
    print("It takes {:.3f} seconds to load the model.".format(time.time() - t0))
    return model


def get_fns(meta):
    fns = []
    with open(meta, 'rt', encoding='utf-8') as f:
        next(f)  # skip header
        for line in f:
            recs = line.strip().split('\t')
            fns.append(recs[0])

    return fns


def make_dir(path):
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


if __name__ == '__main__':
    main()
182
bin/eval.py
Normal file
@@ -0,0 +1,182 @@
#!/usr/bin/env python3

"""
Copyright 2019 Brian Thompson

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import sys
import argparse
from ast import literal_eval
from collections import defaultdict

import numpy as np

"""
Faster implementation of lax and strict precision and recall, based on
https://www.aclweb.org/anthology/W11-4624/.
"""


def read_alignments(fin):
    alignments = []
    with open(fin, 'rt', encoding="utf-8") as infile:
        for line in infile:
            fields = [x.strip() for x in line.split(':') if len(x.strip())]
            if len(fields) < 2:
                raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip())
            try:
                src = literal_eval(fields[0])
                tgt = literal_eval(fields[1])
            except:
                raise Exception('Failed to parse line "%s"' % line.strip())
            alignments.append((src, tgt))

    return alignments

def _precision(goldalign, testalign):
    """
    Computes tpstrict, fpstrict, tplax, fplax for gold/test alignments
    """
    tpstrict = 0  # true positive strict counter
    tplax = 0     # true positive lax counter
    fpstrict = 0  # false positive strict counter
    fplax = 0     # false positive lax counter

    # convert to sets, remove alignments empty on both sides
    testalign = set([(tuple(x), tuple(y)) for x, y in testalign if len(x) or len(y)])
    goldalign = set([(tuple(x), tuple(y)) for x, y in goldalign if len(x) or len(y)])

    # mappings from source test sentence idxs to
    # target gold sentence idxs for which the source test sentence
    # was found in the corresponding source gold alignment
    src_id_to_gold_tgt_ids = defaultdict(set)
    for gold_src, gold_tgt in goldalign:
        for gold_src_id in gold_src:
            for gold_tgt_id in gold_tgt:
                src_id_to_gold_tgt_ids[gold_src_id].add(gold_tgt_id)

    for (test_src, test_target) in testalign:
        if (test_src, test_target) == ((), ()):
            continue
        if (test_src, test_target) in goldalign:
            # strict match
            tpstrict += 1
            tplax += 1
        else:
            # For anything with partial gold/test overlap on the source,
            # see if there is also partial overlap on the gold/test target.
            # If so, it's a lax match.
            target_ids = set()
            for src_test_id in test_src:
                for tgt_id in src_id_to_gold_tgt_ids[src_test_id]:
                    target_ids.add(tgt_id)
            if set(test_target).intersection(target_ids):
                fpstrict += 1
                tplax += 1
            else:
                fpstrict += 1
                fplax += 1

    return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32)

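# Worked example: with gold bead ((0,), (0, 1)) and test bead ((0,), (0,)),
# the test bead is not an exact match, but source sentence 0 maps to gold
# targets {0, 1}, which overlaps the test target {0}; the bead is therefore
# counted as a strict false positive and a lax true positive.
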
def score_multiple(gold_list, test_list, value_for_div_by_0=0.0):
    # accumulate counts for all gold/test files
    pcounts = np.array([0, 0, 0, 0], dtype=np.int32)
    rcounts = np.array([0, 0, 0, 0], dtype=np.int32)
    for goldalign, testalign in zip(gold_list, test_list):
        pcounts += _precision(goldalign=goldalign, testalign=testalign)
        # recall is precision with no insertions/deletions and with the arguments swapped
        test_no_del = [(x, y) for x, y in testalign if len(x) and len(y)]
        gold_no_del = [(x, y) for x, y in goldalign if len(x) and len(y)]
        rcounts += _precision(goldalign=test_no_del, testalign=gold_no_del)

    # Compute results
    # pcounts: tpstrict, fpstrict, tplax, fplax
    # rcounts: tpstrict, fnstrict, tplax, fnlax

    if pcounts[0] + pcounts[1] == 0:
        pstrict = value_for_div_by_0
    else:
        pstrict = pcounts[0] / float(pcounts[0] + pcounts[1])

    if pcounts[2] + pcounts[3] == 0:
        plax = value_for_div_by_0
    else:
        plax = pcounts[2] / float(pcounts[2] + pcounts[3])

    if rcounts[0] + rcounts[1] == 0:
        rstrict = value_for_div_by_0
    else:
        rstrict = rcounts[0] / float(rcounts[0] + rcounts[1])

    if rcounts[2] + rcounts[3] == 0:
        rlax = value_for_div_by_0
    else:
        rlax = rcounts[2] / float(rcounts[2] + rcounts[3])

    if (pstrict + rstrict) == 0:
        fstrict = value_for_div_by_0
    else:
        fstrict = 2 * (pstrict * rstrict) / (pstrict + rstrict)

    if (plax + rlax) == 0:
        flax = value_for_div_by_0
    else:
        flax = 2 * (plax * rlax) / (plax + rlax)

    result = dict(recall_strict=rstrict,
                  recall_lax=rlax,
                  precision_strict=pstrict,
                  precision_lax=plax,
                  f1_strict=fstrict,
                  f1_lax=flax)

    return result

def log_final_scores(res):
    print(' ---------------------------------', file=sys.stderr)
    print('|             |  Strict |   Lax   |', file=sys.stderr)
    print('| Precision   |  {precision_strict:.3f}  |  {precision_lax:.3f}  |'.format(**res), file=sys.stderr)
    print('| Recall      |  {recall_strict:.3f}  |  {recall_lax:.3f}  |'.format(**res), file=sys.stderr)
    print('| F1          |  {f1_strict:.3f}  |  {f1_lax:.3f}  |'.format(**res), file=sys.stderr)
    print(' ---------------------------------', file=sys.stderr)


def main():
    parser = argparse.ArgumentParser(
        'Compute strict/lax precision and recall for one or more pairs of gold/test alignments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('-t', '--test', type=str, required=True,
                        help='Test alignment directory.')

    parser.add_argument('-g', '--gold', type=str, required=True,
                        help='Gold alignment directory.')

    args = parser.parse_args()

    gold_list = [read_alignments(os.path.join(args.gold, x)) for x in sorted(os.listdir(args.gold))]
    test_list = [read_alignments(os.path.join(args.test, x)) for x in sorted(os.listdir(args.test))]

    res = score_multiple(gold_list=gold_list, test_list=test_list)
    log_final_scores(res)


if __name__ == '__main__':
    main()
109
bin/eval_bible.py
Normal file
@@ -0,0 +1,109 @@
import os
import argparse
from ast import literal_eval
from collections import defaultdict
from eval import score_multiple, log_final_scores


def main():
    parser = argparse.ArgumentParser('Evaluate alignment quality for the Bible corpus')
    parser.add_argument('-t', '--test', type=str, required=True, help='Test alignment file.')
    parser.add_argument('-g', '--gold', type=str, required=True, help='Gold alignment file.')
    parser.add_argument('--src_verse', type=str, required=True, help='Source verse file.')
    parser.add_argument('--tgt_verse', type=str, required=True, help='Target verse file.')
    args = parser.parse_args()

    test_alignments = read_alignments(args.test)
    gold_alignments = read_alignments(args.gold)

    src_verse = get_verse(args.src_verse)
    tgt_verse = get_verse(args.tgt_verse)

    merged_test_alignments = merge_test_alignments(test_alignments, src_verse, tgt_verse)
    res = score_multiple(gold_list=[gold_alignments], test_list=[merged_test_alignments])
    log_final_scores(res)


def merge_test_alignments(alignments, src_verse, tgt_verse):
    merged_align = []
    last_beads_type = None

    for beads in alignments:
        beads_type = find_beads_type(beads, src_verse, tgt_verse)
        if not last_beads_type:
            merged_align.append(beads)
        else:
            if beads_type == last_beads_type:
                # consecutive beads fall inside the same verse: merge them
                merged_align[-1][0].extend(beads[0])
                merged_align[-1][1].extend(beads[1])
            else:
                merged_align.append(beads)

        last_beads_type = beads_type

    return merged_align

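# Illustration with hypothetical data: if source sentences 3-4 and target
# sentences 5-6 all carry the same verse label, the consecutive test beads
# ([3], [5]) and ([4], [6]) are merged into ([3, 4], [5, 6]), so sub-verse
# automatic alignments can be scored against verse-level gold alignments.
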
def find_beads_type(beads, src_verse, tgt_verse):
    src_bead = beads[0]
    tgt_bead = beads[1]

    src_bead_type = find_bead_type(src_bead, src_verse)
    tgt_bead_type = find_bead_type(tgt_bead, tgt_verse)

    src_bead_len = len(src_bead_type)
    tgt_bead_len = len(tgt_bead_type)

    if src_bead_len != 1 or tgt_bead_len != 1:
        return None
    else:
        src_verse = src_bead_type[0]
        tgt_verse = tgt_bead_type[0]
        if src_verse != tgt_verse:
            if src_verse == 'NULL':
                return tgt_verse
            elif tgt_verse == 'NULL':
                return src_verse
            else:
                return None
        else:
            return src_verse


def find_bead_type(bead, verse):
    bead_type = ['NULL']
    if len(bead) > 0:
        bead_type = unique_list([verse[idx] for idx in bead])

    return bead_type


def unique_list(items):
    unique = []
    for x in items:
        if x not in unique:
            unique.append(x)

    return unique


def get_verse(file):
    verse = defaultdict()
    with open(file, 'rt', encoding='utf-8') as f:
        for (i, line) in enumerate(f):
            verse[i] = line.strip()

    return verse


def read_alignments(fin):
    alignments = []
    with open(fin, 'rt', encoding="utf-8") as infile:
        for line in infile:
            fields = [x.strip() for x in line.split(':') if len(x.strip())]
            if len(fields) < 2:
                raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip())
            try:
                src = literal_eval(fields[0])
                tgt = literal_eval(fields[1])
            except:
                raise Exception('Failed to parse line "%s"' % line.strip())
            alignments.append((src, tgt))

    return alignments


if __name__ == '__main__':
    main()
226
bin/gale_align.py
Normal file
@@ -0,0 +1,226 @@
# 2021/11/27
# bfsujason@163.com

"""
Usage:

python bin/gale_align.py \
  -m data/mac/test/meta_data.tsv \
  -s data/mac/test/zh \
  -t data/mac/test/en \
  -o data/mac/test/auto
"""

import os
import time
import math
import shutil
import argparse
import numba as nb
import numpy as np


def main():
    # user-defined parameters
    parser = argparse.ArgumentParser(description='Sentence alignment using the Gale-Church algorithm')
    parser.add_argument('-s', '--src', type=str, required=True, help='Source directory.')
    parser.add_argument('-t', '--tgt', type=str, required=True, help='Target directory.')
    parser.add_argument('-o', '--out', type=str, required=True, help='Output directory.')
    parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.')
    args = parser.parse_args()

    make_dir(args.out)

    # fixed parameters that determine the window size for alignment
    min_win_size = 10
    max_win_size = 600
    win_per_100 = 8

    # alignment types
    align_types = np.array(
      [
        [0, 1],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 1],
        [2, 2],
      ], dtype=np.int64)

    # prior probabilities, indexed by the total number
    # of sentences (m + n) in an alignment type
    priors = np.array([0, 0.0099, 0.89, 0.089, 0.011])

    # mean and variance of the target/source length ratio
    c = 1
    s2 = 6.8

    # perform Gale-Church alignment
    jobs = create_jobs(args.meta, args.src, args.tgt, args.out)
    for rec in jobs:
        src_file, tgt_file, align_file = rec.split("\t")
        print("Aligning {} to {}".format(src_file, tgt_file))
        src_lines = open(src_file, 'rt', encoding="utf-8").readlines()
        tgt_lines = open(tgt_file, 'rt', encoding="utf-8").readlines()
        src_len = calculate_txt_len(src_lines)
        tgt_len = calculate_txt_len(tgt_lines)

        m = src_len.shape[0] - 1
        n = tgt_len.shape[0] - 1

        # find search path
        w, search_path = find_search_path(m, n, min_win_size, max_win_size, win_per_100)
        cost, back = align(src_len, tgt_len, w, search_path, align_types, priors, c, s2)
        alignments = back_track(m, n, back, search_path, align_types)

        # save alignments
        save_alignments(alignments, align_file)

def save_alignments(alignments, file):
    with open(file, 'wt', encoding='utf-8') as f:
        for bead in alignments:
            f.write("{}:{}\n".format(bead[0], bead[1]))


def back_track(i, j, b, search_path, a_types):
    alignment = []
    while i != 0 and j != 0:
        j_offset = j - search_path[i][0]
        a = b[i][j_offset]
        s = a_types[a][0]
        t = a_types[a][1]
        src_range = [i - offset - 1 for offset in range(s)][::-1]
        tgt_range = [j - offset - 1 for offset in range(t)][::-1]
        alignment.append((src_range, tgt_range))

        i = i - s
        j = j - t

    return alignment[::-1]


@nb.jit(nopython=True, fastmath=True, cache=True)
def align(src_len, tgt_len, w, search_path, align_types, priors, c, s2):
    # initialize cost and back-pointer matrices
    m = src_len.shape[0] - 1
    cost = np.zeros((m + 1, 2 * w + 1))
    back = np.zeros((m + 1, 2 * w + 1), dtype=nb.int64)
    cost[0][0] = 0
    back[0][0] = -1

    for i in range(m + 1):
        i_start = search_path[i][0]
        i_end = search_path[i][1]

        for j in range(i_start, i_end + 1):
            if i + j == 0:
                continue

            best_score = np.inf
            best_a = -1
            for a in range(align_types.shape[0]):
                a_1 = align_types[a][0]
                a_2 = align_types[a][1]
                prev_i = i - a_1
                prev_j = j - a_2

                if prev_i < 0 or prev_j < 0:  # no previous cell
                    continue

                prev_i_start = search_path[prev_i][0]
                prev_i_end = search_path[prev_i][1]

                if prev_j < prev_i_start or prev_j > prev_i_end:  # out of bounds of cost matrix
                    continue

                prev_j_offset = prev_j - prev_i_start

                score = cost[prev_i][prev_j_offset] - math.log(priors[a_1 + a_2]) + \
                    get_score(src_len[i] - src_len[i - a_1], tgt_len[j] - tgt_len[j - a_2], c, s2)

                if score < best_score:
                    best_score = score
                    best_a = a

            j_offset = j - i_start
            cost[i][j_offset] = best_score
            back[i][j_offset] = best_a

    return cost, back

@nb.jit(nopython=True, fastmath=True, cache=True)
def get_score(len_s, len_t, c, s2):
    mean = (len_s + len_t / c) / 2
    z = (len_t - len_s * c) / math.sqrt(mean * s2)

    pd = 2 * (1 - norm_cdf(abs(z)))
    if pd > 0:
        return -math.log(pd)

    return 25

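# This is the length-based match cost of Gale & Church (1993): z is treated
# as a standard normal deviate under the assumption that the target/source
# character-length ratio is Gaussian with mean c = 1 and variance s2 = 6.8,
# and the cost is -log of the two-tailed probability, capped at 25 when the
# probability underflows to zero.
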
@nb.jit(nopython=True, fastmath=True, cache=True)
def find_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
    yx_ratio = tgt_len / src_len
    win_size_1 = int(yx_ratio * tgt_len * win_per_100 / 100)
    win_size_2 = int(abs(tgt_len - src_len) * 3 / 4)

    w_1 = min(max(min_win_size, max(win_size_1, win_size_2)), max_win_size)
    w_2 = int(max(src_len, tgt_len) * 0.06)
    w = max(w_1, w_2)

    search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
    for i in range(0, src_len + 1):
        center = int(yx_ratio * i)
        w_start = max(0, center - w)
        w_end = min(center + w, tgt_len)
        search_path[i] = [w_start, w_end]

    return w, search_path

@nb.jit(nopython=True, fastmath=True, cache=True)
def norm_cdf(z):
    t = 1 / float(1 + 0.2316419 * z)  # t = 1/(1+pz), p = 0.2316419
    p_norm = 1 - 0.3989423 * math.exp(-z * z / 2) * ((0.319381530 * t) + \
        (-0.356563782 * t**2) + \
        (1.781477937 * t**3) + \
        (-1.821255978 * t**4) + \
        (1.330274429 * t**5))

    return p_norm

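# norm_cdf above is the Abramowitz & Stegun polynomial approximation
# (formula 26.2.17) of the standard normal CDF, accurate to roughly 7.5e-8
# for z >= 0. A quick check against the exact erf-based form:
#
#   exact = 0.5 * (1 + math.erf(z / math.sqrt(2)))
#   abs(norm_cdf(1.0) - 0.8413447) < 1e-6
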
def calculate_txt_len(lines):
    txt_len = []
    txt_len.append(0)
    for i, line in enumerate(lines):
        # cumulative UTF-8 byte length
        txt_len.append(txt_len[i] + len(line.strip().encode("utf-8")))

    return np.array(txt_len)


def create_jobs(meta, src, tgt, out):
    jobs = []
    fns = get_fns(meta)
    for file in fns:
        src_path = os.path.abspath(os.path.join(src, file))
        tgt_path = os.path.abspath(os.path.join(tgt, file))
        out_path = os.path.abspath(os.path.join(out, file + '.align'))
        jobs.append('\t'.join([src_path, tgt_path, out_path]))

    return jobs


def get_fns(meta):
    fns = []
    with open(meta, 'rt', encoding='utf-8') as f:
        next(f)  # skip header
        for line in f:
            recs = line.strip().split('\t')
            fns.append(recs[0])

    return fns


def make_dir(path):
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


if __name__ == '__main__':
    t_0 = time.time()
    main()
    print("It takes {:.3f} seconds to align all the sentences.".format(time.time() - t_0))