Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: move plotting methods into Column and Table classes #88

Merged
merged 4 commits into from
Mar 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions docs/tutorials/data_visualization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import correlation_heatmap\n",
"\n",
"correlation_heatmap(titanic_numerical)"
"titanic_numerical.correlation_heatmap()"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -145,9 +143,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import lineplot\n",
"\n",
"lineplot(titanic_numerical, \"survived\", \"fare\")"
"titanic_numerical.lineplot(\"survived\", \"fare\")"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -180,9 +176,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import boxplot\n",
"\n",
"boxplot(titanic_numerical.get_column(\"age\"))"
"titanic_numerical.get_column(\"age\").boxplot()"
],
"metadata": {
"collapsed": false,
Expand All @@ -206,9 +200,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import histogram\n",
"\n",
"histogram(titanic_numerical.get_column(\"fare\"))"
"titanic_numerical.get_column(\"fare\").histogram()"
],
"metadata": {
"collapsed": false,
Expand All @@ -232,9 +224,7 @@
"execution_count": null,
"outputs": [],
"source": [
"from safeds.plotting import scatterplot\n",
"\n",
"scatterplot(titanic_numerical, \"age\", \"fare\")\n"
"titanic_numerical.scatterplot(\"age\", \"fare\")\n"
],
"metadata": {
"collapsed": false,
Expand Down
42 changes: 42 additions & 0 deletions src/safeds/data/tabular/containers/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from numbers import Number
from typing import Any, Callable, Iterable, Iterator, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.display_functions import DisplayHandle, display
from safeds.data.tabular.typing import ColumnType
from safeds.exceptions import (
Expand Down Expand Up @@ -483,3 +485,43 @@ def idness(self) -> float:
if self._data.size == 0:
raise ColumnSizeError("> 0", "0")
return self._data.nunique() / self._data.size

def boxplot(self) -> None:
"""
Plot this column in a boxplot. This function can only plot real numerical data.

Raises
-------
TypeError
If the column contains non-numerical data or complex data.
"""
for data in self._data:
if (
not isinstance(data, int)
and not isinstance(data, float)
and not isinstance(data, complex)
):
raise NonNumericColumnError(self.name)
if isinstance(data, complex):
raise TypeError(
"The column contains complex data. Boxplots cannot plot the imaginary part of complex "
"data. Please provide a Column with only real numbers"
)
ax = sns.boxplot(data=self._data)
ax.set(xlabel=self.name)
plt.tight_layout()
plt.show()

def histogram(self) -> None:
"""
Plot a column in a histogram.
"""

ax = sns.histplot(data=self._data)
ax.set_xticks(ax.get_xticks())
ax.set(xlabel=self.name)
ax.set_xticklabels(
ax.get_xticklabels(), rotation=45, horizontalalignment="right"
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
plt.tight_layout()
plt.show()
95 changes: 95 additions & 0 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from pathlib import Path
from typing import Callable, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.display_functions import DisplayHandle, display
from pandas import DataFrame, Series
from safeds.data.tabular.containers._column import Column
Expand Down Expand Up @@ -970,3 +972,96 @@ def shuffle(self) -> Table:
new_df = self._data.sample(frac=1.0)
new_df.columns = self.schema.get_column_names()
return Table(new_df)

def correlation_heatmap(self) -> None:
"""
Plot a correlation heatmap of an entire table. This function can only plot real numerical data.

Raises
-------
TypeError
If the table contains non-numerical data or complex data.
"""
for column in self.to_columns():
if not column.type.is_numeric():
raise NonNumericColumnError(column.name)
sns.heatmap(
data=self._data.corr(),
vmin=-1,
vmax=1,
xticklabels=self.get_column_names(),
yticklabels=self.get_column_names(),
cmap="vlag",
)
plt.tight_layout()
plt.show()

def lineplot(self, x_column_name: str, y_column_name: str) -> None:
"""
Plot two columns against each other in a lineplot. If there are multiple x-values for a y-value,
the resulting plot will consist of a line representing the mean and the lower-transparency area around the line
representing the 95% confidence interval.

Parameters
----------
x_column_name : str
The column name of the column to be plotted on the x-Axis.
y_column_name : str
The column name of the column to be plotted on the y-Axis.

Raises
---------
UnknownColumnNameError
If either of the columns do not exist.
"""
if not self.has_column(x_column_name):
raise UnknownColumnNameError([x_column_name])
if not self.has_column(y_column_name):
raise UnknownColumnNameError([y_column_name])

ax = sns.lineplot(
data=self._data,
x=self.schema._get_column_index_by_name(x_column_name),
y=self.schema._get_column_index_by_name(y_column_name),
)
ax.set(xlabel=x_column_name, ylabel=y_column_name)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(
ax.get_xticklabels(), rotation=45, horizontalalignment="right"
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
plt.tight_layout()
plt.show()

def scatterplot(self, x_column_name: str, y_column_name: str) -> None:
"""
Plot two columns against each other in a scatterplot.

Parameters
----------
x_column_name : str
The column name of the column to be plotted on the x-Axis.
y_column_name : str
The column name of the column to be plotted on the y-Axis.

Raises
---------
UnknownColumnNameError
If either of the columns do not exist.
"""
if not self.has_column(x_column_name):
raise UnknownColumnNameError([x_column_name])
if not self.has_column(y_column_name):
raise UnknownColumnNameError([y_column_name])

ax = sns.scatterplot(
data=self._data,
x=self.schema._get_column_index_by_name(x_column_name),
y=self.schema._get_column_index_by_name(y_column_name),
)
ax.set(xlabel=x_column_name, ylabel=y_column_name)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(
ax.get_xticklabels(), rotation=45, horizontalalignment="right"
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
plt.tight_layout()
plt.show()
5 changes: 0 additions & 5 deletions src/safeds/plotting/__init__.py

This file was deleted.

38 changes: 0 additions & 38 deletions src/safeds/plotting/_boxplot.py

This file was deleted.

35 changes: 0 additions & 35 deletions src/safeds/plotting/_correlation_heatmap.py

This file was deleted.

23 changes: 0 additions & 23 deletions src/safeds/plotting/_histogram.py

This file was deleted.

44 changes: 0 additions & 44 deletions src/safeds/plotting/_lineplot.py

This file was deleted.

Loading