Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix concurrent access to self.cache in nosqldict #121

Merged
merged 3 commits into from
May 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: pytest
run: |
docker run -d -p 27017:27017 mongo
pip install -r dev-requirements.txt
make test
- name: lint
Expand Down
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,9 @@ format:

test:
pip3 install .
coverage run -m pytest
ifdef GITHUB_ACTIONS
coverage run -m pytest -v --with_nosqldict
else
coverage run -m pytest -v
endif
coverage report -m
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setuptools.setup(
name="enochecker",
version="0.4.1",
version="0.4.2",
author="domenukk",
author_email="[email protected]",
description="Library to build checker scripts for EnoEngine A/D CTF Framework in Python",
Expand Down
2 changes: 1 addition & 1 deletion src/enochecker/enochecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(
self.storage_dir = storage_dir

self._setup_logger()
if use_db_cache:
if use_db_cache and not os.getenv("MONGO_ENABLED"):
self._active_dbs: Dict[str, Union[NoSqlDict, StoredDict]] = global_db_cache
else:
self._active_dbs = {}
Expand Down
105 changes: 59 additions & 46 deletions src/enochecker/nosqldict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from collections.abc import MutableMapping
from functools import wraps
from threading import RLock, current_thread
from threading import Lock, RLock, current_thread
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union

from . import utils
Expand Down Expand Up @@ -135,16 +135,16 @@ def __init__(
self.checker_name = checker_name
self.cache: Dict[Any, Any] = {}
self.hash_cache: Dict[Any, Any] = {}
self._lock: Lock = Lock()
host_: str = host or DB_DEFAULT_HOST
if isinstance(port, int):
port_: int = port
else:
port_ = int(port or DB_DEFAULT_PORT)
username_: Optional[str] = username or DB_DEFAULT_USER
password_: Optional[str] = password or DB_DEFAULT_PASS
self.db = self.get_client(host_, port_, username_, password_, self.logger)[
checker_name
][self.dict_name]
self.client = self.get_client(host_, port_, username_, password_, self.logger)
self.db = self.client[checker_name][self.dict_name]
try:
self.db.index_information()["checker_key"]
except KeyError:
Expand All @@ -162,14 +162,17 @@ def __setitem__(self, key: str, value: Any) -> None:
:param key: key in the dictionary
:param value: value in the dictionary
"""
key = str(key)
with self._lock:
key = str(key)

self.cache[key] = value
hash_ = value_to_hash(value)
if hash_:
self.hash_cache[key] = hash_
self.cache[key] = value
hash_ = value_to_hash(value)
if hash_:
self.hash_cache[key] = hash_
elif key in self.hash_cache:
del self.hash_cache[key]

self._upsert(key, value)
self._upsert(key, value)

def _upsert(self, key: Any, value: Any) -> None:
query_dict = {
Expand Down Expand Up @@ -198,28 +201,31 @@ def __getitem__(self, key: str, print_result: bool = False) -> Any:
:param print_result: TODO
:return: retrieved value
"""
key = str(key)
if key in self.cache.items():
return self.cache[key]
with self._lock:
key = str(key)

to_extract = {
"key": key,
"checker": self.checker_name,
"name": self.dict_name,
}
if key in self.cache:
return self.cache[key]

result = self.db.find_one(to_extract)
to_extract = {
"key": key,
"checker": self.checker_name,
"name": self.dict_name,
}

if print_result:
self.logger.debug(result)
result = self.db.find_one(to_extract)

if result:
self.cache[key] = result["value"]
hash_ = value_to_hash(result)
if hash_:
self.hash_cache[key] = hash_
return result["value"]
raise KeyError("Could not find {} in {}".format(key, self))
if print_result:
self.logger.debug(result)

if result:
val = result["value"]
self.cache[key] = val
hash_ = value_to_hash(val)
if hash_:
self.hash_cache[key] = hash_
return val
raise KeyError("Could not find {} in {}".format(key, self))

@_try_n_times
def __delitem__(self, key: str) -> None:
Expand All @@ -230,16 +236,19 @@ def __delitem__(self, key: str) -> None:

:param key: key to delete
"""
key = str(key)
if key in self.cache:
del self.cache[key]

to_extract = {
"key": key,
"checker": self.checker_name,
"name": self.dict_name,
}
self.db.delete_one(to_extract)
with self._lock:
key = str(key)
if key in self.cache:
del self.cache[key]
if key in self.hash_cache:
del self.hash_cache[key]

to_extract = {
"key": key,
"checker": self.checker_name,
"name": self.dict_name,
}
self.db.delete_one(to_extract)

@_try_n_times
def __len__(self) -> int:
Expand Down Expand Up @@ -267,14 +276,18 @@ def persist(self) -> None:
"""
Persist the changes in the backend.
"""
for (key, value) in self.cache.items():
hash_ = value_to_hash(value)
if (
(not hash_)
or (key not in self.hash_cache)
or (self.hash_cache[key] != hash_)
):
self._upsert(key, value)
with self._lock:
for (key, value) in list(self.cache.items()):
hash_ = value_to_hash(value)
if (
(not hash_)
or (key not in self.hash_cache)
or (self.hash_cache[key] != hash_)
):
self._upsert(key, value)
del self.cache[key]
domenukk marked this conversation as resolved.
Show resolved Hide resolved
if key in self.hash_cache:
del self.hash_cache[key]

def __del__(self) -> None:
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest


def pytest_addoption(parser):
parser.addoption(
"--with_nosqldict", action="store_true", help="Run the tests with the nosqldict"
)


def pytest_configure(config):
config.addinivalue_line("markers", "nosqldict: mark test as requiring MongoDB")


def pytest_collection_modifyitems(config, items):
if config.getoption("--with_nosqldict"):
return
skip_nosqldict = pytest.mark.skip(reason="need --with_nosqldict option to run")
for item in items:
if "nosqldict" in item.keywords:
item.add_marker(skip_nosqldict)
40 changes: 40 additions & 0 deletions tests/test_enochecker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/usr/bin/env python3
import functools
import hashlib
import secrets
import sys
import tempfile
from logging import DEBUG
from unittest import mock

import pytest
from enochecker_core import CheckerMethod, CheckerTaskMessage, CheckerTaskResult
Expand Down Expand Up @@ -310,6 +312,44 @@ def putflagfn(self: CheckerExampleImpl):
assert result.attack_info == attack_info


@pytest.mark.nosqldict
def test_nested_change_enochecker():
import os

with mock.patch.dict(
os.environ,
{
"MONGO_ENABLED": "1",
},
):
dict_name = secrets.token_hex(8)

def putflagfn(self: CheckerExampleImpl):
db = self.db(dict_name)
x = {
"asd": 123,
}
db["test"] = x

x["asd"] = 456

def getflagfn(self: CheckerExampleImpl):
db = self.db(dict_name)
assert db["test"]["asd"] == 456

setattr(CheckerExampleImpl, "putflag", putflagfn)
checker = CheckerExampleImpl(method="putflag")

result = checker.run()
assert result.result == CheckerTaskResult.OK

setattr(CheckerExampleImpl, "getflag", getflagfn)
checker = CheckerExampleImpl(method="getflag")

result = checker.run()
assert result.result == CheckerTaskResult.OK


def main():
pytest.main(sys.argv)

Expand Down
55 changes: 55 additions & 0 deletions tests/test_nosqldict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import secrets

import pytest

from enochecker.nosqldict import NoSqlDict


@pytest.fixture
def nosqldict():
dict_name = secrets.token_hex(8)
checker_name = secrets.token_hex(8)
return NoSqlDict(dict_name, checker_name)


@pytest.mark.nosqldict
def test_basic(nosqldict):
nosqldict["abc"] = "xyz"
assert nosqldict["abc"] == "xyz"

with pytest.raises(KeyError):
_ = nosqldict["xyz"]

nosqldict["abc"] = {"stuff": b"asd"}
assert nosqldict["abc"] == {"stuff": b"asd"}

del nosqldict["abc"]
with pytest.raises(KeyError):
_ = nosqldict["abc"]


@pytest.mark.nosqldict
def test_nested_change():
dict_name = secrets.token_hex(8)
checker_name = secrets.token_hex(8)

def scoped_access(dict_name, checker_name):
nosqldict = NoSqlDict(dict_name, checker_name)

x = {
"asd": 123,
}
nosqldict["test"] = x
x["asd"] = 456

assert nosqldict["test"] == {
"asd": 456,
}

scoped_access(dict_name, checker_name)

nosqldict_new = NoSqlDict(dict_name, checker_name)

assert nosqldict_new["test"] == {
"asd": 456,
}