Update bert_align.py
This commit is contained in:
@@ -196,7 +196,8 @@ def second_pass_align(src_vecs,
|
|||||||
|
|
||||||
def second_back_track(i, j, b, search_path, a_types):
|
def second_back_track(i, j, b, search_path, a_types):
|
||||||
alignment = []
|
alignment = []
|
||||||
while ( i !=0 and j != 0 ):
|
#while ( i !=0 and j != 0 ):
|
||||||
|
while ( 1 ):
|
||||||
j_offset = j - search_path[i][0]
|
j_offset = j - search_path[i][0]
|
||||||
a = b[i][j_offset]
|
a = b[i][j_offset]
|
||||||
s = a_types[a][0]
|
s = a_types[a][0]
|
||||||
@@ -207,8 +208,9 @@ def second_back_track(i, j, b, search_path, a_types):
|
|||||||
|
|
||||||
i = i-s
|
i = i-s
|
||||||
j = j-t
|
j = j-t
|
||||||
|
|
||||||
return alignment[::-1]
|
if i == 0 and j == 0:
|
||||||
|
return alignment[::-1]
|
||||||
|
|
||||||
@nb.jit(nopython=True, fastmath=True, cache=True)
|
@nb.jit(nopython=True, fastmath=True, cache=True)
|
||||||
def get_score(src_v, tgt_v,
|
def get_score(src_v, tgt_v,
|
||||||
@@ -306,7 +308,8 @@ def first_back_track(i, j, b, search_path, a_types):
|
|||||||
alignment: list of tuples for 1-1 alignments.
|
alignment: list of tuples for 1-1 alignments.
|
||||||
"""
|
"""
|
||||||
alignment = []
|
alignment = []
|
||||||
while ( i !=0 and j != 0 ):
|
#while ( i !=0 and j != 0 ):
|
||||||
|
while ( 1 ):
|
||||||
j_offset = j - search_path[i][0]
|
j_offset = j - search_path[i][0]
|
||||||
a = b[i][j_offset]
|
a = b[i][j_offset]
|
||||||
s = a_types[a][0]
|
s = a_types[a][0]
|
||||||
@@ -316,8 +319,9 @@ def first_back_track(i, j, b, search_path, a_types):
|
|||||||
|
|
||||||
i = i-s
|
i = i-s
|
||||||
j = j-t
|
j = j-t
|
||||||
|
|
||||||
return alignment[::-1]
|
if i == 0 and j == 0:
|
||||||
|
return alignment[::-1]
|
||||||
|
|
||||||
@nb.jit(nopython=True, fastmath=True, cache=True)
|
@nb.jit(nopython=True, fastmath=True, cache=True)
|
||||||
def first_pass_align(src_len,
|
def first_pass_align(src_len,
|
||||||
|
|||||||
Reference in New Issue
Block a user