-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5f0b7ad
commit 59b27b3
Showing
20 changed files
with
1,945 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import sys | ||
|
||
if sys.version_info < (3, 10): | ||
from importlib_metadata import entry_points | ||
else: | ||
from importlib.metadata import entry_points | ||
|
||
ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyverse_ydoc")} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import Dict, List, Type, Union | ||
|
||
INT = Type[int] | ||
FLOAT = Type[float] | ||
|
||
|
||
def cast_all( | ||
o: Union[List, Dict], from_type: Union[INT, FLOAT], to_type: Union[FLOAT, INT] | ||
) -> Union[List, Dict]: | ||
if isinstance(o, list): | ||
for i, v in enumerate(o): | ||
if type(v) is from_type: | ||
v2 = to_type(v) | ||
if v == v2: | ||
o[i] = v2 | ||
elif isinstance(v, (list, dict)): | ||
cast_all(v, from_type, to_type) | ||
elif isinstance(o, dict): | ||
for k, v in o.items(): | ||
if type(v) is from_type: | ||
v2 = to_type(v) | ||
if v == v2: | ||
o[k] = v2 | ||
elif isinstance(v, (list, dict)): | ||
cast_all(v, from_type, to_type) | ||
return o |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Callable, Dict, Optional | ||
|
||
from pycrdt import Doc, Map | ||
|
||
|
||
class YBaseDoc(ABC): | ||
def __init__(self, ydoc: Optional[Doc] = None): | ||
if ydoc is None: | ||
self._ydoc = Doc() | ||
else: | ||
self._ydoc = ydoc | ||
self._ystate = Map() | ||
self._ydoc["state"] = self._ystate | ||
self._subscriptions: Dict[Any, str] = {} | ||
|
||
@property | ||
@abstractmethod | ||
def version(self) -> str: | ||
... | ||
|
||
@property | ||
def ystate(self) -> Map: | ||
return self._ystate | ||
|
||
@property | ||
def ydoc(self) -> Doc: | ||
return self._ydoc | ||
|
||
@property | ||
def source(self) -> Any: | ||
return self.get() | ||
|
||
@source.setter | ||
def source(self, value: Any): | ||
return self.set(value) | ||
|
||
@property | ||
def dirty(self) -> Optional[bool]: | ||
return self._ystate.get("dirty") | ||
|
||
@dirty.setter | ||
def dirty(self, value: bool) -> None: | ||
self._ystate["dirty"] = value | ||
|
||
@property | ||
def path(self) -> Optional[str]: | ||
return self._ystate.get("path") | ||
|
||
@path.setter | ||
def path(self, value: str) -> None: | ||
self._ystate["path"] = value | ||
|
||
@abstractmethod | ||
def get(self) -> Any: | ||
... | ||
|
||
@abstractmethod | ||
def set(self, value: Any) -> None: | ||
... | ||
|
||
@abstractmethod | ||
def observe(self, callback: Callable[[str, Any], None]) -> None: | ||
... | ||
|
||
def unobserve(self) -> None: | ||
for k, v in self._subscriptions.items(): | ||
k.unobserve(v) | ||
self._subscriptions = {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import base64 | ||
from functools import partial | ||
from typing import Any, Callable, Optional, Union | ||
|
||
from pycrdt import Doc, Map | ||
|
||
from .ybasedoc import YBaseDoc | ||
|
||
|
||
class YBlob(YBaseDoc): | ||
""" | ||
Extends :class:`YBaseDoc`, and represents a blob document. | ||
It is currently encoded as base64 because of: | ||
https://github.com/y-crdt/ypy/issues/108#issuecomment-1377055465 | ||
The Y document can be set from bytes or from str, in which case it is assumed to be encoded as | ||
base64. | ||
""" | ||
|
||
def __init__(self, ydoc: Optional[Doc] = None): | ||
super().__init__(ydoc) | ||
self._ysource = Map() | ||
self._ydoc["source"] = self._ysource | ||
|
||
@property | ||
def version(self) -> str: | ||
return "1.0.0" | ||
|
||
def get(self) -> bytes: | ||
return base64.b64decode(self._ysource["base64"].encode()) | ||
|
||
def set(self, value: Union[bytes, str]) -> None: | ||
if isinstance(value, bytes): | ||
value = base64.b64encode(value).decode() | ||
self._ysource["base64"] = value | ||
|
||
def observe(self, callback: Callable[[str, Any], None]) -> None: | ||
self.unobserve() | ||
self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state")) | ||
self._subscriptions[self._ysource] = self._ysource.observe(partial(callback, "source")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .yunicode import YUnicode | ||
|
||
|
||
class YFile(YUnicode): # for backwards-compatibility | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import copy | ||
import json | ||
from functools import partial | ||
from typing import Any, Callable, Dict, Optional | ||
from uuid import uuid4 | ||
|
||
from pycrdt import Array, Doc, Map, Text | ||
|
||
from .utils import cast_all | ||
from .ybasedoc import YBaseDoc | ||
|
||
# The default major version of the notebook format. | ||
NBFORMAT_MAJOR_VERSION = 4 | ||
# The default minor version of the notebook format. | ||
NBFORMAT_MINOR_VERSION = 5 | ||
|
||
|
||
class YNotebook(YBaseDoc): | ||
def __init__(self, ydoc: Optional[Doc] = None): | ||
super().__init__(ydoc) | ||
self._ymeta = Map() | ||
self._ycells = Array() | ||
self._ydoc["meta"] = self._ymeta | ||
self._ydoc["cells"] = self._ycells | ||
|
||
@property | ||
def version(self) -> str: | ||
return "1.0.0" | ||
|
||
@property | ||
def ycells(self): | ||
return self._ycells | ||
|
||
@property | ||
def cell_number(self) -> int: | ||
return len(self._ycells) | ||
|
||
def get_cell(self, index: int) -> Dict[str, Any]: | ||
meta = json.loads(str(self._ymeta)) | ||
cell = json.loads(str(self._ycells[index])) | ||
cast_all(cell, float, int) # cells coming from Yjs have e.g. execution_count as float | ||
if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4: | ||
# strip cell IDs if we have notebook format 4.0-4.4 | ||
del cell["id"] | ||
if ( | ||
"attachments" in cell | ||
and cell["cell_type"] in ("raw", "markdown") | ||
and not cell["attachments"] | ||
): | ||
del cell["attachments"] | ||
return cell | ||
|
||
def append_cell(self, value: Dict[str, Any]) -> None: | ||
ycell = self.create_ycell(value) | ||
self._ycells.append(ycell) | ||
|
||
def set_cell(self, index: int, value: Dict[str, Any]) -> None: | ||
ycell = self.create_ycell(value) | ||
self.set_ycell(index, ycell) | ||
|
||
def create_ycell(self, value: Dict[str, Any]) -> Map: | ||
cell = copy.deepcopy(value) | ||
if "id" not in cell: | ||
cell["id"] = str(uuid4()) | ||
cell_type = cell["cell_type"] | ||
cell_source = cell["source"] | ||
cell_source = "".join(cell_source) if isinstance(cell_source, list) else cell_source | ||
cell["source"] = Text(cell_source) | ||
cell["metadata"] = Map(cell.get("metadata", {})) | ||
|
||
if cell_type in ("raw", "markdown"): | ||
if "attachments" in cell and not cell["attachments"]: | ||
del cell["attachments"] | ||
elif cell_type == "code": | ||
cell["outputs"] = Array(cell.get("outputs", [])) | ||
|
||
return Map(cell) | ||
|
||
def set_ycell(self, index: int, ycell: Map) -> None: | ||
self._ycells[index] = ycell | ||
|
||
def get(self) -> Dict: | ||
meta = json.loads(str(self._ymeta)) | ||
cast_all(meta, float, int) # notebook coming from Yjs has e.g. nbformat as float | ||
cells = [] | ||
for i in range(len(self._ycells)): | ||
cell = self.get_cell(i) | ||
if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4: | ||
# strip cell IDs if we have notebook format 4.0-4.4 | ||
del cell["id"] | ||
if ( | ||
"attachments" in cell | ||
and cell["cell_type"] in ["raw", "markdown"] | ||
and not cell["attachments"] | ||
): | ||
del cell["attachments"] | ||
cells.append(cell) | ||
|
||
return dict( | ||
cells=cells, | ||
metadata=meta.get("metadata", {}), | ||
nbformat=int(meta.get("nbformat", 0)), | ||
nbformat_minor=int(meta.get("nbformat_minor", 0)), | ||
) | ||
|
||
def set(self, value: Dict) -> None: | ||
nb_without_cells = {key: value[key] for key in value.keys() if key != "cells"} | ||
nb = copy.deepcopy(nb_without_cells) | ||
cast_all(nb, int, float) # Yjs expects numbers to be floating numbers | ||
cells = value["cells"] or [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": None, | ||
# auto-created empty code cell without outputs ought be trusted | ||
"metadata": {"trusted": True}, | ||
"outputs": [], | ||
"source": "", | ||
"id": str(uuid4()), | ||
} | ||
] | ||
|
||
with self._ydoc.transaction(): | ||
# clear document | ||
self._ymeta.clear() | ||
self._ycells.clear() | ||
for key in [k for k in self._ystate.keys() if k not in ("dirty", "path")]: | ||
del self._ystate[key] | ||
|
||
# initialize document | ||
self._ycells.extend([self.create_ycell(cell) for cell in cells]) | ||
self._ymeta["nbformat"] = nb.get("nbformat", NBFORMAT_MAJOR_VERSION) | ||
self._ymeta["nbformat_minor"] = nb.get("nbformat_minor", NBFORMAT_MINOR_VERSION) | ||
|
||
metadata = nb.get("metadata", {}) | ||
metadata.setdefault("language_info", {"name": ""}) | ||
metadata.setdefault("kernelspec", {"name": "", "display_name": ""}) | ||
|
||
self._ymeta["metadata"] = Map(metadata) | ||
|
||
def observe(self, callback: Callable[[str, Any], None]) -> None: | ||
self.unobserve() | ||
self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state")) | ||
self._subscriptions[self._ymeta] = self._ymeta.observe_deep(partial(callback, "meta")) | ||
self._subscriptions[self._ycells] = self._ycells.observe_deep(partial(callback, "cells")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from functools import partial | ||
from typing import Any, Callable, Optional | ||
|
||
from pycrdt import Doc, Text | ||
|
||
from .ybasedoc import YBaseDoc | ||
|
||
|
||
class YUnicode(YBaseDoc): | ||
def __init__(self, ydoc: Optional[Doc] = None): | ||
super().__init__(ydoc) | ||
self._ysource = Text() | ||
self._ydoc["source"] = self._ysource | ||
|
||
@property | ||
def version(self) -> str: | ||
return "1.0.0" | ||
|
||
def get(self) -> str: | ||
return str(self._ysource) | ||
|
||
def set(self, value: str) -> None: | ||
with self._ydoc.transaction(): | ||
# clear document | ||
del self._ysource[:] | ||
# initialize document | ||
if value: | ||
self._ysource += value | ||
|
||
def observe(self, callback: Callable[[str, Any], None]) -> None: | ||
self.unobserve() | ||
self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state")) | ||
self._subscriptions[self._ysource] = self._ysource.observe(partial(callback, "source")) |
Oops, something went wrong.