diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index 07a2ab8bec233c..e99aa73bdb2d21 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -93,7 +93,7 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, **kwar """ raise NotImplementedError() - def get_pandas_df(self, sql, parameters=None, dialect=None): + def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs): """ Returns a Pandas DataFrame for the results produced by a BigQuery query. The DbApiHook method must be overridden because Pandas @@ -110,6 +110,8 @@ def get_pandas_df(self, sql, parameters=None, dialect=None): :param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL defaults to use `self.use_legacy_sql` if not specified :type dialect: str in {'legacy', 'standard'} + :param kwargs: (optional) passed into pandas_gbq.read_gbq method + :type kwargs: dict """ private_key = self._get_field('key_path', None) or self._get_field('keyfile_dict', None) @@ -120,7 +122,8 @@ def get_pandas_df(self, sql, parameters=None, dialect=None): project_id=self._get_field('project'), dialect=dialect, verbose=False, - private_key=private_key) + private_key=private_key, + **kwargs) def table_exists(self, project_id, dataset_id, table_id): """ diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index ac5488122c39f2..76f4f0ad63bbf2 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -82,7 +82,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): engine_kwargs = {} return create_engine(self.get_uri(), **engine_kwargs) - def get_pandas_df(self, sql, parameters=None): + def get_pandas_df(self, sql, parameters=None, **kwargs): """ Executes the sql and returns a pandas dataframe @@ -90,14 +90,16 @@ def get_pandas_df(self, sql, parameters=None): sql statements to execute :type sql: str or list :param parameters: The parameters to render the SQL query with. - :type parameters: mapping or iterable + :type parameters: dict or iterable + :param kwargs: (optional) passed into pandas.io.sql.read_sql method + :type kwargs: dict """ if sys.version_info[0] < 3: sql = sql.encode('utf-8') import pandas.io.sql as psql with closing(self.get_conn()) as conn: - return psql.read_sql(sql, con=conn, params=parameters) + return psql.read_sql(sql, con=conn, params=parameters, **kwargs) def get_records(self, sql, parameters=None): """ diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py index e521d7bdc6ede7..48def1164c15a7 100644 --- a/airflow/hooks/hive_hooks.py +++ b/airflow/hooks/hive_hooks.py @@ -983,7 +983,7 @@ def get_records(self, hql, schema='default'): """ return self.get_results(hql, schema=schema)['data'] - def get_pandas_df(self, hql, schema='default'): + def get_pandas_df(self, hql, schema='default', **kwargs): """ Get a pandas dataframe from a Hive query @@ -991,6 +991,8 @@ def get_pandas_df(self, hql, schema='default'): :type hql: str or list :param schema: target schema, default to 'default'. :type schema: str + :param kwargs: (optional) passed into pandas.DataFrame constructor + :type kwargs: dict :return: result of hql execution :rtype: DataFrame @@ -1004,6 +1006,6 @@ def get_pandas_df(self, hql, schema='default'): """ import pandas as pd res = self.get_results(hql, schema=schema) - df = pd.DataFrame(res['data']) + df = pd.DataFrame(res['data'], **kwargs) df.columns = [c[0] for c in res['header']] return df diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py index 9788411b97d7ad..7d700aba8217bf 100644 --- a/airflow/hooks/presto_hook.py +++ b/airflow/hooks/presto_hook.py @@ -105,7 +105,7 @@ def get_first(self, hql, parameters=None): except DatabaseError as e: raise PrestoException(self._get_pretty_exception_message(e)) - def get_pandas_df(self, hql, parameters=None): + def get_pandas_df(self, hql, parameters=None, **kwargs): """ Get a pandas dataframe from a sql query. """ @@ -118,10 +118,10 @@ def get_pandas_df(self, hql, parameters=None): raise PrestoException(self._get_pretty_exception_message(e)) column_descriptions = cursor.description if data: - df = pandas.DataFrame(data) + df = pandas.DataFrame(data, **kwargs) df.columns = [c[0] for c in column_descriptions] else: - df = pandas.DataFrame() + df = pandas.DataFrame(**kwargs) return df def run(self, hql, parameters=None):