You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

223 lines
7.0 KiB

"""
    Markov

Per-user Markov-chain sentence generator.

`initialize` reads a chat export (JSON with a top-level `"messages"` array), groups the
text of each message by sender id and builds three transition-count tables per user:

- `ANALYZED_FORWARD`:  `(previous, current) => next`, used to extend a sentence rightwards
- `ANALYZED_BACKWARD`: `(current, next) => previous`, used to extend a sentence leftwards
- `ANALYSED_SINGLE`:   `current => next`, used to pick a second seed word

`nothing` marks the boundaries of a message. Each context maps to a
`Dict{Token, Float64}` of occurrence counts, which are used as sampling weights.
"""
module Markov

import JSON
using StatsBase
using Logging
import HTTP
# NOTE(review): presumably provides `find_username`, which is used below but not
# defined in this module — confirm.
using ..Bot
# `initialize`, `make_sentence` and `list_usernames` are imported (not merely used),
# so the definitions below add methods to the parent module's functions.
import ..initialize, ..make_sentence, ..list_usernames, ..fetch_in_env

# Path of the chat export; presumably taken from the `INPUT_MARKOV` environment
# variable by `fetch_in_env` — confirm.
JSON_FILE = fetch_in_env("INPUT_MARKOV")

# A word, or `nothing` for the start/end-of-message marker.
const Token = Union{String, Nothing}
# outcome => number of times it was observed (stored as Float64 sampling weights)
const Counts = Dict{Token, Float64}
# two-word context => outcome counts
const PairTable = Dict{Tuple{Token, Token}, Counts}
# one-word context => outcome counts
const SingleTable = Dict{Token, Counts}

# username => user id
USERIDS = Dict{String, Int64}()
# user id => current word => next word => count
ANALYSED_SINGLE = Dict{Int64, SingleTable}()
# user id => (previous, current) => next word => count
ANALYZED_FORWARD = Dict{Int64, PairTable}()
# user id => (current, next) => previous word => count
ANALYZED_BACKWARD = Dict{Int64, PairTable}()
# Set to `true` by `initialize` once the tables are built; cleared by `reset_module`.
INITIALIZED = false

"""
    reset_module()

Forget every registered user and transition table and mark the module as
uninitialized.
"""
function reset_module()
    global INITIALIZED, USERIDS, ANALYSED_SINGLE, ANALYZED_FORWARD, ANALYZED_BACKWARD
    USERIDS = Dict{String, Int64}()
    ANALYSED_SINGLE = Dict{Int64, SingleTable}()
    ANALYZED_FORWARD = Dict{Int64, PairTable}()
    ANALYZED_BACKWARD = Dict{Int64, PairTable}()
    INITIALIZED = false
end

"""
    list_usernames()

Return the registered usernames as a vector.
"""
list_usernames() = collect(keys(USERIDS))

"""
    list_users()

Return the registered user ids as a vector.
"""
list_users() = collect(values(USERIDS))

"""
    register_user(username, user_id)

Record `username => user_id` in `USERIDS`, replacing any id previously stored for
that name. Returns `user_id`.
"""
function register_user(username, user_id)
    USERIDS[username] = user_id
    return user_id
end

"""
    register_user(user_id; default="plop")

Register `user_id` under the name returned by `find_username(user_id)`.

If that lookup throws an `HTTP` status error, `default` is used as the name instead
(`"deleted"` when `default === nothing`). Any other exception is rethrown.
"""
function register_user(user_id; default="plop")
    username = try
        find_username(user_id)
    catch e
        e isa HTTP.ExceptionRequest.StatusError || rethrow()
        default = something(default, "deleted")
        @debug "Could not find username for id $user_id , falling back to default : $default"
        default
    end
    return register_user(username, user_id)
end

"""
    analyse_line(words)

Turn one message, given as a sequence of words, into transition observations.

Returns `(backward, forward, single)`, three vectors of `context => outcome` pairs:

- `backward`: `(current, next) => previous`
- `forward`:  `(previous, current) => next`
- `single`:   `current => next`

`nothing` stands for the boundaries of the message.
"""
function analyse_line(words)
    backward = Pair{Tuple{Token, Token}, Token}[(nothing, nothing) => nothing]
    forward = Pair{Tuple{Token, Token}, Token}[]
    single = Pair{Token, Token}[]
    previous::Token = nothing
    current::Token = nothing
    for word in words
        push!(forward, (previous, current) => word)
        push!(backward, (current, word) => previous)
        push!(single, current => word)
        previous, current = current, word
    end
    push!(forward, (previous, current) => nothing)
    # Lets `make_sentence_forward` stop when seeded with the last word and `nothing`.
    push!(forward, (current, nothing) => nothing)
    push!(backward, (current, nothing) => previous)
    push!(single, current => nothing)
    return backward, forward, single
end

"""
    analyse_all_lines(lines)

Split every line of `lines` on whitespace and count the transitions it contains.
Lines without any word are skipped.

Returns `(backward, forward, single)` count tables with the same context/outcome
layout as `analyse_line`.
"""
function analyse_all_lines(lines)
    backward = PairTable()
    forward = PairTable()
    single = SingleTable()
    for line in lines
        words = split(line)
        # A whitespace-only message would otherwise record an "empty sentence".
        isempty(words) && continue
        line_backward, line_forward, line_single = analyse_line(words)
        _count_transitions!(backward, line_backward)
        _count_transitions!(forward, line_forward)
        _count_transitions!(single, line_single)
    end
    return backward, forward, single
end

# Add one to `table[context][outcome]` for every `context => outcome` pair.
function _count_transitions!(table, transitions)
    for (context, outcome) in transitions
        counts = get!(Counts, table, context)
        counts[outcome] = get(counts, outcome, 0.0) + 1.0
    end
    return table
end

"""
    initialize(input_file=JSON_FILE; reset=false)

Load the chat export at `input_file` and build the transition tables of every sender.

Messages without a `"from_id"` key, or whose `"text"` is not a non-empty `String`,
are ignored. `"from_id"` values are used as `Int64` keys. A sender that is not yet
registered is registered with `register_user(user_id; default=<"from" field>)`.

Does nothing when the module is already initialized, unless `reset=true`, which
first calls `reset_module()`.
"""
function initialize(input_file=JSON_FILE; reset=false)
    global INITIALIZED
    reset && reset_module()
    INITIALIZED && return

    messages = JSON.parsefile(input_file)["messages"]
    user_lines = Dict{Int64, Vector{String}}()
    for message in messages
        haskey(message, "from_id") || continue
        text = get(message, "text", nothing)
        (text isa String && !isempty(text)) || continue
        user_id = message["from_id"]
        if !is_registered(user_id)
            register_user(user_id; default=get(message, "from", nothing))
        end
        push!(get!(Vector{String}, user_lines, user_id), text)
    end

    # Analyse senders in parallel, but write into the shared global Dicts from this
    # task only: `Dict` does not support concurrent `setindex!`.
    users = collect(keys(user_lines))
    tables = Vector{Tuple{PairTable, PairTable, SingleTable}}(undef, length(users))
    Threads.@threads for i in eachindex(users)
        tables[i] = analyse_all_lines(user_lines[users[i]])
    end
    for (user, (backward, forward, single)) in zip(users, tables)
        ANALYZED_BACKWARD[user] = backward
        ANALYZED_FORWARD[user] = forward
        ANALYSED_SINGLE[user] = single
    end
    INITIALIZED = true
end

# Draw one key of `counts`, weighted by its value.
_sample_weighted(counts) = sample(collect(keys(counts)), weights(collect(values(counts))))

"""
    choose_next(user, current)

Sample the word that follows `current` for `user` from the one-word table.
Throws `KeyError` if `user` or `current` has no entry.
"""
choose_next(user, current) = _sample_weighted(ANALYSED_SINGLE[user][current])

"""
    choose_next(user, previous, current)

Sample the word that follows `previous current` for `user` from the forward table.
Throws `KeyError` if `user` or the pair has no entry.
"""
choose_next(user, previous, current) =
    _sample_weighted(ANALYZED_FORWARD[user][(previous, current)])

"""
    choose_prev(user, current, next)

Sample the word that precedes `current next` for `user` from the backward table.
Throws `KeyError` if `user` or the pair has no entry.
"""
choose_prev(user, current, next) = _sample_weighted(ANALYZED_BACKWARD[user][(current, next)])

"""
    make_sentence_forward(user, word1, word2)

Extend the seed `word1 word2` to the right until the end marker is drawn.
Returns the generated words (seed excluded) joined by single spaces.
"""
function make_sentence_forward(user, word1, word2)
    words = String[]
    previous, current = word2, choose_next(user, word1, word2)
    while !isnothing(current)
        push!(words, current)
        previous, current = current, choose_next(user, previous, current)
    end
    return join(words, " ")
end

"""
    make_sentence_backward(user, word1, word2)

Extend the seed `word1 word2` to the left until the start marker is drawn.
Returns the generated words (seed excluded) joined by single spaces.
"""
function make_sentence_backward(user, word1, word2)
    words = String[]
    following, current = word1, choose_prev(user, word1, word2)
    while !isnothing(current)
        pushfirst!(words, current)
        current, following = choose_prev(user, current, following), current
    end
    return join(words, " ")
end

"""
    make_sentence(user=nothing, word1=nothing, word2=nothing)

Generate a sentence in the style of `user`, or of a random registered user when
`user` is `nothing`. The sentence optionally contains the seed words `word1 word2`.
If only `word1` is given, `word2` is sampled from the words observed after it.

Returns `"<username> : sentence"`. Returns `"No luck, sorry."` when a lookup raises
`KeyError` (unknown user, or seed words never observed). Other exceptions propagate.
"""
function make_sentence(
    user::Union{Nothing, String}=nothing,
    word1::Token=nothing,
    word2::Token=nothing,
)
    try
        username, userid = find_user(user)
        if !isnothing(word1) && isnothing(word2)
            word2 = choose_next(userid, word1)
        end
        @debug "chose start" word1 word2
        start = make_sentence_backward(userid, word1, word2)
        @debug "start done" start
        finish = make_sentence_forward(userid, word1, word2)
        @debug "finish done" finish
        # Drop missing seed words and empty halves so the output has no double spaces.
        parts = ["<$username>", ":", start, word1, word2, finish]
        return join(filter(p -> !isnothing(p) && !isempty(p), parts), " ")
    catch e
        e isa KeyError || rethrow()
        @debug "Key error" e
        return "No luck, sorry."
    end
end

"""
    is_registered(userid)

Return `true` if `userid` is one of the registered user ids.
"""
is_registered(userid) = userid in values(USERIDS)

"""
    find_user(user::String)
    find_user(nothing)

Return `(username, userid)` for `user`, or for a uniformly random registered user
when `user` is `nothing`. Throws `KeyError` for an unknown username.
"""
find_user(user::String) = (user, USERIDS[user])
find_user(::Nothing) = find_user(rand(list_usernames()))

end