diff --git a/tableaudocumentapi/connection.py b/tableaudocumentapi/connection.py index ab4dcbb..8e9eb58 100644 --- a/tableaudocumentapi/connection.py +++ b/tableaudocumentapi/connection.py @@ -3,6 +3,7 @@ # Connection - A class for writing connections to Tableau files # ############################################################################### +import xml.etree.ElementTree as ET from tableaudocumentapi.dbclass import is_valid_dbclass @@ -33,6 +34,17 @@ def __init__(self, connxml): def __repr__(self): return "''".format(self._server, self._dbname, hex(id(self))) + @classmethod + def from_attributes(cls, server, dbname, username, dbclass, authentication=''): + root = ET.Element('connection', authentication=authentication) + xml = cls(root) + xml.server = server + xml.dbname = dbname + xml.username = username + xml.dbclass = dbclass + + return xml + ########### # dbname ########### @@ -120,4 +132,4 @@ def dbclass(self, value): raise AttributeError("'{}' is not a valid database type".format(value)) self._class = value - self._connectionXML.set('dbclass', value) + self._connectionXML.set('class', value) diff --git a/tableaudocumentapi/datasource.py b/tableaudocumentapi/datasource.py index 09860b4..bae9159 100644 --- a/tableaudocumentapi/datasource.py +++ b/tableaudocumentapi/datasource.py @@ -5,8 +5,10 @@ ############################################################################### import collections import itertools +import random import xml.etree.ElementTree as ET import xml.sax.saxutils as sax +from uuid import uuid4 from tableaudocumentapi import Connection, xfile from tableaudocumentapi import Field @@ -38,6 +40,7 @@ def _is_used_by_worksheet(names, field): class FieldDictionary(MultiLookupDict): + def used_by_sheet(self, name): # If we pass in a string, no need to get complicated, just check to see if name is in # the field's list of worksheets @@ -63,7 +66,36 @@ def _column_object_from_metadata_xml(metadata_xml): return _ColumnObjectReturnTuple(field_object.id, field_object) +def base36encode(number): + """Converts an integer into a base36 string.""" + + ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyz" + + base36 = '' + sign = '' + + if number < 0: + sign = '-' + number = -number + + if 0 <= number < len(ALPHABET): + return sign + ALPHABET[number] + + while number != 0: + number, i = divmod(number, len(ALPHABET)) + base36 = ALPHABET[i] + base36 + + return sign + base36 + + +def make_unique_name(dbclass): + rand_part = base36encode(uuid4().int) + name = dbclass + '.' + rand_part + return name + + class ConnectionParser(object): + def __init__(self, datasource_xml, version): self._dsxml = datasource_xml self._dsversion = version @@ -116,6 +148,20 @@ def from_file(cls, filename): dsxml = xml_open(filename, cls.__name__.lower()).getroot() return cls(dsxml, filename) + @classmethod + def from_connections(cls, caption, connections): + root = ET.Element('datasource', caption=caption, version='10.0', inline='true') + outer_connection = ET.SubElement(root, 'connection') + outer_connection.set('class', 'federated') + named_conns = ET.SubElement(outer_connection, 'named-connections') + for conn in connections: + nc = ET.SubElement(named_conns, + 'named-connection', + name=make_unique_name(conn.dbclass), + caption=conn.server) + nc.append(conn._connectionXML) + return cls(root) + def save(self): """ Call finalization code and save file. @@ -143,6 +189,7 @@ def save_as(self, new_filename): Nothing. """ + xfile._save_file(self._filename, self._datasourceTree, new_filename) ########### diff --git a/tableaudocumentapi/xfile.py b/tableaudocumentapi/xfile.py index 8f9ffd1..66e5aac 100644 --- a/tableaudocumentapi/xfile.py +++ b/tableaudocumentapi/xfile.py @@ -104,6 +104,10 @@ def save_into_archive(xml_tree, filename, new_filename=None): def _save_file(container_file, xml_tree, new_filename=None): + + if container_file is None: + container_file = new_filename + if zipfile.is_zipfile(container_file): save_into_archive(xml_tree, container_file, new_filename) else: diff --git a/test/bvt.py b/test/bvt.py index ce96afe..463a087 100644 --- a/test/bvt.py +++ b/test/bvt.py @@ -74,6 +74,25 @@ def test_bad_dbclass_rasies_attribute_error(self): with self.assertRaises(AttributeError): conn.dbclass = 'NotReal' + def test_can_create_connection_from_scratch(self): + conn = Connection.from_attributes( + server='a', dbname='b', username='c', dbclass='mysql', authentication='d') + self.assertEqual(conn.server, 'a') + self.assertEqual(conn.dbname, 'b') + self.assertEqual(conn.username, 'c') + self.assertEqual(conn.dbclass, 'mysql') + self.assertEqual(conn.authentication, 'd') + + def test_can_create_datasource_from_connections(self): + conn1 = Connection.from_attributes( + server='a', dbname='b', username='c', dbclass='mysql', authentication='d') + conn2 = Connection.from_attributes( + server='1', dbname='2', username='3', dbclass='mysql', authentication='7') + ds = Datasource.from_connections('test', connections=[conn1, conn2]) + + self.assertEqual(ds.connections[0].server, 'a') + self.assertEqual(ds.connections[1].server, '1') + class DatasourceModelTests(unittest.TestCase):