diff --git a/bin/bert_align.py b/bin/bert_align.py new file mode 100644 index 0000000..d61814b --- /dev/null +++ b/bin/bert_align.py @@ -0,0 +1,450 @@ +# 2021/11/27 +# bfsujason@163.com + +""" +Usage: + +python bin/bert_align.py \ + -s data/mac/dev/zh \ + -t data/mac/dev/en \ + -o data/mac/dev/auto \ + -m data/mac/dev/meta_data.tsv \ + --src_embed data/mac/dev/zh/overlap data/mac/dev/zh/overlap.emb \ + --tgt_embed data/mac/dev/en/overlap data/mac/dev/en/overlap.emb \ + --max_align 8 --margin +""" + +import os +import sys +import time +import torch +import faiss +import shutil +import argparse +import numpy as np +import numba as nb + +def main(): + # user-defined parameters + parser = argparse.ArgumentParser('Sentence alignment using Vecalign') + parser.add_argument('-s', '--src', type=str, required=True, help='preprocessed source file to align') + parser.add_argument('-t', '--tgt', type=str, required=True, help='preprocessed target file to align') + parser.add_argument('-o', '--out', type=str, required=True, help='Output directory.') + parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.') + parser.add_argument('--src_embed', type=str, nargs=2, required=True, + help='Source embeddings. Requires two arguments: first is a text file, sencond is a binary embeddings file. ') + parser.add_argument('--tgt_embed', type=str, nargs=2, required=True, + help='Target embeddings. Requires two arguments: first is a text file, sencond is a binary embeddings file. ') + parser.add_argument('--max_align', type=int, default=5, help='Maximum alignment types, n + m <= this value.') + parser.add_argument('--win', type=int, default=5, help='Window size for the second-pass alignment.') + parser.add_argument('--top_k', type=int, default=3, help='Top-k target neighbors of each source sentence.') + parser.add_argument('--skip', type=float, default=-0.1, help='Similarity score for 0-1 and 1-0 alignment.') + parser.add_argument('--margin', action='store_true', help='Margin-based cosine similarity') + args = parser.parse_args() + + # fixed parameters to determine the + # window size for the first-pass alignment + min_win_size = 10 + max_win_size = 600 + win_per_100 = 8 + + # read in embeddings + src_sent2line, src_line_embeddings = read_in_embeddings(args.src_embed[0], args.src_embed[1]) + tgt_sent2line, tgt_line_embeddings = read_in_embeddings(args.tgt_embed[0], args.tgt_embed[1]) + embedding_size = src_line_embeddings.shape[1] + + make_dir(args.out) + jobs = create_jobs(args.meta, args.src, args.tgt, args.out) + + # start alignment + for rec in jobs: + src_file, tgt_file, align_file = rec.split("\t") + print("Aligning {} to {}".format(src_file, tgt_file)) + + # read in source and target sentences + src_lines = open(src_file, 'rt', encoding="utf-8").readlines() + tgt_lines = open(tgt_file, 'rt', encoding="utf-8").readlines() + + # convert source and target texts into embeddings + # and calculate sentence length + t_0 = time.time() + src_vecs, src_lens = doc2feats(src_sent2line, src_line_embeddings, src_lines, args.max_align - 1) + tgt_vecs, tgt_lens = doc2feats(tgt_sent2line, tgt_line_embeddings, tgt_lines, args.max_align - 1) + char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,]) + print("Reading embeddings takes {:.3f}".format(time.time() - t_0)) + + # using faiss, find in the target text + # the k nearest neighbors of each source sentence + t_1 = time.time() + if torch.cuda.is_available(): # GPU version + res = faiss.StandardGpuResources() + index = faiss.IndexFlatIP(embedding_size) 
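+            # NOTE: IndexFlatIP does an exact inner-product search; assuming the
+            # sentence embeddings are L2-normalized (the sentence-transformers LaBSE
+            # model normalizes its output), inner product is equivalent to cosine
+            # similarity, so D holds the top-k similarity scores and I the indices
+            # of the nearest target sentences used by the first-pass alignment.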
+ gpu_index = faiss.index_cpu_to_gpu(res, 0, index) + gpu_index.add(tgt_vecs[0,:]) + xq = src_vecs[0,:] + D,I = gpu_index.search(xq,args.top_k) + else: # CPU version + index = faiss.IndexFlatIP(embedding_size) # use inter product to build index + index.add(tgt_vecs[0,:]) + xq = src_vecs[0,:] + D,I = index.search(xq, args.top_k) + print("Finding top-k neighbors takes {:.3f}".format(time.time() - t_1)) + + # find 1-to-1 alignment + t_2 = time.time() + src_len = len(src_lines) + tgt_len = len(tgt_lines) + first_alignment_types = make_alignment_types(2) # 0-0, 1-0 and 1-1 + first_w, first_search_path = find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100) + first_pointers = first_pass_align(src_len, tgt_len, first_w, first_search_path, first_alignment_types, D, I, args.top_k) + first_alignment = first_back_track(src_len, tgt_len, first_pointers, first_search_path, first_alignment_types) + print("First pass alignment takes {:.3f}".format(time.time() - t_2)) + + # find m-to-n alignment + t_3 = time.time() + second_w, second_search_path = find_second_search_path(first_alignment, args.win, src_len, tgt_len) + second_alignment_types = make_alignment_types(args.max_align) + second_pointers = second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, second_w, second_search_path, second_alignment_types, char_ratio, args.skip, margin=args.margin) + second_alignment = second_back_track(src_len, tgt_len, second_pointers, second_search_path, second_alignment_types) + print("Second pass alignment takes {:.3f}".format(time.time() - t_3)) + + # save alignment + print_alignments(second_alignment, align_file) + +def second_back_track(i, j, b, search_path, a_types): + alignment = [] + while ( i !=0 and j != 0 ): + j_offset = j - search_path[i][0] + a = b[i][j_offset] + s = a_types[a][0] + t = a_types[a][1] + src_range = [i - offset - 1 for offset in range(s)][::-1] + tgt_range = [j - offset - 1 for offset in range(t)][::-1] + alignment.append((src_range, tgt_range)) + + i = i-s + j = j-t + + return alignment[::-1] + +@nb.jit(nopython=True, fastmath=True, cache=True) +def second_pass_align(src_vecs, tgt_vecs, src_lens, tgt_lens, w, search_path, align_types, char_ratio, skip, margin=False): + src_len = src_vecs.shape[1] + tgt_len = tgt_vecs.shape[1] + + # intialize sum matrix + cost = np.zeros((src_len + 1, w)) + #back = np.zeros((tgt_len + 1, w), dtype=nb.int64) + back = np.zeros((src_len + 1, w), dtype=nb.int64) + cost[0][0] = 0 + back[0][0] = -1 + + for i in range(1, src_len + 1): + i_start = search_path[i][0] + i_end = search_path[i][1] + for j in range(i_start, i_end + 1): + if i + j == 0: + continue + best_score = -np.inf + best_a = -1 + for a in range(align_types.shape[0]): + a_1 = align_types[a][0] + a_2 = align_types[a][1] + prev_i = i - a_1 + prev_j = j - a_2 + + if prev_i < 0 or prev_j < 0 : # no previous cell in DP table + continue + prev_i_start = search_path[prev_i][0] + prev_i_end = search_path[prev_i][1] + if prev_j < prev_i_start or prev_j > prev_i_end: # out of bound of cost matrix + continue + prev_j_offset = prev_j - prev_i_start + score = cost[prev_i][prev_j_offset] + if score == -np.inf: + continue + + if a_1 == 0 or a_2 == 0: # deletion or insertion + cur_score = skip + else: + src_v = src_vecs[a_1-1,i-1,:] + tgt_v = tgt_vecs[a_2-1,j-1,:] + src_l = src_lens[a_1-1, i-1] + tgt_l = tgt_lens[a_2-1, j-1] + cur_score = get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=margin) + tgt_l = tgt_l * char_ratio + min_len = min(src_l, tgt_l) 
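+                        # NOTE: the surrounding lines apply a Vecalign-style length
+                        # penalty: tgt_l has already been scaled by char_ratio to
+                        # correct for systematic length differences between the two
+                        # languages, and log2(1 + min/max) is 1.0 for equally long
+                        # beads and tends to 0 for very unbalanced ones.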
+ max_len = max(src_l, tgt_l) + len_p = np.log2(1 + min_len / max_len) + cur_score *= len_p + + score += cur_score + if score > best_score: + best_score = score + best_a = a + + j_offset = j - i_start + cost[i][j_offset] = best_score + back[i][j_offset] = best_a + + return back + +@nb.jit(nopython=True, fastmath=True, cache=True) +def get_score(src_v, tgt_v, a_1, a_2, i, j, src_vecs, tgt_vecs, src_len, tgt_len, margin=False): + similarity = nb_dot(src_v, tgt_v) + if margin: + tgt_neighbor_ave_sim = get_neighbor_sim(src_v, a_2, j, tgt_len, tgt_vecs) + src_neighbor_ave_sim = get_neighbor_sim(tgt_v, a_1, i, src_len, src_vecs) + neighbor_ave_sim = (tgt_neighbor_ave_sim + src_neighbor_ave_sim)/2 + similarity -= neighbor_ave_sim + + return similarity + +@nb.jit(nopython=True, fastmath=True, cache=True) +def get_neighbor_sim(vec, a, j, len, db): + left_idx = j - a + right_idx = j + 1 + + if right_idx > len: + neighbor_right_sim = 0 + else: + right_embed = db[0,right_idx-1,:] + neighbor_right_sim = nb_dot(vec, right_embed) + + if left_idx == 0: + neighbor_left_sim = 0 + else: + left_embed = db[0,left_idx-1,:] + neighbor_left_sim = nb_dot(vec, left_embed) + + #if right_idx > LEN or left_idx < 0: + if right_idx > len or left_idx == 0: + neighbor_ave_sim = neighbor_left_sim + neighbor_right_sim + else: + neighbor_ave_sim = (neighbor_left_sim + neighbor_right_sim) / 2 + + return neighbor_ave_sim + +@nb.jit(nopython=True, fastmath=True, cache=True) +def nb_dot(x, y): + return np.dot(x,y) + +def find_second_search_path(align, w, src_len, tgt_len): + ''' + Convert 1-1 alignment from first-pass to the path for second-pass alignment. + The index along X-axis and Y-axis must be consecutive. + ''' + last_bead_src = align[-1][0] + last_bead_tgt = align[-1][1] + + if last_bead_src != src_len: + if last_bead_tgt == tgt_len: + align.pop() + align.append((src_len, tgt_len)) + else: + if last_bead_tgt != tgt_len: + align.pop() + align.append((src_len, tgt_len)) + + prev_src, prev_tgt = 0,0 + path = [] + max_w = -np.inf + for src, tgt in align: + lower_bound = max(0, prev_tgt - w) + upper_bound = min(tgt_len, tgt + w) + path.extend([(lower_bound, upper_bound) for id in range(prev_src+1, src+1)]) + prev_src, prev_tgt = src, tgt + width = upper_bound - lower_bound + if width > max_w: + max_w = width + path = [path[0]] + path + + return max_w + 1, np.array(path) + +def first_back_track(i, j, b, search_path, a_types): + alignment = [] + while ( i !=0 and j != 0 ): + j_offset = j - search_path[i][0] + a = b[i][j_offset] + s = a_types[a][0] + t = a_types[a][1] + if a == 2: + alignment.append((i, j)) + + i = i-s + j = j-t + + return alignment[::-1] + +@nb.jit(nopython=True, fastmath=True, cache=True) +def first_pass_align(src_len, tgt_len, w, search_path, align_types, dist, index, top_k): + + #initialize cost and backpointer matrix + cost = np.zeros((src_len + 1, 2 * w + 1)) + pointers = np.zeros((src_len + 1, 2 * w + 1), dtype=nb.int64) + cost[0][0] = 0 + pointers[0][0] = -1 + + for i in range(1, src_len + 1): + i_start = search_path[i][0] + i_end = search_path[i][1] + for j in range(i_start, i_end + 1): + if i + j == 0: + continue + best_score = -np.inf + best_a = -1 + for a in range(align_types.shape[0]): + a_1 = align_types[a][0] + a_2 = align_types[a][1] + prev_i = i - a_1 + prev_j = j - a_2 + if prev_i < 0 or prev_j < 0 : # no previous cell + continue + prev_i_start = search_path[prev_i][0] + prev_i_end = search_path[prev_i][1] + if prev_j < prev_i_start or prev_j > prev_i_end: # out of bound of cost matrix + 
continue
+                prev_j_offset = prev_j - prev_i_start
+                score = cost[prev_i][prev_j_offset]
+                if score == -np.inf:
+                    continue
+
+                if a_1 > 0 and a_2 > 0:
+                    for k in range(top_k):
+                        if index[i-1][k] == j - 1:
+                            score += dist[i-1][k]
+                if score > best_score:
+                    best_score = score
+                    best_a = a
+
+            j_offset = j - i_start
+            cost[i][j_offset] = best_score
+            pointers[i][j_offset] = best_a
+
+    return pointers
+
+@nb.jit(nopython=True, fastmath=True, cache=True)
+def find_first_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
+    yx_ratio = tgt_len / src_len
+    win_size_1 = int(yx_ratio * tgt_len * win_per_100 / 100)
+    win_size_2 = int(abs(tgt_len - src_len) * 3/4)
+    w_1 = min(max(min_win_size, max(win_size_1, win_size_2)), max_win_size)
+    w_2 = int(max(src_len, tgt_len) * 0.06)
+    w = max(w_1, w_2)
+    search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
+    for i in range(0, src_len + 1):
+        center = int(yx_ratio * i)
+        w_start = max(0, center - w)
+        w_end = min(center + w, tgt_len)
+        search_path[i] = [w_start, w_end]
+
+    return w, search_path
+
+def doc2feats(sent2line, line_embeddings, lines, num_overlaps):
+    lines = [preprocess_line(line) for line in lines]
+    vecsize = line_embeddings.shape[1]
+    vecs0 = np.empty((num_overlaps, len(lines), vecsize), dtype=np.float32)
+    vecs1 = np.empty((num_overlaps, len(lines)), dtype=np.int64)
+
+    for ii, overlap in enumerate(range(1, num_overlaps + 1)):
+        for jj, out_line in enumerate(layer(lines, overlap)):
+            try:
+                line_id = sent2line[out_line]
+            except KeyError:
+                print('Failed to find overlap=%d line "%s". Will use random vector.' % (overlap, out_line), file=sys.stderr)
+                line_id = None
+
+            if line_id is not None:
+                vec = line_embeddings[line_id]
+            else:
+                vec = np.random.random(vecsize) - 0.5
+                vec = vec / np.linalg.norm(vec)
+
+            vecs0[ii, jj, :] = vec
+            vecs1[ii, jj] = len(out_line.encode("utf-8"))
+
+    return vecs0, vecs1
+
+def preprocess_line(line):
+    line = line.strip()
+    if len(line) == 0:
+        line = 'BLANK_LINE'
+
+    return line
+
+def layer(lines, num_overlaps, comb=' '):
+    """
+    make front-padded overlapping sentences
+    """
+    if num_overlaps < 1:
+        raise Exception('num_overlaps must be >= 1')
+    out = ['PAD', ] * min(num_overlaps - 1, len(lines))
+    for ii in range(len(lines) - num_overlaps + 1):
+        out.append(comb.join(lines[ii:ii + num_overlaps]))
+
+    return out
+
+def read_in_embeddings(text_file, embed_file):
+    sent2line = dict()
+    with open(text_file, 'rt', encoding="utf-8") as fin:
+        for ii, line in enumerate(fin):
+            if line.strip() in sent2line:
+                raise Exception('got multiple embeddings for the same line')
+            sent2line[line.strip()] = ii
+
+    line_embeddings = np.fromfile(embed_file, dtype=np.float32, count=-1)
+    if line_embeddings.size == 0:
+        raise Exception('Got empty embedding file')
+
+    embedding_size = line_embeddings.size // len(sent2line)
+    line_embeddings.resize(line_embeddings.shape[0] // embedding_size, embedding_size)
+
+    return sent2line, line_embeddings
+
+def make_alignment_types(max_alignment_size):
+    # Return list of all (n,m) where n+m <= this
+    alignment_types = []
+    for x in range(1, max_alignment_size):
+        for y in range(1, max_alignment_size):
+            if x + y <= max_alignment_size:
+                alignment_types.append([x, y])
+    alignment_types = [[0,1], [1,0]] + alignment_types
+
+    return np.array(alignment_types)
+
+def create_jobs(meta, src, tgt, out):
+    jobs = []
+    fns = get_fns(meta)
+    for file in fns:
+        src_path = os.path.abspath(os.path.join(src, file))
+        tgt_path = os.path.abspath(os.path.join(tgt, file))
+
+        out_path =
os.path.abspath(os.path.join(out, file + '.align')) + jobs.append('\t'.join([src_path, tgt_path, out_path])) + + return jobs + +def get_fns(meta): + fns = [] + with open(meta, 'rt', encoding='utf-8') as f: + next(f) # skip header + for line in f: + recs = line.strip().split('\t') + fns.append(recs[0]) + + return fns + +def print_alignments(alignments, out): + with open(out, 'wt', encoding='utf-8') as f: + for x, y in alignments: + f.write("{}:{}\n".format(x, y)) + +def make_dir(path): + if os.path.isdir(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + +if __name__ == '__main__': + t_0 = time.time() + main() + print("It takes {:.3f} seconds to align all the sentences.".format(time.time() - t_0)) diff --git a/bin/embed_sents.py b/bin/embed_sents.py new file mode 100644 index 0000000..cc95312 --- /dev/null +++ b/bin/embed_sents.py @@ -0,0 +1,107 @@ +# 2021/11/27 +# bfsujason@163.com + +''' +Usage (Linux): + +python bin/embed_sents.py \ + -i data/mac/dev/zh \ + -o data/mac/dev/zh/overlap data/mac/dev/zh/overlap.emb \ + -m data/mac/test/meta_data.tsv \ + -n 8 +''' + +import os +import time +import shutil +import argparse +import numpy as np +from sentence_transformers import SentenceTransformer + +def main(): + parser = argparse.ArgumentParser(description='Multilingual sentence embeddings') + parser.add_argument('-i', '--input', type=str, required=True, help='Data directory.') + parser.add_argument('-o', '--output', type=str, required=True, nargs=2, help='Overalp and embedding file.') + parser.add_argument('-n', '--num_overlaps', type=int, default=5, help='Maximum number of allowed overlaps.') + parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.') + args = parser.parse_args() + + fns = get_fns(args.meta) + overlap = get_overlap(args.input, fns, args.num_overlaps) + write_overlap(overlap, args.output[0]) + + model = load_model() + embed_overlap(model, overlap, args.output[1]) + +def embed_overlap(model, overlap, fout): + print("Embedding text ...") + t_0 = time.time() + embed = model.encode(overlap) + embed.tofile(fout) + print("It takes {:.3f} seconods to embed text.".format(time.time() - t_0)) + +def write_overlap(overlap, outfile): + with open(outfile, 'wt', encoding="utf-8") as fout: + for line in overlap: + fout.write(line + '\n') + +def get_overlap(dir, fns, n): + overlap = set() + for file in fns: + in_path = os.path.join(dir, file) + lines = open(in_path, 'rt', encoding="utf-8").readlines() + for out_line in yield_overlaps(lines, n): + overlap.add(out_line) + + # for reproducibility + overlap = list(overlap) + overlap.sort() + + return overlap + +def yield_overlaps(lines, num_overlaps): + lines = [preprocess_line(line) for line in lines] + for overlap in range(1, num_overlaps + 1): + for out_line in layer(lines, overlap): + # check must be here so all outputs are unique + out_line2 = out_line[:10000] # limit line so dont encode arbitrarily long sentences + yield out_line2 + +def layer(lines, num_overlaps, comb=' '): + if num_overlaps < 1: + raise Exception('num_overlaps must be >= 1') + out = ['PAD', ] * min(num_overlaps - 1, len(lines)) + for ii in range(len(lines) - num_overlaps + 1): + out.append(comb.join(lines[ii:ii + num_overlaps])) + return out + +def preprocess_line(line): + line = line.strip() + if len(line) == 0: + line = 'BLANK_LINE' + return line + +def load_model(): + print("Loading embedding model ...") + t0 = time.time() + model = SentenceTransformer('LaBSE') + print("It takes {:.3f} seconods to load the 
model.".format(time.time() - t0)) + return model + +def get_fns(meta): + fns = [] + with open(meta, 'rt', encoding='utf-8') as f: + next(f) # skip header + for line in f: + recs = line.strip().split('\t') + fns.append(recs[0]) + + return fns + +def make_dir(path): + if os.path.isdir(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + +if __name__ == '__main__': + main() diff --git a/bin/eval.py b/bin/eval.py new file mode 100644 index 0000000..782b259 --- /dev/null +++ b/bin/eval.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 + +""" +Copyright 2019 Brian Thompson + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +""" + +import os +import sys +import argparse +from ast import literal_eval +from collections import defaultdict + +import numpy as np + +""" +Faster implementation of lax and strict precision and recall, based on + https://www.aclweb.org/anthology/W11-4624/. + +""" + +def read_alignments(fin): + alignments = [] + with open(fin, 'rt', encoding="utf-8") as infile: + for line in infile: + fields = [x.strip() for x in line.split(':') if len(x.strip())] + if len(fields) < 2: + raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip()) + try: + src = literal_eval(fields[0]) + tgt = literal_eval(fields[1]) + except: + raise Exception('Failed to parse line "%s"' % line.strip()) + alignments.append((src, tgt)) + + return alignments + +def _precision(goldalign, testalign): + """ + Computes tpstrict, fpstrict, tplax, fplax for gold/test alignments + """ + tpstrict = 0 # true positive strict counter + tplax = 0 # true positive lax counter + fpstrict = 0 # false positive strict counter + fplax = 0 # false positive lax counter + + # convert to sets, remove alignments empty on both sides + testalign = set([(tuple(x), tuple(y)) for x, y in testalign if len(x) or len(y)]) + goldalign = set([(tuple(x), tuple(y)) for x, y in goldalign if len(x) or len(y)]) + + # mappings from source test sentence idxs to + # target gold sentence idxs for which the source test sentence + # was found in corresponding source gold alignment + src_id_to_gold_tgt_ids = defaultdict(set) + for gold_src, gold_tgt in goldalign: + for gold_src_id in gold_src: + for gold_tgt_id in gold_tgt: + src_id_to_gold_tgt_ids[gold_src_id].add(gold_tgt_id) + + for (test_src, test_target) in testalign: + if (test_src, test_target) == ((), ()): + continue + if (test_src, test_target) in goldalign: + # strict match + tpstrict += 1 + tplax += 1 + else: + # For anything with partial gold/test overlap on the source, + # see if there is also partial overlap on the gold/test target + # If so, its a lax match + target_ids = set() + for src_test_id in test_src: + for tgt_id in src_id_to_gold_tgt_ids[src_test_id]: + target_ids.add(tgt_id) + if set(test_target).intersection(target_ids): + fpstrict += 1 + tplax += 1 + else: + fpstrict += 1 + fplax += 1 + + return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32) + + +def score_multiple(gold_list, test_list, value_for_div_by_0=0.0): + # accumulate counts for 
all gold/test files + pcounts = np.array([0, 0, 0, 0], dtype=np.int32) + rcounts = np.array([0, 0, 0, 0], dtype=np.int32) + for goldalign, testalign in zip(gold_list, test_list): + pcounts += _precision(goldalign=goldalign, testalign=testalign) + # recall is precision with no insertion/deletion and swap args + test_no_del = [(x, y) for x, y in testalign if len(x) and len(y)] + gold_no_del = [(x, y) for x, y in goldalign if len(x) and len(y)] + rcounts += _precision(goldalign=test_no_del, testalign=gold_no_del) + + # Compute results + # pcounts: tpstrict,fnstrict,tplax,fnlax + # rcounts: tpstrict,fpstrict,tplax,fplax + + if pcounts[0] + pcounts[1] == 0: + pstrict = value_for_div_by_0 + else: + pstrict = pcounts[0] / float(pcounts[0] + pcounts[1]) + + if pcounts[2] + pcounts[3] == 0: + plax = value_for_div_by_0 + else: + plax = pcounts[2] / float(pcounts[2] + pcounts[3]) + + if rcounts[0] + rcounts[1] == 0: + rstrict = value_for_div_by_0 + else: + rstrict = rcounts[0] / float(rcounts[0] + rcounts[1]) + + if rcounts[2] + rcounts[3] == 0: + rlax = value_for_div_by_0 + else: + rlax = rcounts[2] / float(rcounts[2] + rcounts[3]) + + if (pstrict + rstrict) == 0: + fstrict = value_for_div_by_0 + else: + fstrict = 2 * (pstrict * rstrict) / (pstrict + rstrict) + + if (plax + rlax) == 0: + flax = value_for_div_by_0 + else: + flax = 2 * (plax * rlax) / (plax + rlax) + + result = dict(recall_strict=rstrict, + recall_lax=rlax, + precision_strict=pstrict, + precision_lax=plax, + f1_strict=fstrict, + f1_lax=flax) + + return result + + +def log_final_scores(res): + print(' ---------------------------------', file=sys.stderr) + print('| | Strict | Lax |', file=sys.stderr) + print('| Precision | {precision_strict:.3f} | {precision_lax:.3f} |'.format(**res), file=sys.stderr) + print('| Recall | {recall_strict:.3f} | {recall_lax:.3f} |'.format(**res), file=sys.stderr) + print('| F1 | {f1_strict:.3f} | {f1_lax:.3f} |'.format(**res), file=sys.stderr) + print(' ---------------------------------', file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser( + 'Compute strict/lax precision and recall for one or more pairs of gold/test alignments', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('-t', '--test', type=str, required=True, + help='Test alignment directory.') + + parser.add_argument('-g', '--gold', type=str, required=True, + help='Gold alignment directory.') + + args = parser.parse_args() + + gold_list = [read_alignments(os.path.join(args.gold, x)) for x in sorted(os.listdir(args.gold))] + test_list = [read_alignments(os.path.join(args.test, x)) for x in sorted(os.listdir(args.test))] + + res = score_multiple(gold_list=gold_list, test_list=test_list) + log_final_scores(res) + + +if __name__ == '__main__': + main() diff --git a/bin/eval_bible.py b/bin/eval_bible.py new file mode 100644 index 0000000..f9b9526 --- /dev/null +++ b/bin/eval_bible.py @@ -0,0 +1,109 @@ +import os +import argparse +from ast import literal_eval +from collections import defaultdict +from eval import score_multiple, log_final_scores + +def main(): + parser = argparse.ArgumentParser('Evaluate aligment quality for Bible corpus') + parser.add_argument('-t', '--test', type=str, required=True, help='Test alignment file.') + parser.add_argument('-g', '--gold', type=str, required=True, help='Gold alignment file.') + parser.add_argument('--src_verse', type=str, required=True, help='Source verse file.') + parser.add_argument('--tgt_verse', type=str, required=True, help='Target verse file.') + args = 
parser.parse_args() + + test_alignments = read_alignments(args.test) + gold_alignments = read_alignments(args.gold) + + src_verse = get_verse(args.src_verse) + tgt_verse = get_verse(args.tgt_verse) + + merged_test_alignments = merge_test_alignments(test_alignments, src_verse, tgt_verse) + res = score_multiple(gold_list=[gold_alignments], test_list=[merged_test_alignments]) + log_final_scores(res) + +def merge_test_alignments(alignments, src_verse, tgt_verse): + merged_align = [] + last_beads_type = None + + for beads in alignments: + beads_type = find_beads_type(beads, src_verse, tgt_verse) + if not last_beads_type: + merged_align.append(beads) + else: + if beads_type == last_beads_type: + merged_align[-1][0].extend(beads[0]) + merged_align[-1][1].extend(beads[1]) + else: + merged_align.append(beads) + + last_beads_type = beads_type + + return merged_align + +def find_beads_type(beads, src_verse, tgt_verse): + src_bead = beads[0] + tgt_bead = beads[1] + + src_bead_type = find_bead_type(src_bead, src_verse) + tgt_bead_type = find_bead_type(tgt_bead, tgt_verse) + + src_bead_len = len(src_bead_type) + tgt_bead_len = len(tgt_bead_type) + + if src_bead_len != 1 or tgt_bead_len != 1: + return None + else: + src_verse = src_bead_type[0] + tgt_verse = tgt_bead_type[0] + if src_verse != tgt_verse: + if src_verse == 'NULL': + return tgt_verse + elif tgt_verse == 'NULL': + return src_verse + else: + return None + else: + return src_verse + +def find_bead_type(bead, verse): + bead_type = ['NULL'] + if len(bead) > 0: + bead_type = unique_list([verse[id] for id in bead]) + + return bead_type + +def unique_list(list): + unique_list = [] + for x in list: + if x not in unique_list: + unique_list.append(x) + + return unique_list + +def get_verse(file): + verse = defaultdict() + with open(file, 'rt', encoding='utf-8') as f: + for (i, line) in enumerate(f): + verse[i] = line.strip() + + return verse + +def read_alignments(fin): + alignments = [] + with open(fin, 'rt', encoding="utf-8") as infile: + for line in infile: + fields = [x.strip() for x in line.split(':') if len(x.strip())] + if len(fields) < 2: + raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip()) + try: + src = literal_eval(fields[0]) + tgt = literal_eval(fields[1]) + except: + raise Exception('Failed to parse line "%s"' % line.strip()) + alignments.append((src, tgt)) + + return alignments + +if __name__ == '__main__': + main() diff --git a/bin/gale_align.py b/bin/gale_align.py new file mode 100644 index 0000000..de5f38f --- /dev/null +++ b/bin/gale_align.py @@ -0,0 +1,226 @@ +# 2021/11/27 +# bfsujason@163.com + +""" +Usage: + +python bin/gale_align.py \ + -m data/mac/test/meta_data.tsv \ + -s data/mac/test/zh \ + -t data/mac/test/en \ + -o data/mac/test/auto +""" + +import os +import time +import math +import shutil +import argparse +import numba as nb +import numpy as np + +def main(): + # user-defined parameters + parser = argparse.ArgumentParser(description='Sentence alignment using Gale-Church Algrorithm') + parser.add_argument('-s', '--src', type=str, required=True, help='Source directory.') + parser.add_argument('-t', '--tgt', type=str, required=True, help='Target directory.') + parser.add_argument('-o', '--out', type=str, required=True, help='Output directory.') + parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.') + args = parser.parse_args() + + make_dir(args.out) + + # fixed parameters to determine the window size for alignment + min_win_size = 10 + 
max_win_size = 600
+    win_per_100 = 8
+
+    # alignment types
+    align_types = np.array(
+        [
+            [0,1],
+            [1,0],
+            [1,1],
+            [1,2],
+            [2,1],
+            [2,2],
+        ], dtype=np.int64)
+
+    # prior probability
+    priors = np.array([0, 0.0099, 0.89, 0.089, 0.011])
+
+    # mean and variance
+    c = 1
+    s2 = 6.8
+
+    # perform gale-church align
+    jobs = create_jobs(args.meta, args.src, args.tgt, args.out)
+    for rec in jobs:
+        src_file, tgt_file, align_file = rec.split("\t")
+        print("Aligning {} to {}".format(src_file, tgt_file))
+        src_lines = open(src_file, 'rt', encoding="utf-8").readlines()
+        tgt_lines = open(tgt_file, 'rt', encoding="utf-8").readlines()
+        src_len = calculate_txt_len(src_lines)
+        tgt_len = calculate_txt_len(tgt_lines)
+
+        m = src_len.shape[0] - 1
+        n = tgt_len.shape[0] - 1
+
+        # find search path
+        w, search_path = find_search_path(m, n, min_win_size, max_win_size, win_per_100)
+        cost, back = align(src_len, tgt_len, w, search_path, align_types, priors, c, s2)
+        alignments = back_track(m, n, back, search_path, align_types)
+
+        # save alignments
+        save_alignments(alignments, align_file)
+
+def save_alignments(alignments, file):
+    with open(file, 'wt', encoding='utf-8') as f:
+        for id in alignments:
+            f.write("{}:{}\n".format(id[0], id[1]))
+
+def back_track(i, j, b, search_path, a_types):
+    alignment = []
+    while ( i !=0 and j != 0 ):
+        j_offset = j - search_path[i][0]
+        a = b[i][j_offset]
+        s = a_types[a][0]
+        t = a_types[a][1]
+        src_range = [i - offset - 1 for offset in range(s)][::-1]
+        tgt_range = [j - offset - 1 for offset in range(t)][::-1]
+        alignment.append((src_range, tgt_range))
+
+        i = i-s
+        j = j-t
+
+    return alignment[::-1]
+
+@nb.jit(nopython=True, fastmath=True, cache=True)
+def align(src_len, tgt_len, w, search_path, align_types, priors, c, s2):
+    # initialize cost and backpointer matrix
+    m = src_len.shape[0] - 1
+    cost = np.zeros((m + 1, 2 * w + 1))
+    back = np.zeros((m + 1, 2 * w + 1), dtype=nb.int64)
+    cost[0][0] = 0
+    back[0][0] = -1
+
+    for i in range(m + 1):
+        i_start = search_path[i][0]
+        i_end = search_path[i][1]
+
+        for j in range(i_start, i_end + 1):
+            if i + j == 0:
+                continue
+
+            best_score = np.inf
+            best_a = -1
+            for a in range(align_types.shape[0]):
+                a_1 = align_types[a][0]
+                a_2 = align_types[a][1]
+                prev_i = i - a_1
+                prev_j = j - a_2
+
+                if prev_i < 0 or prev_j < 0:  # no previous cell
+                    continue
+
+                prev_i_start = search_path[prev_i][0]
+                prev_i_end = search_path[prev_i][1]
+
+                if prev_j < prev_i_start or prev_j > prev_i_end:  # out of bound of cost matrix
+                    continue
+
+                prev_j_offset = prev_j - prev_i_start
+
+                score = cost[prev_i][prev_j_offset] - math.log(priors[a_1 + a_2]) + \
+                    get_score(src_len[i] - src_len[i - a_1], tgt_len[j] - tgt_len[j - a_2], c, s2)
+
+                if score < best_score:
+                    best_score = score
+                    best_a = a
+
+            j_offset = j - i_start
+            cost[i][j_offset] = best_score
+            back[i][j_offset] = best_a
+
+    return cost, back
+
+@nb.jit(nopython=True, fastmath=True, cache=True)
+def get_score(len_s, len_t, c, s2):
+    mean = (len_s + len_t / c) / 2
+    z = (len_t - len_s * c) / math.sqrt(mean * s2)
+
+    pd = 2 * (1 - norm_cdf(abs(z)))
+    if pd > 0:
+        return -math.log(pd)
+
+    return 25
+
+@nb.jit(nopython=True, fastmath=True, cache=True)
+def find_search_path(src_len, tgt_len, min_win_size, max_win_size, win_per_100):
+    yx_ratio = tgt_len / src_len
+    win_size_1 = int(yx_ratio * tgt_len * win_per_100 / 100)
+    win_size_2 = int(abs(tgt_len - src_len) * 3/4)
+
+    w_1 = min(max(min_win_size, max(win_size_1, win_size_2)), max_win_size)
+    w_2 = int(max(src_len, tgt_len) * 0.06)
+    w = max(w_1, w_2)
+
+    search_path = np.zeros((src_len + 1, 2), dtype=nb.int64)
+    for i in range(0, src_len + 1):
+        center = int(yx_ratio * i)
+        w_start = max(0, center - w)
+        w_end = min(center + w, tgt_len)
+        search_path[i] = [w_start, w_end]
+
+    return w, search_path
+
+@nb.jit(nopython=True, fastmath=True, cache=True)
+def norm_cdf(z):
+    # Zelen-Severo approximation of the standard normal CDF
+    t = 1/float(1+0.2316419*z)  # t = 1/(1+pz), p = 0.2316419
+    p_norm = 1 - 0.3989423*math.exp(-z*z/2) * ((0.319381530 * t) + \
+             (-0.356563782 * t**2) + \
+             (1.781477937 * t**3) + \
+             (-1.821255978 * t**4) + \
+             (1.330274429 * t**5))
+
+    return p_norm
+
+def calculate_txt_len(lines):
+    txt_len = []
+    txt_len.append(0)
+    for i, line in enumerate(lines):
+        # UTF-8 byte length
+        txt_len.append(txt_len[i] + len(line.strip().encode("utf-8")))
+
+    return np.array(txt_len)
+
+def create_jobs(meta, src, tgt, out):
+    jobs = []
+    fns = get_fns(meta)
+    for file in fns:
+        src_path = os.path.abspath(os.path.join(src, file))
+        tgt_path = os.path.abspath(os.path.join(tgt, file))
+        out_path = os.path.abspath(os.path.join(out, file + '.align'))
+        jobs.append('\t'.join([src_path, tgt_path, out_path]))
+
+    return jobs
+
+def get_fns(meta):
+    fns = []
+    with open(meta, 'rt', encoding='utf-8') as f:
+        next(f)  # skip header
+        for line in f:
+            recs = line.strip().split('\t')
+            fns.append(recs[0])
+
+    return fns
+
+def make_dir(path):
+    if os.path.isdir(path):
+        shutil.rmtree(path)
+    os.makedirs(path, exist_ok=True)
+
+if __name__ == '__main__':
+    t_0 = time.time()
+    main()
+    print("It takes {:.3f} seconds to align all the sentences.".format(time.time() - t_0))