diff --git a/tables/automl/automl_tables_predict.py b/tables/automl/automl_tables_predict.py index 4a3423e3d537..786f80fcb856 100644 --- a/tables/automl/automl_tables_predict.py +++ b/tables/automl/automl_tables_predict.py @@ -85,7 +85,7 @@ def batch_predict( project_id, compute_region, model_display_name, - gcs_input_uris, + gcs_input_uri, gcs_output_uri, ): """Make a batch of predictions.""" @@ -94,17 +94,19 @@ def batch_predict( # project_id = 'PROJECT_ID_HERE' # compute_region = 'COMPUTE_REGION_HERE' # model_display_name = 'MODEL_DISPLAY_NAME_HERE' - # gcs_input_uris = ['gs://path/to/file.csv] - # gcs_output_uri = 'gs://path' + # gcs_input_uri = 'gs://YOUR_BUCKET_ID/path_to_your_input_csv' + # gcs_output_uri = 'gs://YOUR_BUCKET_ID/path_to_save_results/' from google.cloud import automl_v1beta1 as automl client = automl.TablesClient(project=project_id, region=compute_region) # Query model - response = client.batch_predict(gcs_input_uris=gcs_input_uris, - gcs_output_uri_prefix=gcs_output_uri, - model_display_name=model_display_name) + response = client.batch_predict( + gcs_input_uris=gcs_input_uri, + gcs_output_uri_prefix=gcs_output_uri, + model_display_name=model_display_name, + ) print("Making batch prediction... ") response.result() print("Batch prediction complete.\n{}".format(response.metadata))