Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add 1st pytorch example #1180

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
350 changes: 350 additions & 0 deletions examples/pytorch/01-Getting-started.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,350 @@
{
Copy link
Contributor

@bschifferer bschifferer Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we propose with Loader(train, batch_size=1024) as loader: which is different to our TensorFlow examples?


Reply via ReviewNB

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something @edknv suggested, it ensures that the background thread gets removed. I am working on a way to see if we can move this inside our model/trainer code. Because I think the context manager approach works for single GPU, but I don't think it will work in a multi-GPU setting.

Copy link
Contributor

@bschifferer bschifferer Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure, if we have the next notebook available on PyTorch - we might need to reference the TensorFlow one OR the next steps are removed OR we link to the other training examples


Reply via ReviewNB

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! removed the cell for now and will add it once we have more examples

Copy link
Contributor

@rnyak rnyak Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line #4.    model.initialize(train_loader)

can we add some explanation why we do need model.initialize() step?


Reply via ReviewNB

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added note on the functionality of initialize

Copy link
Contributor

@rnyak rnyak Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this entire block again?


Reply via ReviewNB

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure -- was just copying over what we have on the TF side, we follow the same pattern there

"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bb28e271",
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2023 NVIDIA Corporation. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"# ==============================================================================\n",
"\n",
"# Each user is responsible for checking the content of datasets and the\n",
"# applicable licenses and determining if suitable for the intended use."
]
},
{
"cell_type": "markdown",
"id": "23d9bf34",
"metadata": {},
"source": [
"<img src=\"https://developer.download.nvidia.com/notebooks/dlsw%20notebooks/merlin_models_pytorch_01-getting-started/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
"\n",
"# Getting Started with Merlin Models: Develop a Model for MovieLens using the PyTorch API\n",
"\n",
"This notebook is created using the latest stable [merlin-pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-pytorch/tags) container. \n",
"\n",
"## Overview\n",
"\n",
"[Merlin Models](https://github.com/NVIDIA-Merlin/models/) is a library for training recommender models. Merlin Models let Data Scientists and ML Engineers easily train standard RecSys models on their own dataset, getting GPU-accelerated models with best practices baked into the library. This will also let researchers to build custom models by incorporating standard components of deep learning recommender models, and then benchmark their new models on example offline datasets. Merlin Models is part of the [Merlin open source framework](https://developer.nvidia.com/nvidia-merlin).\n",
"\n",
"Core features are:\n",
"- Many different recommender system architectures (tabular, two-tower, sequential) or tasks (binary, multi-class classification, multi-task)\n",
"- Flexible APIs targeted to both production and research\n",
"- Deep integration with NVIDIA Merlin platform, including NVTabular for ETL and Merlin Systems model serving\n",
"\n",
"\n",
"### Learning objectives\n",
"\n",
"- Training [Facebook's DLRM model](https://arxiv.org/pdf/1906.00091.pdf) very easily with our high-level API.\n",
"- Understanding Merlin Models high-level API"
]
},
{
"cell_type": "markdown",
"id": "1c5598ae",
"metadata": {},
"source": [
"## Downloading and preparing the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "60653f70",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n",
" warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n",
"/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"import merlin.models.torch as mm\n",
"from merlin.loader.torch import Loader\n",
"import pytorch_lightning as pl\n",
"\n",
"from merlin.datasets.entertainment import get_movielens"
]
},
{
"cell_type": "markdown",
"id": "5327924b",
"metadata": {},
"source": [
"We provide the `get_movielens()` function as a convenience to download the dataset, perform simple preprocessing, and split the data into training and validation datasets."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9ba8b53d",
"metadata": {},
"outputs": [],
"source": [
"input_path = os.environ.get(\"INPUT_DATA_DIR\", os.path.expanduser(\"~/merlin-models-data/movielens/\"))\n",
"train, valid = get_movielens(variant=\"ml-1m\", path=input_path)"
]
},
{
"cell_type": "markdown",
"id": "2ee5c7c2",
"metadata": {},
"source": [
"## Training the DLRM Model with Merlin Models"
]
},
{
"cell_type": "markdown",
"id": "688b89c7",
"metadata": {},
"source": [
"We define the DLRM model, whose prediction task is a binary classification. From the `schema`, the categorical features are identified (and embedded) and the target columns are also automatically inferred, because of the schema tags. We talk more about the schema in the next [example notebook (02)](02-Merlin-Models-and-NVTabular-integration.ipynb),"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d3b8942c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.\n",
" warnings.warn('Lazy modules are a new feature under heavy development '\n"
]
}
],
"source": [
"model = mm.DLRMModel(\n",
" train.schema,\n",
" embedding_dim=64,\n",
" bottom_block=mm.MLPBlock([128, 64]),\n",
" top_block=mm.MLPBlock([128, 64, 32]),\n",
" output_block=mm.BinaryOutput(train.schema.select_by_name('rating_binary')),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "64ee4cef",
"metadata": {},
"source": [
"Next, we train the model."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "33343067",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
" warning_cache.warn(\n",
"/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py:70: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n",
" rank_zero_warn(\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params\n",
"--------------------------------------\n",
"0 | values | ModuleList | 1.1 M \n",
"--------------------------------------\n",
"1.1 M Trainable params\n",
"0 Non-trainable params\n",
"1.1 M Total params\n",
"4.459 Total estimated model params size (MB)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:07<00:00, 101.54it/s, v_num=10]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=1` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 782/782 [00:07<00:00, 101.21it/s, v_num=10]\n"
]
}
],
"source": [
"trainer = pl.Trainer(max_epochs=1)\n",
"train_loader = Loader(train, batch_size=1024)\n",
"\n",
"# The initialize step ensures the model and data are on the correct device\n",
"# and prepares the model for training\n",
"model.initialize(train_loader)\n",
"trainer.fit(model, train_loader)"
]
},
{
"cell_type": "markdown",
"id": "4bd668ab",
"metadata": {},
"source": [
"We evaluate the model and check the evaluation metrics."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "34f01ce5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation DataLoader 0: 7%|██████▌ | 13/196 [00:00<00:01, 161.86it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1024. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
" warning_cache.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [00:01<00:00, 165.71it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Validate metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" val_binary_accuracy 0.7193425297737122\n",
" val_binary_auroc 0.7803523540496826\n",
" val_binary_precision 0.7274115681648254\n",
" val_binary_recall 0.8201844692230225\n",
" val_loss 0.5525734424591064\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 361. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
" warning_cache.warn(\n"
]
}
],
"source": [
"val_loader = Loader(valid, batch_size=1024)\n",
"metrics = trainer.validate(model, val_loader)"
]
},
{
"cell_type": "markdown",
"id": "2a6ad327",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"id": "eeba861b",
"metadata": {},
"source": [
"Merlin Models enables users to define and train a deep learning recommeder model with just a handful of commands.\n",
"\n",
"```python\n",
"model = mm.DLRMModel(\n",
" train.schema,\n",
" embedding_dim=64,\n",
" bottom_block=mm.MLPBlock([128, 64]),\n",
" top_block=mm.MLPBlock([128, 64, 32]),\n",
" output_block=mm.BinaryOutput(train.schema.select_by_name('rating_binary')),\n",
")\n",
"\n",
"trainer = pl.Trainer(max_epochs=1)\n",
"train_loader = Loader(train, batch_size=1024)\n",
"\n",
"model.initialize(train_loader)\n",
"trainer.fit(model, train_loader)\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"merlin": {
"containers": [
"nvcr.io/nvidia/merlin/merlin-tensorflow:latest"
]
},
"vscode": {
"interpreter": {
"hash": "ab403bb43341787581f43b51cdd291d61392c89ddb0f92179de653921d4e05db"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
10 changes: 5 additions & 5 deletions merlin/models/torch/models/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class DLRMModel(Model):
----------
schema : Schema
The schema to use for selection.
dim : int
The dimensionality of the output vectors.
embedding_dim : int
The dimensionality of the embedding vectors for CONTINUOUS and CATEGORICAL features.
bottom_block : Block
Block to pass the continuous features to.
Note that, the output dimensionality of this block must be equal to ``dim``.
Expand All @@ -46,7 +46,7 @@ class DLRMModel(Model):
-------------
>>> model = mm.DLRMModel(
... schema,
... dim=64,
... embedding_dim=64,
... bottom_block=mm.MLPBlock([256, 64]),
... output_block=mm.BinaryOutput(ColumnSchema("target")),
... )
Expand All @@ -59,7 +59,7 @@ class DLRMModel(Model):
def __init__(
self,
schema: Schema,
dim: int,
embedding_dim: int,
bottom_block: Block,
top_block: Optional[Block] = None,
interaction: Optional[nn.Module] = None,
Expand All @@ -70,7 +70,7 @@ def __init__(

dlrm_body = DLRMBlock(
schema,
dim,
embedding_dim,
bottom_block,
top_block=top_block,
interaction=interaction,
Expand Down
Loading
Loading