# 2021/11/27
# bfsujason@163.com
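'''
Build overlapping line windows for every file listed in a metadata TSV,
write them to a text file, and embed them with a LaBSE SentenceTransformer
model. The -o option takes two paths: the overlap text file and the raw
embedding file.

Usage (Linux):

python bin/embed_sents.py \
    -i data/mac/dev/zh \
    -o data/mac/dev/zh/overlap data/mac/dev/zh/overlap.emb \
    -m data/mac/test/meta_data.tsv \
    -n 8
'''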
|
|
|
|
import os
|
|
import time
|
|
import shutil
|
|
import argparse
|
|
import numpy as np
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
def main():
    """Command-line entry point: build overlaps for the listed files and embed them.

    Reads the file names from the metadata TSV (-m) and builds overlap
    strings from those files in the data directory (-i), using windows of
    1..N lines (-n). It writes the overlaps to the first -o path, then embeds
    them and writes the raw embedding array to the second -o path.
    """
    parser = argparse.ArgumentParser(description='Multilingual sentence embeddings')
    parser.add_argument('-i', '--input', type=str, required=True, help='Data directory.')
    parser.add_argument('-o', '--output', type=str, required=True, nargs=2,
                        metavar=('OVERLAP', 'EMBED'),
                        help='Overlap and embedding file.')
    parser.add_argument('-n', '--num_overlaps', type=int, default=5,
                        help='Maximum number of allowed overlaps.')
    parser.add_argument('-m', '--meta', type=str, required=True, help='Metadata file.')
    args = parser.parse_args()

    # A value below 1 makes yield_overlaps produce no windows at all, so the
    # script would silently write empty output files; reject it up front.
    if args.num_overlaps < 1:
        parser.error('--num_overlaps must be >= 1')

    fns = get_fns(args.meta)
    overlap = get_overlap(args.input, fns, args.num_overlaps)
    write_overlap(overlap, args.output[0])

    model = load_model()
    # The same list is embedded that was just written, so the two output
    # files are produced from identical input order.
    embed_overlap(model, overlap, args.output[1])
|
|
|
|
def embed_overlap(model, overlap, fout):
    """Embed every overlap string with *model* and dump the result to a file.

    Args:
        model: object exposing ``encode(list_of_str)``. Its return value must
            provide ``tofile``; presumably a NumPy array (``tofile`` is an
            ``ndarray`` method).
        overlap: list of strings to embed.
        fout: output file path (not a file object). ``ndarray.tofile`` writes
            the raw array data with no header, so shape and dtype are not
            recorded in the file.
    """
    print("Embedding text ...")
    # perf_counter is monotonic and intended for measuring elapsed time.
    t_0 = time.perf_counter()
    embed = model.encode(overlap)
    embed.tofile(fout)
    print("It takes {:.3f} seconds to embed text.".format(time.perf_counter() - t_0))
|
|
|
|
def write_overlap(overlap, outfile):
    """Write *overlap* to the UTF-8 text file *outfile*, one string per line."""
    with open(outfile, mode='wt', encoding="utf-8") as handle:
        handle.writelines(text + '\n' for text in overlap)
|
|
|
|
def get_overlap(dir, fns, n):
    """Collect the distinct overlap strings of all listed files, sorted.

    Args:
        dir: directory containing the input files. The name shadows the
            builtin ``dir`` but is kept for backward compatibility.
        fns: file names, joined onto *dir*.
        n: maximum window size passed to ``yield_overlaps``.

    Returns:
        A sorted list of unique overlap strings. Duplicates across and within
        files are collapsed here via a set.
    """
    overlap = set()
    for fn in fns:
        in_path = os.path.join(dir, fn)
        # Context manager closes each handle deterministically instead of
        # leaving it to the garbage collector.
        with open(in_path, 'rt', encoding="utf-8") as fin:
            lines = fin.readlines()
        overlap.update(yield_overlaps(lines, n))

    # Sort for reproducibility: set iteration order for str depends on hash
    # randomization, which can differ between runs.
    return sorted(overlap)
|
|
|
|
def yield_overlaps(lines, num_overlaps, max_chars=10000):
    """Yield window strings of 1..*num_overlaps* consecutive lines.

    Every line is first normalized with ``preprocess_line``. For each window
    size, the strings produced by ``layer`` (including its 'PAD'
    placeholders) are yielded in order. Duplicates are NOT removed here;
    callers that need unique strings must deduplicate themselves.

    Args:
        lines: raw text lines.
        num_overlaps: largest window size; a value below 1 yields nothing.
        max_chars: each yielded string is truncated to this many characters
            so the encoder never receives arbitrarily long input. ``None``
            disables truncation.
    """
    cleaned = [preprocess_line(line) for line in lines]
    for size in range(1, num_overlaps + 1):
        for window in layer(cleaned, size):
            yield window[:max_chars]
|
|
|
|
def layer(lines, num_overlaps, comb=' '):
    """Build one window string per input line for a fixed window size.

    With k = *num_overlaps*, the result begins with ``min(k - 1, len(lines))``
    'PAD' placeholders, followed by every run of k consecutive lines joined
    with *comb*. The result therefore always has exactly ``len(lines)``
    entries.

    Raises:
        ValueError: if *num_overlaps* < 1. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    if num_overlaps < 1:
        raise ValueError('num_overlaps must be >= 1')
    out = ['PAD'] * min(num_overlaps - 1, len(lines))
    for start in range(len(lines) - num_overlaps + 1):
        out.append(comb.join(lines[start:start + num_overlaps]))
    return out
|
|
|
|
def preprocess_line(line):
    """Return *line* without surrounding whitespace, or 'BLANK_LINE' if nothing remains."""
    return line.strip() or 'BLANK_LINE'
|
|
|
|
def load_model(model_name='LaBSE'):
    """Load a SentenceTransformer model and report how long loading took.

    Args:
        model_name: model name or path passed to ``SentenceTransformer``.
            Defaults to 'LaBSE'.

    Returns:
        The loaded ``SentenceTransformer`` instance.
    """
    print("Loading embedding model ...")
    t0 = time.perf_counter()
    model = SentenceTransformer(model_name)
    print("It takes {:.3f} seconds to load the model.".format(time.perf_counter() - t0))
    return model
|
|
|
|
def get_fns(meta):
    """Read the first column (file names) of a tab-separated metadata file.

    The first line is treated as a header and skipped. For each remaining
    non-blank line, the text before the first tab (after stripping
    surrounding whitespace) is collected.

    Args:
        meta: path to the UTF-8 metadata TSV.

    Returns:
        List of file names in file order; empty if there are no data rows
        (including when the file itself is empty).
    """
    fns = []
    with open(meta, 'rt', encoding='utf-8') as f:
        # Skip the header; the default avoids a stray StopIteration when the
        # file is completely empty.
        next(f, None)
        for line in f:
            rec = line.strip()
            if not rec:
                # A blank line would yield an empty name, which get_overlap
                # would join into the bare directory path and fail to open.
                continue
            fns.append(rec.split('\t')[0])
    return fns
|
|
|
|
def make_dir(path):
    """Ensure *path* exists as a fresh, empty directory.

    Any directory already at *path* is deleted together with its contents
    first; missing parent directories are created.
    """
    already_present = os.path.isdir(path)
    if already_present:
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)
|
|
|
|
if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()
|