Bertalign
Bertalign
This commit is contained in:
165
bin/vecalign/vecalign.py
Normal file
165
bin/vecalign/vecalign.py
Normal file
@@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Copyright 2019 Brian Thompson
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import pickle
|
||||
from math import ceil
|
||||
from random import seed as seed
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Console logging for the 'vecalign' logger: build the formatter and the
# stream handler first, then attach the handler to the logger.  The level
# defaults to WARNING.
logFormatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)

logger = logging.getLogger('vecalign')
logger.setLevel(logging.WARNING)
logger.addHandler(consoleHandler)
|
||||
|
||||
from dp_utils import make_alignment_types, print_alignments, read_alignments, \
|
||||
read_in_embeddings, make_doc_embedding, vecalign
|
||||
|
||||
from score import score_multiple, log_final_scores
|
||||
|
||||
|
||||
def _main():
    """Command-line entry point: align every document pair listed in a job file.

    Parses the command-line options, loads the source and target sentence
    embeddings once, and then, for each record of the job file (``--job``),
    builds document embeddings for the two text files, runs ``vecalign`` and
    writes the resulting alignments to the record's output file.

    Each job record must hold exactly three tab-separated fields: the source
    text file, the target text file and the alignment output file.  Empty
    records are skipped.

    Raises:
        ValueError: if a non-empty job record does not have three fields.
    """
    # make runs consistent
    seed(42)
    np.random.seed(42)

    # The text is passed as ``description`` (not positionally, which would
    # set ``prog`` and garble the usage line).
    parser = argparse.ArgumentParser(
        description='Sentence alignment using sentence embeddings and FastDTW',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--job', type=str, required=True, help='Job file for alignment task.')

    # NOTE(review): --gold_alignment and --debug_save_stack are still parsed
    # for command-line compatibility, but nothing in this function reads them.
    parser.add_argument('-g', '--gold_alignment', type=str, nargs='+', required=False,
                        help='preprocessed target file to align')

    parser.add_argument('--src_embed', type=str, nargs=2, required=True,
                        help='Source embeddings. Requires two arguments: first is a text file, sencond is a binary embeddings file. ')

    parser.add_argument('--tgt_embed', type=str, nargs=2, required=True,
                        help='Target embeddings. Requires two arguments: first is a text file, sencond is a binary embeddings file. ')

    parser.add_argument('-a', '--alignment_max_size', type=int, default=5,
                        help='Searches for alignments up to size N-M, where N+M <= this value. Note that the the embeddings must support the requested number of overlaps')

    parser.add_argument('-d', '--del_percentile_frac', type=float, default=0.2,
                        help='Deletion penalty is set to this percentile (as a fraction) of the cost matrix distribution. Should be between 0 and 1.')

    parser.add_argument('-v', '--verbose', help='sets consle to logging.DEBUG instead of logging.WARN',
                        action='store_true')

    parser.add_argument('--max_size_full_dp', type=int, default=300,
                        help='Maximum size N for which is is acceptable to run full N^2 dynamic programming.')

    parser.add_argument('--costs_sample_size', type=int, default=20000,
                        help='Sample size to estimate costs distribution, used to set deletion penalty in conjunction with deletion_percentile.')

    parser.add_argument('--num_samps_for_norm', type=int, default=100,
                        help='Number of samples used for normalizing embeddings')

    parser.add_argument('--search_buffer_size', type=int, default=5,
                        help='Width (one side) of search buffer. Larger values makes search more likely to recover from errors but increases runtime.')

    parser.add_argument('--debug_save_stack', type=str,
                        help='Write stack to pickle file for debug purposes')

    args = parser.parse_args()

    if args.verbose:
        # NOTE(review): the --verbose help text mentions DEBUG, but the level
        # set here is INFO, so logger.debug() output below stays hidden.
        # Confirm which level is intended.
        logger.setLevel(logging.INFO)

    if args.alignment_max_size < 2:
        logger.warning('Alignment_max_size < 2. Increasing to 2 so that 1-1 alignments will be considered')
        args.alignment_max_size = 2

    # The embeddings are loaded once and shared by every job record.
    src_sent2line, src_line_embeddings = read_in_embeddings(args.src_embed[0], args.src_embed[1])
    tgt_sent2line, tgt_line_embeddings = read_in_embeddings(args.tgt_embed[0], args.tgt_embed[1])

    width_over2 = ceil(args.alignment_max_size / 2.0) + args.search_buffer_size

    # read in alignment jobs
    job = read_job(args.job)
    for rec in job:
        if not rec:
            # A blank job line carries no work; skip it instead of failing.
            continue

        fields = rec.split("\t")
        if len(fields) != 3:
            raise ValueError(
                'Job record must have 3 tab-separated fields '
                '(source, target, output), got {}: {!r}'.format(len(fields), rec))
        src_file, tgt_file, align_file = fields

        print("Aligning {} to {}".format(src_file, tgt_file))

        # Context managers close the input files as soon as they are read.
        with open(src_file, 'rt', encoding="utf-8") as src_f:
            src_lines = src_f.readlines()
        vecs0 = make_doc_embedding(src_sent2line, src_line_embeddings, src_lines, args.alignment_max_size)

        with open(tgt_file, 'rt', encoding="utf-8") as tgt_f:
            tgt_lines = tgt_f.readlines()
        vecs1 = make_doc_embedding(tgt_sent2line, tgt_line_embeddings, tgt_lines, args.alignment_max_size)

        final_alignment_types = make_alignment_types(args.alignment_max_size)
        logger.debug('Considering alignment types %s', final_alignment_types)

        stack = vecalign(vecs0=vecs0,
                         vecs1=vecs1,
                         final_alignment_types=final_alignment_types,
                         del_percentile_frac=args.del_percentile_frac,
                         width_over2=width_over2,
                         max_size_full_dp=args.max_size_full_dp,
                         costs_sample_size=args.costs_sample_size,
                         num_samps_for_norm=args.num_samps_for_norm)

        # Write the final alignments to this record's output file; the
        # ``with`` block guarantees the file is flushed and closed.
        with open(align_file, 'w', encoding="utf-8") as out_f:
            print_alignments(stack[0]['final_alignments'], file=out_f)
|
||||
|
||||
def read_job(file):
    """Read the alignment job file and return its records.

    Lines that start with ``#`` are treated as comments and ignored.  Every
    other line is stripped of surrounding whitespace; lines that are empty
    after stripping are dropped so callers never receive an empty record.

    Args:
        file: path of the UTF-8 encoded job file.

    Returns:
        A list of stripped, non-empty record strings in file order.
    """
    job = []
    with open(file, 'r', encoding="utf-8") as f:
        for line in f:
            # The comment check is done on the raw line, so only a '#' in
            # the very first column marks a comment.
            if line.startswith("#"):
                continue
            record = line.strip()
            if record:
                job.append(record)
    return job
|
||||
|
||||
# Run the aligner only when this file is executed as a script.
if __name__ == "__main__":
    _main()
|
||||
Reference in New Issue
Block a user