Skip to content

Commit

Permalink
Merge pull request #72 from itsdfish/blending_update
Browse files Browse the repository at this point in the history
update for non-numeric blending
  • Loading branch information
itsdfish authored Nov 12, 2024
2 parents 598b57c + 24de4af commit 5b72b93
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ACTRModels"
uuid = "c095b0ea-a6ca-5cbd-afed-dbab2e976880"
authors = ["itsdfish"]
version = "0.13.1"
version = "0.13.2"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Expand Down
21 changes: 14 additions & 7 deletions src/MemoryFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1255,13 +1255,18 @@ function blend_chunks(actr::AbstractACTR, blended_slots, cur_time; request...)
return blend_slots(actr, chunks, probs, blended_slots)
end

function blend_slots(actr::AbstractACTR, chunks, probs, blended_slots)
function blend_slots(
actr::AbstractACTR,
chunks::Vector{<:AbstractChunk},
probs::Vector{<:Real},
blended_slots
)
return map(s -> blend_slots(actr, chunks, probs, s), blended_slots)
end

function blend_slots(actr::AbstractACTR, chunks, probs, slot::Symbol)
function blend_slots(actr::AbstractACTR, chunks::Vector{<:AbstractChunk}, probs::Vector{<:Real}, slot::Symbol)
values = map(c -> c.slots[slot], chunks)
return blend_slots(actr, probs, values)
return blend_slots(actr, probs, values, slot)
end

"""
Expand All @@ -1277,8 +1282,9 @@ Computes an expected value over numerical values.
"""
function blend_slots(
actr::AbstractACTR,
probs,
values::AbstractArray{T}
probs::Vector{<:Real},
values::AbstractArray{T},
slot::Symbol
)::Float64 where {T <: Number}
return probs' * values
end
Expand All @@ -1294,16 +1300,17 @@ Computes an expected value over non-numerical values.
- `probs`: a vector of retrieval probabilities
- `values::AbstractArray{T}`: values to be blended
"""
function blend_slots(actr::AbstractACTR, probs, values::AbstractArray{T})::T where {T}
function blend_slots(actr::AbstractACTR, probs::Vector{<:Real}, values::AbstractArray{T}, slot::Symbol)::T where {T}
n_vals = length(values)
u_values = unique(values)
n_unique = length(u_values)
vals = zeros(n_unique)
dissm_func = actr.parms.dissim_func
println("values $values")
for i 1:n_unique
v = 0.0
for j 1:n_vals
v += probs[j] * dissm_func(u_values[i], values[j])^2
v += probs[j] * dissm_func(slot, u_values[i], values[j])^2
#println("i $(u_values[i]) j $(values[j]) probs $(probs[j]) distance $(dissm_func(u_values[i], values[j])^2) v $v")
end
vals[i] = v
Expand Down
6 changes: 3 additions & 3 deletions test/Memory_Tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ using SafeTestsets
using ACTRModels, Test, Random
import ACTRModels: blend_slots

function dissim_func(x, y)
function dissim_func(s, x, y)
if (x == :a1 && y == :a2) || (y == :a1 && x == :a2)
return 0.1
elseif (x == :a1 && y == :a3) || (y == :a1 && x == :a3)
Expand Down Expand Up @@ -595,14 +595,14 @@ using SafeTestsets
probs = [0.40, 0.35, 0.15, 0.10]
values = map(c -> c.slots[blended_slots], chunks)

blended_value = blend_slots(actr, probs, values)
blended_value = blend_slots(actr, probs, values, blended_slots)

@test blended_value == :a1

probs = [0.30, 0.05, 0.55, 0.10]
values = map(c -> c.slots[blended_slots], chunks)

blended_value = blend_slots(actr, probs, values)
blended_value = blend_slots(actr, probs, values, blended_slots)

@test blended_value == :a2
end
Expand Down

0 comments on commit 5b72b93

Please sign in to comment.