Files
bertalign/bin/bert_align.py
nlpfun 025bc2afe4 Bertalign
Bertalign
2021-05-17 23:33:49 +08:00

401 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import argparse
import logging
import os
import sys
import time

import faiss
import numba as nb
import numpy as np
def _main():
    """Command-line entry point: for every (source, target) pair in the job
    file, run the two-pass BERT-embedding alignment and write the result."""
    # user-defined parameters
    parser = argparse.ArgumentParser('Multilingual sentence alignment using BERT embeddings',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--job', type=str, required=True, help='Job file for alignment task.')
    parser.add_argument('--src_embed', type=str, required=True, nargs=2, help='Source overlap and embedding files.')
    parser.add_argument('--tgt_embed', type=str, required=True, nargs=2, help='Target overlap and embedding files.')
    parser.add_argument('--max_align', type=int, default=5, help='Maximum alignment types, n + m <= this value.')
    parser.add_argument('--win', type=int, default=5, help='Window size for the second-pass alignment.')
    parser.add_argument('--top_k', type=int, default=3, help='Top-k target neighbors of each source sentence.')
    parser.add_argument('--skip', type=float, default=-0.1, help='Similarity score for 0-1 and 1-0 alignment.')
    parser.add_argument('--margin', action='store_true', help='Margin-based cosine similarity')
    args = parser.parse_args()

    # fixed parameters that determine the window size of the first-pass band
    min_win_size = 10
    max_win_size = 600
    win_per_100 = 8

    # read in the overlap-sentence embeddings
    src_sent2line, src_line_embeddings = read_in_embeddings(args.src_embed[0], args.src_embed[1])
    tgt_sent2line, tgt_line_embeddings = read_in_embeddings(args.tgt_embed[0], args.tgt_embed[1])
    embedding_size = src_line_embeddings.shape[1]

    # each job record is "src_file\ttgt_file\talign_file"
    for rec in read_job(args.job):
        src_file, tgt_file, align_file = rec.split("\t")
        print("Aligning {} to {}".format(src_file, tgt_file))

        # read in source and target sentences; context managers close the
        # handles that the original code leaked
        with open(src_file, 'rt', encoding="utf-8") as f:
            src_lines = f.readlines()
        with open(tgt_file, 'rt', encoding="utf-8") as f:
            tgt_lines = f.readlines()

        # convert both texts into overlap embeddings and sentence byte lengths
        t_0 = time.time()
        src_vecs, src_lens = doc2feats(src_sent2line, src_line_embeddings, src_lines, args.max_align - 1)
        tgt_vecs, tgt_lens = doc2feats(tgt_sent2line, tgt_line_embeddings, tgt_lines, args.max_align - 1)
        char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
        print("Reading embeddings takes {}".format(time.time() - t_0))

        # using faiss on GPU 0, find the top-k target neighbors of each
        # source sentence with an inner-product flat index
        t_1 = time.time()
        res = faiss.StandardGpuResources()
        index = faiss.IndexFlatIP(embedding_size)
        gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
        gpu_index.add(tgt_vecs[0, :])
        xq = src_vecs[0, :]
        D, I = gpu_index.search(xq, args.top_k)
        print("Finding top-k neighbors takes {}".format(time.time() - t_1))

        # first pass: 1-1 alignment over a wide search band
        t_2 = time.time()
        src_len = len(src_lines)
        tgt_len = len(tgt_lines)
        first_alignment_types = make_alignment_types(2)  # 0-1, 1-0 and 1-1
        first_w, first_search_path = find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100)
        first_pointers = first_pass_align(src_len, tgt_len, first_w, first_search_path, first_alignment_types, D, I, args.top_k)
        first_alignment = first_back_track(src_len, tgt_len, first_pointers, first_search_path, first_alignment_types)
        print("First pass alignment takes {}".format(time.time() - t_2))

        # second pass: m-to-n alignment in a narrow band around the first pass
        t_3 = time.time()
        second_w, second_search_path = find_second_search_path(first_alignment, args.win, src_len, tgt_len)
        second_alignment_types = make_alignment_types(args.max_align)
        second_pointers = second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, second_w, second_search_path, second_alignment_types, char_ratio, args.skip, margin=args.margin)
        second_alignment = second_back_track(src_len, tgt_len, second_pointers, second_search_path, second_alignment_types)
        print("Second pass alignment takes {}".format(time.time() - t_3))

        # save alignment (with-block fixes the original's unclosed output handle)
        with open(align_file, 'w', encoding="utf-8") as out_f:
            print_alignments(second_alignment, file=out_f)
def print_alignments(alignments, file=sys.stdout):
    """Write each alignment bead as 'src_indices:tgt_indices', one per line."""
    for src_bead, tgt_bead in alignments:
        print('%s:%s' % (src_bead, tgt_bead), file=file)
def second_back_track(i, j, b, search_path, a_types):
    """Recover the m-to-n alignment by following back-pointers from (i, j).

    Returns a list of (src_indices, tgt_indices) beads in document order;
    indices are 0-based sentence positions.
    """
    beads = []
    while i != 0 and j != 0:
        # back-pointer matrix is band-compressed: column 0 of row i is
        # target index search_path[i][0]
        col = j - search_path[i][0]
        choice = b[i][col]
        n_src = a_types[choice][0]
        n_tgt = a_types[choice][1]
        # the last n_src source / n_tgt target sentences, ascending
        src_ids = [i - 1 - off for off in reversed(range(n_src))]
        tgt_ids = [j - 1 - off for off in reversed(range(n_tgt))]
        beads.append((src_ids, tgt_ids))
        i -= n_src
        j -= n_tgt
    beads.reverse()
    return beads
@nb.jit(nopython=True, fastmath=True, cache=True)
def second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, w, search_path, align_types, char_ratio, skip, margin=False):
    """
    Second-pass dynamic program over the banded search path, considering every
    m-to-n alignment type and returning the back-pointer matrix.

    src_vecs/tgt_vecs: (num_overlaps, num_sents, dim) overlap embeddings.
    src_lens/tgt_lens: (num_overlaps, num_sents) sentence byte lengths.
    w: column count (band width) of the DP matrices.
    search_path: (src_len + 1, 2) per-row target-index bounds of the band.
    align_types: array of (m, n) alignment types.
    char_ratio: source/target total-length ratio used by the length penalty.
    skip: fixed score assigned to 0-1 / 1-0 (insertion / deletion) moves.
    margin: if True, similarity is margin-adjusted (see get_score).

    Returns back: (src_len + 1, w) matrix; back[i][j - search_path[i][0]] is
    the index into align_types of the best move into cell (i, j).
    """
    # axis 1 of the overlap tensors is the sentence index
    src_len = src_vecs.shape[1]
    tgt_len = tgt_vecs.shape[1]
    # band-compressed DP score matrix and back-pointer matrix
    cost = np.zeros((src_len + 1, w))
    back = np.zeros((src_len + 1, w), dtype=nb.int64)
    cost[0][0] = 0
    back[0][0] = -1
    for i in range(1, src_len + 1):
        i_start = search_path[i][0]
        i_end = search_path[i][1]
        for j in range(i_start, i_end + 1):
            if i + j == 0:
                continue
            best_score = -np.inf
            best_a = -1
            for a in range(align_types.shape[0]):
                a_1 = align_types[a][0]
                a_2 = align_types[a][1]
                prev_i = i - a_1
                prev_j = j - a_2
                if prev_i < 0 or prev_j < 0 :  # no previous cell in DP table
                    continue
                prev_i_start = search_path[prev_i][0]
                prev_i_end = search_path[prev_i][1]
                if prev_j < prev_i_start or prev_j > prev_i_end: # out of bound of cost matrix
                    continue
                prev_j_offset = prev_j - prev_i_start
                score = cost[prev_i][prev_j_offset]
                if score == -np.inf:
                    continue
                if a_1 == 0 or a_2 == 0:  # deletion or insertion
                    cur_score = skip
                else:
                    # overlap embedding / byte length of the last a_1 source
                    # (a_2 target) sentences ending at i-1 (j-1)
                    src_v = src_vecs[a_1-1,i-1,:]
                    tgt_v = tgt_vecs[a_2-1,j-1,:]
                    src_l = src_lens[a_1-1, i-1]
                    tgt_l = tgt_lens[a_2-1, j-1]
                    cur_score = get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=margin)
                    # length penalty: after rescaling the target length by
                    # char_ratio, beads with unbalanced lengths are
                    # down-weighted (penalty is 1 when lengths match)
                    tgt_l = tgt_l * char_ratio
                    min_len = min(src_l, tgt_l)
                    max_len = max(src_l, tgt_l)
                    len_p = np.log2(1 + min_len / max_len)
                    cur_score *= len_p
                score += cur_score
                if score > best_score:
                    best_score = score
                    best_a = a
            j_offset = j - i_start
            cost[i][j_offset] = best_score
            back[i][j_offset] = best_a
    return back
@nb.jit(nopython=True, fastmath=True, cache=True)
def get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=False):
    """Dot-product similarity of two overlap embeddings; when margin is True,
    the average similarity of each vector to the sentences flanking the other
    side's span is subtracted (margin-based scoring)."""
    sim = nb_dot(src_v, tgt_v)
    if not margin:
        return sim
    tgt_side = get_neighbor_sim(src_v, a_2, j, tgt_len, tgt_vecs)
    src_side = get_neighbor_sim(tgt_v, a_1, i, src_len, src_vecs)
    return sim - (tgt_side + src_side) / 2
@nb.jit(nopython=True, fastmath=True, cache=True)
def get_neighbor_sim(vec, a, j, len, db):
    """Similarity between vec and the single-sentence embeddings immediately
    before and after the a-sentence span ending at j; averaged when both
    neighbors exist, otherwise the one existing similarity (or 0).

    NOTE(review): the parameter name `len` shadows the builtin; kept as-is
    for interface compatibility with existing callers.
    """
    left_idx = j - a
    right_idx = j + 1
    has_right = right_idx <= len
    has_left = left_idx != 0
    neighbor_right_sim = nb_dot(vec, db[0, right_idx - 1, :]) if has_right else 0
    neighbor_left_sim = nb_dot(vec, db[0, left_idx - 1, :]) if has_left else 0
    if has_right and has_left:
        return (neighbor_left_sim + neighbor_right_sim) / 2
    return neighbor_left_sim + neighbor_right_sim
@nb.jit(nopython=True, fastmath=True, cache=True)
def nb_dot(x, y):
    # Thin jitted wrapper around np.dot so the other nopython functions can
    # call a cached, compiled dot product.
    return np.dot(x,y)
def find_second_search_path(align, w, src_len, tgt_len):
    """
    Convert the first-pass 1-1 alignment into a banded search path for the
    second pass; the per-row target bounds are consecutive along both axes.

    Mutates `align` in place so its last bead is (src_len, tgt_len) when
    exactly one coordinate already equals its document length.
    Returns (band_width, path) where path is an (len+1, 2) array of
    (lower, upper) target bounds per source index.
    """
    # Snap the final bead to the bottom-right corner of the DP table.
    final_src, final_tgt = align[-1]
    if final_src != src_len:
        if final_tgt == tgt_len:
            align[-1] = (src_len, tgt_len)
    elif final_tgt != tgt_len:
        align[-1] = (src_len, tgt_len)

    path = []
    max_w = -np.inf
    prev_src = prev_tgt = 0
    for src, tgt in align:
        low = max(0, prev_tgt - w)
        high = min(tgt_len, tgt + w)
        # every source row between consecutive beads shares the same bounds
        for _ in range(prev_src + 1, src + 1):
            path.append((low, high))
        prev_src, prev_tgt = src, tgt
        max_w = max(max_w, high - low)
    # duplicate row 0 so the path is indexed by source position 0..src_len
    path.insert(0, path[0])
    return max_w + 1, np.array(path)
def first_back_track(i, j, b, search_path, a_types):
    """Trace first-pass back-pointers from (i, j), keeping only the 1-1 beads
    (alignment-type index 2); returns (i, j) pairs in document order."""
    one_to_one = []
    while i != 0 and j != 0:
        col = j - search_path[i][0]
        move = b[i][col]
        if move == 2:
            one_to_one.append((i, j))
        i -= a_types[move][0]
        j -= a_types[move][1]
    return one_to_one[::-1]
@nb.jit(nopython=True, fastmath=True, cache=True)
def first_pass_align(src_len, tgt_len, w, search_path, align_types, dist, index, top_k):
    """
    First-pass dynamic program restricted to the 0-1, 1-0 and 1-1 alignment
    types. A 1-1 move earns similarity credit only when target j-1 appears
    among the top_k faiss neighbors of source i-1 (dist/index are the faiss
    search outputs D and I).

    Returns pointers: (src_len + 1, 2*w + 1) band-compressed matrix where
    pointers[i][j - search_path[i][0]] is the align_types index of the best
    move into cell (i, j).
    """
    # initialize band-compressed cost and back-pointer matrices
    cost = np.zeros((src_len + 1, 2 * w + 1))
    pointers = np.zeros((src_len + 1, 2 * w + 1), dtype=nb.int64)
    cost[0][0] = 0
    pointers[0][0] = -1
    for i in range(1, src_len + 1):
        i_start = search_path[i][0]
        i_end = search_path[i][1]
        for j in range(i_start, i_end + 1):
            if i + j == 0:
                continue
            best_score = -np.inf
            best_a = -1
            for a in range(align_types.shape[0]):
                a_1 = align_types[a][0]
                a_2 = align_types[a][1]
                prev_i = i - a_1
                prev_j = j - a_2
                if prev_i < 0 or prev_j < 0 :  # no previous cell
                    continue
                prev_i_start = search_path[prev_i][0]
                prev_i_end = search_path[prev_i][1]
                if prev_j < prev_i_start or prev_j > prev_i_end: # out of bound of cost matrix
                    continue
                prev_j_offset = prev_j - prev_i_start
                score = cost[prev_i][prev_j_offset]
                if score == -np.inf:
                    continue
                if a_1 > 0 and a_2 > 0:
                    # add similarity credit if j-1 is a top-k neighbor of i-1
                    for k in range(top_k):
                        if index[i-1][k] == j - 1:
                            score += dist[i-1][k]
                if score > best_score:
                    best_score = score
                    best_a = a
            j_offset = j - i_start
            cost[i][j_offset] = best_score
            pointers[i][j_offset] = best_a
    return pointers
@nb.jit(nopython=True, fastmath=True, cache=True)
def find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
    """Compute the half-width w of the first-pass search band and, for each
    source index, the (start, end) target range centered on the diagonal.

    Returns (w, search_path) with search_path of shape (src_len + 1, 2).
    """
    yx_ratio = tgt_len / src_len
    # candidate widths: a percentage of the target length, scaled by the
    # length ratio, and a fraction of the absolute length difference
    size_by_pct = int(yx_ratio * tgt_len * win_per_100 / 100)
    size_by_diff = int(abs(tgt_len - src_len) * 3/4)
    clamped = min(max(min_win_size, max(size_by_pct, size_by_diff)), max_win_size)
    # also enforce a floor of 6% of the longer document
    floor_w = int(max(src_len, tgt_len) * 0.06)
    w = max(clamped, floor_w)
    search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
    for i in range(src_len + 1):
        center = int(yx_ratio * i)
        search_path[i] = [max(0, center - w), min(center + w, tgt_len)]
    return w, search_path
def doc2feats(sent2line, line_embeddings, lines, num_overlaps):
    """
    Build the overlap-sentence embedding and length tensors for one document.

    Returns (vecs, lens): vecs[k, j] is the embedding of the overlap of k+1
    sentences ending at line j; lens[k, j] is that overlap's UTF-8 byte
    length. Overlaps missing from sent2line get a random unit vector.
    """
    lines = [preprocess_line(line) for line in lines]
    vecsize = line_embeddings.shape[1]
    vecs0 = np.empty((num_overlaps, len(lines), vecsize), dtype=np.float32)
    # np.int was removed in NumPy 1.24; the builtin int is what it aliased.
    vecs1 = np.empty((num_overlaps, len(lines)), dtype=int)
    for ii, overlap in enumerate(range(1, num_overlaps + 1)):
        for jj, out_line in enumerate(layer(lines, overlap)):
            try:
                line_id = sent2line[out_line]
            except KeyError:
                # `logger` was undefined in the original, so a missing line
                # crashed with NameError; log through the stdlib instead.
                logging.getLogger(__name__).warning(
                    'Failed to find overlap=%d line "%s". Will use random vector.', overlap, out_line)
                line_id = None
            if line_id is not None:
                vec = line_embeddings[line_id]
            else:
                # random unit vector as a stand-in for the missing embedding
                vec = np.random.random(vecsize) - 0.5
                vec = vec / np.linalg.norm(vec)
            vecs0[ii, jj, :] = vec
            vecs1[ii, jj] = len(out_line.encode("utf-8"))
    return vecs0, vecs1
def preprocess_line(line):
    """Strip surrounding whitespace; an empty line becomes the sentinel
    'BLANK_LINE' so every line has a non-empty key."""
    stripped = line.strip()
    return stripped if stripped else 'BLANK_LINE'
def layer(lines, num_overlaps, comb=' '):
    """
    Make front-padded overlapping sentences: element i is the join (by
    `comb`) of the num_overlaps lines ending at i, with 'PAD' entries where
    the window would start before the document.

    Raises ValueError if num_overlaps < 1.
    """
    if num_overlaps < 1:
        # ValueError is more precise than the bare Exception raised before,
        # and still satisfies callers catching Exception.
        raise ValueError('num_overlaps must be >= 1')
    out = ['PAD', ] * min(num_overlaps - 1, len(lines))
    out.extend(comb.join(lines[ii:ii + num_overlaps])
               for ii in range(len(lines) - num_overlaps + 1))
    return out
def read_in_embeddings(text_file, embed_file):
    """
    Load the overlap-sentence -> embedding-row mapping.

    text_file: one overlap sentence per line; line number is the row index.
    embed_file: raw float32 binary of the embedding matrix.
    Returns (sent2line dict, 2-D float32 embedding array).
    Raises ValueError on duplicate sentences or empty input files.
    """
    sent2line = dict()
    with open(text_file, 'rt', encoding="utf-8") as fin:
        for ii, line in enumerate(fin):
            sent = line.strip()  # strip once instead of twice per line
            if sent in sent2line:
                raise ValueError('got multiple embeddings for the same line')
            sent2line[sent] = ii
    if not sent2line:
        # guard the division below from an opaque ZeroDivisionError
        raise ValueError('Got empty text file')
    line_embeddings = np.fromfile(embed_file, dtype=np.float32, count=-1)
    if line_embeddings.size == 0:
        raise ValueError('Got empty embedding file')
    embedding_size = line_embeddings.size // len(sent2line)
    # in-place resize truncates any trailing values that do not fill a row
    line_embeddings.resize(line_embeddings.shape[0] // embedding_size, embedding_size)
    return sent2line, line_embeddings
def make_alignment_types(max_alignment_size):
    """Return an array of all (n, m) alignment types with n + m <= the given
    size, preceded by the deletion (0,1) and insertion (1,0) types."""
    pairs = [[0, 1], [1, 0]]
    pairs += [[x, y]
              for x in range(1, max_alignment_size)
              for y in range(1, max_alignment_size)
              if x + y <= max_alignment_size]
    return np.array(pairs)
def read_job(file):
    """Read job records (one per line, stripped), skipping '#' comment lines."""
    with open(file, 'r', encoding="utf-8") as f:
        return [line.strip() for line in f if not line.startswith("#")]
# Script entry point.
if __name__ == '__main__':
    _main()