Skip to content

Commit

Permalink
Merge pull request #42 from takuti/ml-latest
Browse files Browse the repository at this point in the history
Add ML-latest-small loader with `binarize_multi_label`
  • Loading branch information
takuti authored Feb 13, 2022
2 parents dd8924b + 8e173cb commit 42fcd36
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 2 deletions.
90 changes: 89 additions & 1 deletion src/datasets.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export get_data_home, download_file, unzip, load_movielens_100k, load_amazon_review
export get_data_home, download_file, unzip, load_movielens_100k, load_movielens_latest, load_amazon_review

"""
get_data_home([data_home=nothing]) -> String
Expand Down Expand Up @@ -124,6 +124,94 @@ function load_movielens_100k(path::Union{String, Nothing}=nothing)
end


"""
    load_movielens_latest([path=nothing]) -> DataAccessor

`path` points to a locally saved [MovieLens Latest (Small)](https://files.grouplens.org/datasets/movielens/ml-latest-small-README.html) folder.
Read user-item-rating triples from `ratings.csv` in the folder, and convert them into a `DataAccessor` instance. The genres listed in `movies.csv` are attached to every rated item as a binary attribute vector.
Download and decompress the corresponding zip file if `path` is not given or the specified folder does not exist.
"""
function load_movielens_latest(path::Union{String, Nothing}=nothing)
    # Fetch and extract the archive when no usable local folder is available.
    # If a (non-existing) `path` was given, the zip file is stored next to it.
    if isnothing(path) || !isdir(path)
        zip_path = path
        if !isnothing(zip_path)
            zip_path = joinpath(dirname(zip_path), "ml-latest-small.zip")
        end
        zip_path = download_file("https://files.grouplens.org/datasets/movielens/ml-latest-small.zip", zip_path)
        path = joinpath(unzip(zip_path, path), "ml-latest-small/")
    end

    # Raw user/movie IDs are re-indexed to consecutive integers (1, 2, ...) in
    # the order in which they first appear in `ratings.csv`.
    events = Array{Event, 1}()
    user_ids, item_ids = Dict{Int, Int}(), Dict{Int, Int}()
    open(joinpath(path, "ratings.csv"), "r") do io
        for (index, line) in enumerate(eachline(io))
            # Skip the header row and any blank (e.g. trailing) line.
            if index == 1 || isempty(strip(line))
                continue
            end
            l = split(line, ",")
            user, item, rating = parse(Int, l[1]), parse(Int, l[2]), parse(Float64, l[3])
            u = get!(user_ids, user, length(user_ids) + 1)
            i = get!(item_ids, item, length(item_ids) + 1)
            push!(events, Event(u, i, rating))
        end
    end
    data = DataAccessor(events, length(user_ids), length(item_ids))

    all_genres = [
        "Action",
        "Adventure",
        "Animation",
        "Children",
        "Comedy",
        "Crime",
        "Documentary",
        "Drama",
        "Fantasy",
        "Film-Noir",
        "Horror",
        "Musical",
        "Mystery",
        "Romance",
        "Sci-Fi",
        "Thriller",
        "War",
        "Western"
    ]
    open(joinpath(path, "movies.csv"), "r") do io
        for (index, line) in enumerate(eachline(io))
            # Skip the header row and any blank (e.g. trailing) line.
            if index == 1 || isempty(strip(line))
                continue
            end
            # Rows look like `movieId,title,genres`, but a title may itself
            # contain commas (it is then double-quoted), e.g.
            # `11,"American President, The (1995)",Comedy|Drama|Romance`.
            # The movie ID is always the first field and the pipe-separated
            # genres are always the last one, so read them from both ends
            # instead of relying on a fixed column position.
            fields = split(line, ",")
            item = parse(Int, first(fields))
            # Movies that never appear in `ratings.csv` have no internal index.
            if haskey(item_ids, item)
                genres = binarize_multi_label(split(last(fields), "|"), all_genres)
                set_item_attribute(data, item_ids[item], genres)
            end
        end
    end

    return data
end


"""
load_amazon_review([path=nothing; category="Electronics"]) -> DataAccessor
Expand Down
15 changes: 14 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export matrix, vector, isfilled, onehot
export matrix, vector, isfilled, onehot, binarize_multi_label

function matrix(m::Integer, n::Integer)
Array{Union{Missing, AbstractFloat}}(missing, m, n)
Expand Down Expand Up @@ -59,3 +59,16 @@ Onehot-encode the individual columns and horizontally concatenate them as an out
function onehot(mat::AbstractMatrix)
    # Encode every column on its own, then place the encoded pieces side by side.
    encoded_columns = [onehot(column) for column in eachcol(mat)]
    return hcat(encoded_columns...)
end

"""
    binarize_multi_label(values::AbstractVector, value_set::AbstractVector) -> Vector{Float64}

Encode the labels in `values` as a 0/1 indicator vector over `value_set`.

The `k`-th element of the result is `1.0` if the `k`-th element of `value_set`
occurs in `values`, and `0.0` otherwise. Elements of either argument that are
instances of `Unknown` are ignored. Labels in `values` that are not part of
`value_set` do not contribute to the result.

Neither `values` nor `value_set` is modified.

Throw an `ErrorException` if `value_set` contains duplicated elements.
"""
function binarize_multi_label(values::AbstractVector, value_set::AbstractVector)
    if !allunique(value_set)
        error("duplicated value exists in a value set")
    end
    # Filter into fresh vectors: an in-place `filter!` would silently shrink
    # the vectors owned by the caller.
    known_values = filter(val -> !isa(val, Unknown), values)
    known_set = filter(val -> !isa(val, Unknown), value_set)
    vec = zeros(length(known_set))
    indices = findall(v -> v in known_values, known_set)
    vec[indices] .= 1.0
    return vec
end
18 changes: 18 additions & 0 deletions test/test_datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ function validate_movielens_100k(data::DataAccessor)
@test get_item_attribute(data, 1) == expected_item_attribute
end

function test_load_movielens_latest()
    path = tempname()
    println("-- Testing to download and read the MovieLens latest (small) data file: $path")
    data = load_movielens_latest(path)

    # Expected entries in row 1, columns 1 to 4 of the rating matrix.
    for (item, rating) in enumerate([4.0, 4.0, 4.0, 5.0])
        @test data.R[1, item] == rating
    end

    # Genres from 1st item in `movies.csv` - Toy Story: Adventure|Animation|Children|Comedy|Fantasy
    toy_story_genres = [
        0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0
    ]
    @test get_item_attribute(data, 1) == toy_story_genres
end

function test_load_amazon_review()
println("-- Testing download and read Amazon Review Dataset")
@test_throws ErrorException load_amazon_review(category="foo")
Expand All @@ -78,4 +95,5 @@ test_get_data_home()
# The test functions are invoked directly at load time, in this order.
# `test_load_movielens_latest` downloads its data file into a temporary location.
test_download_file()
test_unzip()
test_load_movielens_100k()
test_load_movielens_latest()
test_load_amazon_review()
8 changes: 8 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ function test_onehot_matrix()
@test m_encoded == expected
end

function test_binarize_multi_label()
    println("-- Testing multi-label binarization")

    # `nothing` and `missing` entries among the labels must not affect the encoding.
    labels = ["Comedy", nothing, "Action", missing, "Anime"]
    vocabulary = ["Action", "Anime", "Bollywood", "Comedy"]
    @test binarize_multi_label(labels, vocabulary) == [1, 1, 0, 1]

    # A value set containing a duplicate must be rejected.
    duplicated_vocabulary = [1, 1, 2, 3, 4]
    @test_throws ErrorException binarize_multi_label([1, 2, 3, 4], duplicated_vocabulary)
end

# The test functions are invoked directly at load time, in this order.
test_onehot_value()
test_onehot_vector()
test_onehot_matrix()
test_binarize_multi_label()

0 comments on commit 42fcd36

Please sign in to comment.