diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index 44181aa477b72..4948ca4263135 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -2027,6 +2027,13 @@ def executemany(self, operation, seq_of_parameters): for parameters in seq_of_parameters: self.execute(operation, parameters) + def flush_results(self): + """ Flush results related cursor attributes. """ + self.page_token = None + self.job_id = None + self.all_pages_loaded = False + self.buffer = [] + def fetchone(self): """ Fetch the next row of a query result set. """ return self.next() @@ -2067,9 +2074,7 @@ def next(self): else: # Reset all state since we've exhausted the results. - self.page_token = None - self.job_id = None - self.page_token = None + self.flush_results() return None return self.buffer.pop(0) diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py index 7341add645739..f30c0c83a2b2b 100644 --- a/tests/contrib/hooks/test_bigquery_hook.py +++ b/tests/contrib/hooks/test_bigquery_hook.py @@ -636,6 +636,25 @@ def test_execute_with_parameters(self, mocked_rwc): "SELECT %(foo)s", {"foo": "bar"}) assert mocked_rwc.call_count == 1 + @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') + @mock.patch.object(hook.BigQueryCursor, 'flush_results') + def test_flush_cursor_in_execute(self, _, mocked_fr): + hook.BigQueryCursor("test", "test").execute( + "SELECT %(foo)s", {"foo": "bar"}) + assert mocked_fr.call_count == 1 + + def test_flush_cursor(self): + bq_cursor = hook.BigQueryCursor("test", "test") + bq_cursor.page_token = '456dcea9-fcbf-4f02-b570-83f5297c685e' + bq_cursor.job_id = 'c0a79ae4-0e72-4593-a0d0-7dbbf726f193' + bq_cursor.all_pages_loaded = True + bq_cursor.buffer = [('a', 100, 200), ('b', 200, 300)] + bq_cursor.flush_results() + self.assertIsNone(bq_cursor.page_token) + self.assertIsNone(bq_cursor.job_id) + self.assertFalse(bq_cursor.all_pages_loaded) + self.assertListEqual(bq_cursor.buffer, []) + class TestLabelsInRunJob(unittest.TestCase): @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')