Skip to content

Commit

Permalink
More typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ricklupton committed Oct 3, 2023
1 parent e398c29 commit 8760868
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
21 changes: 14 additions & 7 deletions src/rmscene/crdt_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from .tagged_block_common import CrdtId


_T = tp.TypeVar("_T")
# If the type constraint is for a CrdtSequenceItem[Superclass], then a
# CrdtSequenceItem[Subclass] would do, so it is covariant.

_T = tp.TypeVar("_T", covariant=True)


@dataclass
Expand All @@ -22,7 +25,11 @@ class CrdtSequenceItem(tp.Generic[_T]):
value: _T


class CrdtSequence(tp.Generic[_T]):
# As a mutable container, CrdtSequence is invariant.
_Ti = tp.TypeVar("_Ti", covariant=False)


class CrdtSequence(tp.Generic[_Ti]):
"""Ordered CRDT Sequence container.
The Sequence contains `CrdtSequenceItem`s, each of which has an ID and
Expand Down Expand Up @@ -57,27 +64,27 @@ def keys(self) -> list[CrdtId]:
"""Return CrdtIds in order."""
return list(self)

def values(self) -> list[_T]:
def values(self) -> list[_Ti]:
"""Return list of sorted values."""
return [self[item_id] for item_id in self]

def items(self) -> Iterable[tuple[CrdtId, _T]]:
def items(self) -> Iterable[tuple[CrdtId, _Ti]]:
"""Return list of sorted key, value pairs."""
return [(item_id, self[item_id]) for item_id in self]

def __getitem__(self, key: CrdtId) -> _T:
def __getitem__(self, key: CrdtId) -> _Ti:
"""Return item with key"""
return self._items[key].value

## Access SequenceItems

def sequence_items(self) -> list[CrdtSequenceItem[_T]]:
def sequence_items(self) -> list[CrdtSequenceItem[_Ti]]:
"""Iterate through CrdtSequenceItems."""
return list(self._items.values())

## Modify sequence

def add(self, item: CrdtSequenceItem[_T]):
def add(self, item: CrdtSequenceItem[_Ti]):
if item.item_id in self._items:
raise ValueError("Already have item %s" % item.item_id)
self._items[item.item_id] = item
Expand Down
23 changes: 17 additions & 6 deletions src/rmscene/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
_logger = logging.getLogger(__name__)


def expand_text_item(item: CrdtSequenceItem[str | int]) -> Iterable[CrdtSequenceItem[str | int]]:
def expand_text_item(
item: CrdtSequenceItem[str | int],
) -> tp.Iterator[CrdtSequenceItem[str | int]]:
"""Expand TextItem into single-character TextItems.
Text is stored as strings in TextItems, each with an associated ID for the
Expand Down Expand Up @@ -56,7 +58,9 @@ def expand_text_item(item: CrdtSequenceItem[str | int]) -> Iterable[CrdtSequence
yield CrdtSequenceItem(item_id, left_id, item.right_id, deleted_length, chars[-1])


def expand_text_items(items: Iterable[CrdtSequenceItem[str | int]]) -> Iterable[CrdtSequenceItem[str | int]]:
def expand_text_items(
items: Iterable[CrdtSequenceItem[str | int]],
) -> tp.Iterator[CrdtSequenceItem[str | int]]:
"""Expand a sequence of TextItems into single-character TextItems."""
for item in items:
yield from expand_text_item(item)
Expand All @@ -74,6 +78,7 @@ def __str__(self):
@dataclass
class TextSpan:
"""Base class for text spans with formatting."""

contents: list[tp.Union["TextSpan", CrdtStr]]


Expand All @@ -88,6 +93,7 @@ class ItalicSpan(TextSpan):
@dataclass
class Paragraph:
"""Paragraph of text."""

contents: list[TextSpan]
start_id: CrdtId
style: LwwValue[si.ParagraphStyle]
Expand All @@ -98,7 +104,6 @@ def __str__(self):

@dataclass
class TextDocument:

contents: list[Paragraph]

@classmethod
Expand Down Expand Up @@ -146,8 +151,12 @@ def parse_paragraph_contents():
elif char in span_end_codes:
span_type, nested = stack.pop()
if span_type is not span_end_codes[char]:
_logger.error("Unexpected end of span at %s: got %s, expected %s",
k, span_end_codes[char], span_type)
_logger.error(
"Unexpected end of span at %s: got %s, expected %s",
k,
span_end_codes[char],
span_type,
)
if span_type is not None:
stack[-1][1].append(span_type(nested))
else:
Expand All @@ -172,7 +181,9 @@ def parse_paragraph_contents():

paragraphs = []
while keys:
style = text.styles.get(last_linebreak, LwwValue(CrdtId(0, 0), si.ParagraphStyle.PLAIN))
style = text.styles.get(
last_linebreak, LwwValue(CrdtId(0, 0), si.ParagraphStyle.PLAIN)
)
contents = parse_paragraph_contents()
p = Paragraph(contents, last_linebreak, style)
paragraphs += [p]
Expand Down

0 comments on commit 8760868

Please sign in to comment.