Skip to content

Commit

Permalink
Unify predict and ranking for simplicity
Browse files Browse the repository at this point in the history
since `ranking` is used as a simple wrapper of `predict`.
  • Loading branch information
takuti committed Feb 28, 2022
1 parent 37cc748 commit 16cfe67
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 24 deletions.
10 changes: 2 additions & 8 deletions src/base_recommender.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export Recommender
export isdefined, validate, fit!, recommend, predict, ranking
export isdefined, validate, fit!, recommend, predict

abstract type Recommender end

Expand Down Expand Up @@ -31,7 +31,7 @@ end
function recommend(recommender::Recommender, u::Integer, k::Integer, candidates::Array{T}) where {T<:Integer}
d = Dict{T,AbstractFloat}()
for candidate in candidates
score = ranking(recommender, u, candidate)
score = predict(recommender, u, candidate)
if isnan(score); continue; end
d[candidate] = score
end
Expand All @@ -42,9 +42,3 @@ end
function predict(recommender::Recommender, u::Integer, i::Integer)
error("predict is not implemented for recommender type $(typeof(recommender))")
end

# Return a ranking score of item i for user u
function ranking(recommender::Recommender, u::Integer, i::Integer)
validate(recommender)
predict(recommender, u, i)
end
2 changes: 1 addition & 1 deletion src/baseline/co_occurrence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function fit!(recommender::CoOccurrence)
recommender.scores[:] = CC / c * 100.0
end

function ranking(recommender::CoOccurrence, u::Integer, i::Integer)
function predict(recommender::CoOccurrence, u::Integer, i::Integer)
validate(recommender)
recommender.scores[i]
end
2 changes: 1 addition & 1 deletion src/baseline/most_popular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function fit!(recommender::MostPopular)
recommender.scores[:] = vec(sum(!iszero, recommender.data.R, dims=1))
end

function ranking(recommender::MostPopular, u::Integer, i::Integer)
function predict(recommender::MostPopular, u::Integer, i::Integer)
validate(recommender)
recommender.scores[i]
end
2 changes: 1 addition & 1 deletion src/baseline/threshold_percentage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function fit!(recommender::ThresholdPercentage)
recommender.scores[:] = vec(users_rated_higher ./ users_rated * 100.0)
end

function ranking(recommender::ThresholdPercentage, u::Integer, i::Integer)
function predict(recommender::ThresholdPercentage, u::Integer, i::Integer)
validate(recommender)
recommender.scores[i]
end
6 changes: 3 additions & 3 deletions test/baseline/test_co_occurrence.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
function test_co_occurrence(data)
recommender = CoOccurrence(data, 1)
fit!(recommender)
@test ranking(recommender, 1, 1) == 100.0
@test ranking(recommender, 1, 2) == 50.0
@test ranking(recommender, 1, 3) == 0.0
@test predict(recommender, 1, 1) == 100.0
@test predict(recommender, 1, 2) == 50.0
@test predict(recommender, 1, 3) == 0.0
end

println("-- Testing CoOccurrence recommender")
Expand Down
14 changes: 7 additions & 7 deletions test/baseline/test_most_popular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@ function test_most_popular()
data = DataAccessor([1 2 3; 4 5 nothing])
recommender = MostPopular(data)
fit!(recommender)
@test ranking(recommender, 1, 1) == 2.0
@test ranking(recommender, 1, 3) == 1.0
@test predict(recommender, 1, 1) == 2.0
@test predict(recommender, 1, 3) == 1.0

data = DataAccessor(sparse([1 2 3; 4 5 0]))
recommender = MostPopular(data)
fit!(recommender)
@test ranking(recommender, 1, 1) == 2.0
@test ranking(recommender, 1, 3) == 1.0
@test predict(recommender, 1, 1) == 2.0
@test predict(recommender, 1, 3) == 1.0

n_users, n_items = 5, 10
events = [Event(1, 2, 1), Event(3, 2, 1), Event(2, 6, 4)]
data = DataAccessor(events, n_users, n_items)
recommender = MostPopular(data)
fit!(recommender)
@test ranking(recommender, 1, 1) == 0.0
@test ranking(recommender, 1, 2) == 2.0
@test ranking(recommender, 1, 6) == 1.0
@test predict(recommender, 1, 1) == 0.0
@test predict(recommender, 1, 2) == 2.0
@test predict(recommender, 1, 6) == 1.0
end

test_most_popular()
4 changes: 2 additions & 2 deletions test/baseline/test_threshold_percentage.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
function test_threshold_percentage(data)
recommender = ThresholdPercentage(data, 2.0)
fit!(recommender)
@test ranking(recommender, 1, 1) == 50.0
@test ranking(recommender, 1, 2) == 100.0
@test predict(recommender, 1, 1) == 50.0
@test predict(recommender, 1, 2) == 100.0
end

println("-- Testing ThresholdPercentage recommender")
Expand Down
2 changes: 1 addition & 1 deletion test/test_base_recommender.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function test_not_implemented_error()
recommender = Foo(data)
@test_throws ErrorException fit!(recommender)
@test_throws ErrorException predict(recommender, 1, 1)
@test_throws ErrorException ranking(recommender, 1, 1)
@test_throws ErrorException predict(recommender, 1, 1)
end

function test_not_build_error()
Expand Down

0 comments on commit 16cfe67

Please sign in to comment.