Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC + WIP: Support multiprocessing Version 1 #3392

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
3 changes: 0 additions & 3 deletions pygmt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from pygmt.accessors import GMTDataArrayAccessor
from pygmt.figure import Figure, set_display
from pygmt.io import load_dataarray
from pygmt.session_management import begin as _begin
from pygmt.session_management import end as _end
from pygmt.src import (
binstats,
Expand Down Expand Up @@ -66,7 +65,5 @@
xyz2grd,
)

# Start our global modern mode session. This runs at import time because it is a
# module-level statement of the package's __init__.py.
_begin()
# Tell Python to run _end when shutting down, so that the session started above
# is ended again.
_atexit.register(_end)
9 changes: 9 additions & 0 deletions pygmt/_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Private dictionary to keep track of the current PyGMT state.

The feature is only meant for internal use by PyGMT and is experimental!
"""

# "session_name" is None until a session name has been assigned.
_STATE = {
"session_name": None,
}
15 changes: 15 additions & 0 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import contextlib
import ctypes as ctp
import io
import os
import pathlib
import sys
import warnings
Expand All @@ -18,6 +19,7 @@
import pandas as pd
import xarray as xr
from packaging.version import Version
from pygmt._state import _STATE
from pygmt.clib.conversion import (
array_to_datetime,
as_c_contiguous,
Expand Down Expand Up @@ -216,7 +218,20 @@ def __enter__(self):

Calls :meth:`pygmt.clib.Session.create`.
"""
_init_cli_session = False
# This is the first time a Session object is created.
if _STATE["session_name"] is None:
# Set GMT_SESSION_NAME to the current process id.
_STATE["session_name"] = os.environ["GMT_SESSION_NAME"] = str(os.getpid())
# Need to initialize the GMT CLI session.
_init_cli_session = True
self.create("pygmt-session")

if _init_cli_session:
self.call_module("begin", args=["pygmt-session"])
self.call_module(module="set", args=["GMT_COMPATIBILITY=6"])
del _init_cli_session

return self

def __exit__(self, exc_type, exc_value, traceback):
Expand Down
11 changes: 3 additions & 8 deletions pygmt/session_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
Modern mode session management modules.
"""

import os
import sys

from pygmt._state import _STATE
from pygmt.clib import Session
from pygmt.helpers import unique_name


def begin():
Expand All @@ -17,10 +14,6 @@ def begin():

Only meant to be used once for creating the global session.
"""
# On Windows, need to set GMT_SESSION_NAME to a unique value
if sys.platform == "win32":
os.environ["GMT_SESSION_NAME"] = unique_name()

prefix = "pygmt-session"
with Session() as lib:
lib.call_module(module="begin", args=[prefix])
Expand All @@ -39,3 +32,5 @@ def end():
"""
with Session() as lib:
lib.call_module(module="end", args=[])

_STATE["session_name"] = None # Reset the sesion name to None
85 changes: 85 additions & 0 deletions pygmt/tests/test_multiprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Test multiprocessing support.
"""

import multiprocessing as mp
from importlib import reload
from pathlib import Path

import numpy.testing as npt
import pygmt


def _func(figname):
    """
    Plot a simple basemap and save it to the file ``figname``.

    Worker function that ``test_multiprocessing`` hands to ``Pool.map``.
    """
    figure = pygmt.Figure()
    figure.basemap(
        region=[10, 70, -3, 8],
        projection="X8c/6c",
        frame="afg",
    )
    figure.savefig(figname)


def test_multiprocessing():
    """
    Test multiprocessing support for plotting figures.

    A pool of two worker processes plots and saves one figure per task. The test
    checks that both output files exist and always removes them afterwards.
    """
    prefix = "test_session_multiprocessing"
    fignames = [f"{prefix}-{i}.png" for i in (1, 2)]
    try:
        with mp.Pool(2) as pool:
            pool.map(_func, fignames)
        # Check the outputs explicitly rather than relying on unlink() raising
        # FileNotFoundError for a missing file.
        for figname in fignames:
            assert Path(figname).exists()
    finally:
        # Clean up even if a worker or an assertion fails, so that a failed run
        # does not leave stray PNG files in the working directory.
        for figname in fignames:
            Path(figname).unlink(missing_ok=True)


def _func_datacut(dataset):
    """
    Cut ``dataset`` to the region [-10, 10, -5, 5] and return the result.

    Worker function that ``test_multiprocessing_data_processing`` hands to
    ``Pool.map``; the return value is whatever ``pygmt.grdcut`` produces.
    """
    return pygmt.grdcut(dataset, region=[-10, 10, -5, 5])


def test_multiprocessing_data_processing():
    """
    Test multiprocessing support for data processing.

    Two remote datasets are cut to the same region in separate worker processes,
    and the shape and value range of each returned grid are checked.
    """
    datasets = ["@earth_relief_01d_g", "@moon_relief_01d_g"]
    with mp.Pool(2) as pool:
        grids = pool.map(_func_datacut, datasets)
    assert len(grids) == 2

    # Expected (minimum, maximum) of each cut grid, in the same order as datasets.
    expected_ranges = [
        (-5118.0, 680.5),  # The Earth relief dataset
        (-1122.0, 943.0),  # The Moon relief dataset
    ]
    for grid, (vmin, vmax) in zip(grids, expected_ranges):
        assert grid.shape == (11, 21)
        npt.assert_allclose(grid.min(), vmin, atol=0.5)
        npt.assert_allclose(grid.max(), vmax, atol=0.5)


def _func_reload(figname):
    """
    Reload PyGMT, then plot a simple basemap and save it to ``figname``.

    Reloading the PyGMT library in each process was the workaround needed before
    the official multiprocessing support in PyGMT; see
    https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875.

    Keeping this worker around makes sure that the old workaround still works.
    """
    import pygmt

    reload(pygmt)
    figure = pygmt.Figure()
    figure.basemap(
        region=[10, 70, -3, 8],
        projection="X8c/6c",
        frame="afg",
    )
    figure.savefig(figname)


def test_multiprocessing_reload():
    """
    Make sure that multiprocessing is supported if pygmt is re-imported.

    A pool of two worker processes reloads pygmt, then plots and saves one figure
    per task. The test checks that both output files exist and always removes
    them afterwards.
    """
    prefix = "test_session_multiprocessing"
    fignames = [f"{prefix}-{i}.png" for i in (1, 2)]
    try:
        with mp.Pool(2) as pool:
            pool.map(_func_reload, fignames)
        # Check the outputs explicitly rather than relying on unlink() raising
        # FileNotFoundError for a missing file.
        for figname in fignames:
            assert Path(figname).exists()
    finally:
        # Clean up even if a worker or an assertion fails, so that a failed run
        # does not leave stray PNG files in the working directory.
        for figname in fignames:
            Path(figname).unlink(missing_ok=True)
Path(f"{prefix}-2.png").unlink()
34 changes: 2 additions & 32 deletions pygmt/tests/test_session_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Test the session management modules.
"""

import multiprocessing as mp
from importlib import reload
from pathlib import Path

import pytest
Expand Down Expand Up @@ -36,10 +34,8 @@ def test_gmt_compat_6_is_applied(capsys):
"""
end() # Kill the global session
try:
# Generate a gmt.conf file in the current directory
# with GMT_COMPATIBILITY = 5
with Session() as lib:
lib.call_module("gmtset", ["GMT_COMPATIBILITY=5"])
# Generate a gmt.conf file in the current directory with GMT_COMPATIBILITY = 5
Path("gmt.conf").write_text("GMT_COMPATIBILITY = 5", encoding="utf-8")
begin()
with Session() as lib:
lib.call_module("basemap", ["-R10/70/-3/8", "-JX4i/3i", "-Ba"])
Expand All @@ -60,29 +56,3 @@ def test_gmt_compat_6_is_applied(capsys):
# Make sure no global "gmt.conf" in the current directory
assert not Path("gmt.conf").exists()
begin() # Restart the global session


def _gmt_func_wrapper(figname):
    """
    Reload PyGMT, then plot a simple basemap and save it to ``figname``.

    Importing pygmt and reloading it in each process is currently required when
    running PyGMT scripts with multiprocessing. Workaround from
    https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875.
    """
    import pygmt

    reload(pygmt)
    figure = pygmt.Figure()
    figure.basemap(
        region=[10, 70, -3, 8],
        projection="X8c/6c",
        frame="afg",
    )
    figure.savefig(figname)


def test_session_multiprocessing():
    """
    Make sure that multiprocessing is supported if pygmt is re-imported.

    A pool of two worker processes reloads pygmt, then plots and saves one figure
    per task. The test checks that both output files exist and always removes
    them afterwards.
    """
    prefix = "test_session_multiprocessing"
    fignames = [f"{prefix}-{i}.png" for i in (1, 2)]
    try:
        with mp.Pool(2) as pool:
            pool.map(_gmt_func_wrapper, fignames)
        # Check the outputs explicitly rather than relying on unlink() raising
        # FileNotFoundError for a missing file.
        for figname in fignames:
            assert Path(figname).exists()
    finally:
        # Clean up even if a worker or an assertion fails, so that a failed run
        # does not leave stray PNG files in the working directory.
        for figname in fignames:
            Path(figname).unlink(missing_ok=True)
Loading