Update bert_align.py
@@ -1,4 +1,4 @@
-# 2021/11/27
+# 2021/11/29
 # bfsujason@163.com
 
 """
@@ -15,7 +15,6 @@ python bin/bert_align.py \
 """
 
 import os
-import sys
 import time
 import torch
 import faiss
@@ -42,99 +41,100 @@ def main():
     parser.add_argument('--margin', action='store_true', help='Margin-based cosine similarity')
     args = parser.parse_args()
 
-    # fixed parameters to determine the
-    # window size for the first-pass alignment
-    min_win_size = 10
-    max_win_size = 600
-    win_per_100 = 8
-
-    # read in embeddings
-    src_sent2line, src_line_embeddings = read_in_embeddings(args.src_embed[0], args.src_embed[1])
-    tgt_sent2line, tgt_line_embeddings = read_in_embeddings(args.tgt_embed[0], args.tgt_embed[1])
-    embedding_size = src_line_embeddings.shape[1]
+    # Read in source and target embeddings.
+    src_sent2line, src_line_embeddings = \
+        read_in_embeddings(args.src_embed[0], args.src_embed[1])
+    tgt_sent2line, tgt_line_embeddings = \
+        read_in_embeddings(args.tgt_embed[0], args.tgt_embed[1])
 
+    # Perform sentence alignment.
     make_dir(args.out)
     jobs = create_jobs(args.meta, args.src, args.tgt, args.out)
-
-    # start alignment
-    for rec in jobs:
-        src_file, tgt_file, align_file = rec.split("\t")
+    for job in jobs:
+        src_file, tgt_file, out_file = job.split('\t')
         print("Aligning {} to {}".format(src_file, tgt_file))
 
-        # read in source and target sentences
+        # Convert source and target texts into feature matrix.
+        t_0 = time.time()
         src_lines = open(src_file, 'rt', encoding="utf-8").readlines()
         tgt_lines = open(tgt_file, 'rt', encoding="utf-8").readlines()
-
-        # convert source and target texts into embeddings
-        # and calculate sentence length
-        t_0 = time.time()
-        src_vecs, src_lens = doc2feats(src_sent2line, src_line_embeddings, src_lines, args.max_align - 1)
-        tgt_vecs, tgt_lens = doc2feats(tgt_sent2line, tgt_line_embeddings, tgt_lines, args.max_align - 1)
+        src_vecs, src_lens = \
+            doc2feats(src_sent2line, src_line_embeddings, src_lines, args.max_align - 1)
+        tgt_vecs, tgt_lens = \
+            doc2feats(tgt_sent2line, tgt_line_embeddings, tgt_lines, args.max_align - 1)
         char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
-        print("Reading embeddings takes {:.3f}".format(time.time() - t_0))
+        print("Vectorizing source and target texts takes {:.3f} seconds.".format(time.time() - t_0))
 
-        # using faiss, find in the target text
-        # the k nearest neighbors of each source sentence
+        # Find the top_k similar target sentences for each source sentence.
         t_1 = time.time()
-        if torch.cuda.is_available(): # GPU version
-            res = faiss.StandardGpuResources()
-            index = faiss.IndexFlatIP(embedding_size)
-            gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
-            gpu_index.add(tgt_vecs[0,:])
-            xq = src_vecs[0,:]
-            D,I = gpu_index.search(xq,args.top_k)
-        else: # CPU version
-            index = faiss.IndexFlatIP(embedding_size) # use inter product to build index
-            index.add(tgt_vecs[0,:])
-            xq = src_vecs[0,:]
-            D,I = index.search(xq, args.top_k)
-        print("Finding top-k neighbors takes {:.3f}".format(time.time() - t_1))
+        D, I = find_top_k_sents(src_vecs[0,:], tgt_vecs[0,:], k=args.top_k)
+        print("Finding top-k sentences takes {:.3f} seconds.".format(time.time() - t_1))
 
-        # find 1-to-1 alignment
+        # Find optimal 1-1 alignments using dynamic programming.
         t_2 = time.time()
-        src_len = len(src_lines)
-        tgt_len = len(tgt_lines)
-        first_alignment_types = make_alignment_types(2) # 0-0, 1-0 and 1-1
-        first_w, first_search_path = find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100)
-        first_pointers = first_pass_align(src_len, tgt_len, first_w, first_search_path, first_alignment_types, D, I, args.top_k)
-        first_alignment = first_back_track(src_len, tgt_len, first_pointers, first_search_path, first_alignment_types)
-        print("First pass alignment takes {:.3f}".format(time.time() - t_2))
+        m = len(src_lines)
+        n = len(tgt_lines)
+        first_alignment_types = get_alignment_types(2) # 0-1, 1-0, 1-1
+        first_w, first_path = find_first_search_path(m, n)
+        first_pointers = first_pass_align(m, n, first_w,
+                                          first_path, first_alignment_types,
+                                          D, I, args.top_k)
+        first_alignment = first_back_track(m, n,
+                                           first_pointers, first_path,
+                                           first_alignment_types)
+        print("First-pass alignment takes {:.3f} seconds.".format(time.time() - t_2))
 
-        # find m-to-n alignment
+        # Find optimal m-to-n alignments using dynamic programming.
         t_3 = time.time()
-        second_w, second_search_path = find_second_search_path(first_alignment, args.win, src_len, tgt_len)
-        second_alignment_types = make_alignment_types(args.max_align)
-        second_pointers = second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, second_w, second_search_path, second_alignment_types, char_ratio, args.skip, margin=args.margin)
-        second_alignment = second_back_track(src_len, tgt_len, second_pointers, second_search_path, second_alignment_types)
+        second_alignment_types = get_alignment_types(args.max_align)
+        second_w, second_path = find_second_path(first_alignment, args.win, m, n)
+        second_pointers = second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens,
                                            second_w, second_path, second_alignment_types,
+                                            char_ratio, args.skip, margin=args.margin)
+        second_alignment = second_back_track(m, n, second_pointers,
+                                             second_path, second_alignment_types)
         print("Second pass alignment takes {:.3f}".format(time.time() - t_3))
 
-        # save alignment
-        print_alignments(second_alignment, align_file)
+        # save alignment results
+        print_alignments(second_alignment, out_file)
 
-def second_back_track(i, j, b, search_path, a_types):
-    alignment = []
-    while ( i !=0 and j != 0 ):
-        j_offset = j - search_path[i][0]
-        a = b[i][j_offset]
-        s = a_types[a][0]
-        t = a_types[a][1]
-        src_range = [i - offset - 1 for offset in range(s)][::-1]
-        tgt_range = [j - offset - 1 for offset in range(t)][::-1]
-        alignment.append((src_range, tgt_range))
-
-        i = i-s
-        j = j-t
-
-    return alignment[::-1]
+def print_alignments(alignments, out):
+    with open(out, 'wt', encoding='utf-8') as f:
+        for x, y in alignments:
+            f.write("{}:{}\n".format(x, y))
 
 @nb.jit(nopython=True, fastmath=True, cache=True)
-def second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, w, search_path, align_types, char_ratio, skip, margin=False):
+def second_pass_align(src_vecs,
+                      tgt_vecs,
+                      src_lens,
+                      tgt_lens,
+                      w,
+                      search_path,
+                      align_types,
+                      char_ratio,
+                      skip,
+                      margin=False):
+    """
+    Perform the second-pass alignment to extract n-m bitext segments.
+    Args:
+        src_vecs: numpy array of shape (max_align-1, num_src_sents, embedding_size).
+        tgt_vecs: numpy array of shape (max_align-1, num_tgt_sents, embedding_size)
+        src_lens: numpy array of shape (max_align-1, num_src_sents).
+        tgt_lens: numpy array of shape (max_align-1, num_tgt_sents).
+        w: int. Predefined window size for the second-pass alignment.
+        search_path: numpy array. Second-pass alignment search path.
+        align_types: numpy array. Second-pass alignment types.
+        char_ratio: float. Ratio of source length to target length.
+        skip: float. Cost for insertion and deletion.
+        margin: boolean. Set to true if choosing modified cosine similarity score.
+    Returns:
+        pointers: numpy array recording best alignments for each DP cell.
+    """
     src_len = src_vecs.shape[1]
     tgt_len = tgt_vecs.shape[1]
 
-    # intialize sum matrix
+    # Initialize cost and backpointer matrix
     cost = np.zeros((src_len + 1, w))
-    #back = np.zeros((tgt_len + 1, w), dtype=nb.int64)
     back = np.zeros((src_len + 1, w), dtype=nb.int64)
     cost[0][0] = 0
     back[0][0] = -1
@@ -171,7 +171,11 @@ def second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, w, search_path, al
             tgt_v = tgt_vecs[a_2-1,j-1,:]
             src_l = src_lens[a_1-1, i-1]
             tgt_l = tgt_lens[a_2-1, j-1]
-            cur_score = get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=margin)
+            cur_score = get_score(src_v, tgt_v,
+                                  a_1, a_2, i, j,
+                                  src_vecs, tgt_vecs,
+                                  src_len, tgt_len,
+                                  margin=margin)
             tgt_l = tgt_l * char_ratio
             min_len = min(src_l, tgt_l)
             max_len = max(src_l, tgt_l)
@@ -189,8 +193,29 @@ def second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, w, search_path, al
 
     return back
 
+def second_back_track(i, j, b, search_path, a_types):
+    alignment = []
+    while ( i !=0 and j != 0 ):
+        j_offset = j - search_path[i][0]
+        a = b[i][j_offset]
+        s = a_types[a][0]
+        t = a_types[a][1]
+        src_range = [i - offset - 1 for offset in range(s)][::-1]
+        tgt_range = [j - offset - 1 for offset in range(t)][::-1]
+        alignment.append((src_range, tgt_range))
+
+        i = i-s
+        j = j-t
+
+    return alignment[::-1]
+
 @nb.jit(nopython=True, fastmath=True, cache=True)
-def get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=False):
+def get_score(src_v, tgt_v,
+              a_1, a_2,
+              i, j,
+              src_vecs, tgt_vecs,
+              src_len, tgt_len,
+              margin=False):
     similarity = nb_dot(src_v, tgt_v)
     if margin:
         tgt_neighbor_ave_sim = get_neighbor_sim(src_v, a_2, j, tgt_len, tgt_vecs)
@@ -229,10 +254,17 @@ def get_neighbor_sim(vec, a, j, len, db):
 def nb_dot(x, y):
     return np.dot(x,y)
 
-def find_second_search_path(align, w, src_len, tgt_len):
+def find_second_path(align, w, src_len, tgt_len):
     '''
     Convert 1-1 alignment from first-pass to the path for second-pass alignment.
-    The index along X-axis and Y-axis must be consecutive.
+    The indices along X-axis and Y-axis must be consecutive.
+    Args:
+        align: list of tuples. First-pass alignment results.
+        w: int. Predefined window size for the second path.
+        src_len: int. Number of source sentences.
+        tgt_len: int. Number of target sentences.
+    Returns:
+        path: numpy array for the second search path.
     '''
     last_bead_src = align[-1][0]
     last_bead_tgt = align[-1][1]
@@ -262,13 +294,23 @@ def find_second_search_path(align, w, src_len, tgt_len):
     return max_w + 1, np.array(path)
 
 def first_back_track(i, j, b, search_path, a_types):
+    """
+    Retrieve 1-1 alignments from the first-pass DP table.
+    Args:
+        i: int. Number of source sentences.
+        j: int. Number of target sentences.
+        search_path: numpy array. First-pass search path.
+        a_types: numpy array. First-pass alignment types.
+    Returns:
+        alignment: list of tuples for 1-1 alignments.
+    """
     alignment = []
     while ( i !=0 and j != 0 ):
         j_offset = j - search_path[i][0]
         a = b[i][j_offset]
         s = a_types[a][0]
         t = a_types[a][1]
-        if a == 2:
+        if a == 2: # best 1-1 alignment
             alignment.append((i, j))
 
         i = i-s
@@ -277,9 +319,29 @@ def first_back_track(i, j, b, search_path, a_types):
     return alignment[::-1]
 
 @nb.jit(nopython=True, fastmath=True, cache=True)
-def first_pass_align(src_len, tgt_len, w, search_path, align_types, dist, index, top_k):
-    #initialize cost and backpointer matrix
+def first_pass_align(src_len,
+                     tgt_len,
+                     w,
+                     search_path,
+                     align_types,
+                     dist,
+                     index,
+                     top_k):
+    """
+    Perform the first-pass alignment to extract 1-1 bitext segments.
+    Args:
+        src_len: int. Number of source sentences.
+        tgt_len: int. Number of target sentences.
+        w: int. Window size for the first-pass alignment.
+        search_path: numpy array. Search path for the first-pass alignment.
+        align_types: numpy array. Alignment types for the first-pass alignment.
+        dist: numpy array. Distance matrix for top-k similar vecs.
+        index: numpy array. Index matrix for top-k similar vecs.
+        top_k: int. Number of most similar top-k vecs.
+    Returns:
+        pointers: numpy array recording best alignments for each DP cell.
+    """
+    # Initialize cost and backpointer matrix.
     cost = np.zeros((src_len + 1, 2 * w + 1))
     pointers = np.zeros((src_len + 1, 2 * w + 1), dtype=nb.int64)
     cost[0][0] = 0
@@ -323,29 +385,92 @@ def first_pass_align(src_len, tgt_len, w, search_path, align_types, dist, index,
 
     return pointers
 
-@nb.jit(nopython=True, fastmath=True, cache=True)
-def find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
+def find_first_search_path(src_len,
+                           tgt_len,
+                           min_win_size = 250,
+                           percent=0.06):
+    """
+    Find the window size and search path for the first-pass alignment.
+    Args:
+        src_len: int. Number of source sentences.
+        tgt_len: int. Number of target sentences.
+        min_win_size: int. Minimum window size.
+        percent: float. Percent of longer sentences.
+    Returns:
+        win_size: int. Window size along the diagonal of the DP table.
+        search_path: numpy array of shape (src_len + 1, 2), containing the start
+            and end index of target sentences for each source sentence.
+            One extra row is added in the search_path for calculation of
+            deletions and omissions.
+    """
+    win_size = max(min_win_size, int(max(src_len, tgt_len) * percent))
+    search_path = []
     yx_ratio = tgt_len / src_len
-    win_size_1 = int(yx_ratio * tgt_len * win_per_100 / 100)
-    win_size_2 = int(abs(tgt_len - src_len) * 3/4)
-    w_1 = min(max(min_win_size, max(win_size_1, win_size_2)), max_win_size)
-    w_2 = int(max(src_len, tgt_len) * 0.06)
-    w = max(w_1, w_2)
-    search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
     for i in range(0, src_len + 1):
         center = int(yx_ratio * i)
-        w_start = max(0, center - w)
-        w_end = min(center + w, tgt_len)
-        search_path[i] = [w_start, w_end]
+        win_start = max(0, center - win_size)
+        win_end = min(center + win_size, tgt_len)
+        search_path.append([win_start, win_end])
+    return win_size, np.array(search_path)
 
-    return w, search_path
+def get_alignment_types(max_alignment_size):
+    """
+    Get all the possible alignment types.
+    Args:
+        max_alignment_size: int. Source sentences number +
+            Target sentences number <= this value.
+    Returns:
+        alignment_types: numpy array.
+    """
+    alignment_types = [[0,1], [1,0]]
+    for x in range(1, max_alignment_size):
+        for y in range(1, max_alignment_size):
+            if x + y <= max_alignment_size:
+                alignment_types.append([x, y])
+    return np.array(alignment_types)
+
+def find_top_k_sents(src_vecs, tgt_vecs, k=3):
+    """
+    Find the top_k similar vecs in tgt_vecs for each vec in src_vecs.
+    Args:
+        src_vecs: numpy array of shape (num_src_sents, embedding_size)
+        tgt_vecs: numpy array of shape (num_tgt_sents, embedding_size)
+        k: int. Number of most similar target sentences.
+    Returns:
+        D: numpy array. Similarity score matrix of shape (num_src_sents, k).
+        I: numpy array. Target index matrix of shape (num_src_sents, k).
+    """
+    embedding_size = src_vecs.shape[1]
+    if torch.cuda.is_available(): # GPU version
+        res = faiss.StandardGpuResources()
+        index = faiss.IndexFlatIP(embedding_size)
+        gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
+        gpu_index.add(tgt_vecs)
+        D, I = gpu_index.search(src_vecs, k)
+    else: # CPU version
+        index = faiss.IndexFlatIP(embedding_size)
+        index.add(tgt_vecs)
+        D, I = index.search(src_vecs, k)
+    return D, I
+
 def doc2feats(sent2line, line_embeddings, lines, num_overlaps):
+    """
+    Convert texts into feature matrix.
+    Args:
+        sent2line: dict. Map each sentence to its ID.
+        line_embeddings: numpy array of sentence embeddings.
+        lines: list of sentences.
+        num_overlaps: int. Maximum number of overlapping sentences allowed.
+    Returns:
+        vecs0: numpy array of shape (num_overlaps, num_lines, size_embedding)
+            for overlapping sentence embeddings.
+        vecs1: numpy array of shape (num_overlap, num_lines)
+            for overlapping sentence lengths.
+    """
     lines = [preprocess_line(line) for line in lines]
     vecsize = line_embeddings.shape[1]
     vecs0 = np.empty((num_overlaps, len(lines), vecsize), dtype=np.float32)
     vecs1 = np.empty((num_overlaps, len(lines)), dtype=np.int)
 
     for ii, overlap in enumerate(range(1, num_overlaps + 1)):
         for jj, out_line in enumerate(layer(lines, overlap)):
             try:
@@ -353,96 +478,91 @@ def doc2feats(sent2line, line_embeddings, lines, num_overlaps):
             except KeyError:
                 logger.warning('Failed to find overlap=%d line "%s". Will use random vector.', overlap, out_line)
                 line_id = None
 
             if line_id is not None:
                 vec = line_embeddings[line_id]
             else:
                 vec = np.random.random(vecsize) - 0.5
                 vec = vec / np.linalg.norm(vec)
 
             vecs0[ii, jj, :] = vec
             vecs1[ii, jj] = len(out_line.encode("utf-8"))
 
     return vecs0, vecs1
 
-def preprocess_line(line):
-    line = line.strip()
-    if len(line) == 0:
-        line = 'BLANK_LINE'
-
-    return line
-
 def layer(lines, num_overlaps, comb=' '):
     """
-    make front-padded overlapping sentences
+    Make front-padded overlapping sentences.
     """
     if num_overlaps < 1:
         raise Exception('num_overlaps must be >= 1')
     out = ['PAD', ] * min(num_overlaps - 1, len(lines))
     for ii in range(len(lines) - num_overlaps + 1):
         out.append(comb.join(lines[ii:ii + num_overlaps]))
 
     return out
 
+def preprocess_line(line):
+    """
+    Clean each line of the text.
+    """
+    line = line.strip()
+    if len(line) == 0:
+        line = 'BLANK_LINE'
+    return line
+
 def read_in_embeddings(text_file, embed_file):
+    """
+    Read in the overlap lines and line embeddings.
+    Args:
+        text_file: str. Overlap file path.
+        embed_file: str. Embedding file path.
+    Returns:
+        sent2line: dict. Map overlap sentences to line IDs.
+        line_embeddings: numpy array of the shape (num_lines, embedding_size).
+            For sentence-transformers, the embedding_size is 768.
+    """
     sent2line = dict()
-    with open(text_file, 'rt', encoding="utf-8") as fin:
-        for ii, line in enumerate(fin):
-            if line.strip() in sent2line:
-                raise Exception('got multiple embeddings for the same line')
-            sent2line[line.strip()] = ii
-
-    line_embeddings = np.fromfile(embed_file, dtype=np.float32, count=-1)
-    if line_embeddings.size == 0:
-        raise Exception('Got empty embedding file')
-
+    with open(text_file, 'rt', encoding="utf-8") as f:
+        for i, line in enumerate(f):
+            sent2line[line.strip()] = i
+    line_embeddings = np.fromfile(embed_file, dtype=np.float32)
     embedding_size = line_embeddings.size // len(sent2line)
     line_embeddings.resize(line_embeddings.shape[0] // embedding_size, embedding_size)
 
     return sent2line, line_embeddings
 
-def make_alignment_types(max_alignment_size):
-    # Return list of all (n,m) where n+m <= this
-    alignment_types = []
-    for x in range(1, max_alignment_size):
-        for y in range(1, max_alignment_size):
-            if x + y <= max_alignment_size:
-                alignment_types.append([x, y])
-    alignment_types = [[0,1], [1,0]] + alignment_types
-
-    return np.array(alignment_types)
-
-def create_jobs(meta, src, tgt, out):
+def create_jobs(meta_data_file, src_dir, tgt_dir, alignment_dir):
+    """
+    Create a job list consisting of source, target and alignment file paths.
+    """
     jobs = []
-    fns = get_fns(meta)
-    for file in fns:
-        src_path = os.path.abspath(os.path.join(src, file))
-        tgt_path = os.path.abspath(os.path.join(tgt, file))
-        out_path = os.path.abspath(os.path.join(out, file + '.align'))
+    text_ids = get_text_ids(meta_data_file)
+    for id in text_ids:
+        src_path = os.path.abspath(os.path.join(src_dir, id))
+        tgt_path = os.path.abspath(os.path.join(tgt_dir, id))
+        out_path = os.path.abspath(os.path.join(alignment_dir, id + '.align'))
         jobs.append('\t'.join([src_path, tgt_path, out_path]))
 
     return jobs
 
-def get_fns(meta):
-    fns = []
-    with open(meta, 'rt', encoding='utf-8') as f:
+def get_text_ids(meta_data_file):
+    """
+    Get the text IDs to be aligned.
+    Args:
+        meta_data_file: str. TSV file with the first column being text ID.
+    Returns:
+        text_ids: list.
+    """
+    text_ids = []
+    with open(meta_data_file, 'rt', encoding='utf-8') as f:
         next(f) # skip header
         for line in f:
            recs = line.strip().split('\t')
-            fns.append(recs[0])
-
-    return fns
-
-def print_alignments(alignments, out):
-    with open(out, 'wt', encoding='utf-8') as f:
-        for x, y in alignments:
-            f.write("{}:{}\n".format(x, y))
+            text_ids.append(recs[0])
+    return text_ids
 
-def make_dir(path):
-    if os.path.isdir(path):
-        shutil.rmtree(path)
-    os.makedirs(path, exist_ok=True)
+def make_dir(auto_alignment_path):
+    """
+    Make an empty directory for saving automatic alignment results.
+    """
+    if os.path.isdir(auto_alignment_path):
+        shutil.rmtree(auto_alignment_path)
+    os.makedirs(auto_alignment_path, exist_ok=True)
 
 if __name__ == '__main__':
     t_0 = time.time()
|
|||||||
Reference in New Issue
Block a user