Bertalign and evaluation scripts
This commit is contained in:
109
bin/eval_bible.py
Normal file
109
bin/eval_bible.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import os
|
||||
import argparse
|
||||
from ast import literal_eval
|
||||
from collections import defaultdict
|
||||
from eval import score_multiple, log_final_scores
|
||||
|
||||
def main():
    """Score a test alignment of the Bible corpus against a gold alignment.

    Command-line arguments (all required):
        -t / --test    test alignment file
        -g / --gold    gold alignment file
        --src_verse    source verse file
        --tgt_verse    target verse file

    Both alignment files are parsed with read_alignments() and both verse
    files with get_verse(). The test alignments are merged with
    merge_test_alignments() before being compared with the gold alignments
    by score_multiple(); the result is passed to log_final_scores().
    """
    # The text must be given as ``description``. The first positional
    # parameter of ArgumentParser is ``prog``, so passing the text
    # positionally would replace the program name in the usage line.
    parser = argparse.ArgumentParser(
        description='Evaluate alignment quality for Bible corpus')
    parser.add_argument('-t', '--test', type=str, required=True, help='Test alignment file.')
    parser.add_argument('-g', '--gold', type=str, required=True, help='Gold alignment file.')
    parser.add_argument('--src_verse', type=str, required=True, help='Source verse file.')
    parser.add_argument('--tgt_verse', type=str, required=True, help='Target verse file.')
    args = parser.parse_args()

    test_alignments = read_alignments(args.test)
    gold_alignments = read_alignments(args.gold)

    src_verse = get_verse(args.src_verse)
    tgt_verse = get_verse(args.tgt_verse)

    # Only the test side is merged; the gold alignments are scored as read.
    merged_test_alignments = merge_test_alignments(test_alignments, src_verse, tgt_verse)
    res = score_multiple(gold_list=[gold_alignments], test_list=[merged_test_alignments])
    log_final_scores(res)
|
||||
|
||||
def merge_test_alignments(alignments, src_verse, tgt_verse):
    """Merge runs of consecutive beads that share the same beads type.

    Args:
        alignments: iterable of beads; each bead is indexable as
            (source ids, target ids), both being sequences of indices.
        src_verse: mapping from a source index to its verse label.
        tgt_verse: mapping from a target index to its verse label.

    Returns:
        A new list of (source ids, target ids) tuples of lists. A bead is
        folded into the previous output bead when find_beads_type() returns
        the same value for it as for the bead before it and that previous
        value is truthy; otherwise the bead starts a new output entry.

    The beads in ``alignments`` are never modified: every output entry owns
    fresh copies of the id lists, so extending a merged entry cannot leak
    back into the caller's data.
    """
    merged_align = []
    last_beads_type = None

    for beads in alignments:
        beads_type = find_beads_type(beads, src_verse, tgt_verse)

        if last_beads_type and beads_type == last_beads_type:
            # Same type as the previous bead: grow the last merged entry.
            merged_align[-1][0].extend(beads[0])
            merged_align[-1][1].extend(beads[1])
        else:
            # Copy the id lists so that later extend() calls only touch
            # the merged output, not the input alignment.
            merged_align.append((list(beads[0]), list(beads[1])))

        last_beads_type = beads_type

    return merged_align
|
||||
|
||||
def find_beads_type(beads, src_verse, tgt_verse):
    """Return the single verse label shared by both sides of a bead.

    Args:
        beads: a bead indexable as (source ids, target ids).
        src_verse: mapping from a source index to its verse label.
        tgt_verse: mapping from a target index to its verse label.

    Returns:
        The verse label when each side resolves to exactly one label and
        the two labels agree, or when exactly one side is 'NULL' (in which
        case the label of the other side is returned). None when either
        side carries more than one label or the two labels conflict.
    """
    src_ids, tgt_ids = beads[0], beads[1]
    src_labels = find_bead_type(src_ids, src_verse)
    tgt_labels = find_bead_type(tgt_ids, tgt_verse)

    # A side that does not map to exactly one label has no single type.
    if not (len(src_labels) == 1 and len(tgt_labels) == 1):
        return None

    src_label, tgt_label = src_labels[0], tgt_labels[0]
    if src_label == tgt_label:
        return src_label
    # The labels differ: only acceptable when one side is the 'NULL' marker.
    if src_label == 'NULL':
        return tgt_label
    if tgt_label == 'NULL':
        return src_label
    return None
|
||||
|
||||
def find_bead_type(bead, verse):
    """Return the distinct verse labels covered by one side of a bead.

    Args:
        bead: sequence of indices forming one side of a bead.
        verse: mapping from an index to its verse label.

    Returns:
        The labels looked up for every index in ``bead``, de-duplicated by
        unique_list(); ['NULL'] when ``bead`` is empty.
    """
    if len(bead) > 0:
        return unique_list([verse[idx] for idx in bead])
    # An empty side is represented by the 'NULL' marker label.
    return ['NULL']
|
||||
|
||||
def unique_list(list):
    """Return the items of ``list`` without duplicates, in first-seen order.

    Membership is checked with ``in`` against the result built so far, so
    the items only need to support equality (they may be unhashable).
    """
    deduped = []
    for item in list:
        if item in deduped:
            continue
        deduped.append(item)
    return deduped
|
||||
|
||||
def get_verse(file):
    """Load a verse file into a mapping keyed by 0-based line number.

    Args:
        file: path of a UTF-8 text file.

    Returns:
        A defaultdict mapping each line number to that line with
        surrounding whitespace stripped. It is created without a default
        factory, so looking up a missing line number still raises KeyError.
    """
    verse = defaultdict()
    with open(file, 'rt', encoding='utf-8') as handle:
        verse.update(
            (line_no, text.strip()) for line_no, text in enumerate(handle)
        )
    return verse
|
||||
|
||||
def read_alignments(fin):
    """Read an alignment file into a list of (src, tgt) pairs.

    Every line is split on ':' and empty fields are dropped. The first two
    remaining fields are parsed with ast.literal_eval as the source and
    target side of one alignment; any further fields are ignored.

    Args:
        fin: path of a UTF-8 alignment file.

    Returns:
        A list of (src, tgt) tuples in file order.

    Raises:
        Exception: if a line has fewer than two non-empty ':'-separated
            fields, or if either of the first two fields cannot be parsed.
            In the latter case the underlying parse error is attached as
            the exception's cause.
    """
    alignments = []
    with open(fin, 'rt', encoding="utf-8") as infile:
        for line in infile:
            fields = [x.strip() for x in line.split(':') if len(x.strip())]
            if len(fields) < 2:
                raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip())
            try:
                src = literal_eval(fields[0])
                tgt = literal_eval(fields[1])
            except Exception as err:
                # Not a bare ``except``: KeyboardInterrupt and SystemExit
                # must propagate, and the original parse error is kept as
                # the cause for easier debugging.
                raise Exception('Failed to parse line "%s"' % line.strip()) from err
            alignments.append((src, tgt))

    return alignments
|
||||
|
||||
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user