129 lines
4.8 KiB
Python
129 lines
4.8 KiB
Python
import sys
|
|
import numpy as np
|
|
|
|
from ast import literal_eval
|
|
from collections import defaultdict
|
|
|
|
def score_multiple(gold_list, test_list, value_for_div_by_0=0.0):
    """Aggregate strict/lax precision, recall and F1 over several documents.

    ``gold_list`` and ``test_list`` are walked in lockstep with ``zip``, so
    any surplus entries in the longer one are ignored. Each element is an
    iterable of ``(source_ids, target_ids)`` alignment pairs.

    Args:
        gold_list: Reference alignments, one entry per document.
        test_list: Hypothesis alignments, one entry per document.
        value_for_div_by_0: Value reported for any score whose denominator
            is zero.

    Returns:
        dict with the keys ``recall_strict``, ``recall_lax``,
        ``precision_strict``, ``precision_lax``, ``f1_strict`` and
        ``f1_lax``.
    """

    def ratio(hits, misses):
        # hits / (hits + misses), with the caller-chosen fallback when the
        # denominator is zero.
        total = hits + misses
        if total == 0:
            return value_for_div_by_0
        return hits / float(total)

    def f_measure(precision, recall):
        # Harmonic mean of precision and recall, with the same fallback.
        denominator = precision + recall
        if denominator == 0:
            return value_for_div_by_0
        return 2 * (precision * recall) / denominator

    # Both vectors use the layout returned by _precision:
    #   [strict hits, strict misses, lax hits, lax misses]
    # precision_counts: test scored against gold, so misses are false positives.
    # recall_counts:    gold scored against test, so misses are false negatives.
    precision_counts = np.zeros(4, dtype=np.int32)
    recall_counts = np.zeros(4, dtype=np.int32)

    for gold_doc, test_doc in zip(gold_list, test_list):
        precision_counts += _precision(goldalign=gold_doc, testalign=test_doc)

        # Recall is precision with the arguments swapped, after dropping
        # insertions/deletions (pairs with an empty side) from both inputs.
        test_pairs = [(src, tgt) for src, tgt in test_doc if len(src) and len(tgt)]
        gold_pairs = [(src, tgt) for src, tgt in gold_doc if len(src) and len(tgt)]
        recall_counts += _precision(goldalign=test_pairs, testalign=gold_pairs)

    precision_strict = ratio(precision_counts[0], precision_counts[1])
    precision_lax = ratio(precision_counts[2], precision_counts[3])
    recall_strict = ratio(recall_counts[0], recall_counts[1])
    recall_lax = ratio(recall_counts[2], recall_counts[3])

    return {
        'recall_strict': recall_strict,
        'recall_lax': recall_lax,
        'precision_strict': precision_strict,
        'precision_lax': precision_lax,
        'f1_strict': f_measure(precision_strict, recall_strict),
        'f1_lax': f_measure(precision_lax, recall_lax),
    }
|
|
|
|
def _precision(goldalign, testalign):
    """Count strict and lax matches of ``testalign`` against ``goldalign``.

    Both arguments are iterables of ``(source_ids, target_ids)`` pairs.
    Pairs that are empty on both sides are discarded, and duplicates are
    collapsed because the pairs are compared as sets.

    A test alignment is a *strict* match when the identical pair occurs in
    gold. Otherwise it is a *lax* match when at least one of its source ids
    is aligned in gold to a target id that also appears in the test pair's
    target side. A pair with an empty side can therefore only score through
    an exact (strict) match.

    Args:
        goldalign: Reference alignments.
        testalign: Alignments to be scored.

    Returns:
        np.ndarray of dtype int32:
        ``[tpstrict, fpstrict, tplax, fplax]``.
    """
    # Normalise to sets of hashable (src_tuple, tgt_tuple) pairs and drop
    # alignments that are empty on both sides.
    test_pairs = {(tuple(src), tuple(tgt))
                  for src, tgt in testalign if len(src) or len(tgt)}
    gold_pairs = {(tuple(src), tuple(tgt))
                  for src, tgt in goldalign if len(src) or len(tgt)}

    # Map every gold source id to all target ids it is aligned with in gold.
    gold_tgt_ids_by_src_id = defaultdict(set)
    for gold_src, gold_tgt in gold_pairs:
        for src_id in gold_src:
            gold_tgt_ids_by_src_id[src_id].update(gold_tgt)

    tpstrict = 0  # exact matches
    fpstrict = 0  # test pairs with no exact gold counterpart
    tplax = 0     # exact or partially overlapping matches
    fplax = 0     # test pairs with no overlap with gold at all

    for test_src, test_tgt in test_pairs:
        if (test_src, test_tgt) in gold_pairs:
            tpstrict += 1
            tplax += 1
            continue

        # Not an exact match, so it is always a strict false positive.
        fpstrict += 1

        # Collect the gold targets reachable from any source id of this test
        # pair; .get avoids inserting empty entries into the mapping.
        reachable_tgt_ids = set()
        for src_id in test_src:
            reachable_tgt_ids.update(gold_tgt_ids_by_src_id.get(src_id, ()))

        if reachable_tgt_ids.intersection(test_tgt):
            tplax += 1
        else:
            fplax += 1

    return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32)
|
|
|
|
def log_final_scores(res):
    """Write a small strict/lax score table to standard error.

    Args:
        res: Mapping that provides ``precision_strict``, ``precision_lax``,
            ``recall_strict``, ``recall_lax``, ``f1_strict`` and ``f1_lax``;
            each value is rendered with three decimal places.
    """
    border = ' ---------------------------------'
    row_templates = (
        '| Precision | {precision_strict:.3f} | {precision_lax:.3f} |',
        '| Recall | {recall_strict:.3f} | {recall_lax:.3f} |',
        '| F1 | {f1_strict:.3f} | {f1_lax:.3f} |',
    )

    stream = sys.stderr
    print(border, file=stream)
    print('| | Strict | Lax |', file=stream)
    # Rows are formatted one at a time, just before they are printed.
    for template in row_templates:
        print(template.format(**res), file=stream)
    print(border, file=stream)
|
|
|
|
def read_alignments(file):
    """Read alignments from a UTF-8 text file, one alignment per line.

    Each line is split on ':'; fields that are empty after stripping are
    dropped. The first two remaining fields are parsed as Python literals
    with ``ast.literal_eval`` and used as the source and target sides. Any
    further fields are ignored.

    Args:
        file: Path of the alignment file to read.

    Returns:
        list of ``(src, tgt)`` tuples in file order.

    Raises:
        ValueError: If a line has fewer than two non-empty fields, or if
            either of the first two fields cannot be parsed as a literal.
            The parser's original exception is attached as ``__cause__``.
    """
    alignments = []
    with open(file, 'rt', encoding="utf-8") as f:
        for line in f:
            fields = [field for field in (part.strip() for part in line.split(':')) if field]
            if len(fields) < 2:
                raise ValueError('Got line "%s", which does not have at least two ":" separated fields' % line.strip())
            try:
                src = literal_eval(fields[0])
                tgt = literal_eval(fields[1])
            except Exception as err:
                # literal_eval can raise several Exception subclasses on bad
                # input. Catching Exception rather than using a bare except
                # lets KeyboardInterrupt and SystemExit propagate.
                raise ValueError('Failed to parse line "%s"' % line.strip()) from err
            alignments.append((src, tgt))
    return alignments
|