Skip to content

Commit

Permalink
move default encoding definition to backend, re-order engine, format,…
Browse files Browse the repository at this point in the history
… encoding
  • Loading branch information
xflr6 committed Dec 24, 2020
1 parent 8fe3f02 commit ab7e243
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 28 deletions.
4 changes: 3 additions & 1 deletion graphviz/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@

FORMATTERS = {'cairo', 'core', 'gd', 'gdiplus', 'gdwbmp', 'xlib'}

ENCODING = 'utf-8'

PLATFORM = platform.system().lower()


Expand Down Expand Up @@ -249,7 +251,7 @@ def pipe(engine, format, data, renderer=None, formatter=None, quiet=False):

def unflatten(source,
stagger=None, fanout=False, chain=None,
encoding='utf-8'):
encoding=ENCODING):
"""Return DOT ``source`` piped through Graphviz *unflatten* preprocessor.
Args:
Expand Down
5 changes: 3 additions & 2 deletions graphviz/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
'test-output/m00se.gv.pdf'
"""

from . import lang
from . import backend
from . import files
from . import lang

__all__ = ['Graph', 'Digraph']

Expand All @@ -52,7 +53,7 @@ class Dot(files.File):

def __init__(self, name=None, comment=None,
filename=None, directory=None,
format=None, engine=None, encoding=files.ENCODING,
format=None, engine=None, encoding=backend.ENCODING,
graph_attr=None, node_attr=None, edge_attr=None, body=None,
strict=False):
self.name = name
Expand Down
34 changes: 17 additions & 17 deletions graphviz/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,17 @@

__all__ = ['File', 'Source']

ENCODING = 'utf-8'


log = logging.getLogger(__name__)


class Base(object):

_format = 'pdf'
_engine = 'dot'
_encoding = ENCODING

@property
def format(self):
"""The output format used for rendering (``'pdf'``, ``'png'``, ...)."""
return self._format
_format = 'pdf'

@format.setter
def format(self, format):
format = format.lower()
if format not in backend.FORMATS:
raise ValueError('unknown format: %r' % format)
self._format = format
_encoding = backend.ENCODING

@property
def engine(self):
Expand All @@ -51,6 +39,18 @@ def engine(self, engine):
raise ValueError('unknown engine: %r' % engine)
self._engine = engine

@property
def format(self):
"""The output format used for rendering (``'pdf'``, ``'png'``, ...)."""
return self._format

@format.setter
def format(self, format):
format = format.lower()
if format not in backend.FORMATS:
raise ValueError('unknown format: %r' % format)
self._format = format

@property
def encoding(self):
"""The encoding for the saved source file."""
Expand Down Expand Up @@ -85,7 +85,7 @@ class File(Base):
_default_extension = 'gv'

def __init__(self, filename=None, directory=None,
format=None, engine=None, encoding=ENCODING):
format=None, engine=None, encoding=backend.ENCODING):
if filename is None:
name = getattr(self, 'name', None) or self.__class__.__name__
filename = '%s.%s' % (name, self._default_extension)
Expand Down Expand Up @@ -323,7 +323,7 @@ class Source(File):

@classmethod
def from_file(cls, filename, directory=None,
format=None, engine=None, encoding=ENCODING):
format=None, engine=None, encoding=backend.ENCODING):
"""Return an instance with the source string read from the given file.
Args:
Expand All @@ -342,7 +342,7 @@ def from_file(cls, filename, directory=None,
return cls(source, filename, directory, format, engine, encoding)

def __init__(self, source, filename=None, directory=None,
format=None, engine=None, encoding=ENCODING):
format=None, engine=None, encoding=backend.ENCODING):
super(Source, self).__init__(filename, directory,
format, engine, encoding)
self.source = source #: The verbatim DOT source code string.
Expand Down
16 changes: 8 additions & 8 deletions tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,6 @@ def source():
return Source(**SOURCE)


def test_format(source):
assert not SOURCE['format'].islower()

assert source.format == SOURCE['format'].lower()
with pytest.raises(ValueError, match=r'format'):
source.format = ''


def test_engine(source):
assert not SOURCE['engine'].islower()

Expand All @@ -35,6 +27,14 @@ def test_engine(source):
source.engine = ''


def test_format(source):
assert not SOURCE['format'].islower()

assert source.format == SOURCE['format'].lower()
with pytest.raises(ValueError, match=r'format'):
source.format = ''


def test_encoding(source):
assert source.encoding == SOURCE['encoding']

Expand Down

0 comments on commit ab7e243

Please sign in to comment.