Bertalign and evaluation scripts

nlpfun
2021-11-28 13:58:26 +08:00
parent e4e4c31b22
commit e033edad52
5 changed files with 1074 additions and 0 deletions

bin/bert_align.py Normal file

@@ -0,0 +1,450 @@
# 2021/11/27
# bfsujason@163.com
"""
Usage:
python bin/bert_align.py \
-s data/mac/dev/zh \
-t data/mac/dev/en \
-o data/mac/dev/auto \
-m data/mac/dev/meta_data.tsv \
--src_embed data/mac/dev/zh/overlap data/mac/dev/zh/overlap.emb \
--tgt_embed data/mac/dev/en/overlap data/mac/dev/en/overlap.emb \
--max_align 8 --margin
"""
import os
import sys
import time
import torch
import faiss
import shutil
import argparse
import numpy as np
import numba as nb
import logging
logger = logging.getLogger(__name__)
def main():
# user-defined parameters
parser = argparse.ArgumentParser('Sentence alignment using Bertalign')
parser.add_argument('-s', '--src', type=str, required=True, help='preprocessed source file to align')
parser.add_argument('-t', '--tgt', type=str, required=True, help='preprocessed target file to align')
parser.add_argument('-o', '--out', type=str, required=True, help='Output directory.')
parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.')
parser.add_argument('--src_embed', type=str, nargs=2, required=True,
help='Source embeddings. Requires two arguments: first is a text file, second is a binary embeddings file.')
parser.add_argument('--tgt_embed', type=str, nargs=2, required=True,
help='Target embeddings. Requires two arguments: first is a text file, second is a binary embeddings file.')
parser.add_argument('--max_align', type=int, default=5, help='Maximum alignment types, n + m <= this value.')
parser.add_argument('--win', type=int, default=5, help='Window size for the second-pass alignment.')
parser.add_argument('--top_k', type=int, default=3, help='Top-k target neighbors of each source sentence.')
parser.add_argument('--skip', type=float, default=-0.1, help='Similarity score for 0-1 and 1-0 alignment.')
parser.add_argument('--margin', action='store_true', help='Margin-based cosine similarity')
args = parser.parse_args()
# fixed parameters to determine the
# window size for the first-pass alignment
min_win_size = 10
max_win_size = 600
win_per_100 = 8
# read in embeddings
src_sent2line, src_line_embeddings = read_in_embeddings(args.src_embed[0], args.src_embed[1])
tgt_sent2line, tgt_line_embeddings = read_in_embeddings(args.tgt_embed[0], args.tgt_embed[1])
embedding_size = src_line_embeddings.shape[1]
make_dir(args.out)
jobs = create_jobs(args.meta, args.src, args.tgt, args.out)
# start alignment
for rec in jobs:
src_file, tgt_file, align_file = rec.split("\t")
print("Aligning {} to {}".format(src_file, tgt_file))
# read in source and target sentences
src_lines = open(src_file, 'rt', encoding="utf-8").readlines()
tgt_lines = open(tgt_file, 'rt', encoding="utf-8").readlines()
# convert source and target texts into embeddings
# and calculate sentence length
t_0 = time.time()
src_vecs, src_lens = doc2feats(src_sent2line, src_line_embeddings, src_lines, args.max_align - 1)
tgt_vecs, tgt_lens = doc2feats(tgt_sent2line, tgt_line_embeddings, tgt_lines, args.max_align - 1)
char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
print("Reading embeddings takes {:.3f}".format(time.time() - t_0))
# using faiss, find in the target text
# the k nearest neighbors of each source sentence
t_1 = time.time()
if torch.cuda.is_available(): # GPU version
res = faiss.StandardGpuResources()
index = faiss.IndexFlatIP(embedding_size)
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
gpu_index.add(tgt_vecs[0,:])
xq = src_vecs[0,:]
D,I = gpu_index.search(xq,args.top_k)
else: # CPU version
index = faiss.IndexFlatIP(embedding_size) # use inner product to build the index
index.add(tgt_vecs[0,:])
xq = src_vecs[0,:]
D,I = index.search(xq, args.top_k)
print("Finding top-k neighbors takes {:.3f}".format(time.time() - t_1))
# find 1-to-1 alignment
t_2 = time.time()
src_len = len(src_lines)
tgt_len = len(tgt_lines)
first_alignment_types = make_alignment_types(2) # 0-1, 1-0 and 1-1
first_w, first_search_path = find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100)
first_pointers = first_pass_align(src_len, tgt_len, first_w, first_search_path, first_alignment_types, D, I, args.top_k)
first_alignment = first_back_track(src_len, tgt_len, first_pointers, first_search_path, first_alignment_types)
print("First pass alignment takes {:.3f}".format(time.time() - t_2))
# find m-to-n alignment
t_3 = time.time()
second_w, second_search_path = find_second_search_path(first_alignment, args.win, src_len, tgt_len)
second_alignment_types = make_alignment_types(args.max_align)
second_pointers = second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, second_w, second_search_path, second_alignment_types, char_ratio, args.skip, margin=args.margin)
second_alignment = second_back_track(src_len, tgt_len, second_pointers, second_search_path, second_alignment_types)
print("Second pass alignment takes {:.3f}".format(time.time() - t_3))
# save alignment
print_alignments(second_alignment, align_file)
def second_back_track(i, j, b, search_path, a_types):
alignment = []
while ( i !=0 and j != 0 ):
j_offset = j - search_path[i][0]
a = b[i][j_offset]
s = a_types[a][0]
t = a_types[a][1]
src_range = [i - offset - 1 for offset in range(s)][::-1]
tgt_range = [j - offset - 1 for offset in range(t)][::-1]
alignment.append((src_range, tgt_range))
i = i-s
j = j-t
return alignment[::-1]
@nb.jit(nopython=True, fastmath=True, cache=True)
def second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, w, search_path, align_types, char_ratio, skip, margin=False):
src_len = src_vecs.shape[1]
tgt_len = tgt_vecs.shape[1]
# initialize the cost and backpointer matrices
cost = np.zeros((src_len + 1, w))
#back = np.zeros((tgt_len + 1, w), dtype=nb.int64)
back = np.zeros((src_len + 1, w), dtype=nb.int64)
cost[0][0] = 0
back[0][0] = -1
for i in range(1, src_len + 1):
i_start = search_path[i][0]
i_end = search_path[i][1]
for j in range(i_start, i_end + 1):
if i + j == 0:
continue
best_score = -np.inf
best_a = -1
for a in range(align_types.shape[0]):
a_1 = align_types[a][0]
a_2 = align_types[a][1]
prev_i = i - a_1
prev_j = j - a_2
if prev_i < 0 or prev_j < 0 : # no previous cell in DP table
continue
prev_i_start = search_path[prev_i][0]
prev_i_end = search_path[prev_i][1]
if prev_j < prev_i_start or prev_j > prev_i_end: # out of bound of cost matrix
continue
prev_j_offset = prev_j - prev_i_start
score = cost[prev_i][prev_j_offset]
if score == -np.inf:
continue
if a_1 == 0 or a_2 == 0: # deletion or insertion
cur_score = skip
else:
src_v = src_vecs[a_1-1,i-1,:]
tgt_v = tgt_vecs[a_2-1,j-1,:]
src_l = src_lens[a_1-1, i-1]
tgt_l = tgt_lens[a_2-1, j-1]
cur_score = get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=margin)
tgt_l = tgt_l * char_ratio
min_len = min(src_l, tgt_l)
max_len = max(src_l, tgt_l)
len_p = np.log2(1 + min_len / max_len)
cur_score *= len_p
score += cur_score
if score > best_score:
best_score = score
best_a = a
j_offset = j - i_start
cost[i][j_offset] = best_score
back[i][j_offset] = best_a
return back
@nb.jit(nopython=True, fastmath=True, cache=True)
def get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=False):
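# Dot-product similarity of the candidate source/target overlap embeddings; with
# --margin, the average similarity to the neighbouring target/source sentences is
# subtracted, so only locally outstanding pairs keep a high score.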
similarity = nb_dot(src_v, tgt_v)
if margin:
tgt_neighbor_ave_sim = get_neighbor_sim(src_v, a_2, j, tgt_len, tgt_vecs)
src_neighbor_ave_sim = get_neighbor_sim(tgt_v, a_1, i, src_len, src_vecs)
neighbor_ave_sim = (tgt_neighbor_ave_sim + src_neighbor_ave_sim)/2
similarity -= neighbor_ave_sim
return similarity
@nb.jit(nopython=True, fastmath=True, cache=True)
def get_neighbor_sim(vec, a, j, len, db):
left_idx = j - a
right_idx = j + 1
if right_idx > len:
neighbor_right_sim = 0
else:
right_embed = db[0,right_idx-1,:]
neighbor_right_sim = nb_dot(vec, right_embed)
if left_idx == 0:
neighbor_left_sim = 0
else:
left_embed = db[0,left_idx-1,:]
neighbor_left_sim = nb_dot(vec, left_embed)
#if right_idx > LEN or left_idx < 0:
if right_idx > len or left_idx == 0:
neighbor_ave_sim = neighbor_left_sim + neighbor_right_sim
else:
neighbor_ave_sim = (neighbor_left_sim + neighbor_right_sim) / 2
return neighbor_ave_sim
@nb.jit(nopython=True, fastmath=True, cache=True)
def nb_dot(x, y):
return np.dot(x,y)
def find_second_search_path(align, w, src_len, tgt_len):
'''
Convert the 1-1 alignments from the first pass into the search path
for the second-pass alignment. The indices along the X-axis and Y-axis
must be consecutive.
'''
last_bead_src = align[-1][0]
last_bead_tgt = align[-1][1]
if last_bead_src != src_len:
if last_bead_tgt == tgt_len:
align.pop()
align.append((src_len, tgt_len))
else:
if last_bead_tgt != tgt_len:
align.pop()
align.append((src_len, tgt_len))
prev_src, prev_tgt = 0,0
path = []
max_w = -np.inf
for src, tgt in align:
lower_bound = max(0, prev_tgt - w)
upper_bound = min(tgt_len, tgt + w)
path.extend([(lower_bound, upper_bound) for id in range(prev_src+1, src+1)])
prev_src, prev_tgt = src, tgt
width = upper_bound - lower_bound
if width > max_w:
max_w = width
path = [path[0]] + path
return max_w + 1, np.array(path)
def first_back_track(i, j, b, search_path, a_types):
alignment = []
while ( i !=0 and j != 0 ):
j_offset = j - search_path[i][0]
a = b[i][j_offset]
s = a_types[a][0]
t = a_types[a][1]
if a == 2:
alignment.append((i, j))
i = i-s
j = j-t
return alignment[::-1]
@nb.jit(nopython=True, fastmath=True, cache=True)
def first_pass_align(src_len, tgt_len, w, search_path, align_types, dist, index, top_k):
# initialize the cost and backpointer matrices
cost = np.zeros((src_len + 1, 2 * w + 1))
pointers = np.zeros((src_len + 1, 2 * w + 1), dtype=nb.int64)
cost[0][0] = 0
pointers[0][0] = -1
for i in range(1, src_len + 1):
i_start = search_path[i][0]
i_end = search_path[i][1]
for j in range(i_start, i_end + 1):
if i + j == 0:
continue
best_score = -np.inf
best_a = -1
for a in range(align_types.shape[0]):
a_1 = align_types[a][0]
a_2 = align_types[a][1]
prev_i = i - a_1
prev_j = j - a_2
if prev_i < 0 or prev_j < 0 : # no previous cell
continue
prev_i_start = search_path[prev_i][0]
prev_i_end = search_path[prev_i][1]
if prev_j < prev_i_start or prev_j > prev_i_end: # out of bound of cost matrix
continue
prev_j_offset = prev_j - prev_i_start
score = cost[prev_i][prev_j_offset]
if score == -np.inf:
continue
if a_1 > 0 and a_2 > 0:
for k in range(top_k):
if index[i-1][k] == j - 1:
score += dist[i-1][k]
if score > best_score:
best_score = score
best_a = a
j_offset = j - i_start
cost[i][j_offset] = best_score
pointers[i][j_offset] = best_a
return pointers
@nb.jit(nopython=True, fastmath=True, cache=True)
def find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
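# band of candidate target indices for each source index, centred on the
# diagonal j = (tgt_len / src_len) * i with half-width w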
yx_ratio = tgt_len / src_len
win_size_1 = int(yx_ratio * tgt_len * win_per_100 / 100)
win_size_2 = int(abs(tgt_len - src_len) * 3/4)
w_1 = min(max(min_win_size, max(win_size_1, win_size_2)), max_win_size)
w_2 = int(max(src_len, tgt_len) * 0.06)
w = max(w_1, w_2)
search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
for i in range(0, src_len + 1):
center = int(yx_ratio * i)
w_start = max(0, center - w)
w_end = min(center + w, tgt_len)
search_path[i] = [w_start, w_end]
return w, search_path
def doc2feats(sent2line, line_embeddings, lines, num_overlaps):
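# vecs0[k, i] holds the embedding of the overlap that ends at sentence i and
# spans k+1 sentences (front-padded with 'PAD'); vecs1[k, i] is its UTF-8 byte length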
lines = [preprocess_line(line) for line in lines]
vecsize = line_embeddings.shape[1]
vecs0 = np.empty((num_overlaps, len(lines), vecsize), dtype=np.float32)
vecs1 = np.empty((num_overlaps, len(lines)), dtype=np.int64)
for ii, overlap in enumerate(range(1, num_overlaps + 1)):
for jj, out_line in enumerate(layer(lines, overlap)):
try:
line_id = sent2line[out_line]
except KeyError:
logger.warning('Failed to find overlap=%d line "%s". Will use random vector.', overlap, out_line)
line_id = None
if line_id is not None:
vec = line_embeddings[line_id]
else:
vec = np.random.random(vecsize) - 0.5
vec = vec / np.linalg.norm(vec)
vecs0[ii, jj, :] = vec
vecs1[ii, jj] = len(out_line.encode("utf-8"))
return vecs0, vecs1
def preprocess_line(line):
line = line.strip()
if len(line) == 0:
line = 'BLANK_LINE'
return line
def layer(lines, num_overlaps, comb=' '):
"""
make front-padded overlapping sentences
"""
if num_overlaps < 1:
raise Exception('num_overlaps must be >= 1')
out = ['PAD', ] * min(num_overlaps - 1, len(lines))
for ii in range(len(lines) - num_overlaps + 1):
out.append(comb.join(lines[ii:ii + num_overlaps]))
return out
def read_in_embeddings(text_file, embed_file):
sent2line = dict()
with open(text_file, 'rt', encoding="utf-8") as fin:
for ii, line in enumerate(fin):
if line.strip() in sent2line:
raise Exception('got multiple embeddings for the same line')
sent2line[line.strip()] = ii
line_embeddings = np.fromfile(embed_file, dtype=np.float32, count=-1)
if line_embeddings.size == 0:
raise Exception('Got empty embedding file')
embedding_size = line_embeddings.size // len(sent2line)
line_embeddings.resize(line_embeddings.shape[0] // embedding_size, embedding_size)
return sent2line, line_embeddings
def make_alignment_types(max_alignment_size):
# Return a list of all (n, m) alignment types with n + m <= max_alignment_size
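# e.g. max_alignment_size=4 -> [[0,1], [1,0], [1,1], [1,2], [1,3], [2,1], [2,2], [3,1]]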
alignment_types = []
for x in range(1, max_alignment_size):
for y in range(1, max_alignment_size):
if x + y <= max_alignment_size:
alignment_types.append([x, y])
alignment_types = [[0,1], [1,0]] + alignment_types
return np.array(alignment_types)
def create_jobs(meta, src, tgt, out):
jobs = []
fns = get_fns(meta)
for file in fns:
src_path = os.path.abspath(os.path.join(src, file))
tgt_path = os.path.abspath(os.path.join(tgt, file))
out_path = os.path.abspath(os.path.join(out, file + '.align'))
jobs.append('\t'.join([src_path, tgt_path, out_path]))
return jobs
def get_fns(meta):
fns = []
with open(meta, 'rt', encoding='utf-8') as f:
next(f) # skip header
for line in f:
recs = line.strip().split('\t')
fns.append(recs[0])
return fns
def print_alignments(alignments, out):
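# each output line is '<src_ids>:<tgt_ids>', e.g. '[0, 1]:[2]' for a 2-1 bead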
with open(out, 'wt', encoding='utf-8') as f:
for x, y in alignments:
f.write("{}:{}\n".format(x, y))
def make_dir(path):
if os.path.isdir(path):
shutil.rmtree(path)
os.makedirs(path, exist_ok=True)
if __name__ == '__main__':
t_0 = time.time()
main()
print("It takes {:.3f} seconds to align all the sentences.".format(time.time() - t_0))

bin/embed_sents.py Normal file

@@ -0,0 +1,107 @@
# 2021/11/27
# bfsujason@163.com
'''
Usage (Linux):
python bin/embed_sents.py \
-i data/mac/dev/zh \
-o data/mac/dev/zh/overlap data/mac/dev/zh/overlap.emb \
-m data/mac/dev/meta_data.tsv \
-n 8
'''
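# Note: -n/--num_overlaps should be at least max_align - 1 as later passed to
# bin/bert_align.py, since alignment looks up overlaps of up to that many sentences.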
import os
import time
import shutil
import argparse
import numpy as np
from sentence_transformers import SentenceTransformer
def main():
parser = argparse.ArgumentParser(description='Multilingual sentence embeddings')
parser.add_argument('-i', '--input', type=str, required=True, help='Data directory.')
parser.add_argument('-o', '--output', type=str, required=True, nargs=2, help='Overlap and embedding files.')
parser.add_argument('-n', '--num_overlaps', type=int, default=5, help='Maximum number of allowed overlaps.')
parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.')
args = parser.parse_args()
fns = get_fns(args.meta)
overlap = get_overlap(args.input, fns, args.num_overlaps)
write_overlap(overlap, args.output[0])
model = load_model()
embed_overlap(model, overlap, args.output[1])
def embed_overlap(model, overlap, fout):
print("Embedding text ...")
t_0 = time.time()
embed = model.encode(overlap)
embed.tofile(fout)
print("It takes {:.3f} seconods to embed text.".format(time.time() - t_0))
def write_overlap(overlap, outfile):
with open(outfile, 'wt', encoding="utf-8") as fout:
for line in overlap:
fout.write(line + '\n')
def get_overlap(dir, fns, n):
overlap = set()
for file in fns:
in_path = os.path.join(dir, file)
lines = open(in_path, 'rt', encoding="utf-8").readlines()
for out_line in yield_overlaps(lines, n):
overlap.add(out_line)
# for reproducibility
overlap = list(overlap)
overlap.sort()
return overlap
def yield_overlaps(lines, num_overlaps):
lines = [preprocess_line(line) for line in lines]
for overlap in range(1, num_overlaps + 1):
for out_line in layer(lines, overlap):
# truncate very long lines so we don't encode arbitrarily long sentences
out_line2 = out_line[:10000]
yield out_line2
def layer(lines, num_overlaps, comb=' '):
if num_overlaps < 1:
raise Exception('num_overlaps must be >= 1')
out = ['PAD', ] * min(num_overlaps - 1, len(lines))
for ii in range(len(lines) - num_overlaps + 1):
out.append(comb.join(lines[ii:ii + num_overlaps]))
return out
def preprocess_line(line):
line = line.strip()
if len(line) == 0:
line = 'BLANK_LINE'
return line
def load_model():
print("Loading embedding model ...")
t0 = time.time()
model = SentenceTransformer('LaBSE')
print("It takes {:.3f} seconods to load the model.".format(time.time() - t0))
return model
def get_fns(meta):
fns = []
with open(meta, 'rt', encoding='utf-8') as f:
next(f) # skip header
for line in f:
recs = line.strip().split('\t')
fns.append(recs[0])
return fns
def make_dir(path):
if os.path.isdir(path):
shutil.rmtree(path)
os.makedirs(path, exist_ok=True)
if __name__ == '__main__':
main()

bin/eval.py Normal file

@@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""
Copyright 2019 Brian Thompson
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import sys
import argparse
from ast import literal_eval
from collections import defaultdict
import numpy as np
"""
Faster implementation of lax and strict precision and recall, based on
https://www.aclweb.org/anthology/W11-4624/.
"""
def read_alignments(fin):
alignments = []
with open(fin, 'rt', encoding="utf-8") as infile:
for line in infile:
fields = [x.strip() for x in line.split(':') if len(x.strip())]
if len(fields) < 2:
raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip())
try:
src = literal_eval(fields[0])
tgt = literal_eval(fields[1])
except:
raise Exception('Failed to parse line "%s"' % line.strip())
alignments.append((src, tgt))
return alignments
def _precision(goldalign, testalign):
"""
Computes tpstrict, fpstrict, tplax, fplax for gold/test alignments
"""
tpstrict = 0 # true positive strict counter
tplax = 0 # true positive lax counter
fpstrict = 0 # false positive strict counter
fplax = 0 # false positive lax counter
# convert to sets, remove alignments empty on both sides
testalign = set([(tuple(x), tuple(y)) for x, y in testalign if len(x) or len(y)])
goldalign = set([(tuple(x), tuple(y)) for x, y in goldalign if len(x) or len(y)])
# map each gold source sentence index to the set of gold target sentence
# indices it is aligned with, so partial overlaps can be checked for lax matching
src_id_to_gold_tgt_ids = defaultdict(set)
for gold_src, gold_tgt in goldalign:
for gold_src_id in gold_src:
for gold_tgt_id in gold_tgt:
src_id_to_gold_tgt_ids[gold_src_id].add(gold_tgt_id)
for (test_src, test_target) in testalign:
if (test_src, test_target) == ((), ()):
continue
if (test_src, test_target) in goldalign:
# strict match
tpstrict += 1
tplax += 1
else:
# For anything with partial gold/test overlap on the source,
# see if there is also partial overlap on the gold/test target
# If so, it's a lax match
target_ids = set()
for src_test_id in test_src:
for tgt_id in src_id_to_gold_tgt_ids[src_test_id]:
target_ids.add(tgt_id)
if set(test_target).intersection(target_ids):
fpstrict += 1
tplax += 1
else:
fpstrict += 1
fplax += 1
return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32)
def score_multiple(gold_list, test_list, value_for_div_by_0=0.0):
# accumulate counts for all gold/test files
pcounts = np.array([0, 0, 0, 0], dtype=np.int32)
rcounts = np.array([0, 0, 0, 0], dtype=np.int32)
for goldalign, testalign in zip(gold_list, test_list):
pcounts += _precision(goldalign=goldalign, testalign=testalign)
# recall is precision with no insertion/deletion and swap args
test_no_del = [(x, y) for x, y in testalign if len(x) and len(y)]
gold_no_del = [(x, y) for x, y in goldalign if len(x) and len(y)]
rcounts += _precision(goldalign=test_no_del, testalign=gold_no_del)
# Compute results
# pcounts: tpstrict, fpstrict, tplax, fplax (precision counts)
# rcounts: tpstrict, fnstrict, tplax, fnlax (recall counts: gold and test are swapped above)
if pcounts[0] + pcounts[1] == 0:
pstrict = value_for_div_by_0
else:
pstrict = pcounts[0] / float(pcounts[0] + pcounts[1])
if pcounts[2] + pcounts[3] == 0:
plax = value_for_div_by_0
else:
plax = pcounts[2] / float(pcounts[2] + pcounts[3])
if rcounts[0] + rcounts[1] == 0:
rstrict = value_for_div_by_0
else:
rstrict = rcounts[0] / float(rcounts[0] + rcounts[1])
if rcounts[2] + rcounts[3] == 0:
rlax = value_for_div_by_0
else:
rlax = rcounts[2] / float(rcounts[2] + rcounts[3])
if (pstrict + rstrict) == 0:
fstrict = value_for_div_by_0
else:
fstrict = 2 * (pstrict * rstrict) / (pstrict + rstrict)
if (plax + rlax) == 0:
flax = value_for_div_by_0
else:
flax = 2 * (plax * rlax) / (plax + rlax)
result = dict(recall_strict=rstrict,
recall_lax=rlax,
precision_strict=pstrict,
precision_lax=plax,
f1_strict=fstrict,
f1_lax=flax)
return result
def log_final_scores(res):
print(' ---------------------------------', file=sys.stderr)
print('| | Strict | Lax |', file=sys.stderr)
print('| Precision | {precision_strict:.3f} | {precision_lax:.3f} |'.format(**res), file=sys.stderr)
print('| Recall | {recall_strict:.3f} | {recall_lax:.3f} |'.format(**res), file=sys.stderr)
print('| F1 | {f1_strict:.3f} | {f1_lax:.3f} |'.format(**res), file=sys.stderr)
print(' ---------------------------------', file=sys.stderr)
def main():
parser = argparse.ArgumentParser(
'Compute strict/lax precision and recall for one or more pairs of gold/test alignments',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-t', '--test', type=str, required=True,
help='Test alignment directory.')
parser.add_argument('-g', '--gold', type=str, required=True,
help='Gold alignment directory.')
args = parser.parse_args()
gold_list = [read_alignments(os.path.join(args.gold, x)) for x in sorted(os.listdir(args.gold))]
test_list = [read_alignments(os.path.join(args.test, x)) for x in sorted(os.listdir(args.test))]
res = score_multiple(gold_list=gold_list, test_list=test_list)
log_final_scores(res)
if __name__ == '__main__':
main()

bin/eval_bible.py Normal file

@@ -0,0 +1,109 @@
import os
import argparse
from ast import literal_eval
from collections import defaultdict
from eval import score_multiple, log_final_scores
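# Example usage (all file paths below are illustrative, not part of the repo):
#   python bin/eval_bible.py -t auto/bible.align -g gold/bible.align \
#     --src_verse zh_verse.txt --tgt_verse en_verse.txt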
def main():
parser = argparse.ArgumentParser('Evaluate alignment quality for the Bible corpus')
parser.add_argument('-t', '--test', type=str, required=True, help='Test alignment file.')
parser.add_argument('-g', '--gold', type=str, required=True, help='Gold alignment file.')
parser.add_argument('--src_verse', type=str, required=True, help='Source verse file.')
parser.add_argument('--tgt_verse', type=str, required=True, help='Target verse file.')
args = parser.parse_args()
test_alignments = read_alignments(args.test)
gold_alignments = read_alignments(args.gold)
src_verse = get_verse(args.src_verse)
tgt_verse = get_verse(args.tgt_verse)
merged_test_alignments = merge_test_alignments(test_alignments, src_verse, tgt_verse)
res = score_multiple(gold_list=[gold_alignments], test_list=[merged_test_alignments])
log_final_scores(res)
def merge_test_alignments(alignments, src_verse, tgt_verse):
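# collapse consecutive sentence-level beads that map to the same verse (or to a
# NULL side) into a single verse-level bead, so sentence alignments can be
# scored against verse-aligned gold data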
merged_align = []
last_beads_type = None
for beads in alignments:
beads_type = find_beads_type(beads, src_verse, tgt_verse)
if not last_beads_type:
merged_align.append(beads)
else:
if beads_type == last_beads_type:
merged_align[-1][0].extend(beads[0])
merged_align[-1][1].extend(beads[1])
else:
merged_align.append(beads)
last_beads_type = beads_type
return merged_align
def find_beads_type(beads, src_verse, tgt_verse):
src_bead = beads[0]
tgt_bead = beads[1]
src_bead_type = find_bead_type(src_bead, src_verse)
tgt_bead_type = find_bead_type(tgt_bead, tgt_verse)
src_bead_len = len(src_bead_type)
tgt_bead_len = len(tgt_bead_type)
if src_bead_len != 1 or tgt_bead_len != 1:
return None
else:
src_verse = src_bead_type[0]
tgt_verse = tgt_bead_type[0]
if src_verse != tgt_verse:
if src_verse == 'NULL':
return tgt_verse
elif tgt_verse == 'NULL':
return src_verse
else:
return None
else:
return src_verse
def find_bead_type(bead, verse):
bead_type = ['NULL']
if len(bead) > 0:
bead_type = unique_list([verse[id] for id in bead])
return bead_type
def unique_list(items):
unique = []
for x in items:
if x not in unique:
unique.append(x)
return unique
def get_verse(file):
verse = defaultdict()
with open(file, 'rt', encoding='utf-8') as f:
for (i, line) in enumerate(f):
verse[i] = line.strip()
return verse
def read_alignments(fin):
alignments = []
with open(fin, 'rt', encoding="utf-8") as infile:
for line in infile:
fields = [x.strip() for x in line.split(':') if len(x.strip())]
if len(fields) < 2:
raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip())
try:
src = literal_eval(fields[0])
tgt = literal_eval(fields[1])
except:
raise Exception('Failed to parse line "%s"' % line.strip())
alignments.append((src, tgt))
return alignments
if __name__ == '__main__':
main()

bin/gale_align.py Normal file

@@ -0,0 +1,226 @@
# 2021/11/27
# bfsujason@163.com
"""
Usage:
python bin/gale_align.py \
-m data/mac/test/meta_data.tsv \
-s data/mac/test/zh \
-t data/mac/test/en \
-o data/mac/test/auto
"""
import os
import time
import math
import shutil
import argparse
import numba as nb
import numpy as np
def main():
# user-defined parameters
parser = argparse.ArgumentParser(description='Sentence alignment using the Gale-Church algorithm')
parser.add_argument('-s', '--src', type=str, required=True, help='Source directory.')
parser.add_argument('-t', '--tgt', type=str, required=True, help='Target directory.')
parser.add_argument('-o', '--out', type=str, required=True, help='Output directory.')
parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.')
args = parser.parse_args()
make_dir(args.out)
# fixed parameters to determine the window size for alignment
min_win_size = 10
max_win_size = 600
win_per_100 = 8
# alignment types
align_types = np.array(
[
[0,1],
[1,0],
[1,1],
[1,2],
[2,1],
[2,2],
], dtype=np.int64)
# prior probability of each bead size n + m: 1-1 = 0.89, 0-1/1-0 = 0.0099, 2-1/1-2 = 0.089, 2-2 = 0.011
priors = np.array([0, 0.0099, 0.89, 0.089, 0.011])
# expected target/source character-length ratio (c) and its variance (s2)
c = 1
s2 = 6.8
# perform gale-church align
jobs = create_jobs(args.meta, args.src, args.tgt, args.out)
for rec in jobs:
src_file, tgt_file, align_file = rec.split("\t")
print("Aligning {} to {}".format(src_file, tgt_file))
src_lines = open(src_file, 'rt', encoding="utf-8").readlines()
tgt_lines = open(tgt_file, 'rt', encoding="utf-8").readlines()
src_len = calculate_txt_len(src_lines)
tgt_len = calculate_txt_len(tgt_lines)
m = src_len.shape[0] - 1
n = tgt_len.shape[0] - 1
# find search path
w, search_path = find_search_path(m, n, min_win_size, max_win_size, win_per_100)
cost, back = align(src_len, tgt_len, w, search_path, align_types, priors, c, s2)
alignments = back_track(m, n, back, search_path, align_types)
# save alignments
save_alignments(alignments, align_file)
def save_alignments(alignments, file):
with open(file, 'wt', encoding='utf-8') as f:
for id in alignments:
f.write("{}:{}\n".format(id[0], id[1]))
def back_track(i, j, b, search_path, a_types):
alignment = []
while ( i !=0 and j != 0 ):
j_offset = j - search_path[i][0]
a = b[i][j_offset]
s = a_types[a][0]
t = a_types[a][1]
src_range = [i - offset - 1 for offset in range(s)][::-1]
tgt_range = [j - offset - 1 for offset in range(t)][::-1]
alignment.append((src_range, tgt_range))
i = i-s
j = j-t
return alignment[::-1]
@nb.jit(nopython=True, fastmath=True, cache=True)
def align(src_len, tgt_len, w, search_path, align_types, priors, c, s2):
#initialize cost and backpointer matrix
m = src_len.shape[0] - 1
cost = np.zeros((m + 1, 2 * w + 1))
back = np.zeros((m + 1, 2 * w + 1), dtype=nb.int64)
cost[0][0] = 0
back[0][0] = -1
for i in range(m + 1):
i_start = search_path[i][0]
i_end = search_path[i][1]
for j in range(i_start, i_end + 1):
if i + j == 0:
continue
best_score = np.inf
best_a = -1
for a in range(align_types.shape[0]):
a_1 = align_types[a][0]
a_2 = align_types[a][1]
prev_i = i - a_1
prev_j = j - a_2
if prev_i < 0 or prev_j < 0 : # no previous cell
continue
prev_i_start = search_path[prev_i][0]
prev_i_end = search_path[prev_i][1]
if prev_j < prev_i_start or prev_j > prev_i_end: # out of bound of cost matrix
continue
prev_j_offset = prev_j - prev_i_start
score = cost[prev_i][prev_j_offset] - math.log(priors[a_1 + a_2]) + \
get_score(src_len[i] - src_len[i - a_1], tgt_len[j] - tgt_len[j - a_2], c, s2)
if score < best_score:
best_score = score
best_a = a
j_offset = j - i_start
cost[i][j_offset] = best_score
back[i][j_offset] = best_a
return cost, back
@nb.jit(nopython=True, fastmath=True, cache=True)
def get_score(len_s, len_t, c, s2):
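# Gale-Church match cost: -log of the two-tailed probability of the observed
# length difference under a normal model; returns 25 when the probability underflows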
mean = (len_s + len_t / c) / 2
z = (len_t - len_s * c) / math.sqrt(mean * s2)
pd = 2 * (1 - norm_cdf(abs(z)))
if pd > 0:
return -math.log(pd)
return 25
@nb.jit(nopython=True, fastmath=True, cache=True)
def find_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
yx_ratio = tgt_len / src_len
win_size_1 = int(yx_ratio * tgt_len * win_per_100 / 100)
win_size_2 = int(abs(tgt_len - src_len) * 3/4)
w_1 = min(max(min_win_size, max(win_size_1, win_size_2)), max_win_size)
w_2 = int(max(src_len, tgt_len) * 0.06)
w = max(w_1, w_2)
search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
for i in range(0, src_len + 1):
center = int(yx_ratio * i)
w_start = max(0, center - w)
w_end = min(center + w, tgt_len)
search_path[i] = [w_start, w_end]
return w, search_path
@nb.jit(nopython=True, fastmath=True, cache=True)
def norm_cdf(z):
t = 1/float(1 + 0.2316419*z) # t = 1/(1 + p*z), p = 0.2316419
# Abramowitz & Stegun 26.2.17 polynomial approximation of the standard normal CDF
p_norm = 1 - 0.3989423*math.exp(-z*z/2) * \
((((1.330274429*t - 1.821255978)*t + 1.781477937)*t - 0.356563782)*t + 0.319381530)*t
return p_norm
def calculate_txt_len(lines):
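# cumulative UTF-8 byte lengths, e.g. ["ab\n", "cde\n"] -> [0, 2, 5]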
txt_len = []
txt_len.append(0)
for i, line in enumerate(lines):
# UTF-8 byte length
txt_len.append(txt_len[i] + len(line.strip().encode("utf-8")))
return np.array(txt_len)
def create_jobs(meta, src, tgt, out):
jobs = []
fns = get_fns(meta)
for file in fns:
src_path = os.path.abspath(os.path.join(src, file))
tgt_path = os.path.abspath(os.path.join(tgt, file))
out_path = os.path.abspath(os.path.join(out, file + '.align'))
jobs.append('\t'.join([src_path, tgt_path, out_path]))
return jobs
def get_fns(meta):
fns = []
with open(meta, 'rt', encoding='utf-8') as f:
next(f) # skip header
for line in f:
recs = line.strip().split('\t')
fns.append(recs[0])
return fns
def make_dir(path):
if os.path.isdir(path):
shutil.rmtree(path)
os.makedirs(path, exist_ok=True)
if __name__ == '__main__':
t_0 = time.time()
main()
print("It takes {:.3f} seconds to align all the sentences.".format(time.time() - t_0))