commit 0a92061119d8d3715fee97d584f73622b08114db
Author: bfsujason
Date:   Sun May 16 02:15:28 2021 +0800

    Add files via upload

diff --git a/bert_align.py b/bert_align.py
new file mode 100644
index 0000000..5b8fc75
--- /dev/null
+++ b/bert_align.py
@@ -0,0 +1,382 @@
import argparse
import logging
import sys

import faiss
import numba as nb
import numpy as np

logger = logging.getLogger(__name__)

def _main():
    # user-defined parameters
    parser = argparse.ArgumentParser(description='Multilingual sentence alignment using BERT embeddings',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--job', type=str, required=True, help='Job file for the alignment task.')
    parser.add_argument('--src_embed', type=str, required=True, nargs=2, help='Source overlap and embedding files.')
    parser.add_argument('--tgt_embed', type=str, required=True, nargs=2, help='Target overlap and embedding files.')
    parser.add_argument('--max_align', type=int, default=8, help='Maximum alignment types, n + m <= this value.')
    parser.add_argument('--win', type=int, default=5, help='Window size for the second-pass alignment.')
    parser.add_argument('--top_k', type=int, default=3, help='Top-k target neighbors of each source sentence.')
    parser.add_argument('--skip', type=float, default=-0.1, help='Similarity score for 0-1 and 1-0 alignments.')
    parser.add_argument('--margin', action='store_true', help='Margin-based cosine similarity.')
    args = parser.parse_args()

    # fixed parameters that determine the window size
    # for the first-pass alignment
    min_win_size = 10
    max_win_size = 600
    win_per_100 = 8

    # read in embeddings
    src_sent2line, src_line_embeddings = read_in_embeddings(args.src_embed[0], args.src_embed[1])
    tgt_sent2line, tgt_line_embeddings = read_in_embeddings(args.tgt_embed[0], args.tgt_embed[1])
    embedding_size = src_line_embeddings.shape[1]

    # read in alignment jobs
    job = read_job(args.job)

    # start alignment
    for rec in job:
        src_file, tgt_file, align_file = rec.split("\t")
        print("Aligning {} to {}".format(src_file, tgt_file))

        # read in source and target sentences
        with open(src_file, 'rt', encoding="utf-8") as f:
            src_lines = f.readlines()
        with open(tgt_file, 'rt', encoding="utf-8") as f:
            tgt_lines = f.readlines()

        # convert source and target texts into embeddings
        # and calculate sentence lengths
        src_vecs, src_lens = doc2feats(src_sent2line, src_line_embeddings, src_lines, args.max_align - 1)
        tgt_vecs, tgt_lens = doc2feats(tgt_sent2line, tgt_line_embeddings, tgt_lines, args.max_align - 1)
        char_ratio = np.sum(src_lens[0, :]) / np.sum(tgt_lens[0, :])

        # using faiss, find in the target text
        # the k nearest neighbors of each source sentence
        index = faiss.IndexFlatIP(embedding_size)  # use inner product to build the index
        index.add(tgt_vecs[0, :])
        xq = src_vecs[0, :]
        D, I = index.search(xq, args.top_k)

        # find 1-to-1 alignments
        src_len = len(src_lines)
        tgt_len = len(tgt_lines)
        first_alignment_types = make_alignment_types(2)  # 0-1, 1-0 and 1-1
        first_w, first_search_path = find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100)
        first_pointers = first_pass_align(src_len, tgt_len, first_w, first_search_path, first_alignment_types, D, I, args.top_k)
        first_alignment = first_back_track(src_len, tgt_len, first_pointers, first_search_path, first_alignment_types)

        # find m-to-n alignments
        second_w, second_search_path = find_second_search_path(first_alignment, args.win, src_len, tgt_len)
        second_alignment_types = make_alignment_types(args.max_align)
        second_pointers = second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens,
                                            second_w, second_search_path, second_alignment_types,
                                            char_ratio, args.skip, margin=args.margin)
        second_alignment = second_back_track(src_len, tgt_len, second_pointers, second_search_path, second_alignment_types)

        # save alignments
        with open(align_file, 'w', encoding="utf-8") as out_f:
            print_alignments(second_alignment, file=out_f)

def print_alignments(alignments, file=sys.stdout):
    for x, y in alignments:
        print('%s:%s' % (x, y), file=file)
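# --------------------------------------------------------------------
# An alignment job file (parsed by read_job() below) lists one task per
# line as tab-separated source, target and output paths; lines starting
# with '#' are skipped as comments. The paths below are illustrative
# only, not files shipped with this commit:
#
#   # src_file <TAB> tgt_file <TAB> align_file
#   corpus/src/chapter1.txt	corpus/tgt/chapter1.txt	corpus/auto/chapter1.align
# --------------------------------------------------------------------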
def second_back_track(i, j, b, search_path, a_types):
    alignment = []
    while i != 0 and j != 0:
        j_offset = j - search_path[i][0]
        a = b[i][j_offset]
        s = a_types[a][0]
        t = a_types[a][1]
        src_range = [i - offset - 1 for offset in range(s)][::-1]
        tgt_range = [j - offset - 1 for offset in range(t)][::-1]
        alignment.append((src_range, tgt_range))

        i = i - s
        j = j - t

    return alignment[::-1]

@nb.jit(nopython=True, fastmath=True, cache=True)
def second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, w, search_path, align_types, char_ratio, skip, margin=False):
    src_len = src_vecs.shape[1]
    tgt_len = tgt_vecs.shape[1]

    # initialize cost and backpointer matrices
    cost = np.zeros((src_len + 1, w))
    back = np.zeros((src_len + 1, w), dtype=nb.int64)
    cost[0][0] = 0
    back[0][0] = -1

    for i in range(1, src_len + 1):
        i_start = search_path[i][0]
        i_end = search_path[i][1]
        for j in range(i_start, i_end + 1):
            if i + j == 0:
                continue
            best_score = -np.inf
            best_a = -1
            for a in range(align_types.shape[0]):
                a_1 = align_types[a][0]
                a_2 = align_types[a][1]
                prev_i = i - a_1
                prev_j = j - a_2

                if prev_i < 0 or prev_j < 0:  # no previous cell in the DP table
                    continue
                prev_i_start = search_path[prev_i][0]
                prev_i_end = search_path[prev_i][1]
                if prev_j < prev_i_start or prev_j > prev_i_end:  # out of bounds of the cost matrix
                    continue
                prev_j_offset = prev_j - prev_i_start
                score = cost[prev_i][prev_j_offset]
                if score == -np.inf:
                    continue

                if a_1 == 0 or a_2 == 0:  # deletion or insertion
                    cur_score = skip
                else:
                    src_v = src_vecs[a_1-1, i-1, :]
                    tgt_v = tgt_vecs[a_2-1, j-1, :]
                    src_l = src_lens[a_1-1, i-1]
                    tgt_l = tgt_lens[a_2-1, j-1]
                    cur_score = get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=margin)
                    tgt_l = tgt_l * char_ratio
                    min_len = min(src_l, tgt_l)
                    max_len = max(src_l, tgt_l)
                    len_p = np.log2(1 + min_len / max_len)  # length penalty
                    cur_score *= len_p

                score += cur_score
                if score > best_score:
                    best_score = score
                    best_a = a

            j_offset = j - i_start
            cost[i][j_offset] = best_score
            back[i][j_offset] = best_a

    return back

@nb.jit(nopython=True, fastmath=True, cache=True)
def get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=False):
    similarity = nb_dot(src_v, tgt_v)
    if margin:
        tgt_neighbor_ave_sim = get_neighbor_sim(src_v, a_2, j, tgt_len, tgt_vecs)
        src_neighbor_ave_sim = get_neighbor_sim(tgt_v, a_1, i, src_len, src_vecs)
        neighbor_ave_sim = (tgt_neighbor_ave_sim + src_neighbor_ave_sim) / 2
        similarity -= neighbor_ave_sim
    return similarity

@nb.jit(nopython=True, fastmath=True, cache=True)
def get_neighbor_sim(vec, a, j, db_len, db):
    left_idx = j - a
    right_idx = j + 1

    if right_idx > db_len:
        neighbor_right_sim = 0
    else:
        right_embed = db[0, right_idx-1, :]
        neighbor_right_sim = nb_dot(vec, right_embed)

    if left_idx == 0:
        neighbor_left_sim = 0
    else:
        left_embed = db[0, left_idx-1, :]
        neighbor_left_sim = nb_dot(vec, left_embed)

    # average over the available neighbors: when only one neighbor
    # exists, the missing side contributes 0 and the sum is not halved
    if right_idx > db_len or left_idx == 0:
        neighbor_ave_sim = neighbor_left_sim + neighbor_right_sim
    else:
        neighbor_ave_sim = (neighbor_left_sim + neighbor_right_sim) / 2

    return neighbor_ave_sim
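# --------------------------------------------------------------------
# Illustration (a minimal sketch, not called anywhere in this script):
# plain-NumPy arithmetic mirroring the scoring above. Margin-based
# similarity subtracts the average neighbor similarity from the raw
# dot product, and the final score is weighted by the length penalty
# log2(1 + min_len / max_len). The vectors and character counts below
# are made up, and the margin is simplified to the target side only;
# the real get_score() averages target- and source-side neighbors.
def _margin_score_sketch():
    rng = np.random.default_rng(0)
    vecs = rng.random((4, 8)).astype(np.float32)
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # unit-normalize
    src_v, tgt_v, tgt_left, tgt_right = vecs
    sim = np.dot(src_v, tgt_v)                           # raw similarity
    neighbor_ave = (np.dot(src_v, tgt_left) + np.dot(src_v, tgt_right)) / 2
    margin_sim = sim - neighbor_ave                      # margin-based similarity
    src_chars, tgt_chars = 40, 52                        # toy sentence lengths
    len_p = np.log2(1 + min(src_chars, tgt_chars) / max(src_chars, tgt_chars))
    return margin_sim * len_p
# --------------------------------------------------------------------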
@nb.jit(nopython=True, fastmath=True, cache=True)
def nb_dot(x, y):
    return np.dot(x, y)

def find_second_search_path(align, w, src_len, tgt_len):
    '''
    Convert the 1-1 alignment from the first pass into a search path
    for the second-pass alignment. The indices along the X-axis and
    Y-axis must be consecutive.
    '''
    # make sure the path ends at the bottom-right corner (src_len, tgt_len)
    last_bead_src = align[-1][0]
    last_bead_tgt = align[-1][1]

    if last_bead_src != src_len:
        if last_bead_tgt == tgt_len:
            align.pop()
        align.append((src_len, tgt_len))
    else:
        if last_bead_tgt != tgt_len:
            align.pop()
            align.append((src_len, tgt_len))

    prev_src, prev_tgt = 0, 0
    path = []
    max_w = -np.inf
    for src, tgt in align:
        lower_bound = max(0, prev_tgt - w)
        upper_bound = min(tgt_len, tgt + w)
        path.extend([(lower_bound, upper_bound) for _ in range(prev_src + 1, src + 1)])
        prev_src, prev_tgt = src, tgt
        width = upper_bound - lower_bound
        if width > max_w:
            max_w = width
    path = [path[0]] + path
    return max_w + 1, np.array(path)

def first_back_track(i, j, b, search_path, a_types):
    alignment = []
    while i != 0 and j != 0:
        j_offset = j - search_path[i][0]
        a = b[i][j_offset]
        s = a_types[a][0]
        t = a_types[a][1]
        if a == 2:  # record only the 1-1 beads
            alignment.append((i, j))

        i = i - s
        j = j - t

    return alignment[::-1]

@nb.jit(nopython=True, fastmath=True, cache=True)
def first_pass_align(src_len, tgt_len, w, search_path, align_types, dist, index, top_k):
    # initialize cost and backpointer matrices
    cost = np.zeros((src_len + 1, 2 * w + 1))
    pointers = np.zeros((src_len + 1, 2 * w + 1), dtype=nb.int64)
    cost[0][0] = 0
    pointers[0][0] = -1

    for i in range(1, src_len + 1):
        i_start = search_path[i][0]
        i_end = search_path[i][1]
        for j in range(i_start, i_end + 1):
            if i + j == 0:
                continue
            best_score = -np.inf
            best_a = -1
            for a in range(align_types.shape[0]):
                a_1 = align_types[a][0]
                a_2 = align_types[a][1]
                prev_i = i - a_1
                prev_j = j - a_2
                if prev_i < 0 or prev_j < 0:  # no previous cell
                    continue
                prev_i_start = search_path[prev_i][0]
                prev_i_end = search_path[prev_i][1]
                if prev_j < prev_i_start or prev_j > prev_i_end:  # out of bounds of the cost matrix
                    continue
                prev_j_offset = prev_j - prev_i_start
                score = cost[prev_i][prev_j_offset]
                if score == -np.inf:
                    continue

                if a_1 > 0 and a_2 > 0:
                    # add the similarity score if the target sentence is
                    # among the top-k neighbors of the source sentence
                    for k in range(top_k):
                        if index[i-1][k] == j - 1:
                            score += dist[i-1][k]
                if score > best_score:
                    best_score = score
                    best_a = a

            j_offset = j - i_start
            cost[i][j_offset] = best_score
            pointers[i][j_offset] = best_a

    return pointers

@nb.jit(nopython=True, fastmath=True, cache=True)
def find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
    yx_ratio = tgt_len / src_len
    win_size_1 = int(yx_ratio * tgt_len * win_per_100 / 100)
    win_size_2 = int(abs(tgt_len - src_len) * 3 / 4)
    w_1 = min(max(min_win_size, max(win_size_1, win_size_2)), max_win_size)
    w_2 = int(max(src_len, tgt_len) * 0.05)
    w = max(w_1, w_2)
    search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
    for i in range(0, src_len + 1):
        center = int(yx_ratio * i)
        w_start = max(0, center - w)
        w_end = min(center + w, tgt_len)
        search_path[i] = [w_start, w_end]
    return w, search_path
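# --------------------------------------------------------------------
# Worked example (illustrative only, not executed by this script):
# for two texts of src_len = tgt_len = 100, yx_ratio = 1, so
# win_size_1 = int(1 * 100 * 8 / 100) = 8 and win_size_2 = 0, giving
# w_1 = min(max(10, 8), 600) = 10, w_2 = int(100 * 0.05) = 5 and
# w = 10; each source index i then searches the target band
# [max(0, i - 10), min(i + 10, 100)].
#
# Likewise, make_alignment_types(4) (defined below) would yield
#   [[0,1], [1,0], [1,1], [1,2], [1,3], [2,1], [2,2], [3,1]]
# i.e. insertion, deletion and every n-m bead with n + m <= 4.
# --------------------------------------------------------------------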
def doc2feats(sent2line, line_embeddings, lines, num_overlaps):
    lines = [preprocess_line(line) for line in lines]
    vecsize = line_embeddings.shape[1]
    vecs0 = np.empty((num_overlaps, len(lines), vecsize), dtype=np.float32)
    vecs1 = np.empty((num_overlaps, len(lines)), dtype=np.int64)

    for ii, overlap in enumerate(range(1, num_overlaps + 1)):
        for jj, out_line in enumerate(layer(lines, overlap)):
            try:
                line_id = sent2line[out_line]
            except KeyError:
                logger.warning('Failed to find overlap=%d line "%s". Will use random vector.', overlap, out_line)
                line_id = None

            if line_id is not None:
                vec = line_embeddings[line_id]
            else:
                vec = np.random.random(vecsize) - 0.5
                vec = vec / np.linalg.norm(vec)

            vecs0[ii, jj, :] = vec
            vecs1[ii, jj] = len(out_line.encode("utf-8"))

    return vecs0, vecs1

def preprocess_line(line):
    line = line.strip()
    if len(line) == 0:
        line = 'BLANK_LINE'
    return line

def layer(lines, num_overlaps, comb=' '):
    """
    Make front-padded overlapping sentences.
    """
    if num_overlaps < 1:
        raise Exception('num_overlaps must be >= 1')
    out = ['PAD', ] * min(num_overlaps - 1, len(lines))
    for ii in range(len(lines) - num_overlaps + 1):
        out.append(comb.join(lines[ii:ii + num_overlaps]))
    return out

def read_in_embeddings(text_file, embed_file):
    sent2line = dict()
    with open(text_file, 'rt', encoding="utf-8") as fin:
        for ii, line in enumerate(fin):
            if line.strip() in sent2line:
                raise Exception('Got multiple embeddings for the same line')
            sent2line[line.strip()] = ii

    line_embeddings = np.fromfile(embed_file, dtype=np.float32, count=-1)
    if line_embeddings.size == 0:
        raise Exception('Got empty embedding file')

    embedding_size = line_embeddings.size // len(sent2line)
    line_embeddings.resize(line_embeddings.shape[0] // embedding_size, embedding_size)
    return sent2line, line_embeddings

def make_alignment_types(max_alignment_size):
    # return all alignment types [n, m] with n, m >= 1 and
    # n + m <= max_alignment_size, plus insertion [0,1] and deletion [1,0]
    alignment_types = []
    for x in range(1, max_alignment_size):
        for y in range(1, max_alignment_size):
            if x + y <= max_alignment_size:
                alignment_types.append([x, y])
    alignment_types = [[0,1], [1,0]] + alignment_types
    return np.array(alignment_types)

def read_job(file):
    job = []
    with open(file, 'r', encoding="utf-8") as f:
        for line in f:
            if not line.startswith("#"):
                job.append(line.strip())
    return job

if __name__ == '__main__':
    _main()
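# --------------------------------------------------------------------
# Example invocation (a sketch; all file names are illustrative, not
# files shipped with this commit). The overlap files map overlap
# sentences to line numbers and the .emb files hold the corresponding
# raw float32 embeddings, as expected by read_in_embeddings() above:
#
#   python bert_align.py \
#       --job meta/align.job \
#       --src_embed src.overlap src.emb \
#       --tgt_embed tgt.overlap tgt.emb \
#       --max_align 8 --win 5 --top_k 3 --margin
# --------------------------------------------------------------------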