Skip to content

Commit

Permalink
update of example notebook 1
Browse files Browse the repository at this point in the history
Signed-off-by: Jonita Ruiter <[email protected]>
  • Loading branch information
JonitaRuiter committed Jan 18, 2024
1 parent f84bfd0 commit 17a2eb0
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,6 @@ dmypy.json
.pyre/

# Trained models
*trained_models*
mlflow_artifacts/
mlflow_trained_models/
trained_models/
47 changes: 37 additions & 10 deletions examples/01. Train a model using high-level pipelines.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
}
},
"source": [
"# Example to train a model\n",
"Using the openstf tasks"
"# Train a model\n",
"In this example notebook, a model is trained for a location with id '287'. The data for this location can be found in the 'data' folder. \n",
"First, the prediction job will be defined, which contains the properties of the training and prediction. Thereafter, the model can be trained by the ```train_model_pipeline()```. "
]
},
{
Expand All @@ -30,7 +31,8 @@
"import IPython\n",
"from openstef.pipeline.train_model import train_model_pipeline\n",
"from openstef.pipeline.create_forecast import create_forecast_pipeline\n",
"from openstef.data_classes.prediction_job import PredictionJobDataClass"
"from openstef.data_classes.prediction_job import PredictionJobDataClass\n",
"\n"
]
},
{
Expand All @@ -45,9 +47,9 @@
},
"outputs": [],
"source": [
"# define properties of training/prediction. We call this a 'prediction_job'\n",
"# Define properties of training/prediction. We call this a 'prediction_job'\n",
"pj = dict(id=287,\n",
" model='xgb',\n",
" model='xgb', \n",
" quantiles=[10,30,50,70,90],\n",
" forecast_type=\"demand\",\n",
" lat=52.0,\n",
Expand Down Expand Up @@ -81,6 +83,8 @@
},
"outputs": [],
"source": [
"# Print the train data. \n",
"# For every timestamp, bot the load as well as feature data is available. \n",
"train_data.head()"
]
},
Expand All @@ -105,7 +109,7 @@
"metadata": {},
"source": [
"# Train a model\n",
"Train a model using the high-level pipeline. Store the model and reports on training proces in ./trained_models"
"Train the model by using the high-level pipelin ```train_model_pipeline```. Store the model and reports on training proces in ./trained_models by setting mlflow_tracking_uri and artifact_folder to this path. "
]
},
{
Expand All @@ -125,8 +129,8 @@
" pj,\n",
" train_data,\n",
" check_old_model_age=False,\n",
" mlflow_tracking_uri=\"./trained_models\",\n",
" artifact_folder=\"./trained_models\",\n",
" mlflow_tracking_uri=\"./mlflow_trained_models\",\n",
" artifact_folder=\"./mlflow_artifacts\",\n",
" )"
]
},
Expand All @@ -135,7 +139,8 @@
"id": "7209dca5",
"metadata": {},
"source": [
"You can find the trained model in ./trained_models, along with reports on the training process"
"You can find the trained model in ./trained_models, along with reports on the training process. Below the Predictor0.25 and Predictor47.0 plots are shown, as well as the weight plot. The predictor plots show {nog invullen}.\n",
"The weight plot shows the importance and weight of every feature."
]
},
{
Expand All @@ -155,6 +160,28 @@
" f\"<iframe src=./trained_models/{pj['id']}/Predictor47.0.html width=800 height=400></iframe>\"\n",
" f\"<iframe src=./trained_models/{pj['id']}/weight_plot.html width=800 height=400></iframe>\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "990f046f",
"metadata": {},
"outputs": [],
"source": [
"html_file_path=\"examples\\\\mlflow_artifacts\\*\"\n",
"with open(html_file_path, 'r') as f:\n",
" html_content = f.read()\n",
"\n",
"import plotly.io as pio\n",
"\n",
"# This ensures Plotly output works in multiple places:\n",
"# plotly_mimetype: VS Code notebook UI\n",
"# notebook: \"Jupyter: Export to HTML\" command in VS Code\n",
"# See https://plotly.com/python/renderers/#multiple-renderers\n",
"pio.renderers.default = \"plotly_mimetype+notebook\"\n",
"\n",
"fig= go.Figure(data=html_content)"
]
}
],
"metadata": {
Expand All @@ -173,7 +200,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.10.13"
},
"toc": {
"base_numbering": 1,
Expand Down

0 comments on commit 17a2eb0

Please sign in to comment.