diff --git a/src/localsearch/gmm.jl b/src/localsearch/gmm.jl index 22712a2..80bc80a 100644 --- a/src/localsearch/gmm.jl +++ b/src/localsearch/gmm.jl @@ -111,6 +111,23 @@ function GMMResult(n::Integer, clusters::AbstractVector{<:AbstractVector{<:Real} return result end +function initialize!(result::GMMResult, data::AbstractMatrix{<:Real}, indices::AbstractVector{<:Integer}; verbose::Bool = false) + n, d = size(data) + k = length(indices) + + for i in 1:k + for j in 1:d + result.clusters[i][j] = data[indices[i], j] + end + end + + if verbose + print_initial_clusters(indices) + end + + return nothing +end + function estimate_gaussian_parameters( gmm::GMM, data::AbstractMatrix{<:Real}, @@ -373,15 +390,7 @@ function fit(gmm::GMM, data::AbstractMatrix{<:Real}, initial_clusters::AbstractV @assert k > 0 @assert n >= k - for i in 1:k - for j in 1:d - result.clusters[i][j] = data[initial_clusters[i], j] - end - end - - if gmm.verbose - print_initial_clusters(initial_clusters) - end + initialize!(result, data, initial_clusters, verbose = gmm.verbose) fit!(gmm, data, result) @@ -418,13 +427,19 @@ result = fit(gmm, data, k) function fit(gmm::GMM, data::AbstractMatrix{<:Real}, k::Integer)::GMMResult n, d = size(data) + result = GMMResult(d, n, k) if n == 0 - return GMMResult(d, n, k) + return result end + @assert d > 0 @assert k > 0 @assert n >= k - initial_clusters = StatsBase.sample(gmm.rng, 1:n, k, replace = false) - return fit(gmm, data, initial_clusters) + unique_data, indices = sample_unique_data(gmm.rng, data, k) + initialize!(result, unique_data, indices, verbose = gmm.verbose) + + fit!(gmm, data, result) + + return result end diff --git a/src/localsearch/kmedoids.jl b/src/localsearch/kmedoids.jl index e1a12ab..5b5580a 100644 --- a/src/localsearch/kmedoids.jl +++ b/src/localsearch/kmedoids.jl @@ -90,6 +90,20 @@ function KmedoidsResult(n::Integer, clusters::AbstractVector{<:Integer}) return result end +function initialize!(result::KmedoidsResult, indices::AbstractVector{<:Integer}; verbose::Bool = false) + k = length(indices) + + for i in 1:k + result.clusters[i] = indices[i] + end + + if verbose + print_initial_clusters(indices) + end + + return nothing +end + @doc """ fit!( kmedoids::Kmedoids, @@ -241,13 +255,7 @@ function fit(kmedoids::Kmedoids, distances::AbstractMatrix{<:Real}, initial_clus @assert k > 0 @assert n >= k - for i in 1:k - result.clusters[i] = initial_clusters[i] - end - - if kmedoids.verbose - print_initial_clusters(initial_clusters) - end + initialize!(result, initial_clusters, verbose = kmedoids.verbose) fit!(kmedoids, distances, result) @@ -285,13 +293,18 @@ result = fit(kmedoids, distances, k) function fit(kmedoids::Kmedoids, distances::AbstractMatrix{<:Real}, k::Integer)::KmedoidsResult n = size(distances, 1) + result = KmedoidsResult(n, k) if n == 0 - return KmedoidsResult(n, k) + return result end @assert k > 0 @assert n >= k initial_clusters = StatsBase.sample(kmedoids.rng, 1:n, k, replace = false) - return fit(kmedoids, distances, initial_clusters) + initialize!(result, initial_clusters, verbose = kmedoids.verbose) + + fit!(kmedoids, distances, result) + + return result end diff --git a/src/localsearch/ksegmentation.jl b/src/localsearch/ksegmentation.jl index d108389..0ff8158 100644 --- a/src/localsearch/ksegmentation.jl +++ b/src/localsearch/ksegmentation.jl @@ -88,11 +88,12 @@ end function fit(ksegmentation::Ksegmentation, data::AbstractMatrix{<:Real}, k::Integer)::KsegmentationResult n, d = size(data) + result = KsegmentationResult(d, n, k) if n == 0 - return KsegmentationResult(d, n, k) + return result end - result = KsegmentationResult(d, n, k) fit!(ksegmentation, data, result) + return result end