From 93600571c88d31c94d5080af35541d27b0021443 Mon Sep 17 00:00:00 2001
From: Peter Lamut
Date: Wed, 27 Nov 2019 22:55:43 +0100
Subject: [PATCH] fix(bigquery): add close() method to client for releasing open sockets (#9894)

* Add close() method to Client

* Add psutil as an extra test dependency

* Fix open sockets leak in IPython magics

* Move psutil test dependency to noxfile

* Wrap entire cell magic into try-finally block

A single common cleanup point at the end makes it much less likely to
accidentally re-introduce an open socket leak.
---
 bigquery/google/cloud/bigquery/client.py |  12 ++
 bigquery/google/cloud/bigquery/magics.py | 160 +++++++++++++----------
 bigquery/noxfile.py                      |   2 +-
 bigquery/tests/system.py                 |  28 ++++
 bigquery/tests/unit/test_client.py       |  11 ++
 bigquery/tests/unit/test_magics.py       |  67 ++++++++++
 6 files changed, 211 insertions(+), 69 deletions(-)

diff --git a/bigquery/google/cloud/bigquery/client.py b/bigquery/google/cloud/bigquery/client.py
index c8df21e91f55..e6eaf5fcb3ba 100644
--- a/bigquery/google/cloud/bigquery/client.py
+++ b/bigquery/google/cloud/bigquery/client.py
@@ -194,6 +194,18 @@ def location(self):
         """Default location for jobs / datasets / tables."""
         return self._location

+    def close(self):
+        """Close the underlying transport objects, releasing system resources.
+
+        .. note::
+
+            The client instance can be used for making additional requests even
+            after closing, in which case the underlying connections are
+            automatically re-created.
+        """
+        self._http._auth_request.session.close()
+        self._http.close()
+
     def get_service_account_email(self, project=None):
         """Get the email address of the project's BigQuery service account

diff --git a/bigquery/google/cloud/bigquery/magics.py b/bigquery/google/cloud/bigquery/magics.py
index 59265ed6b0c5..5ca6817a99c6 100644
--- a/bigquery/google/cloud/bigquery/magics.py
+++ b/bigquery/google/cloud/bigquery/magics.py
@@ -137,6 +137,7 @@
 import re
 import ast
+import functools
 import sys
 import time
 from concurrent import futures
@@ -494,86 +495,91 @@ def _cell_magic(line, query):
         args.use_bqstorage_api or context.use_bqstorage_api, context.credentials
     )

-    if args.max_results:
-        max_results = int(args.max_results)
-    else:
-        max_results = None
+    close_transports = functools.partial(_close_transports, client, bqstorage_client)

-    query = query.strip()
+    try:
+        if args.max_results:
+            max_results = int(args.max_results)
+        else:
+            max_results = None
+
+        query = query.strip()
+
+        # Any query that does not contain whitespace (aside from leading and trailing whitespace)
+        # is assumed to be a table id
+        if not re.search(r"\s", query):
+            try:
+                rows = client.list_rows(query, max_results=max_results)
+            except Exception as ex:
+                _handle_error(ex, args.destination_var)
+                return
+
+            result = rows.to_dataframe(bqstorage_client=bqstorage_client)
+            if args.destination_var:
+                IPython.get_ipython().push({args.destination_var: result})
+                return
+            else:
+                return result
+
+        job_config = bigquery.job.QueryJobConfig()
+        job_config.query_parameters = params
+        job_config.use_legacy_sql = args.use_legacy_sql
+        job_config.dry_run = args.dry_run
+
+        if args.destination_table:
+            split = args.destination_table.split(".")
+            if len(split) != 2:
+                raise ValueError(
+                    "--destination_table should be in a <dataset_id>.<table_id> format."
+                )
+            dataset_id, table_id = split
+            job_config.allow_large_results = True
+            dataset_ref = client.dataset(dataset_id)
+            destination_table_ref = dataset_ref.table(table_id)
+            job_config.destination = destination_table_ref
+            job_config.create_disposition = "CREATE_IF_NEEDED"
+            job_config.write_disposition = "WRITE_TRUNCATE"
+            _create_dataset_if_necessary(client, dataset_id)
+
+        if args.maximum_bytes_billed == "None":
+            job_config.maximum_bytes_billed = 0
+        elif args.maximum_bytes_billed is not None:
+            value = int(args.maximum_bytes_billed)
+            job_config.maximum_bytes_billed = value

-    # Any query that does not contain whitespace (aside from leading and trailing whitespace)
-    # is assumed to be a table id
-    if not re.search(r"\s", query):
         try:
-            rows = client.list_rows(query, max_results=max_results)
+            query_job = _run_query(client, query, job_config=job_config)
         except Exception as ex:
             _handle_error(ex, args.destination_var)
             return

-        result = rows.to_dataframe(bqstorage_client=bqstorage_client)
-        if args.destination_var:
-            IPython.get_ipython().push({args.destination_var: result})
-            return
-        else:
-            return result
-
-    job_config = bigquery.job.QueryJobConfig()
-    job_config.query_parameters = params
-    job_config.use_legacy_sql = args.use_legacy_sql
-    job_config.dry_run = args.dry_run
+        if not args.verbose:
+            display.clear_output()

-    if args.destination_table:
-        split = args.destination_table.split(".")
-        if len(split) != 2:
-            raise ValueError(
-                "--destination_table should be in a <dataset_id>.<table_id> format."
+        if args.dry_run and args.destination_var:
+            IPython.get_ipython().push({args.destination_var: query_job})
+            return
+        elif args.dry_run:
+            print(
+                "Query validated. This query will process {} bytes.".format(
+                    query_job.total_bytes_processed
+                )
             )
-        dataset_id, table_id = split
-        job_config.allow_large_results = True
-        dataset_ref = client.dataset(dataset_id)
-        destination_table_ref = dataset_ref.table(table_id)
-        job_config.destination = destination_table_ref
-        job_config.create_disposition = "CREATE_IF_NEEDED"
-        job_config.write_disposition = "WRITE_TRUNCATE"
-        _create_dataset_if_necessary(client, dataset_id)
-
-    if args.maximum_bytes_billed == "None":
-        job_config.maximum_bytes_billed = 0
-    elif args.maximum_bytes_billed is not None:
-        value = int(args.maximum_bytes_billed)
-        job_config.maximum_bytes_billed = value
-
-    try:
-        query_job = _run_query(client, query, job_config=job_config)
-    except Exception as ex:
-        _handle_error(ex, args.destination_var)
-        return
-
-    if not args.verbose:
-        display.clear_output()
+        return query_job

-    if args.dry_run and args.destination_var:
-        IPython.get_ipython().push({args.destination_var: query_job})
-        return
-    elif args.dry_run:
-        print(
-            "Query validated. This query will process {} bytes.".format(
-                query_job.total_bytes_processed
+        if max_results:
+            result = query_job.result(max_results=max_results).to_dataframe(
+                bqstorage_client=bqstorage_client
             )
-        )
-        return query_job
-
-    if max_results:
-        result = query_job.result(max_results=max_results).to_dataframe(
-            bqstorage_client=bqstorage_client
-        )
-    else:
-        result = query_job.to_dataframe(bqstorage_client=bqstorage_client)
+        else:
+            result = query_job.to_dataframe(bqstorage_client=bqstorage_client)

-    if args.destination_var:
-        IPython.get_ipython().push({args.destination_var: result})
-    else:
-        return result
+        if args.destination_var:
+            IPython.get_ipython().push({args.destination_var: result})
+        else:
+            return result
+    finally:
+        close_transports()


 def _make_bqstorage_client(use_bqstorage_api, credentials):
@@ -601,3 +607,21 @@
         credentials=credentials,
         client_info=gapic_client_info.ClientInfo(user_agent=IPYTHON_USER_AGENT),
     )
+
+
+def _close_transports(client, bqstorage_client):
+    """Close the given clients' underlying transport channels.
+
+    Closing the transport is needed to release system resources, namely open
+    sockets.
+
+    Args:
+        client (:class:`~google.cloud.bigquery.client.Client`):
+        bqstorage_client
+            (Optional[:class:`~google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient`]):
+            A client for the BigQuery Storage API.
+
+    """
+    client.close()
+    if bqstorage_client is not None:
+        bqstorage_client.transport.channel.close()
diff --git a/bigquery/noxfile.py b/bigquery/noxfile.py
index a6d8094ebbc3..87809b74a569 100644
--- a/bigquery/noxfile.py
+++ b/bigquery/noxfile.py
@@ -81,7 +81,7 @@ def system(session):
     session.install("--pre", "grpcio")

     # Install all test dependencies, then install local packages in place.
- session.install("mock", "pytest") + session.install("mock", "pytest", "psutil") for local_dep in LOCAL_DEPS: session.install("-e", local_dep) session.install("-e", os.path.join("..", "storage")) diff --git a/bigquery/tests/system.py b/bigquery/tests/system.py index 4816962a70d6..bba527178f47 100644 --- a/bigquery/tests/system.py +++ b/bigquery/tests/system.py @@ -27,6 +27,7 @@ import re import six +import psutil import pytest import pytz @@ -203,6 +204,27 @@ def _create_bucket(self, bucket_name, location=None): return bucket + def test_close_releases_open_sockets(self): + current_process = psutil.Process() + conn_count_start = len(current_process.connections()) + + client = Config.CLIENT + client.query( + """ + SELECT + source_year AS year, COUNT(is_male) AS birth_count + FROM `bigquery-public-data.samples.natality` + GROUP BY year + ORDER BY year DESC + LIMIT 15 + """ + ) + + client.close() + + conn_count_end = len(current_process.connections()) + self.assertEqual(conn_count_end, conn_count_start) + def test_create_dataset(self): DATASET_ID = _make_dataset_id("create_dataset") dataset = self.temp_dataset(DATASET_ID) @@ -2417,6 +2439,9 @@ def temp_dataset(self, dataset_id, location=None): @pytest.mark.usefixtures("ipython_interactive") def test_bigquery_magic(): ip = IPython.get_ipython() + current_process = psutil.Process() + conn_count_start = len(current_process.connections()) + ip.extension_manager.load_extension("google.cloud.bigquery") sql = """ SELECT @@ -2432,6 +2457,8 @@ def test_bigquery_magic(): with io.capture_output() as captured: result = ip.run_cell_magic("bigquery", "", sql) + conn_count_end = len(current_process.connections()) + lines = re.split("\n|\r", captured.stdout) # Removes blanks & terminal code (result of display clearing) updates = list(filter(lambda x: bool(x) and x != "\x1b[2K", lines)) @@ -2441,6 +2468,7 @@ def test_bigquery_magic(): assert isinstance(result, pandas.DataFrame) assert len(result) == 10 # verify row count assert list(result) == ["url", "view_count"] # verify column names + assert conn_count_end == conn_count_start # system resources are released def _job_done(instance): diff --git a/bigquery/tests/unit/test_client.py b/bigquery/tests/unit/test_client.py index da3fb2c56689..ecde69d2cf97 100644 --- a/bigquery/tests/unit/test_client.py +++ b/bigquery/tests/unit/test_client.py @@ -1398,6 +1398,17 @@ def test_create_table_alreadyexists_w_exists_ok_true(self): ] ) + def test_close(self): + creds = _make_credentials() + http = mock.Mock() + http._auth_request.session = mock.Mock() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + + client.close() + + http.close.assert_called_once() + http._auth_request.session.close.assert_called_once() + def test_get_model(self): path = "projects/%s/datasets/%s/models/%s" % ( self.PROJECT, diff --git a/bigquery/tests/unit/test_magics.py b/bigquery/tests/unit/test_magics.py index 6ff9819854a8..8e768c1b7d23 100644 --- a/bigquery/tests/unit/test_magics.py +++ b/bigquery/tests/unit/test_magics.py @@ -545,6 +545,7 @@ def test_bigquery_magic_with_bqstorage_from_argument(monkeypatch): bqstorage_instance_mock = mock.create_autospec( bigquery_storage_v1beta1.BigQueryStorageClient, instance=True ) + bqstorage_instance_mock.transport = mock.Mock() bqstorage_mock.return_value = bqstorage_instance_mock bqstorage_client_patch = mock.patch( "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock @@ -601,6 +602,7 @@ def 
     bqstorage_instance_mock = mock.create_autospec(
         bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
     )
+    bqstorage_instance_mock.transport = mock.Mock()
     bqstorage_mock.return_value = bqstorage_instance_mock
     bqstorage_client_patch = mock.patch(
         "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -728,6 +730,41 @@ def test_bigquery_magic_w_max_results_valid_calls_queryjob_result():
     query_job_mock.result.assert_called_with(max_results=5)


+@pytest.mark.usefixtures("ipython_interactive")
+def test_bigquery_magic_w_max_results_query_job_results_fails():
+    ip = IPython.get_ipython()
+    ip.extension_manager.load_extension("google.cloud.bigquery")
+    magics.context._project = None
+
+    credentials_mock = mock.create_autospec(
+        google.auth.credentials.Credentials, instance=True
+    )
+    default_patch = mock.patch(
+        "google.auth.default", return_value=(credentials_mock, "general-project")
+    )
+    client_query_patch = mock.patch(
+        "google.cloud.bigquery.client.Client.query", autospec=True
+    )
+    close_transports_patch = mock.patch(
+        "google.cloud.bigquery.magics._close_transports", autospec=True,
+    )
+
+    sql = "SELECT 17 AS num"
+
+    query_job_mock = mock.create_autospec(
+        google.cloud.bigquery.job.QueryJob, instance=True
+    )
+    query_job_mock.result.side_effect = [[], OSError]
+
+    with pytest.raises(
+        OSError
+    ), client_query_patch as client_query_mock, default_patch, close_transports_patch as close_transports:
+        client_query_mock.return_value = query_job_mock
+        ip.run_cell_magic("bigquery", "--max_results=5", sql)
+
+    assert close_transports.called
+
+
 def test_bigquery_magic_w_table_id_invalid():
     ip = IPython.get_ipython()
     ip.extension_manager.load_extension("google.cloud.bigquery")
@@ -820,6 +857,7 @@ def test_bigquery_magic_w_table_id_and_bqstorage_client():
     bqstorage_instance_mock = mock.create_autospec(
         bigquery_storage_v1beta1.BigQueryStorageClient, instance=True
     )
+    bqstorage_instance_mock.transport = mock.Mock()
     bqstorage_mock.return_value = bqstorage_instance_mock
     bqstorage_client_patch = mock.patch(
         "google.cloud.bigquery_storage_v1beta1.BigQueryStorageClient", bqstorage_mock
@@ -1290,3 +1328,32 @@ def test_bigquery_magic_w_destination_table():
     assert job_config_used.write_disposition == "WRITE_TRUNCATE"
     assert job_config_used.destination.dataset_id == "dataset_id"
     assert job_config_used.destination.table_id == "table_id"
+
+
+@pytest.mark.usefixtures("ipython_interactive")
+def test_bigquery_magic_create_dataset_fails():
+    ip = IPython.get_ipython()
+    ip.extension_manager.load_extension("google.cloud.bigquery")
+    magics.context.credentials = mock.create_autospec(
+        google.auth.credentials.Credentials, instance=True
+    )
+
+    create_dataset_if_necessary_patch = mock.patch(
+        "google.cloud.bigquery.magics._create_dataset_if_necessary",
+        autospec=True,
+        side_effect=OSError,
+    )
+    close_transports_patch = mock.patch(
+        "google.cloud.bigquery.magics._close_transports", autospec=True,
+    )
+
+    with pytest.raises(
+        OSError
+    ), create_dataset_if_necessary_patch, close_transports_patch as close_transports:
+        ip.run_cell_magic(
+            "bigquery",
+            "--destination_table dataset_id.table_id",
+            "SELECT foo FROM WHERE LIMIT bar",
+        )
+
+    assert close_transports.called
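-- 
Reviewer notes, not part of the patch itself.

A minimal usage sketch of the new Client.close(), assuming default
application credentials are available; the project id is a placeholder:

    from google.cloud import bigquery

    client = bigquery.Client(project="my-project")  # placeholder project id
    try:
        list(client.query("SELECT 1").result())
    finally:
        # Releases the sockets held by the transport and its auth session.
        client.close()

    # Per the new docstring, a closed client stays usable; the underlying
    # connections are re-created automatically on the next request.
    list(client.query("SELECT 2").result())
    client.close()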
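The try/finally refactoring in magics.py is easier to see in isolation.
Below is a standalone sketch of the same idiom with a hypothetical stand-in
class rather than the real clients: binding both clients into a single
callable via functools.partial up front means every early return and every
raise in the body funnels through one cleanup point, which is what makes it
hard to re-introduce a socket leak.

    import functools


    class FakeClient:
        """Hypothetical stand-in for google.cloud.bigquery.Client."""

        def close(self):
            print("client closed")


    def _close_transports(client, bqstorage_client):
        # Mirrors the helper added in magics.py: always close the main
        # client, and close the storage client only if one was created.
        client.close()
        if bqstorage_client is not None:
            bqstorage_client.close()


    def run_cell(client, bqstorage_client=None):
        close_transports = functools.partial(
            _close_transports, client, bqstorage_client
        )
        try:
            return "query result"  # early returns still hit the finally
        finally:
            close_transports()


    print(run_cell(FakeClient()))  # "client closed", then "query result"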
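On the test strategy: the new system tests detect leaks by comparing
psutil.Process().connections() counts before and after the exercised code,
since that call enumerates the sockets currently held open by the process.
A standalone illustration of the same check, using a plain socket pair
instead of a BigQuery client:

    import socket

    import psutil

    proc = psutil.Process()
    baseline = len(proc.connections())

    # A listening socket plus a connected client stand in for the HTTP/gRPC
    # connections a BigQuery client keeps open.
    server = socket.socket()
    server.bind(("127.0.0.1", 0))
    server.listen(1)
    client = socket.create_connection(server.getsockname())

    assert len(proc.connections()) > baseline  # the "leak" is visible

    client.close()
    server.close()

    assert len(proc.connections()) == baseline  # released after close()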