Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bounded file history #1819

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion src/prompt_toolkit/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ def append_string(self, string: str) -> None:

class FileHistory(History):
"""
:class:`.History` class that stores all strings in a file.
:class:`.History` class that stores all strings in a file. You can optionally specify the
maximum amount of initially loaded commands.
"""

def __init__(self, filename: str) -> None:
Expand Down Expand Up @@ -300,3 +301,53 @@ def write(t: str) -> None:
write("\n# %s\n" % datetime.datetime.now())
for line in string.split("\n"):
write("+%s\n" % line)


class BoundedFileHistory(FileHistory):
"""
:class:`.History` class that stores all strings in a file but also limits the total number of
contained history items. The file will be re-written with the specified bound number as the
number of most recent history items when re-loading the history strings.
"""

def __init__(self, filename: str, bound: int) -> None:
self.bound = bound
super().__init__(filename)

def load_history_strings(self) -> Iterable[str]:
strings: list[str] = []
date_lines: list[bytes] = []
lines: list[str] = []

def add() -> None:
if lines:
# Join and drop trailing newline.
string = "".join(lines)[:-1]

strings.append(string)

if os.path.exists(self.filename):
with open(self.filename, "rb") as f:
for line_bytes in f:
line = line_bytes.decode("utf-8", errors="replace")
if line.startswith("+"):
lines.append(line[1:])
else:
if line.startswith("#"):
date_lines.append(line_bytes)
add()
lines = []

add()

if len(strings) > self.bound:
assert len(date_lines) == len(strings)
# Reverse the order, because newest items have to go first.
list_of_strings = list(reversed(strings))[: self.bound]
# Re-write the truncated file.
with open(self.filename, "wb") as f:
for date_str, string in zip(date_lines, strings):
f.write(date_str)
f.write(f"{string}\n".encode())
return list_of_strings
return reversed(strings)
34 changes: 33 additions & 1 deletion tests/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from asyncio import run

from prompt_toolkit.history import FileHistory, InMemoryHistory, ThreadedHistory
from prompt_toolkit.history import (
BoundedFileHistory,
FileHistory,
InMemoryHistory,
ThreadedHistory,
)


def _call_history_load(history):
Expand Down Expand Up @@ -60,6 +65,33 @@ def test_file_history(tmpdir):
assert _call_history_load(history2) == ["test3", "world", "hello"]


def test_bounded_file_history(tmpdir):
histfile = tmpdir.join("history")

history = BoundedFileHistory(histfile, bound=4)

history.append_string("hello")
history.append_string("world")

# Newest should yield first.
assert _call_history_load(history) == ["world", "hello"]

# Test another call.
assert _call_history_load(history) == ["world", "hello"]

history.append_string("test3")
assert _call_history_load(history) == ["test3", "world", "hello"]
history.append_string("test4")
assert _call_history_load(history) == ["test4", "test3", "world", "hello"]
history.append_string("test5")
# In-memory history still can contain more files.
assert _call_history_load(history) == ["test5", "test4", "test3", "world", "hello"]

# The newly loaded history will now only get four files.
new_history = BoundedFileHistory(histfile, bound=4)
assert _call_history_load(new_history) == ["test5", "test4", "test3", "world"]


def test_threaded_file_history(tmpdir):
histfile = tmpdir.join("history")

Expand Down
Loading