import sys
from ast import literal_eval
from collections import defaultdict

import numpy as np


def _safe_ratio(numerator, denominator, value_for_div_by_0):
    """Return numerator / denominator, or the fallback if denominator is 0."""
    if denominator == 0:
        return value_for_div_by_0
    return numerator / float(denominator)


def score_multiple(gold_list, test_list, value_for_div_by_0=0.0):
    """Score test alignments against gold alignments over several documents.

    Counts are accumulated over every (gold, test) pair before any ratio is
    taken, so the scores are micro-averaged across documents.

    Args:
        gold_list: iterable of gold alignments, one per document. Each
            alignment is an iterable of ``(src_ids, tgt_ids)`` pairs.
        test_list: iterable of test alignments, paired with ``gold_list``
            by position (``zip`` stops at the shorter of the two).
        value_for_div_by_0: value reported for any metric whose denominator
            is zero.

    Returns:
        dict with keys ``recall_strict``, ``recall_lax``,
        ``precision_strict``, ``precision_lax``, ``f1_strict``, ``f1_lax``.
    """
    # accumulate counts for all gold/test files
    pcounts = np.array([0, 0, 0, 0], dtype=np.int32)
    rcounts = np.array([0, 0, 0, 0], dtype=np.int32)

    for goldalign, testalign in zip(gold_list, test_list):
        pcounts += _precision(goldalign=goldalign, testalign=testalign)
        # recall is precision with no insertion/deletion and swap args
        test_no_del = [(x, y) for x, y in testalign if len(x) and len(y)]
        gold_no_del = [(x, y) for x, y in goldalign if len(x) and len(y)]
        rcounts += _precision(goldalign=test_no_del, testalign=gold_no_del)

    # Compute results
    # pcounts: tpstrict, fpstrict, tplax, fplax
    # rcounts: tpstrict, fnstrict, tplax, fnlax
    #   (with the arguments swapped, an unmatched gold alignment is counted
    #    in the "false positive" slot, which makes it a false negative)
    pstrict = _safe_ratio(pcounts[0], pcounts[0] + pcounts[1], value_for_div_by_0)
    plax = _safe_ratio(pcounts[2], pcounts[2] + pcounts[3], value_for_div_by_0)
    rstrict = _safe_ratio(rcounts[0], rcounts[0] + rcounts[1], value_for_div_by_0)
    rlax = _safe_ratio(rcounts[2], rcounts[2] + rcounts[3], value_for_div_by_0)

    fstrict = _safe_ratio(2 * (pstrict * rstrict), pstrict + rstrict,
                          value_for_div_by_0)
    flax = _safe_ratio(2 * (plax * rlax), plax + rlax, value_for_div_by_0)

    return dict(recall_strict=rstrict,
                recall_lax=rlax,
                precision_strict=pstrict,
                precision_lax=plax,
                f1_strict=fstrict,
                f1_lax=flax)


def _precision(goldalign, testalign):
    """Compute tpstrict, fpstrict, tplax, fplax for gold/test alignments.

    A test alignment is a strict true positive when the identical
    ``(src, tgt)`` pair occurs in gold. Otherwise it is a strict false
    positive, and it still counts as a lax true positive if it overlaps a
    gold alignment on both the source and the target side.

    Returns:
        ``np.int32`` array ``[tpstrict, fpstrict, tplax, fplax]``.
    """
    tpstrict = 0  # true positive strict counter
    tplax = 0     # true positive lax counter
    fpstrict = 0  # false positive strict counter
    fplax = 0     # false positive lax counter

    # convert to sets, remove alignments empty on both sides
    testalign = {(tuple(x), tuple(y)) for x, y in testalign if len(x) or len(y)}
    goldalign = {(tuple(x), tuple(y)) for x, y in goldalign if len(x) or len(y)}

    # mapping from each source sentence idx to every target sentence idx
    # it is aligned with in some gold alignment
    src_id_to_gold_tgt_ids = defaultdict(set)
    for gold_src, gold_tgt in goldalign:
        for gold_src_id in gold_src:
            src_id_to_gold_tgt_ids[gold_src_id].update(gold_tgt)

    # No check for ((), ()) is needed here: such pairs were dropped above.
    for test_src, test_target in testalign:
        if (test_src, test_target) in goldalign:
            # strict match
            tpstrict += 1
            tplax += 1
            continue

        # For anything with partial gold/test overlap on the source,
        # see if there is also partial overlap on the gold/test target.
        # If so, it is a lax match.
        target_ids = set()
        for src_test_id in test_src:
            target_ids.update(src_id_to_gold_tgt_ids[src_test_id])

        fpstrict += 1
        if target_ids.intersection(test_target):
            tplax += 1
        else:
            fplax += 1

    return np.array([tpstrict, fpstrict, tplax, fplax], dtype=np.int32)


def log_final_scores(res):
    """Print a strict/lax precision, recall and F1 table to stderr.

    Args:
        res: mapping with the keys produced by ``score_multiple``.
    """
    print(' ---------------------------------', file=sys.stderr)
    print('| | Strict | Lax |', file=sys.stderr)
    print('| Precision | {precision_strict:.3f} | {precision_lax:.3f} |'.format(**res), file=sys.stderr)
    print('| Recall | {recall_strict:.3f} | {recall_lax:.3f} |'.format(**res), file=sys.stderr)
    print('| F1 | {f1_strict:.3f} | {f1_lax:.3f} |'.format(**res), file=sys.stderr)
    print(' ---------------------------------', file=sys.stderr)


def read_alignments(file):
    """Read alignments from a text file, one alignment per line.

    Each line holds ":"-separated fields. The first two are Python literals
    for the source and target ids, e.g. ``[0, 1]:[0]``. Any further fields
    are ignored.

    Args:
        file: path of the alignment file (read as UTF-8).

    Returns:
        list of ``(src, tgt)`` tuples in file order.

    Raises:
        Exception: if a line has fewer than two non-empty fields, or if
            either of the first two fields is not a valid Python literal
            (the underlying parse error is attached as ``__cause__``).
    """
    alignments = []
    with open(file, 'rt', encoding="utf-8") as f:
        for line in f:
            fields = [x.strip() for x in line.split(':') if len(x.strip())]
            if len(fields) < 2:
                raise Exception('Got line "%s", which does not have at least two ":" separated fields' % line.strip())
            try:
                src = literal_eval(fields[0])
                tgt = literal_eval(fields[1])
            except (ValueError, TypeError, SyntaxError,
                    MemoryError, RecursionError) as err:
                # Only the errors literal_eval raises on malformed input are
                # caught, so KeyboardInterrupt/SystemExit still propagate.
                raise Exception('Failed to parse line "%s"' % line.strip()) from err
            alignments.append((src, tgt))
    return alignments