Refactor with k-fold splitter for DataAccessor
takuti committed Apr 3, 2022
1 parent 4cd0ced commit 1ec68b7
Showing 3 changed files with 57 additions and 71 deletions.
43 changes: 42 additions & 1 deletion src/data_accessor.jl
@@ -1,5 +1,5 @@
export DataAccessor
export create_matrix, set_user_attribute, get_user_attribute, set_item_attribute, get_item_attribute
export create_matrix, set_user_attribute, get_user_attribute, set_item_attribute, get_item_attribute, split_events

struct DataAccessor
events::Array{Event,1}
@@ -62,3 +62,44 @@ end
function get_item_attribute(data::DataAccessor, item::Integer)
get(data.item_attributes, item, [])
end

function split_events(data::DataAccessor, n_folds::Integer)
if n_folds < 2
error("`n_folds` must be greater than 1 to split the samples into train and test sets.")
end

events = shuffle(data.events)
n_events = length(events)

if n_folds > n_events
error("`n_folds = $n_folds` must be less than $n_events, the number of all samples.")
end

n_users, n_items = size(data.R)

step = convert(Integer, round(n_events / n_folds))

if n_folds == n_events
@info "Splitting $n_events samples for leave-one-out cross validation"
else
@info "Splitting $n_events samples for $n_folds-fold cross validation"
end

train_test_pairs = Array{Tuple{DataAccessor, DataAccessor},1}()

for (index, head) in enumerate(1:step:n_events)
tail = min(head + step - 1, n_events)

truth_events = events[head:tail]
truth_data = DataAccessor(truth_events, n_users, n_items)

train_events = vcat(events[1:head - 1], events[tail + 1:end])
train_data = DataAccessor(train_events, n_users, n_items)

push!(train_test_pairs, (train_data, truth_data))

@debug "fold#$index will test the samples in [$head, $tail]"
end

train_test_pairs
end
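
For context, a minimal sketch of how the new `split_events` helper might be exercised, built on the same toy matrix used in the test file further down (the `using` lines are assumptions about the caller's environment, not part of this commit):

using Recommendation, SparseArrays

data = DataAccessor(sparse([1 0 0; 4 5 0]))  # 2 users x 3 items, 3 observed events

# 3 events split into 3 folds amounts to leave-one-out splitting
for (train_data, truth_data) in split_events(data, 3)
    # each pair holds disjoint, shuffled subsets of `data.events`
    @show length(train_data.events), length(truth_data.events)
end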
80 changes: 10 additions & 70 deletions src/evaluation/cross_validation.jl
@@ -13,48 +13,18 @@ export cross_validation, leave_one_out
Conduct `n_folds`-fold cross validation for a combination of recommender `recommender_type` and ranking metric `metric`. A recommender is initialized with `recommender_args` and runs top-`topk` recommendation.
"""
function cross_validation(n_folds::Integer, metric::Type{<:RankingMetric}, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)

if n_folds < 2
error("`n_folds` must be greater than 1 to split the samples into train and test sets.")
end

events = shuffle(data.events)
n_events = length(events)

if n_folds > n_events
error("`n_folds = $n_folds` must be less than $n_events, the number of all samples.")
end

n_users, n_items = size(data.R)

step = convert(Integer, round(n_events / n_folds))
accum = 0.0

if n_folds == n_events
@info "Leave-one-out cross validation against $n_events samples"
else
@info "$n_folds-fold cross validation against $n_events samples"
end

for (index, head) in enumerate(1:step:n_events)
tail = min(head + step - 1, n_events)

truth_events = events[head:tail]
truth_data = DataAccessor(truth_events, n_users, n_items)

train_events = vcat(events[1:head - 1], events[tail + 1:end])
train_data = DataAccessor(train_events, n_users, n_items)

for (train_data, truth_data) in split_events(data, n_folds)
# get recommender from the specified data type
recommender = recommender_type(train_data, recommender_args...)
fit!(recommender)

accuracy = evaluate(recommender, truth_data, metric(), topk)
@info "fold#$index: tested with samples in [$head, $tail]. $metric = $accuracy"
if isnan(accuracy); continue; end
if isnan(accuracy)
@warn "cannot calculate a metric for $truth_data"
continue
end
accum += accuracy
end

accum / n_folds
end
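
For illustration, a hedged usage sketch of this entry point; `MostPopular` and `Recall` are assumed to be among the recommender types and ranking metrics the package exports, and are not dictated by this commit:

# as above: using Recommendation, SparseArrays
data = DataAccessor(sparse([1 0 0; 4 5 0]))

# 2-fold cross validation of top-2 recommendation, returning Recall averaged over the folds
avg_recall = cross_validation(2, Recall, 2, MostPopular, data)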

@@ -70,48 +70,18 @@ end
Conduct `n_folds`-fold cross validation for a combination of recommender `recommender_type` and accuracy metric `metric`. A recommender is initialized with `recommender_args`.
"""
function cross_validation(n_folds::Integer, metric::Type{<:AccuracyMetric}, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)

if n_folds < 2
error("`n_folds` must be greater than 1 to split the samples into train and test sets.")
end

events = shuffle(data.events)
n_events = length(events)

if n_folds > n_events
error("`n_folds = $n_folds` must be less than $n_events, the number of all samples.")
end

n_users, n_items = size(data.R)

step = convert(Integer, round(n_events / n_folds))
accum = 0.0

if n_folds == n_events
@info "Leave-one-out cross validation against $n_events samples"
else
@info "$n_folds-fold cross validation against $n_events samples"
end

for (index, head) in enumerate(1:step:n_events)
tail = min(head + step - 1, n_events)

truth_events = events[head:tail]
truth_data = DataAccessor(truth_events, n_users, n_items)

train_events = vcat(events[1:head - 1], events[tail + 1:end])
train_data = DataAccessor(train_events, n_users, n_items)

for (train_data, truth_data) in split_events(data, n_folds)
# get recommender from the specified data type
recommender = recommender_type(train_data, recommender_args...)
fit!(recommender)

accuracy = evaluate(recommender, truth_data, metric())
@info "fold#$index: tested with samples in [$head, $tail]. $metric = $accuracy"
if isnan(accuracy); continue; end
if isnan(accuracy)
@warn "cannot calculate a metric for $truth_data"
continue
end
accum += accuracy
end

accum / n_folds
end
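
Similarly, a short sketch for the accuracy-metric variant; `UserMean` and `RMSE` are assumed names for a baseline recommender and an `AccuracyMetric` subtype, so substitute whatever the package actually exports:

# data as in the earlier sketches
avg_rmse = cross_validation(2, RMSE, UserMean, data)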

5 changes: 5 additions & 0 deletions test/test_data_accessor.jl
@@ -15,6 +15,11 @@ function test_data_accessor()

data = DataAccessor(sparse([1 0 0; 4 5 0]))
@test size(data.R) == (2, 3)

@test_throws ErrorException split_events(data, 1)
@test_throws ErrorException split_events(data, 100)
@test length(split_events(data, 2)) == 2
@test length(split_events(data, 3)) == 3
end

test_data_accessor()
