Commit 2a0ff1f

#1685 Add support for multiple results in batch jobs

* Adds a new method `get_batch_result_ids` and maintains backwards compatibility with the old method, which now emits a deprecation warning.
* Adds the ability to merge result sets, with basic testing. Note that merging only supports CSV, since that is what the library produces by default.
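For orientation, here is a minimal, hypothetical sketch of the new call flow; `sf`, `job_id`, and `batch_id` are assumed to come from surrounding code, and the logic mirrors what this commit adds to `QuerySalesforce.run` below:

```python
# Hypothetical usage of the new multi-result API; `sf`, `job_id`,
# and `batch_id` are assumed to exist already.
result_ids = sf.get_batch_result_ids(job_id, batch_id)

if len(result_ids) == 1:
    # A single result can be downloaded directly.
    data = sf.get_batch_result(job_id, batch_id, result_ids[0])
else:
    # Multiple results must each be fetched and then merged
    # (for CSV, de-duplicating the header, as merge_batch_results does).
    parts = [sf.get_batch_result(job_id, batch_id, rid) for rid in result_ids]
```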
luigi/contrib/salesforce.py:

```diff
@@ -17,6 +17,7 @@
 import time
 import abc
 import logging
+import warnings
 import xml.etree.ElementTree as ET
 from collections import OrderedDict
 import re
```
```diff
@@ -139,6 +140,11 @@ def is_soql_file(self):
         """Override to True if soql property is a file path."""
         return False

+    @property
+    def content_type(self):
+        """Override to use a different content type. (e.g. XML)"""
+        return "CSV"
+
     def run(self):
         if self.use_sandbox and not self.sandbox_name:
             raise Exception("Parameter sf_sandbox_name must be provided when uploading to a Salesforce Sandbox")
```
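The new `content_type` property gives subclasses a single override point for the bulk job's format. A hedged sketch of a hypothetical subclass (`XmlQuerySalesforce` is illustrative; XML is one of the values the `create_operation_job` docstring below allows, though note that `merge_batch_results` currently raises for anything other than CSV):

```python
from luigi.contrib.salesforce import QuerySalesforce

# Hypothetical subclass; only the content_type property is real API surface.
class XmlQuerySalesforce(QuerySalesforce):
    @property
    def content_type(self):
        # Flows through to create_operation_job and create_batch.
        return "XML"
```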
```diff
@@ -170,11 +176,23 @@ def run(self):
                 if 'foreign key relationships not supported' not in status['state_message'].lower():
                     raise Exception(msg)
             else:
-                result_id = sf.get_batch_results(job_id, batch_id)
-                data = sf.get_batch_result(job_id, batch_id, result_id)
-
-                with open(self.output().fn, 'w') as outfile:
-                    outfile.write(data)
+                result_ids = sf.get_batch_result_ids(job_id, batch_id)
+
+                # If there's only one result, just download it, otherwise we need to merge the resulting downloads
+                if len(result_ids) == 1:
+                    data = sf.get_batch_result(job_id, batch_id, result_ids[0])
+                    with open(self.output().path, 'w') as outfile:
+                        outfile.write(data)
+                else:
+                    # Download each file to disk, and then merge into one.
+                    # Preferring to do it this way so as to minimize memory consumption.
+                    for i, result_id in enumerate(result_ids):
+                        logger.info("Downloading batch result %s for batch: %s and job: %s" % (result_id, batch_id, job_id))
+                        with open("%s.%d" % (self.output().path, i), 'w') as outfile:
+                            outfile.write(sf.get_batch_result(job_id, batch_id, result_id))
+
+                    logger.info("Merging results of batch %s" % batch_id)
+                    self.merge_batch_results(result_ids)
         finally:
             logger.info("Closing job %s" % job_id)
             sf.close_job(job_id)
```
```diff
@@ -184,11 +202,30 @@ def run(self):
             data_file = sf.query_all(self.soql)

             reader = csv.reader(data_file)
-            with open(self.output().fn, 'w') as outfile:
+            with open(self.output().path, 'w') as outfile:
                 writer = csv.writer(outfile, dialect='excel')
                 for row in reader:
                     writer.writerow(row)

+    def merge_batch_results(self, result_ids):
+        """
+        Merges the resulting files of a multi-result batch bulk query.
+        """
+        outfile = open(self.output().path, 'w')
+
+        if self.content_type == 'CSV':
+            for i, result_id in enumerate(result_ids):
+                with open("%s.%d" % (self.output().path, i), 'r') as f:
+                    header = f.readline()
+                    if i == 0:
+                        outfile.write(header)
+                    for line in f:
+                        outfile.write(line)
+        else:
+            raise Exception("Batch result merging not implemented for %s" % self.content_type)
+
+        outfile.close()
+

 class SalesforceAPI(object):
     """
```
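To make the merge behaviour concrete, here is a small, self-contained sketch of the same header-deduplication idea, using in-memory strings instead of the task's part files (all data here is made up):

```python
# Standalone sketch: keep the CSV header from the first part only,
# then append the data rows of every part (mirrors merge_batch_results).
part_contents = [
    "Id,Name\n1,Alice\n2,Bob\n",  # contents of <output>.0
    "Id,Name\n3,Carol\n",         # contents of <output>.1
]

merged = []
for i, content in enumerate(part_contents):
    lines = content.splitlines(True)  # keep line endings
    if i == 0:
        merged.append(lines[0])  # write the header exactly once
    merged.extend(lines[1:])     # data rows from every part

assert "".join(merged) == "Id,Name\n1,Alice\n2,Bob\n3,Carol\n"
```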
```diff
@@ -353,15 +390,17 @@ def restful(self, path, params):
         else:
             return json_result

-    def create_operation_job(self, operation, obj, external_id_field_name=None, content_type='CSV'):
+    def create_operation_job(self, operation, obj, external_id_field_name=None, content_type=None):
         """
         Creates a new SF job for doing any operation (insert, upsert, update, delete, query)

         :param operation: delete, insert, query, upsert, update, hardDelete. Must be lowercase.
         :param obj: Parent SF object
         :param external_id_field_name: Optional.
         :param content_type: XML, CSV, ZIP_CSV, or ZIP_XML. Defaults to CSV
         """
+        if content_type is None:
+            content_type = self.content_type
+
         if not self.has_active_session():
             self.start_session()
```
```diff
@@ -419,7 +458,7 @@ def close_job(self, job_id):
         return response

-    def create_batch(self, job_id, data, file_type='csv'):
+    def create_batch(self, job_id, data, file_type=None):
         """
         Creates a batch with either a string of data or a file containing data.
```
```diff
@@ -429,13 +468,15 @@ def create_batch(self, job_id, data, file_type='csv'):
         :param job_id: job_id as returned by 'create_operation_job(...)'
         :param data:
         :param file_type:
         :return: Returns batch_id
         """
         if not job_id or not self.has_active_session():
             raise Exception("Can not create a batch without a valid job_id and an active session.")

+        if file_type is None:
+            file_type = self.content_type.lower()
+
         headers = self._get_create_batch_content_headers(file_type)
         headers['Content-Length'] = len(data)
```
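Taken together, both methods now accept `None` and fall back to the object's `content_type`, while explicit values behave exactly as before. A hedged sketch (the credentials and the `Contact` object are placeholders; the three-argument constructor form follows the tests below):

```python
# Placeholder credentials; 'Contact' is an illustrative Salesforce object.
sf = SalesforceAPI('user@example.com', 'password', 'security_token')

# Explicit values, as before:
job_id = sf.create_operation_job('query', 'Contact', content_type='CSV')
batch_id = sf.create_batch(job_id, "SELECT Id, Name FROM Contact", file_type='csv')

# Or omit both and let the new None defaults defer to self.content_type.
```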
```diff
@@ -473,22 +514,27 @@ def block_on_batch(self, job_id, batch_id, sleep_time_seconds=5, max_wait_time_s
     def get_batch_results(self, job_id, batch_id):
         """
-        Get results of a batch that has completed processing.
-        If the batch is a CSV file, the response is in CSV format.
-        If the batch is an XML file, the response is in XML format.
+        DEPRECATED: Use `get_batch_result_ids`
+        """
+        warnings.warn("get_batch_results is deprecated and only returns one batch result. Please use get_batch_result_ids")
+        return self.get_batch_result_ids(job_id, batch_id)[0]
+
+    def get_batch_result_ids(self, job_id, batch_id):
+        """
+        Get result IDs of a batch that has completed processing.
         :param job_id: job_id as returned by 'create_operation_job(...)'
         :param batch_id: batch_id as returned by 'create_batch(...)'
-        :return: batch result response as either CSV or XML, dependent on the batch
+        :return: list of batch result IDs to be used in 'get_batch_result(...)'
         """
         response = requests.get(self._get_batch_results_url(job_id, batch_id),
                                 headers=self._get_batch_info_headers())
         response.raise_for_status()

         root = ET.fromstring(response.text)
-        result = root.find('%sresult' % self.API_NS).text
+        result_ids = [r.text for r in root.findall('%sresult' % self.API_NS)]

-        return result
+        return result_ids

     def get_batch_result(self, job_id, batch_id, result_id):
         """
```
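For existing callers, migration is mechanical; a before/after sketch (assuming an authenticated `sf` instance and a completed batch):

```python
# Before (deprecated): emits a warning and returns only the first result ID.
result_id = sf.get_batch_results(job_id, batch_id)
data = sf.get_batch_result(job_id, batch_id, result_id)

# After: handle every result of the batch.
for result_id in sf.get_batch_result_ids(job_id, batch_id):
    data = sf.get_batch_result(job_id, batch_id, result_id)
```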
New file (unit tests for the Salesforce contrib package; the path is not shown in this capture):

```diff
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2016 Simply Measured
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+#
+
+"""
+Unit test for the Salesforce contrib package
+"""
+
+from luigi.contrib.salesforce import SalesforceAPI, QuerySalesforce
+
+from helpers import unittest
+import mock
+from luigi.mock import MockTarget
+from luigi.six import PY3
+import re
+
+
||
def mocked_requests_get(*args, **kwargs): | ||
class MockResponse: | ||
def __init__(self, body, status_code): | ||
self.body = body | ||
self.status_code = status_code | ||
|
||
@property | ||
def text(self): | ||
return self.body | ||
|
||
def raise_for_status(self): | ||
return None | ||
|
||
result_list = ( | ||
'<result-list xmlns="http://www.force.com/2009/06/asyncapi/dataload">' | ||
'<result>1234</result><result>1235</result><result>1236</result>' | ||
'</result-list>' | ||
) | ||
return MockResponse(result_list, 200) | ||
|
||
# Keep open around so we can use it in the mock responses | ||
old__open = open | ||
|
||
|
||
def mocked_open(*args, **kwargs): | ||
if re.match("job_data", args[0]): | ||
return MockTarget(args[0]).open(args[1]) | ||
else: | ||
return old__open(*args) | ||
|
||
|
||
+class TestSalesforceAPI(unittest.TestCase):
+    # We patch 'requests.get' with our own method. The mock object is passed in to our test case method.
+    @mock.patch('requests.get', side_effect=mocked_requests_get)
+    def test_deprecated_results(self, mock_get):
+        sf = SalesforceAPI('xx', 'xx', 'xx')
+        result_id = sf.get_batch_results('job_id', 'batch_id')
+        self.assertEqual('1234', result_id)
+
+    @mock.patch('requests.get', side_effect=mocked_requests_get)
+    def test_result_ids(self, mock_get):
+        sf = SalesforceAPI('xx', 'xx', 'xx')
+        result_ids = sf.get_batch_result_ids('job_id', 'batch_id')
+        self.assertEqual(['1234', '1235', '1236'], result_ids)
+
+
+class TestQuerySalesforce(QuerySalesforce):
+    def output(self):
+        return MockTarget('job_data.csv')
+
+    @property
+    def object_name(self):
+        return 'dual'
+
+    @property
+    def soql(self):
+        return "SELECT * FROM %s" % self.object_name
+
+
+class TestSalesforceQuery(unittest.TestCase):
+    patch_name = '__builtin__.open'
+    if PY3:
+        patch_name = 'builtins.open'
+
+    @mock.patch(patch_name, side_effect=mocked_open)
+    def setUp(self, mock_open):
+        MockTarget.fs.clear()
+        self.result_ids = ['a', 'b', 'c']
+
+        counter = 1
+        self.all_lines = "Lines\n"
+        self.header = "Lines"
+        for i, id in enumerate(self.result_ids):
+            filename = "%s.%d" % ('job_data.csv', i)
+            with MockTarget(filename).open('w') as f:
+                line = "%d line\n%d line" % ((counter), (counter + 1))
+                f.write(self.header + "\n" + line + "\n")
+            self.all_lines += line + "\n"
+            counter += 2
+
+    @mock.patch(patch_name, side_effect=mocked_open)
+    def test_multi_csv_download(self, mock_open):
+        qsf = TestQuerySalesforce()
+
+        qsf.merge_batch_results(self.result_ids)
+        self.assertEqual(MockTarget(qsf.output().path).open('r').read(), self.all_lines)
```
1 comment on commit 2a0ff1f

dlstadther (Collaborator) commented:
I'll be submitting a fix to my comments. Basically, if you're going to set the default `content_type` to `csv`, then `content_type` will never be `None`. I'll submit the PR soon.

I apologize for not catching these breaking changes when the PR was pending.

Again, I'm sorry for not asking these questions awhile back... For what purpose did you change this?
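To restate the reviewer's point in code, a simplified, hypothetical illustration (not the real classes): because the task-level `content_type` property returns `"CSV"` rather than `None`, the new `is None` fallback never fires on the task's own call path.

```python
# Simplified illustration of the reviewer's concern; not the real classes.
class FakeTask:
    @property
    def content_type(self):
        return "CSV"  # concrete default, never None


def create_operation_job(content_type=None, fallback="CSV"):
    # Dead branch whenever callers pass task.content_type:
    if content_type is None:
        content_type = fallback
    return content_type


task = FakeTask()
assert create_operation_job(task.content_type) == "CSV"
```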