Skip to content

Commit

Permalink
add slot function support
Browse files Browse the repository at this point in the history
  • Loading branch information
itsdfish committed Nov 25, 2024
1 parent b1ddaf7 commit 130dce5
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ACTRModels"
uuid = "c095b0ea-a6ca-5cbd-afed-dbab2e976880"
authors = ["itsdfish"]
version = "0.13.4"
version = "0.13.5"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Expand Down
83 changes: 47 additions & 36 deletions src/MemoryFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ Computes the activation of a vector of chunks. By default, current time is compu
function compute_activation!(
actr::AbstractACTR,
chunks::Vector{<:AbstractChunk};
funs = (),
request...
)
return compute_activation!(actr, chunks, get_time(actr); request...)
return compute_activation!(actr, chunks, get_time(actr); funs, request...)
end

"""
Expand All @@ -110,13 +111,14 @@ function compute_activation!(
actr::AbstractACTR,
chunks::Vector{<:AbstractChunk},
cur_time::Float64;
funs = (),
request...
)
(; sa) = actr.parms
sa ? cache_denomitors(actr) : nothing
# compute activation for each chunk
for chunk in chunks
activation!(actr, chunk, cur_time; request...)
activation!(actr, chunk, cur_time; funs, request...)
end
return nothing
end
Expand All @@ -135,8 +137,8 @@ with `get_time`.
- `request...`: optional keywords for the retrieval request
"""
function compute_activation!(actr::AbstractACTR, chunk::AbstractChunk; request...)
return compute_activation!(actr, chunk, get_time(actr); request...)
function compute_activation!(actr::AbstractACTR, chunk::AbstractChunk; funs = (), request...)
return compute_activation!(actr, chunk, get_time(actr); funs, request...)
end

"""
Expand Down Expand Up @@ -171,8 +173,8 @@ with `get_time`.
- `request...`: optional keywords for the retrieval request
"""
compute_activation!(actr::AbstractACTR; request...) =
compute_activation!(actr, actr.declarative.memory, get_time(actr); request...)
compute_activation!(actr::AbstractACTR; funs = (), request...) =
compute_activation!(actr, actr.declarative.memory, get_time(actr); funs, request...)

"""
compute_activation!(actr::AbstractACTR, cur_time::Float64; request...)
Expand All @@ -188,8 +190,8 @@ Computes the activation of all chunks in declarative memory
- `request...`: optional keywords for the retrieval request
"""
compute_activation!(actr::AbstractACTR, cur_time::Float64; request...) =
compute_activation!(actr, actr.declarative.memory, cur_time; request...)
compute_activation!(actr::AbstractACTR, cur_time::Float64; funs = (), request...) =
compute_activation!(actr, actr.declarative.memory, cur_time; funs, request...)

"""
activation!(actr, chunk::AbstractChunk, cur_time; request...)
Expand All @@ -206,7 +208,7 @@ Computes the activation of a chunk
- `request...`: optional keywords for the retrieval request
"""
function activation!(actr::AbstractACTR, chunk::AbstractChunk, cur_time = 0.0; request...)
function activation!(actr::AbstractACTR, chunk::AbstractChunk, cur_time = 0.0; funs = (), request...)
memory = actr.declarative
(; sa_fun, bll, mmp, sa, noise, blc, τ) = actr.parms
reset_activation!(chunk)
Expand All @@ -216,7 +218,8 @@ function activation!(actr::AbstractACTR, chunk::AbstractChunk, cur_time = 0.0; r
baselevel!(actr, chunk)
end
if mmp
partial_matching!(actr, chunk; request...)
_funs = isempty(funs) ? fill(==, length(request)) : funs
partial_matching!(actr, chunk; funs = _funs, request...)
end
if sa
sa_fun(actr, chunk)
Expand Down Expand Up @@ -282,13 +285,15 @@ Computes activation for partial matching component
- `request...`: optional keyword arguments for retrieval request
"""
function partial_matching!(actr::AbstractACTR, chunk::AbstractChunk; request...)
function partial_matching!(actr::AbstractACTR, chunk::AbstractChunk; funs, request...)
slots = chunk.slots
p = 0.0
δ = actr.parms.δ
i = 1
for (k, v) in request
dissim = actr.parms.dissim_func(k, slots[k], v)
dissim = actr.parms.dissim_func(k, slots[k], v, funs[i])
p += δ * dissim
i += 1
end
chunk.act_pm = p
return nothing
Expand Down Expand Up @@ -448,15 +453,17 @@ function retrieval_prob(
actr::AbstractACTR,
target::Array{<:AbstractChunk, 1},
cur_time;
funs = (),
request...
)
(; τ, s, noise) = actr.parms
σ = s * sqrt(2)
chunks = retrieval_request(actr; request...)
_funs = isempty(funs) ? fill(==, length(request)) : funs
chunks = retrieval_request(actr; funs = _funs, request...)
filter!(x -> (x chunks), target)
isempty(target) ? (return (0.0, 1.0)) : nothing
set_noise!(actr, false)
compute_activation!(actr, chunks, cur_time; request...)
compute_activation!(actr, chunks, cur_time; funs = _funs, request...)
set_noise!(actr, noise)
denom = fill(target[1].act_mean, length(chunks) + 1)
map!(x -> exp(x.act_mean / σ), denom, chunks)
Expand All @@ -482,8 +489,8 @@ By default, current time is computed from `get_time`.
- `request...`: optional keyword pairs representing a retrieval request
"""
function retrieval_prob(actr::AbstractACTR, chunk::AbstractChunk; request...)
return retrieval_prob(actr, chunk, get_time(actr); request...)
function retrieval_prob(actr::AbstractACTR, chunk::AbstractChunk; funs = (), request...)
return retrieval_prob(actr, chunk, get_time(actr); funs, request...)
end

"""
Expand All @@ -501,13 +508,14 @@ Uses the softmax approximation to compute the retrieval probability of retrievin
- `request...`: optional keyword pairs representing a retrieval request
"""
function retrieval_prob(actr::AbstractACTR, chunk::AbstractChunk, cur_time; request...)
function retrieval_prob(actr::AbstractACTR, chunk::AbstractChunk, cur_time; funs = (), request...)
(; τ, s, noise) = actr.parms
σ = s * sqrt(2)
chunks = retrieval_request(actr; request...)
_funs = isempty(funs) ? fill(==, length(request)) : funs
chunks = retrieval_request(actr; funs = _funs, request...)
!(chunk chunks) ? (return (0.0, 1.0)) : nothing
set_noise!(actr, false)
compute_activation!(actr, chunks, cur_time; request...)
compute_activation!(actr, chunks, cur_time; funs = _funs, request...)
set_noise!(actr, noise)
v = fill(chunk.act_mean, length(chunks) + 1)
map!(x -> exp(x.act_mean / σ), v, chunks)
Expand All @@ -532,8 +540,8 @@ current time is computed from `get_time`.
- `request...`: optional keyword pairs representing a retrieval request
"""

function retrieval_probs(actr::AbstractACTR; request...)
return retrieval_probs(actr, get_time(actr); request...)
function retrieval_probs(actr::AbstractACTR; funs = (), request...)
return retrieval_probs(actr, get_time(actr); funs = (), request...)
end

"""
Expand All @@ -550,13 +558,14 @@ Computes the retrieval probability for each chunk matching the retrieval request
- `request...`: optional keyword pairs representing a retrieval request
"""
function retrieval_probs(actr::AbstractACTR, cur_time; request...)
function retrieval_probs(actr::AbstractACTR, cur_time; funs = (), request...)
(; τ, s, noise) = actr.parms
σ = s * sqrt(2)
chunks = retrieval_request(actr; request...)
_funs = isempty(funs) ? fill(==, length(request)) : funs
chunks = retrieval_request(actr; funs = _funs, request...)
isempty(chunks) ? (return ([0.0], chunks)) : nothing
set_noise!(actr, false)
compute_activation!(actr, chunks, cur_time; request...)
compute_activation!(actr, chunks, cur_time; funs = _funs, request...)
set_noise!(actr, noise)
v = Array{typeof(chunks[1].act), 1}(undef, length(chunks) + 1)
map!(x -> exp(x.act_mean / σ), v, chunks)
Expand Down Expand Up @@ -1030,9 +1039,9 @@ Returns chunks matching a retrieval request.
- `request...`: optional keyword arguments corresponding to retrieval request e.g. dog = :fiddo
"""
function retrieval_request(actr::AbstractACTR; request...)
function retrieval_request(actr::AbstractACTR; funs = (), request...)
(; mmp,) = actr.parms
!mmp ? (return get_chunks(actr; request...)) : nothing
!mmp ? (return get_chunks(actr, funs...; request...)) : nothing
chunks = get_chunks(actr; check_value = false, request...)
c = get_subset(actr; request...)
return get_chunks(chunks; check_value = true, c...)
Expand Down Expand Up @@ -1103,8 +1112,8 @@ actr = ACTR(;declarative, parms...)
retrieve(actr; country=:Germany)
```
"""
function retrieve(actr::AbstractACTR; request...)
return retrieve(actr, get_time(actr); request...)
function retrieve(actr::AbstractACTR; funs = (), request...)
return retrieve(actr, get_time(actr); funs, request...)
end

"""
Expand All @@ -1121,14 +1130,15 @@ Retrieves a chunk given a retrieval request
- `request...`: optional keyword arguments representing a retrieval request, e.g. person=:bob
"""
function retrieve(actr::AbstractACTR, cur_time; request...)
function retrieve(actr::AbstractACTR, cur_time; funs = (), request...)
(; declarative, parms) = actr
arr = Array{eltype(declarative.memory), 1}()
chunks = retrieval_request(actr; request...)
_funs = isempty(funs) ? fill(==, length(request)) : funs
chunks = retrieval_request(actr; funs = _funs, request...)
# add noise to threshold even if result of request is empty
actr.parms.noise ? add_noise!(actr) : (parms.τ′ = parms.τ)
isempty(chunks) ? (return arr) : nothing
compute_activation!(actr, chunks, cur_time; request...)
compute_activation!(actr, chunks, cur_time; funs = _funs, request...)
best = get_max_active(chunks)
if best[1].act >= parms.τ′
return best
Expand Down Expand Up @@ -1250,9 +1260,10 @@ for numeric slot-values.
- `request...`: optional keywords for the retrieval request
"""
function blend_chunks(actr::AbstractACTR, blended_slots, cur_time; request...)
chunks = retrieval_request(actr; request...)
compute_activation!(actr, chunks, cur_time; request...)
function blend_chunks(actr::AbstractACTR, blended_slots, cur_time; funs = (), request...)
_funs = isempty(funs) ? fill(==, length(request)) : funs
chunks = retrieval_request(actr; funs = _funs, request...)
compute_activation!(actr, chunks, cur_time; funs = _funs, request...)
probs = soft_max(actr, chunks)
return blend_slots(actr, chunks, probs, blended_slots)
end
Expand All @@ -1261,7 +1272,7 @@ function blend_slots(
actr::AbstractACTR,
chunks::Vector{<:AbstractChunk},
probs::Vector{<:Real},
blended_slots
blended_slots,
)
return map(s -> blend_slots(actr, chunks, probs, s), blended_slots)
end
Expand Down Expand Up @@ -1321,7 +1332,7 @@ function blend_slots(
for i 1:n_unique
v = 0.0
for j 1:n_vals
v += probs[j] * dissm_func(slot, u_values[i], values[j])^2
v += probs[j] * dissm_func(slot, u_values[i], values[j], ==)^2
#println("i $(u_values[i]) j $(values[j]) probs $(probs[j]) distance $(dissm_func(u_values[i], values[j])^2) v $v")
end
vals[i] = v
Expand Down
3 changes: 2 additions & 1 deletion src/Structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,9 @@ A default dissimilarity function which returns 1 for a mismatch and 0 otherwise.
- `s`: the slot
- `v1`: slot value 1
- `v2`: slot value 2
- `f`: function for evaluating v1 ≠ v2. Use `!` for negation
"""
default_dissim_func(s, v1, v2) = v1 v2 ? 1.0 : 0.0
default_dissim_func(s, v1, v2, f) = !f(v1, v2) ? 1.0 : 0.0

Broadcast.broadcastable(x::Declarative) = Ref(x)

Expand Down
Loading

0 comments on commit 130dce5

Please sign in to comment.