# Spaces:
# Sleeping
# Sleeping
import math

from darknet import *
def predict_tactic(net, s):
    """Sample one tactic string from the network, one character at a time.

    The prompt ``s`` is fed to ``predict`` character by character as a
    256-wide one-hot vector indexed by ``ord(c)``.  The outputs for every
    prompt character except the last are ignored (presumably they only
    advance the network's recurrent state -- confirm against darknet).
    Starting from the last prompt character, the 256 outputs are passed to
    ``sample`` to pick the next character, which is then fed back in, until
    a ``'.'`` has been emitted.

    Args:
        net: network handle accepted by ``predict``.
        s: prompt string; an empty prompt is treated as ``'\\n'``.  Every
            character must have a code point below 256, because it is used
            as an index into the 256-entry input vector.

    Returns:
        A ``(tactic, log_prob)`` tuple: the sampled text, ending with
        ``'.'``, and the sum of ``math.log`` of the network output for each
        character in that text.

    Note:
        The loop ends only when ``'.'`` is sampled; there is no length cap.
    """
    prob = 0
    d = c_array(c_float, [0.0] * 256)
    if not len(s):
        s = '\n'

    # Feed the prompt prefix; the outputs are not needed here.
    for c in s[:-1]:
        d[ord(c)] = 1
        predict(net, d)
        d[ord(c)] = 0  # restore the all-zero vector for the next character

    c = s[-1]
    chars = []
    while True:
        d[ord(c)] = 1
        pred = predict(net, d)
        d[ord(c)] = 0
        # Copy the 256 outputs into a plain list before sampling.
        pred = [pred[i] for i in range(256)]
        ind = sample(pred)
        c = chr(ind)
        prob += math.log(pred[ind])
        chars.append(c)
        # Stop as soon as the terminator is emitted, so the score covers
        # exactly the characters that are returned (no extra sample drawn
        # and scored after the final '.').
        if c == '.':
            break
    return (''.join(chars), prob)
def predict_tactics(net, s, n):
    """Draw ``n`` tactic samples for prompt ``s`` and rank them by score.

    ``reset_rnn(net)`` is called before each draw, then ``predict_tactic``
    produces one ``(tactic, log_prob)`` pair.

    Returns:
        A list of ``n`` ``(tactic, log_prob)`` tuples ordered from highest
        to lowest ``log_prob``; ties keep the order in which they were drawn.
    """
    def _draw():
        # Reset first so each sample starts from the same network state call.
        reset_rnn(net)
        return predict_tactic(net, s)

    ranked = [_draw() for _ in range(n)]
    ranked.sort(key=lambda scored: -scored[1])
    return ranked
# Demo run: load the network from a config file and a saved-weights file
# (both paths are hard-coded; the weights path is machine-specific), then
# sample 10 tactics for a fixed prompt and print them best-first.
# NOTE(review): the meaning of the trailing 0 passed to load_net is not
# visible here -- confirm against darknet's load_net.
net = load_net("cfg/coq.test.cfg", "/home/pjreddie/backup/coq.backup", 0)
t = predict_tactics(net, "+++++\n", 10)
# Parenthesised form prints the same list under Python 2 and also parses
# under Python 3.
print(t)