Skip to content

Commit

Permalink
!1 修复一个 CRF 的 bug
Browse files Browse the repository at this point in the history
Merge pull request !1 from WillQvQ/dev
  • Loading branch information
WillQvQ authored and gitee-org committed Nov 6, 2020
2 parents 148ad1d + 8505617 commit f58e10d
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 80 deletions.
9 changes: 6 additions & 3 deletions .Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ pipeline {
}
}
post {
always {
sh 'post'
failure {
sh 'post 1'
}
success {
sh 'post 0'
sh 'post github'
}

}

}
2 changes: 1 addition & 1 deletion fastNLP/io/pipe/cws.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _find_and_replace_digit_spans(line):
otherwise unkdgt
"""
new_line = ''
pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])'
pattern = r'\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])'
prev_end = 0
for match in re.finditer(pattern, line):
start, end = match.span()
Expand Down
51 changes: 35 additions & 16 deletions fastNLP/modules/decoder/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,18 @@ def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=
constrain = torch.zeros(num_tags + 2, num_tags + 2)
else:
constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float)
has_start = False
has_end = False
for from_tag_id, to_tag_id in allowed_transitions:
constrain[from_tag_id, to_tag_id] = 0
if from_tag_id==num_tags:
has_start = True
if to_tag_id==num_tags+1:
has_end = True
if not has_start:
constrain[num_tags, :].fill_(0)
if not has_end:
constrain[:, num_tags+1].fill_(0)
self._constrain = nn.Parameter(constrain, requires_grad=False)

initial_parameter(self, initial_method)
Expand Down Expand Up @@ -290,51 +300,60 @@ def viterbi_decode(self, logits, mask, unpad=False):
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。
"""
batch_size, seq_len, n_tags = logits.size()
batch_size, max_len, n_tags = logits.size()
seq_len = mask.long().sum(1)
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.eq(True) # L, B
flip_mask = mask.eq(False)

# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0]
vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0] # bsz x n_tags
transitions = self._constrain.data.clone()
transitions[:n_tags, :n_tags] += self.trans_m.data
if self.include_start_end_trans:
transitions[n_tags, :n_tags] += self.start_scores.data
transitions[:n_tags, n_tags + 1] += self.end_scores.data

vscore += transitions[n_tags, :n_tags]

trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags

# 针对长度为1的句子
vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \
.masked_fill(seq_len.ne(1).view(-1, 1), 0)
for i in range(1, max_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0)
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) # bsz x n_tag x n_tag
# 需要考虑当前位置是该序列的最后一个
score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0)

best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score

if self.include_start_end_trans:
vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
# 由于最终是通过last_tags回溯,需要保持每个位置的vscore情况
vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device)
lens = (seq_len - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len

ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans = logits.new_empty((max_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
for i in range(seq_len - 1):
for i in range(max_len - 1):
last_tags = vpath[idxes[i], batch_idx, last_tags]
ans[idxes[i + 1], batch_idx] = last_tags
ans = ans.transpose(0, 1)
if unpad:
paths = []
for idx, seq_len in enumerate(lens):
paths.append(ans[idx, :seq_len + 1].tolist())
for idx, max_len in enumerate(lens):
paths.append(ans[idx, :max_len + 1].tolist())
else:
paths = ans
return paths, ans_score
1 change: 1 addition & 0 deletions test/data_for_tests/modules/decoder/crf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"bio_logits": [[[-1.8154915571212769, -1.3753865957260132, -10001.513671875, -1.619813084602356, -10001.79296875], [-1.742034673690796, -1.5048011541366577, -2.042131185531616, -1.2594754695892334, -1.6648437976837158], [-1.5522804260253906, -1.2926381826400757, -1.8607124090194702, -1.6692707538604736, -1.7734650373458862], [-1.6101375818252563, -1.3285458087921143, -1.7735439538955688, -1.5734118223190308, -1.8438279628753662], [-1.6522153615951538, -1.2640260457992554, -1.9092718362808228, -1.6192445755004883, -1.7168875932693481], [-1.4932769536972046, -1.4628725051879883, -1.9623159170150757, -1.497014045715332, -1.7177777290344238], [-1.8419824838638306, -2.1428799629211426, -1.4285861253738403, -1.2972710132598877, -1.5546820163726807], [-1.671349048614502, -1.4115079641342163, -1.624293565750122, -1.537371277809143, -1.8563929796218872], [-1.5080815553665161, -1.3281997442245483, -1.7912147045135498, -1.5656323432922363, -1.980512022972107], [-2.0562098026275635, -1.4711416959762573, -1.5297126770019531, -1.7554184198379517, -1.3744999170303345]], [[-1.3193378448486328, -1.997290849685669, -10002.0751953125, -1.3334847688674927, -10001.5712890625], [-1.229069471359253, -1.2702847719192505, -2.0717740058898926, -1.9828989505767822, -1.8136863708496094], [-1.8161871433258057, -1.4339262247085571, -1.4476666450500488, -1.8693819046020508, -1.562330722808838], [-1.897119402885437, -1.5767627954483032, -1.54145348072052, -1.6185026168823242, -1.4649395942687988], [-1.8498220443725586, -1.264282464981079, -1.7192784547805786, -1.8041315078735352, -1.530255913734436], [-1.1517643928527832, -1.6473538875579834, -1.5833101272583008, -1.9973593950271606, -1.894622802734375], [-1.7796387672424316, -1.8036197423934937, -1.2666513919830322, -1.4641741514205933, -1.8736846446990967], [-1.555580496788025, -1.5448863506317139, -1.609066128730774, -1.5487936735153198, -1.8138916492462158], [-1.8701002597808838, -2.0567376613616943, -1.6318782567977905, -1.2336504459381104, -1.4643338918685913], [-1.6615228652954102, -1.9764257669448853, -1.277781367301941, -1.3614437580108643, -1.990394949913025]], [[-1.74202299118042, -1.659791111946106, -10001.9951171875, -1.0417697429656982, -10001.9248046875], [-1.2423228025436401, -1.7404581308364868, -1.7569608688354492, -1.5077661275863647, -1.9528108835220337], [-1.7840592861175537, -1.50230872631073, -1.4460601806640625, -1.9473626613616943, -1.4641118049621582], [-1.6109998226165771, -2.0336639881134033, -1.3807575702667236, -1.221280574798584, -2.0938124656677246], [-1.8956525325775146, -1.6966334581375122, -1.8089725971221924, -1.9510140419006348, -1.020185947418213], [-1.7131900787353516, -1.7260419130325317, -2.161870241165161, -1.2767468690872192, -1.3956587314605713], [-1.7567639350891113, -1.1352611780166626, -1.7109652757644653, -1.8825695514678955, -1.7534843683242798], [-1.826012372970581, -1.9964908361434937, -1.7898284196853638, -1.2279980182647705, -1.413594365119934], [-1.522060513496399, -1.56121826171875, -1.5711766481399536, -1.4620665311813354, -2.0226776599884033], [-1.3122025728225708, -2.0931777954101562, -1.8858696222305298, -1.831908106803894, -1.2184979915618896]], [[-1.3956559896469116, -1.8315693140029907, -10001.48046875, -1.844576358795166, -10001.5771484375], [-1.562046766281128, -1.7216087579727173, -1.5044764280319214, -1.4362742900848389, -1.8867106437683105], [-1.5304349660873413, -1.5527287721633911, -1.5590341091156006, -1.6369349956512451, -1.7899152040481567], [-1.6007282733917236, -2.054649829864502, -1.9757367372512817, -1.4219664335250854, -1.2371348142623901], [-1.841418981552124, -1.8178046941757202, -1.5939710140228271, -1.2179311513900757, -1.7144266366958618], [-1.6715152263641357, -1.5060933828353882, -1.6629694700241089, -1.633326530456543, -1.5827515125274658], [-1.9413940906524658, -1.853175163269043, -1.6390701532363892, -1.2217824459075928, -1.5564061403274536], [-1.746218204498291, -1.7089520692825317, -1.6738371849060059, -1.627657175064087, -1.344780445098877], [-1.1776174306869507, -1.629957675933838, -1.79096519947052, -1.7566864490509033, -1.853833556175232], [-1.4880272150039673, -1.4722591638565063, -1.631064534187317, -1.9562634229660034, -1.5718109607696533]]], "bio_scores": [-1.3754, -4.5403, -8.7047, -12.8693], "bio_path": [[1], [3, 0, 1, 1], [3, 0, 1, 3, 4, 3, 1, 3], [0, 1, 1, 0, 3, 0, 3, 0, 1, 0]], "bio_trans_m": [[-0.095858134329319, 0.01011368352919817, -0.33539193868637085, -0.20200660824775696, 0.136741504073143], [0.5436117649078369, 0.37222158908843994, -0.15174923837184906, 0.10455792397260666, -0.35702475905418396], [0.3681447505950928, -0.6996435523033142, -0.002348324516788125, 0.5087339282035828, -0.08750446885824203], [0.6505969762802124, 0.0064192176796495914, -0.10901711881160736, -0.24849674105644226, -0.1375938355922699], [-0.019853945821523666, -0.9098508954048157, 0.06740495562553406, 0.2244909256696701, -0.29204151034355164]], "bio_seq_lens": [1, 4, 8, 10], "bmes_logits": [[[-10002.5830078125, -20002.54296875, -10001.9765625, -2.033155679702759, -10001.712890625, -20001.68359375, -10002.4130859375, -2.1159744262695312], [-1.870416283607483, -2.2075278759002686, -1.9922529458999634, -2.1696650981903076, -2.4956214427948, -2.1040704250335693, -2.065218925476074, -1.869700312614441], [-1.8947919607162476, -2.398089647293091, -2.1316606998443604, -1.6458176374435425, -2.001098871231079, -2.362668514251709, -2.513232707977295, -1.9884836673736572], [-1.5058399438858032, -2.3359181880950928, -2.382275342941284, -2.4573683738708496, -1.7870502471923828, -2.342841148376465, -2.1982951164245605, -2.0483522415161133], [-2.0845396518707275, -2.0447516441345215, -1.7635326385498047, -1.9375617504119873, -2.530120611190796, -1.8380637168884277, -2.099860906600952, -2.666682481765747], [-2.299673557281494, -2.3165550231933594, -1.9403637647628784, -1.8729832172393799, -1.8798956871032715, -1.8799573183059692, -2.2314014434814453, -2.39471173286438], [-1.9613308906555176, -2.136000633239746, -2.1178860664367676, -2.1553683280944824, -1.7840471267700195, -2.4148807525634766, -2.4621479511260986, -1.817263126373291], [-2.056917428970337, -2.5026133060455322, -1.9233015775680542, -2.0078444480895996, -2.064028024673462, -1.776533842086792, -2.3748488426208496, -2.114560127258301], [-2.3671767711639404, -1.7896978855133057, -2.416537284851074, -2.26574444770813, -2.2460145950317383, -1.7739624977111816, -1.9555294513702393, -2.045677661895752], [-2.3571174144744873, -1.820650577545166, -2.2781612873077393, -1.9325084686279297, -1.863953948020935, -2.2260994911193848, -2.5020244121551514, -1.8891260623931885]], [[-2.0461926460266113, -10002.0625, -10001.712890625, -2.251368761062622, -2.2985825538635254, -10002.146484375, -10002.0185546875, -2.225799560546875], [-1.9879356622695923, -2.4706358909606934, -2.3151662349700928, -1.5818747282028198, -2.329188346862793, -2.1170380115509033, -2.159011125564575, -1.9593485593795776], [-2.2397706508636475, -2.2388737201690674, -1.826286792755127, -2.444268226623535, -1.7793290615081787, -2.402519941329956, -1.8540253639221191, -2.09319806098938], [-1.7938345670700073, -2.525993585586548, -1.9962739944458008, -1.9414381980895996, -2.5183513164520264, -2.5057737827301025, -1.7933388948440552, -1.925837755203247], [-2.2330663204193115, -2.098536491394043, -1.9872602224349976, -1.7660422325134277, -2.5269722938537598, -1.9648237228393555, -1.80750572681427, -2.551790475845337], [-1.802718162536621, -2.4936702251434326, -1.846991777420044, -2.6299049854278564, -1.8180453777313232, -2.010246992111206, -1.9285591840744019, -2.5121750831604004], [-1.7665618658065796, -2.2445054054260254, -1.822519063949585, -2.5471863746643066, -2.719733715057373, -1.9708809852600098, -1.7871110439300537, -2.2026400566101074], [-2.2046854496002197, -2.375577926635742, -1.9162014722824097, -2.397550344467163, -1.9547137022018433, -1.759222149848938, -1.818831443786621, -2.4931435585021973], [-1.9187703132629395, -2.5046753883361816, -1.871201515197754, -2.3421711921691895, -2.372368335723877, -1.883248209953308, -1.8868682384490967, -2.0830271244049072], [-2.406679630279541, -1.7564219236373901, -2.340674877166748, -1.8392919301986694, -2.3711328506469727, -1.913435935974121, -2.221808433532715, -2.019878625869751]], [[-10001.7607421875, -20002.30078125, -10001.9677734375, -1.7931804656982422, -10002.2451171875, -20002.15234375, -10002.208984375, -2.4127495288848877], [-2.162931442260742, -2.121459484100342, -2.4020097255706787, -2.5620131492614746, -1.7713403701782227, -2.1945695877075195, -1.8392865657806396, -1.8513271808624268], [-2.2151875495910645, -1.9279260635375977, -2.24403977394104, -2.1955597400665283, -2.2283377647399902, -1.7366830110549927, -2.634793519973755, -1.757084608078003], [-1.813708782196045, -1.93169105052948, -2.2419192790985107, -2.307635545730591, -2.19914174079895, -2.070988178253174, -2.0030927658081055, -2.1678688526153564], [-2.118651866912842, -1.867727518081665, -2.312565326690674, -2.274792194366455, -1.9973562955856323, -2.000102996826172, -1.8425841331481934, -2.3635623455047607], [-2.435579538345337, -1.7167878150939941, -2.3040761947631836, -1.657408595085144, -2.462364912033081, -2.2767324447631836, -1.7957141399383545, -2.425132989883423], [-1.806656837463379, -1.7759110927581787, -2.5295629501342773, -1.9216285943984985, -2.2615668773651123, -1.8556532859802246, -2.4842538833618164, -2.3384106159210205], [-1.9859262704849243, -1.6575560569763184, -2.2854154109954834, -1.9267034530639648, -2.5214226245880127, -2.0166244506835938, -2.479127883911133, -2.0595011711120605], [-2.0371243953704834, -2.2420313358306885, -2.0946967601776123, -2.2463889122009277, -1.8954271078109741, -1.942257285118103, -2.0445871353149414, -2.1946396827697754], [-2.0210611820220947, -2.362877130508423, -1.9862446784973145, -1.8275481462478638, -2.140009880065918, -1.869648814201355, -2.6818318367004395, -2.0021097660064697]], [[-1.986312985420227, -10002.50390625, -10002.0361328125, -1.908732295036316, -2.21740984916687, -10002.1318359375, -10002.1044921875, -1.87873113155365], [-1.9292036294937134, -2.163956880569458, -2.3703503608703613, -1.939669132232666, -1.8776776790618896, -2.4469380378723145, -2.423905611038208, -1.7453217506408691], [-2.0289347171783447, -2.520860195159912, -2.5013701915740967, -2.078547477722168, -1.9699862003326416, -1.8206181526184082, -1.7796630859375, -2.1984922885894775], [-1.8523262739181519, -1.978093147277832, -2.558772087097168, -2.498471260070801, -1.9756053686141968, -1.8080697059631348, -1.9115748405456543, -2.357147216796875], [-2.314960479736328, -2.2433876991271973, -1.6113512516021729, -2.19716477394104, -1.78402578830719, -2.343987226486206, -2.3425848484039307, -2.084155797958374], [-2.002289056777954, -2.2630276679992676, -1.887984275817871, -2.044983386993408, -2.217646360397339, -1.9103771448135376, -2.154231548309326, -2.2321436405181885], [-2.199540853500366, -2.063075065612793, -1.813851237297058, -2.3199379444122314, -1.7984188795089722, -2.4952447414398193, -2.4516515731811523, -1.7922154664993286], [-2.509786367416382, -1.79443359375, -1.8561275005340576, -2.2977330684661865, -2.2080044746398926, -1.7294546365737915, -2.4617154598236084, -2.0944302082061768], [-2.491340160369873, -2.403804063796997, -1.8452543020248413, -1.6882175207138062, -2.5513625144958496, -2.294516086578369, -1.9522627592086792, -1.8124374151229858], [-2.1524035930633545, -2.2049806118011475, -2.3353655338287354, -2.317572832107544, -2.2914233207702637, -1.8211665153503418, -1.69517982006073, -2.0270023345947266]]], "bmes_scores": [-2.0332, -6.1623, -1.7932, -16.7561], "bmes_path": [[3], [7, 3, 4, 6], [3], [3, 4, 5, 6, 7, 3, 4, 5, 6, 7]], "bmes_trans_m": [[0.47934335470199585, -0.2151593416929245, -0.12467780709266663, -0.44244644045829773, 0.16480575501918793, -0.006573359947651625, -1.187401294708252, -0.17424514889717102], [-0.03494556248188019, -0.8173441290855408, -0.2682552933692932, 0.18933893740177155, 0.2203899323940277, 0.3905894160270691, -0.007638207171112299, 0.19527725875377655], [-0.2779119908809662, -0.37053248286247253, 0.34394705295562744, -0.26433902978897095, -0.0001995275670196861, -0.39156094193458557, -0.035449881106615067, 0.02454843744635582], [-0.01391045656055212, 0.3419516384601593, -0.48559853434562683, -0.5893992781639099, 0.9119477272033691, 0.1731061041355133, -0.15039317309856415, 0.1523006409406662], [0.4866299033164978, 0.28264448046684265, -0.25895795226097107, 0.0404033362865448, -0.060920555144548416, 0.12364576756954193, 0.1294233351945877, 0.2434755265712738], [-0.04159824922680855, 0.25353407859802246, 0.12913571298122406, -0.036356933414936066, -0.18522876501083374, -0.5329958200454712, 0.2505933344364166, 0.26512718200683594], [-0.2509276270866394, 0.3572998046875, 0.01873799040913582, -0.30620086193084717, -0.09893298894166946, -0.37399813532829285, -0.6530448198318481, -0.17514197528362274], [-0.29702028632164, 0.680363118648529, -0.6010262370109558, 0.17669369280338287, 0.45010149478912354, -0.1026386097073555, 0.34120017290115356, -0.04910941794514656]], "bmes_seq_lens": [1, 4, 1, 10]}
Loading

0 comments on commit f58e10d

Please sign in to comment.