delete
This commit is contained in:
@@ -1,170 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Copyright 2019 Brian Thompson
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from dp_utils import read_alignments
|
||||
|
||||
"""
|
||||
Faster implementation of lax and strict precision and recall, based on
|
||||
https://www.aclweb.org/anthology/W11-4624/.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _precision(goldalign, testalign):
    """
    Count strict and lax true/false positives of test alignments against gold.

    Every alignment is a (source ids, target ids) pair. Pairs that are empty
    on both sides are ignored, and duplicate pairs are counted once because
    both inputs are converted to sets.

    A test pair is a strict true positive when the identical pair occurs in
    gold. Otherwise it is a strict false positive, and it is still a lax true
    positive when at least one of its source ids is aligned in gold to at
    least one of its target ids; if not, it is a lax false positive.

    Args:
        goldalign: iterable of (src_ids, tgt_ids) pairs used as the reference.
        testalign: iterable of (src_ids, tgt_ids) pairs being scored.

    Returns:
        np.int32 array [tpstrict, fpstrict, tplax, fplax].
    """
    # Hashable (src tuple, tgt tuple) pairs; pairs empty on both sides are
    # dropped here, so no later check for ((), ()) is needed.
    test_pairs = {(tuple(src), tuple(tgt))
                  for src, tgt in testalign if len(src) or len(tgt)}
    gold_pairs = {(tuple(src), tuple(tgt))
                  for src, tgt in goldalign if len(src) or len(tgt)}

    # Map each gold source id to every gold target id it is aligned with.
    gold_tgts_by_src = defaultdict(set)
    for gold_src, gold_tgt in gold_pairs:
        for src_id in gold_src:
            gold_tgts_by_src[src_id].update(gold_tgt)

    tpstrict = 0  # true positive strict counter
    fpstrict = 0  # false positive strict counter
    tplax = 0  # true positive lax counter
    fplax = 0  # false positive lax counter

    for test_src, test_tgt in test_pairs:
        if (test_src, test_tgt) in gold_pairs:
            # An exact match counts for both the strict and the lax score.
            tpstrict += 1
            tplax += 1
            continue

        fpstrict += 1
        # Collect the gold targets reachable from any source id of this test
        # pair. .get() is used so that unseen source ids do not add empty
        # entries to the defaultdict.
        reachable_tgts = set()
        for src_id in test_src:
            reachable_tgts.update(gold_tgts_by_src.get(src_id, ()))

        if reachable_tgts.intersection(test_tgt):
            # Partial overlap on both the source and the target side.
            tplax += 1
        else:
            fplax += 1

    return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32)
|
||||
|
||||
|
||||
def _ratio(hits, misses, value_for_div_by_0):
    """Return hits / (hits + misses), or the fallback when that total is zero."""
    total = hits + misses
    if total == 0:
        return value_for_div_by_0
    return hits / float(total)


def _f1(precision, recall, value_for_div_by_0):
    """Return the harmonic mean of precision and recall, or the fallback when their sum is zero."""
    total = precision + recall
    if total == 0:
        return value_for_div_by_0
    return 2 * (precision * recall) / total


def score_multiple(gold_list, test_list, value_for_div_by_0=0.0):
    """
    Compute strict/lax precision, recall and F1 over several gold/test pairs.

    Counts from all pairs are summed before any ratio is taken. Gold and test
    entries are paired with zip, so surplus entries in the longer list are
    ignored.

    Args:
        gold_list: sequence of gold alignments, one entry per document.
        test_list: sequence of test alignments, in the same order as gold_list.
        value_for_div_by_0: value reported for any score whose denominator is zero.

    Returns:
        dict with keys recall_strict, recall_lax, precision_strict,
        precision_lax, f1_strict and f1_lax.
    """
    # accumulate counts for all gold/test files
    pcounts = np.array([0, 0, 0, 0], dtype=np.int32)
    rcounts = np.array([0, 0, 0, 0], dtype=np.int32)
    for goldalign, testalign in zip(gold_list, test_list):
        pcounts += _precision(goldalign=goldalign, testalign=testalign)
        # recall is precision with no insertion/deletion and swap args
        test_no_del = [(x, y) for x, y in testalign if len(x) and len(y)]
        gold_no_del = [(x, y) for x, y in goldalign if len(x) and len(y)]
        rcounts += _precision(goldalign=test_no_del, testalign=gold_no_del)

    # pcounts: tpstrict, fpstrict, tplax, fplax
    # rcounts: tpstrict, fnstrict, tplax, fnlax
    # (gold scored against test, so its "false positives" are false negatives)
    pstrict = _ratio(pcounts[0], pcounts[1], value_for_div_by_0)
    plax = _ratio(pcounts[2], pcounts[3], value_for_div_by_0)
    rstrict = _ratio(rcounts[0], rcounts[1], value_for_div_by_0)
    rlax = _ratio(rcounts[2], rcounts[3], value_for_div_by_0)

    fstrict = _f1(pstrict, rstrict, value_for_div_by_0)
    flax = _f1(plax, rlax, value_for_div_by_0)

    return dict(recall_strict=rstrict,
                recall_lax=rlax,
                precision_strict=pstrict,
                precision_lax=plax,
                f1_strict=fstrict,
                f1_lax=flax)
|
||||
|
||||
|
||||
def log_final_scores(res):
    """
    Write a small table of strict and lax scores to stderr.

    res must provide the keys precision_strict, precision_lax, recall_strict,
    recall_lax, f1_strict and f1_lax; each value is shown with three decimals.
    """
    border = ' ---------------------------------'
    row_templates = (
        '| Precision | {precision_strict:.3f} | {precision_lax:.3f} |',
        '| Recall | {recall_strict:.3f} | {recall_lax:.3f} |',
        '| F1 | {f1_strict:.3f} | {f1_lax:.3f} |',
    )

    print(border, file=sys.stderr)
    print('| | Strict | Lax |', file=sys.stderr)
    # Format each row right before printing it, one row at a time.
    for template in row_templates:
        print(template.format(**res), file=sys.stderr)
    print(border, file=sys.stderr)
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point.

    Parses the -t/--test and -g/--gold file lists, reads each file with
    read_alignments, scores all pairs together with score_multiple and prints
    the result table with log_final_scores.

    Raises:
        Exception: if the numbers of gold and test files differ.
    """
    # The text must be passed as description=; the first positional argument
    # of ArgumentParser is prog, which would replace the program name in the
    # usage line and leave the help without a description.
    parser = argparse.ArgumentParser(
        description='Compute strict/lax precision and recall for one or more pairs of gold/test alignments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('-t', '--test', type=str, nargs='+', required=True,
                        help='one or more test alignment files')

    parser.add_argument('-g', '--gold', type=str, nargs='+', required=True,
                        help='one or more gold alignment files')

    args = parser.parse_args()

    # Files are paired by position, so both lists need the same length.
    if len(args.test) != len(args.gold):
        raise Exception('number of gold/test files must be the same')

    gold_list = [read_alignments(x) for x in args.gold]
    test_list = [read_alignments(x) for x in args.test]

    res = score_multiple(gold_list=gold_list, test_list=test_list)
    log_final_scores(res)
|
||||
|
||||
|
||||
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user