Bertalign and evaluation scripts
bin/embed_sents.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# 2021/11/27
# bfsujason@163.com

'''
Usage (Linux):

python bin/embed_sents.py \
  -i data/mac/dev/zh \
  -o data/mac/dev/zh/overlap data/mac/dev/zh/overlap.emb \
  -m data/mac/test/meta_data.tsv \
  -n 8
'''

import os
import time
import shutil
import argparse
import numpy as np
from sentence_transformers import SentenceTransformer

def main():
    parser = argparse.ArgumentParser(description='Multilingual sentence embeddings')
    parser.add_argument('-i', '--input', type=str, required=True, help='Data directory.')
    parser.add_argument('-o', '--output', type=str, required=True, nargs=2, help='Overlap and embedding file.')
    parser.add_argument('-n', '--num_overlaps', type=int, default=5, help='Maximum number of allowed overlaps.')
    parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.')
    args = parser.parse_args()

    fns = get_fns(args.meta)
    overlap = get_overlap(args.input, fns, args.num_overlaps)
    write_overlap(overlap, args.output[0])

    model = load_model()
    embed_overlap(model, overlap, args.output[1])

def embed_overlap(model, overlap, fout):
    print("Embedding text ...")
    t_0 = time.time()
    embed = model.encode(overlap)
    embed.tofile(fout)  # raw float32 bytes, no shape or dtype header
    print("It takes {:.3f} seconds to embed text.".format(time.time() - t_0))

def write_overlap(overlap, outfile):
    with open(outfile, 'wt', encoding="utf-8") as fout:
        for line in overlap:
            fout.write(line + '\n')

def get_overlap(dir, fns, n):
    overlap = set()
    for file in fns:
        in_path = os.path.join(dir, file)
        lines = open(in_path, 'rt', encoding="utf-8").readlines()
        for out_line in yield_overlaps(lines, n):
            overlap.add(out_line)

    # for reproducibility
    overlap = list(overlap)
    overlap.sort()

    return overlap

def yield_overlaps(lines, num_overlaps):
    lines = [preprocess_line(line) for line in lines]
    for overlap in range(1, num_overlaps + 1):
        for out_line in layer(lines, overlap):
            # check must be here so all outputs are unique
            out_line2 = out_line[:10000]  # limit length so we don't encode arbitrarily long sentences
            yield out_line2

def layer(lines, num_overlaps, comb=' '):
    if num_overlaps < 1:
        raise Exception('num_overlaps must be >= 1')
    out = ['PAD', ] * min(num_overlaps - 1, len(lines))
    for ii in range(len(lines) - num_overlaps + 1):
        out.append(comb.join(lines[ii:ii + num_overlaps]))
    return out

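# Worked example of the overlap scheme implemented by layer()/yield_overlaps():
# for lines = ['a', 'b', 'c'] and num_overlaps = 2, the generator emits
#   'a', 'b', 'c'           (overlap = 1: single sentences)
#   'PAD', 'a b', 'b c'     (overlap = 2: concatenations of adjacent sentences)
# so every contiguous span of up to num_overlaps sentences receives an embedding.
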
def preprocess_line(line):
    line = line.strip()
    if len(line) == 0:
        line = 'BLANK_LINE'
    return line

def load_model():
    print("Loading embedding model ...")
    t0 = time.time()
    model = SentenceTransformer('LaBSE')
    print("It takes {:.3f} seconds to load the model.".format(time.time() - t0))
    return model

def get_fns(meta):
    fns = []
    with open(meta, 'rt', encoding='utf-8') as f:
        next(f)  # skip header
        for line in f:
            recs = line.strip().split('\t')
            fns.append(recs[0])  # first tab-separated column holds the file name

    return fns

def make_dir(path):
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

if __name__ == '__main__':
    main()
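Note: embed_overlap() stores the embeddings with numpy's tofile(), i.e. raw float32 bytes with no shape or dtype header, so a downstream reader must supply the embedding dimension itself. A minimal sketch of reloading the file, assuming LaBSE's 768-dimensional output (illustrative only, not part of the committed script):

import numpy as np

dim = 768  # assumed sentence-embedding size of LaBSE
vecs = np.fromfile('data/mac/dev/zh/overlap.emb', dtype=np.float32).reshape(-1, dim)
print(vecs.shape)  # (number of overlap lines, dim)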