Skip to content

Commit

Permalink
feat: move plotting methods into Column and Table classes (#88)
Browse files Browse the repository at this point in the history
Closes #62.

### Summary of Changes

* Move plotting methods into `Column` and `Table` classes since they are
specific to these classes. We don't want to create a generic plotting
library.
* Remove the `plotting` module.

---------

Co-authored-by: lars-reimann <[email protected]>
  • Loading branch information
lars-reimann and lars-reimann committed Mar 26, 2023
1 parent 15999b5 commit 5ec6189
Show file tree
Hide file tree
Showing 20 changed files with 153 additions and 218 deletions.
20 changes: 5 additions & 15 deletions docs/tutorials/data_visualization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import correlation_heatmap\n",
"\n",
"correlation_heatmap(titanic_numerical)"
"titanic_numerical.correlation_heatmap()"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -145,9 +143,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import lineplot\n",
"\n",
"lineplot(titanic_numerical, \"survived\", \"fare\")"
"titanic_numerical.lineplot(\"survived\", \"fare\")"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -180,9 +176,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import boxplot\n",
"\n",
"boxplot(titanic_numerical.get_column(\"age\"))"
"titanic_numerical.get_column(\"age\").boxplot()"
],
"metadata": {
"collapsed": false,
Expand All @@ -206,9 +200,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import histogram\n",
"\n",
"histogram(titanic_numerical.get_column(\"fare\"))"
"titanic_numerical.get_column(\"fare\").histogram()"
],
"metadata": {
"collapsed": false,
Expand All @@ -232,9 +224,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import scatterplot\n",
"\n",
"scatterplot(titanic_numerical, \"age\", \"fare\")\n"
"titanic_numerical.scatterplot(\"age\", \"fare\")\n"
],
"metadata": {
"collapsed": false,
Expand Down
42 changes: 42 additions & 0 deletions src/safeds/data/tabular/containers/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from numbers import Number
from typing import Any, Callable, Iterable, Iterator, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.display_functions import DisplayHandle, display
from safeds.data.tabular.typing import ColumnType
from safeds.exceptions import (
Expand Down Expand Up @@ -483,3 +485,43 @@ def idness(self) -> float:
if self._data.size == 0:
raise ColumnSizeError("> 0", "0")
return self._data.nunique() / self._data.size

def boxplot(self) -> None:
"""
Plot this column in a boxplot. This function can only plot real numerical data.
Raises
-------
TypeError
If the column contains non-numerical data or complex data.
"""
for data in self._data:
if (
not isinstance(data, int)
and not isinstance(data, float)
and not isinstance(data, complex)
):
raise NonNumericColumnError(self.name)
if isinstance(data, complex):
raise TypeError(
"The column contains complex data. Boxplots cannot plot the imaginary part of complex "
"data. Please provide a Column with only real numbers"
)
ax = sns.boxplot(data=self._data)
ax.set(xlabel=self.name)
plt.tight_layout()
plt.show()

def histogram(self) -> None:
"""
Plot a column in a histogram.
"""

ax = sns.histplot(data=self._data)
ax.set_xticks(ax.get_xticks())
ax.set(xlabel=self.name)
ax.set_xticklabels(
ax.get_xticklabels(), rotation=45, horizontalalignment="right"
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
plt.tight_layout()
plt.show()
95 changes: 95 additions & 0 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from pathlib import Path
from typing import Callable, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.display_functions import DisplayHandle, display
from pandas import DataFrame, Series
from safeds.data.tabular.containers._column import Column
Expand Down Expand Up @@ -970,3 +972,96 @@ def shuffle(self) -> Table:
new_df = self._data.sample(frac=1.0)
new_df.columns = self.schema.get_column_names()
return Table(new_df)

def correlation_heatmap(self) -> None:
"""
Plot a correlation heatmap of an entire table. This function can only plot real numerical data.
Raises
-------
TypeError
If the table contains non-numerical data or complex data.
"""
for column in self.to_columns():
if not column.type.is_numeric():
raise NonNumericColumnError(column.name)
sns.heatmap(
data=self._data.corr(),
vmin=-1,
vmax=1,
xticklabels=self.get_column_names(),
yticklabels=self.get_column_names(),
cmap="vlag",
)
plt.tight_layout()
plt.show()

def lineplot(self, x_column_name: str, y_column_name: str) -> None:
"""
Plot two columns against each other in a lineplot. If there are multiple x-values for a y-value,
the resulting plot will consist of a line representing the mean and the lower-transparency area around the line
representing the 95% confidence interval.
Parameters
----------
x_column_name : str
The column name of the column to be plotted on the x-Axis.
y_column_name : str
The column name of the column to be plotted on the y-Axis.
Raises
---------
UnknownColumnNameError
If either of the columns do not exist.
"""
if not self.has_column(x_column_name):
raise UnknownColumnNameError([x_column_name])
if not self.has_column(y_column_name):
raise UnknownColumnNameError([y_column_name])

ax = sns.lineplot(
data=self._data,
x=self.schema._get_column_index_by_name(x_column_name),
y=self.schema._get_column_index_by_name(y_column_name),
)
ax.set(xlabel=x_column_name, ylabel=y_column_name)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(
ax.get_xticklabels(), rotation=45, horizontalalignment="right"
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
plt.tight_layout()
plt.show()

def scatterplot(self, x_column_name: str, y_column_name: str) -> None:
"""
Plot two columns against each other in a scatterplot.
Parameters
----------
x_column_name : str
The column name of the column to be plotted on the x-Axis.
y_column_name : str
The column name of the column to be plotted on the y-Axis.
Raises
---------
UnknownColumnNameError
If either of the columns do not exist.
"""
if not self.has_column(x_column_name):
raise UnknownColumnNameError([x_column_name])
if not self.has_column(y_column_name):
raise UnknownColumnNameError([y_column_name])

ax = sns.scatterplot(
data=self._data,
x=self.schema._get_column_index_by_name(x_column_name),
y=self.schema._get_column_index_by_name(y_column_name),
)
ax.set(xlabel=x_column_name, ylabel=y_column_name)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(
ax.get_xticklabels(), rotation=45, horizontalalignment="right"
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
plt.tight_layout()
plt.show()
5 changes: 0 additions & 5 deletions src/safeds/plotting/__init__.py

This file was deleted.

38 changes: 0 additions & 38 deletions src/safeds/plotting/_boxplot.py

This file was deleted.

35 changes: 0 additions & 35 deletions src/safeds/plotting/_correlation_heatmap.py

This file was deleted.

23 changes: 0 additions & 23 deletions src/safeds/plotting/_histogram.py

This file was deleted.

44 changes: 0 additions & 44 deletions src/safeds/plotting/_lineplot.py

This file was deleted.

Loading

0 comments on commit 5ec6189

Please sign in to comment.