@zsh-o
2019-06-16T21:02:10.000000Z
字数 1791
阅读 805
算法
# coding: utf-8
from collections import deque
class Node:
    """A single state (trie node) of the Aho-Corasick automaton.

    Fields:
        key:     the character on the edge leading into this node; ``None``
                 for a node that has no incoming edge (the root).
        ending:  ``True`` when the path from the root to this node spells a
                 complete keyword.
        maps:    child nodes, keyed by character.
        failure: node for the longest proper suffix of this node's path that
                 is also a path in the trie (set by ``ACSearch.build``).
        depth:   number of characters on the path from the root.
        parent:  parent node, used to spell a matched keyword back out.
        output:  nearest node on the failure chain whose ``ending`` is
                 ``True``, or ``None`` if there is none (set by
                 ``ACSearch.build``).
    """

    def __init__(self):
        self.key = None
        self.ending = False
        self.maps = dict()
        self.failure = None
        self.depth = 0
        # parent pointer, followed to rebuild the matched string
        self.parent = None
        # shortcut to the next keyword-ending node along the failure chain
        self.output = None

    def __repr__(self):
        # __repr__ must return a str; the root has key None, so returning
        # the key directly would raise TypeError for it.
        return '<root>' if self.key is None else self.key


class ACSearch:
    """Multi-keyword substring search using an Aho-Corasick automaton.

    The automaton is built once from ``keywords`` in the constructor and can
    then be used to scan any number of texts with ``find_first`` and
    ``find_all``.
    """

    def __init__(self, keywords):
        """Build the automaton for the given iterable of keyword strings."""
        self.root = Node()
        # The root fails to itself so failure-chain walks always terminate.
        self.root.failure = self.root
        self.build(keywords)

    def build(self, keywords):
        """Insert ``keywords`` into the trie and compute failure/output links.

        Empty keywords are ignored (they never mark any node as ending).
        """
        root = self.root

        # Phase 1: build the trie.
        for keyword in keywords:
            node = root
            for depth, c in enumerate(keyword, start=1):
                child = node.maps.get(c)
                if child is None:
                    child = Node()
                    child.key = c
                    child.depth = depth
                    child.parent = node
                    node.maps[c] = child
                node = child
            # Only a node reached by at least one character ends a keyword.
            if node is not root:
                node.ending = True

        # Phase 2: compute failure and output links breadth-first, so every
        # shallower node is finished before a deeper node refers to it.
        queue = deque()
        # All depth-1 nodes fail to the root.
        for child in root.maps.values():
            child.failure = root
            child.output = None
            queue.append(child)

        while queue:
            node = queue.popleft()
            for c, child in node.maps.items():
                fallback = node.failure
                while c not in fallback.maps and fallback is not root:
                    fallback = fallback.failure
                # If even the root has no edge for c, the child fails to root.
                child.failure = fallback.maps.get(c, fallback)
                # The failure target is strictly shallower, so its own
                # output link has already been computed.
                target = child.failure
                child.output = target if target.ending else target.output
                queue.append(child)

    def _advance(self, node, c):
        """Return the state reached from ``node`` after reading ``c``."""
        root = self.root
        while c not in node.maps and node is not root:
            node = node.failure
        # After the loop either the edge exists, or node is the root and
        # the automaton stays there.
        return node.maps.get(c, root)

    @staticmethod
    def _endings(node):
        """Yield every keyword-ending node for ``node``, longest first."""
        hit = node if node.ending else node.output
        while hit is not None:
            yield hit
            hit = hit.output

    @staticmethod
    def _spell(node):
        """Rebuild the string spelled by the path from the root to ``node``."""
        chars = []
        while node.parent is not None:
            chars.append(node.key)
            node = node.parent
        chars.reverse()
        return ''.join(chars)

    def find_first(self, text):
        """Return ``(start_index, keyword)`` for the first match in ``text``.

        "First" means the match that ends earliest; if several keywords end
        at that same position, the longest one is returned. Returns
        ``(-1, '')`` when no keyword occurs in ``text``.
        """
        node = self.root
        for index, c in enumerate(text):
            node = self._advance(node, c)
            hit = next(self._endings(node), None)
            if hit is not None:
                keyword = self._spell(hit)
                return index - len(keyword) + 1, keyword
        return -1, ''

    def find_all(self, text):
        """Return every match in ``text`` as a list of ``[start, keyword]``.

        Matches are ordered by end position; keywords ending at the same
        position are listed longest first. Overlapping matches, and keywords
        that are suffixes of a longer matched path, are all reported.
        """
        res = []
        node = self.root
        for index, c in enumerate(text):
            node = self._advance(node, c)
            for hit in self._endings(node):
                keyword = self._spell(hit)
                res.append([index - len(keyword) + 1, keyword])
        return res
if __name__ == '__main__':
    # Small demo: build an automaton over four keywords, then report the
    # first match in one text and all (possibly overlapping) matches in
    # another.
    searcher = ACSearch(['ash', 'shex', 'bcd', 'sha'])
    print(searcher.find_first('secashcvashebcdashare'))
    print(searcher.find_all('secashcvashexbcdashare'))
# Expected output of the demo above:
# (3, 'ash')
# [[3, 'ash'], [8, 'ash'], [9, 'shex'], [13, 'bcd'], [16, 'ash'], [17, 'sha']]