110 lines
3.1 KiB
Python
110 lines
3.1 KiB
Python
import os
|
|
import argparse
|
|
from ast import literal_eval
|
|
from collections import defaultdict
|
|
from eval import score_multiple, log_final_scores
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser('Evaluate aligment quality for Bible corpus')
|
|
parser.add_argument('-t', '--test', type=str, required=True, help='Test alignment file.')
|
|
parser.add_argument('-g', '--gold', type=str, required=True, help='Gold alignment file.')
|
|
parser.add_argument('--src_verse', type=str, required=True, help='Source verse file.')
|
|
parser.add_argument('--tgt_verse', type=str, required=True, help='Target verse file.')
|
|
args = parser.parse_args()
|
|
|
|
test_alignments = read_alignments(args.test)
|
|
gold_alignments = read_alignments(args.gold)
|
|
|
|
src_verse = get_verse(args.src_verse)
|
|
tgt_verse = get_verse(args.tgt_verse)
|
|
|
|
merged_test_alignments = merge_test_alignments(test_alignments, src_verse, tgt_verse)
|
|
res = score_multiple(gold_list=[gold_alignments], test_list=[merged_test_alignments])
|
|
log_final_scores(res)
|
|
|
|
def merge_test_alignments(alignments, src_verse, tgt_verse):
|
|
merged_align = []
|
|
last_beads_type = None
|
|
|
|
for beads in alignments:
|
|
beads_type = find_beads_type(beads, src_verse, tgt_verse)
|
|
if not last_beads_type:
|
|
merged_align.append(beads)
|
|
else:
|
|
if beads_type == last_beads_type:
|
|
merged_align[-1][0].extend(beads[0])
|
|
merged_align[-1][1].extend(beads[1])
|
|
else:
|
|
merged_align.append(beads)
|
|
|
|
last_beads_type = beads_type
|
|
|
|
return merged_align
|
|
|
|
def find_beads_type(beads, src_verse, tgt_verse):
|
|
src_bead = beads[0]
|
|
tgt_bead = beads[1]
|
|
|
|
src_bead_type = find_bead_type(src_bead, src_verse)
|
|
tgt_bead_type = find_bead_type(tgt_bead, tgt_verse)
|
|
|
|
src_bead_len = len(src_bead_type)
|
|
tgt_bead_len = len(tgt_bead_type)
|
|
|
|
if src_bead_len != 1 or tgt_bead_len != 1:
|
|
return None
|
|
else:
|
|
src_verse = src_bead_type[0]
|
|
tgt_verse = tgt_bead_type[0]
|
|
if src_verse != tgt_verse:
|
|
if src_verse == 'NULL':
|
|
return tgt_verse
|
|
elif tgt_verse == 'NULL':
|
|
return src_verse
|
|
else:
|
|
return None
|
|
else:
|
|
return src_verse
|
|
|
|
def find_bead_type(bead, verse):
|
|
bead_type = ['NULL']
|
|
if len(bead) > 0:
|
|
bead_type = unique_list([verse[id] for id in bead])
|
|
|
|
return bead_type
|
|
|
|
def unique_list(list):
|
|
unique_list = []
|
|
for x in list:
|
|
if x not in unique_list:
|
|
unique_list.append(x)
|
|
|
|
return unique_list
|
|
|
|
def get_verse(file):
|
|
verse = defaultdict()
|
|
with open(file, 'rt', encoding='utf-8') as f:
|
|
for (i, line) in enumerate(f):
|
|
verse[i] = line.strip()
|
|
|
|
return verse
|
|
|
|
def read_alignments(fin):
|
|
alignments = []
|
|
with open(fin, 'rt', encoding="utf-8") as infile:
|
|
for line in infile:
|
|
fields = [x.strip() for x in line.split(':') if len(x.strip())]
|
|
if len(fields) < 2:
|
|
raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip())
|
|
try:
|
|
src = literal_eval(fields[0])
|
|
tgt = literal_eval(fields[1])
|
|
except:
|
|
raise Exception('Failed to parse line "%s"' % line.strip())
|
|
alignments.append((src, tgt))
|
|
|
|
return alignments
|
|
|
|
if __name__ == '__main__':
|
|
main()
|