From 5ec6189a807092b00d38620403549c96a02164a5 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Sun, 26 Mar 2023 22:02:07 +0200 Subject: [PATCH] feat: move plotting methods into `Column` and `Table` classes (#88) Closes #62. ### Summary of Changes * Move plotting methods into `Column` and `Table` classes since they are specific to these classes. We don't want to create a generic plotting library. * Remove the `plotting` module. --------- Co-authored-by: lars-reimann --- docs/tutorials/data_visualization.ipynb | 20 +--- src/safeds/data/tabular/containers/_column.py | 42 ++++++++ src/safeds/data/tabular/containers/_table.py | 95 +++++++++++++++++++ src/safeds/plotting/__init__.py | 5 - src/safeds/plotting/_boxplot.py | 38 -------- src/safeds/plotting/_correlation_heatmap.py | 35 ------- src/safeds/plotting/_histogram.py | 23 ----- src/safeds/plotting/_lineplot.py | 44 --------- src/safeds/plotting/_scatterplot.py | 42 -------- .../containers/_column}/test_boxplot.py | 9 +- .../containers/_column}/test_histogram.py | 3 +- .../_table}/test_correlation_heatmap.py | 5 +- .../containers/_table}/test_lineplot.py | 5 +- .../containers/_table}/test_scatterplot.py | 5 +- tests/safeds/plotting/__init__.py | 0 tests/safeds/plotting/_boxplot/__init__.py | 0 .../plotting/_correlation_heatmap/__init__.py | 0 tests/safeds/plotting/_histogram/__init__.py | 0 tests/safeds/plotting/_lineplot/__init__.py | 0 .../safeds/plotting/_scatterplot/__init__.py | 0 20 files changed, 153 insertions(+), 218 deletions(-) delete mode 100644 src/safeds/plotting/__init__.py delete mode 100644 src/safeds/plotting/_boxplot.py delete mode 100644 src/safeds/plotting/_correlation_heatmap.py delete mode 100644 src/safeds/plotting/_histogram.py delete mode 100644 src/safeds/plotting/_lineplot.py delete mode 100644 src/safeds/plotting/_scatterplot.py rename tests/safeds/{plotting/_boxplot => data/tabular/containers/_column}/test_boxplot.py (79%) rename tests/safeds/{plotting/_histogram => data/tabular/containers/_column}/test_histogram.py (79%) rename tests/safeds/{plotting/_correlation_heatmap => data/tabular/containers/_table}/test_correlation_heatmap.py (83%) rename tests/safeds/{plotting/_lineplot => data/tabular/containers/_table}/test_lineplot.py (83%) rename tests/safeds/{plotting/_scatterplot => data/tabular/containers/_table}/test_scatterplot.py (82%) delete mode 100644 tests/safeds/plotting/__init__.py delete mode 100644 tests/safeds/plotting/_boxplot/__init__.py delete mode 100644 tests/safeds/plotting/_correlation_heatmap/__init__.py delete mode 100644 tests/safeds/plotting/_histogram/__init__.py delete mode 100644 tests/safeds/plotting/_lineplot/__init__.py delete mode 100644 tests/safeds/plotting/_scatterplot/__init__.py diff --git a/docs/tutorials/data_visualization.ipynb b/docs/tutorials/data_visualization.ipynb index 440d07193..b9a4e102f 100644 --- a/docs/tutorials/data_visualization.ipynb +++ b/docs/tutorials/data_visualization.ipynb @@ -102,9 +102,7 @@ "execution_count": null, "outputs": [], "source": [ - "from safeds.plotting import correlation_heatmap\n", - "\n", - "correlation_heatmap(titanic_numerical)" + "titanic_numerical.correlation_heatmap()" ], "metadata": { "collapsed": false, @@ -145,9 +143,7 @@ "execution_count": null, "outputs": [], "source": [ - "from safeds.plotting import lineplot\n", - "\n", - "lineplot(titanic_numerical, \"survived\", \"fare\")" + "titanic_numerical.lineplot(\"survived\", \"fare\")" ], "metadata": { "collapsed": false, @@ -180,9 +176,7 @@ "execution_count": null, "outputs": [], "source": [ - "from safeds.plotting import boxplot\n", - "\n", - "boxplot(titanic_numerical.get_column(\"age\"))" + "titanic_numerical.get_column(\"age\").boxplot()" ], "metadata": { "collapsed": false, @@ -206,9 +200,7 @@ "execution_count": null, "outputs": [], "source": [ - "from safeds.plotting import histogram\n", - "\n", - "histogram(titanic_numerical.get_column(\"fare\"))" + "titanic_numerical.get_column(\"fare\").histogram()" ], "metadata": { "collapsed": false, @@ -232,9 +224,7 @@ "execution_count": null, "outputs": [], "source": [ - "from safeds.plotting import scatterplot\n", - "\n", - "scatterplot(titanic_numerical, \"age\", \"fare\")\n" + "titanic_numerical.scatterplot(\"age\", \"fare\")\n" ], "metadata": { "collapsed": false, diff --git a/src/safeds/data/tabular/containers/_column.py b/src/safeds/data/tabular/containers/_column.py index fad1766bc..ab6433450 100644 --- a/src/safeds/data/tabular/containers/_column.py +++ b/src/safeds/data/tabular/containers/_column.py @@ -3,8 +3,10 @@ from numbers import Number from typing import Any, Callable, Iterable, Iterator, Optional +import matplotlib.pyplot as plt import numpy as np import pandas as pd +import seaborn as sns from IPython.core.display_functions import DisplayHandle, display from safeds.data.tabular.typing import ColumnType from safeds.exceptions import ( @@ -483,3 +485,43 @@ def idness(self) -> float: if self._data.size == 0: raise ColumnSizeError("> 0", "0") return self._data.nunique() / self._data.size + + def boxplot(self) -> None: + """ + Plot this column in a boxplot. This function can only plot real numerical data. + + Raises + ------- + TypeError + If the column contains non-numerical data or complex data. + """ + for data in self._data: + if ( + not isinstance(data, int) + and not isinstance(data, float) + and not isinstance(data, complex) + ): + raise NonNumericColumnError(self.name) + if isinstance(data, complex): + raise TypeError( + "The column contains complex data. Boxplots cannot plot the imaginary part of complex " + "data. Please provide a Column with only real numbers" + ) + ax = sns.boxplot(data=self._data) + ax.set(xlabel=self.name) + plt.tight_layout() + plt.show() + + def histogram(self) -> None: + """ + Plot a column in a histogram. + """ + + ax = sns.histplot(data=self._data) + ax.set_xticks(ax.get_xticks()) + ax.set(xlabel=self.name) + ax.set_xticklabels( + ax.get_xticklabels(), rotation=45, horizontalalignment="right" + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels + plt.tight_layout() + plt.show() diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index e097938ba..b0e4dac20 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -6,8 +6,10 @@ from pathlib import Path from typing import Callable, Optional, Union +import matplotlib.pyplot as plt import numpy as np import pandas as pd +import seaborn as sns from IPython.core.display_functions import DisplayHandle, display from pandas import DataFrame, Series from safeds.data.tabular.containers._column import Column @@ -970,3 +972,96 @@ def shuffle(self) -> Table: new_df = self._data.sample(frac=1.0) new_df.columns = self.schema.get_column_names() return Table(new_df) + + def correlation_heatmap(self) -> None: + """ + Plot a correlation heatmap of an entire table. This function can only plot real numerical data. + + Raises + ------- + TypeError + If the table contains non-numerical data or complex data. + """ + for column in self.to_columns(): + if not column.type.is_numeric(): + raise NonNumericColumnError(column.name) + sns.heatmap( + data=self._data.corr(), + vmin=-1, + vmax=1, + xticklabels=self.get_column_names(), + yticklabels=self.get_column_names(), + cmap="vlag", + ) + plt.tight_layout() + plt.show() + + def lineplot(self, x_column_name: str, y_column_name: str) -> None: + """ + Plot two columns against each other in a lineplot. If there are multiple x-values for a y-value, + the resulting plot will consist of a line representing the mean and the lower-transparency area around the line + representing the 95% confidence interval. + + Parameters + ---------- + x_column_name : str + The column name of the column to be plotted on the x-Axis. + y_column_name : str + The column name of the column to be plotted on the y-Axis. + + Raises + --------- + UnknownColumnNameError + If either of the columns do not exist. + """ + if not self.has_column(x_column_name): + raise UnknownColumnNameError([x_column_name]) + if not self.has_column(y_column_name): + raise UnknownColumnNameError([y_column_name]) + + ax = sns.lineplot( + data=self._data, + x=self.schema._get_column_index_by_name(x_column_name), + y=self.schema._get_column_index_by_name(y_column_name), + ) + ax.set(xlabel=x_column_name, ylabel=y_column_name) + ax.set_xticks(ax.get_xticks()) + ax.set_xticklabels( + ax.get_xticklabels(), rotation=45, horizontalalignment="right" + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels + plt.tight_layout() + plt.show() + + def scatterplot(self, x_column_name: str, y_column_name: str) -> None: + """ + Plot two columns against each other in a scatterplot. + + Parameters + ---------- + x_column_name : str + The column name of the column to be plotted on the x-Axis. + y_column_name : str + The column name of the column to be plotted on the y-Axis. + + Raises + --------- + UnknownColumnNameError + If either of the columns do not exist. + """ + if not self.has_column(x_column_name): + raise UnknownColumnNameError([x_column_name]) + if not self.has_column(y_column_name): + raise UnknownColumnNameError([y_column_name]) + + ax = sns.scatterplot( + data=self._data, + x=self.schema._get_column_index_by_name(x_column_name), + y=self.schema._get_column_index_by_name(y_column_name), + ) + ax.set(xlabel=x_column_name, ylabel=y_column_name) + ax.set_xticks(ax.get_xticks()) + ax.set_xticklabels( + ax.get_xticklabels(), rotation=45, horizontalalignment="right" + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels + plt.tight_layout() + plt.show() diff --git a/src/safeds/plotting/__init__.py b/src/safeds/plotting/__init__.py deleted file mode 100644 index 51b9a8c71..000000000 --- a/src/safeds/plotting/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from ._boxplot import boxplot -from ._correlation_heatmap import correlation_heatmap -from ._histogram import histogram -from ._lineplot import lineplot -from ._scatterplot import scatterplot diff --git a/src/safeds/plotting/_boxplot.py b/src/safeds/plotting/_boxplot.py deleted file mode 100644 index d7d0f24b0..000000000 --- a/src/safeds/plotting/_boxplot.py +++ /dev/null @@ -1,38 +0,0 @@ -import matplotlib.pyplot as plt -import seaborn as sns -from safeds.data.tabular.containers import Column -from safeds.exceptions import NonNumericColumnError - - -def boxplot(column: Column) -> None: - """ - Plot a column in a boxplot. This function can only plot real numerical data. - - Parameters - ---------- - column : Column - The column to be plotted. - - Raises - ------- - TypeError - If the column contains non-numerical data or complex data. - """ - # noinspection PyProtectedMember - for data in column._data: - if ( - not isinstance(data, int) - and not isinstance(data, float) - and not isinstance(data, complex) - ): - raise NonNumericColumnError(column.name) - if isinstance(data, complex): - raise TypeError( - "The column contains complex data. Boxplots cannot plot the imaginary part of complex " - "data. Please provide a Column with only real numbers" - ) - # noinspection PyProtectedMember - ax = sns.boxplot(data=column._data) - ax.set(xlabel=column.name) - plt.tight_layout() - plt.show() diff --git a/src/safeds/plotting/_correlation_heatmap.py b/src/safeds/plotting/_correlation_heatmap.py deleted file mode 100644 index 003a6fce3..000000000 --- a/src/safeds/plotting/_correlation_heatmap.py +++ /dev/null @@ -1,35 +0,0 @@ -import matplotlib.pyplot as plt -import seaborn as sns -from safeds.data.tabular.containers import Table -from safeds.exceptions import NonNumericColumnError - - -def correlation_heatmap(table: Table) -> None: - """ - Plot a correlation heatmap of an entire table. This function can only plot real numerical data. - - Parameters - ---------- - table : Table - The column to be plotted. - - Raises - ------- - TypeError - If the table contains non-numerical data or complex data. - """ - # noinspection PyProtectedMember - for column in table.to_columns(): - if not column.type.is_numeric(): - raise NonNumericColumnError(column.name) - # noinspection PyProtectedMember - sns.heatmap( - data=table._data.corr(), - vmin=-1, - vmax=1, - xticklabels=table.get_column_names(), - yticklabels=table.get_column_names(), - cmap="vlag", - ) - plt.tight_layout() - plt.show() diff --git a/src/safeds/plotting/_histogram.py b/src/safeds/plotting/_histogram.py deleted file mode 100644 index 02c346bc2..000000000 --- a/src/safeds/plotting/_histogram.py +++ /dev/null @@ -1,23 +0,0 @@ -import matplotlib.pyplot as plt -import seaborn as sns -from safeds.data.tabular.containers import Column - - -def histogram(column: Column) -> None: - """ - Plot a column in a histogram. - - Parameters - ---------- - column : Column - The column to be plotted. - """ - # noinspection PyProtectedMember - ax = sns.histplot(data=column._data) - ax.set_xticks(ax.get_xticks()) - ax.set(xlabel=column.name) - ax.set_xticklabels( - ax.get_xticklabels(), rotation=45, horizontalalignment="right" - ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels - plt.tight_layout() - plt.show() diff --git a/src/safeds/plotting/_lineplot.py b/src/safeds/plotting/_lineplot.py deleted file mode 100644 index 5a68b33da..000000000 --- a/src/safeds/plotting/_lineplot.py +++ /dev/null @@ -1,44 +0,0 @@ -import matplotlib.pyplot as plt -import seaborn as sns -from safeds.data.tabular.containers import Table -from safeds.exceptions import UnknownColumnNameError - - -def lineplot(table: Table, x: str, y: str) -> None: - """ - Plot two columns against each other in a lineplot. If there are multiple x-values for a y-value, - the resulting plot will consist of a line representing the mean and the lower-transparency area around the line - representing the 95% confidence interval. - - Parameters - ---------- - table : Table - The table containing the data to be plotted. - x : str - The column name of the column to be plotted on the x-Axis. - y : str - The column name of the column to be plotted on the y-Axis. - - Raises - --------- - UnknownColumnNameError - If either of the columns do not exist. - """ - # noinspection PyProtectedMember - if not table.has_column(x): - raise UnknownColumnNameError([x]) - if not table.has_column(y): - raise UnknownColumnNameError([y]) - - ax = sns.lineplot( - data=table._data, - x=table.schema._get_column_index_by_name(x), - y=table.schema._get_column_index_by_name(y), - ) - ax.set(xlabel=x, ylabel=y) - ax.set_xticks(ax.get_xticks()) - ax.set_xticklabels( - ax.get_xticklabels(), rotation=45, horizontalalignment="right" - ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels - plt.tight_layout() - plt.show() diff --git a/src/safeds/plotting/_scatterplot.py b/src/safeds/plotting/_scatterplot.py deleted file mode 100644 index 3fde8daa3..000000000 --- a/src/safeds/plotting/_scatterplot.py +++ /dev/null @@ -1,42 +0,0 @@ -import matplotlib.pyplot as plt -import seaborn as sns -from safeds.data.tabular.containers import Table -from safeds.exceptions import UnknownColumnNameError - - -def scatterplot(table: Table, x: str, y: str) -> None: - """ - Plot two columns against each other in a scatterplot. - - Parameters - ---------- - table : Table - The table containing the data to be plotted. - x : str - The column name of the column to be plotted on the x-Axis. - y : str - The column name of the column to be plotted on the y-Axis. - - Raises - --------- - UnknownColumnNameError - If either of the columns do not exist. - """ - # noinspection PyProtectedMember - if not table.has_column(x): - raise UnknownColumnNameError([x]) - if not table.has_column(y): - raise UnknownColumnNameError([y]) - - ax = sns.scatterplot( - data=table._data, - x=table.schema._get_column_index_by_name(x), - y=table.schema._get_column_index_by_name(y), - ) - ax.set(xlabel=x, ylabel=y) - ax.set_xticks(ax.get_xticks()) - ax.set_xticklabels( - ax.get_xticklabels(), rotation=45, horizontalalignment="right" - ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels - plt.tight_layout() - plt.show() diff --git a/tests/safeds/plotting/_boxplot/test_boxplot.py b/tests/safeds/data/tabular/containers/_column/test_boxplot.py similarity index 79% rename from tests/safeds/plotting/_boxplot/test_boxplot.py rename to tests/safeds/data/tabular/containers/_column/test_boxplot.py index dee1f14b0..c1e375949 100644 --- a/tests/safeds/plotting/_boxplot/test_boxplot.py +++ b/tests/safeds/data/tabular/containers/_column/test_boxplot.py @@ -2,7 +2,6 @@ import matplotlib.pyplot as plt import pandas as pd import pytest -from safeds import plotting from safeds.data.tabular.containers import Table from safeds.exceptions import NonNumericColumnError @@ -10,22 +9,22 @@ def test_boxplot_complex() -> None: with pytest.raises(TypeError): table = Table(pd.DataFrame(data={"A": [1, 2, complex(1, -2)]})) - plotting.boxplot(table.get_column("A")) + table.get_column("A").boxplot() def test_boxplot_non_numeric() -> None: with pytest.raises(NonNumericColumnError): table = Table(pd.DataFrame(data={"A": [1, 2, "A"]})) - plotting.boxplot(table.get_column("A")) + table.get_column("A").boxplot() def test_boxplot_float(monkeypatch: _pytest.monkeypatch) -> None: monkeypatch.setattr(plt, "show", lambda: None) table = Table(pd.DataFrame(data={"A": [1, 2, 3.5]})) - plotting.boxplot(table.get_column("A")) + table.get_column("A").boxplot() def test_boxplot_int(monkeypatch: _pytest.monkeypatch) -> None: monkeypatch.setattr(plt, "show", lambda: None) table = Table(pd.DataFrame(data={"A": [1, 2, 3]})) - plotting.boxplot(table.get_column("A")) + table.get_column("A").boxplot() diff --git a/tests/safeds/plotting/_histogram/test_histogram.py b/tests/safeds/data/tabular/containers/_column/test_histogram.py similarity index 79% rename from tests/safeds/plotting/_histogram/test_histogram.py rename to tests/safeds/data/tabular/containers/_column/test_histogram.py index d68f7b693..ca9c35d54 100644 --- a/tests/safeds/plotting/_histogram/test_histogram.py +++ b/tests/safeds/data/tabular/containers/_column/test_histogram.py @@ -2,10 +2,9 @@ import matplotlib.pyplot as plt import pandas as pd from safeds.data.tabular.containers import Table -from safeds.plotting import histogram def test_histogram(monkeypatch: _pytest.monkeypatch) -> None: monkeypatch.setattr(plt, "show", lambda: None) table = Table(pd.DataFrame(data={"A": [1, 2, 3]})) - histogram(table.get_column("A")) + table.get_column("A").histogram() diff --git a/tests/safeds/plotting/_correlation_heatmap/test_correlation_heatmap.py b/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py similarity index 83% rename from tests/safeds/plotting/_correlation_heatmap/test_correlation_heatmap.py rename to tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py index 9ddbd509a..8f3b94746 100644 --- a/tests/safeds/plotting/_correlation_heatmap/test_correlation_heatmap.py +++ b/tests/safeds/data/tabular/containers/_table/test_correlation_heatmap.py @@ -2,7 +2,6 @@ import matplotlib.pyplot as plt import pandas as pd import pytest -from safeds import plotting from safeds.data.tabular.containers import Table from safeds.exceptions import NonNumericColumnError @@ -10,10 +9,10 @@ def test_correlation_heatmap_non_numeric() -> None: with pytest.raises(NonNumericColumnError): table = Table(pd.DataFrame(data={"A": [1, 2, "A"], "B": [1, 2, "A"]})) - plotting.correlation_heatmap(table) + table.correlation_heatmap() def test_correlation_heatmap(monkeypatch: _pytest.monkeypatch) -> None: monkeypatch.setattr(plt, "show", lambda: None) table = Table(pd.DataFrame(data={"A": [1, 2, 3.5], "B": [2, 4, 7]})) - plotting.correlation_heatmap(table) + table.correlation_heatmap() diff --git a/tests/safeds/plotting/_lineplot/test_lineplot.py b/tests/safeds/data/tabular/containers/_table/test_lineplot.py similarity index 83% rename from tests/safeds/plotting/_lineplot/test_lineplot.py rename to tests/safeds/data/tabular/containers/_table/test_lineplot.py index af4396480..d4b1fe721 100644 --- a/tests/safeds/plotting/_lineplot/test_lineplot.py +++ b/tests/safeds/data/tabular/containers/_table/test_lineplot.py @@ -2,7 +2,6 @@ import matplotlib.pyplot as plt import pandas as pd import pytest -from safeds import plotting from safeds.data.tabular.containers import Table from safeds.exceptions import UnknownColumnNameError @@ -10,10 +9,10 @@ def test_lineplot(monkeypatch: _pytest.monkeypatch) -> None: monkeypatch.setattr(plt, "show", lambda: None) table = Table(pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 4, 7]})) - plotting.lineplot(table, "A", "B") + table.lineplot("A", "B") def test_lineplot_wrong_column_name() -> None: with pytest.raises(UnknownColumnNameError): table = Table(pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 4, 7]})) - plotting.lineplot(table, "C", "A") + table.lineplot("C", "A") diff --git a/tests/safeds/plotting/_scatterplot/test_scatterplot.py b/tests/safeds/data/tabular/containers/_table/test_scatterplot.py similarity index 82% rename from tests/safeds/plotting/_scatterplot/test_scatterplot.py rename to tests/safeds/data/tabular/containers/_table/test_scatterplot.py index d41691f89..8ff59e995 100644 --- a/tests/safeds/plotting/_scatterplot/test_scatterplot.py +++ b/tests/safeds/data/tabular/containers/_table/test_scatterplot.py @@ -2,7 +2,6 @@ import matplotlib.pyplot as plt import pandas as pd import pytest -from safeds import plotting from safeds.data.tabular.containers import Table from safeds.exceptions import UnknownColumnNameError @@ -10,10 +9,10 @@ def test_scatterplot(monkeypatch: _pytest.monkeypatch) -> None: monkeypatch.setattr(plt, "show", lambda: None) table = Table(pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 4, 7]})) - plotting.scatterplot(table, "A", "B") + table.scatterplot("A", "B") def test_scatterplot_wrong_column_name() -> None: with pytest.raises(UnknownColumnNameError): table = Table(pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 4, 7]})) - plotting.scatterplot(table, "C", "A") + table.scatterplot("C", "A") diff --git a/tests/safeds/plotting/__init__.py b/tests/safeds/plotting/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/safeds/plotting/_boxplot/__init__.py b/tests/safeds/plotting/_boxplot/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/safeds/plotting/_correlation_heatmap/__init__.py b/tests/safeds/plotting/_correlation_heatmap/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/safeds/plotting/_histogram/__init__.py b/tests/safeds/plotting/_histogram/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/safeds/plotting/_lineplot/__init__.py b/tests/safeds/plotting/_lineplot/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/safeds/plotting/_scatterplot/__init__.py b/tests/safeds/plotting/_scatterplot/__init__.py deleted file mode 100644 index e69de29bb..000000000