Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce valid identifiers for "key value keys" #571

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions nats/js/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ def __str__(self) -> str:
return "nats: history limited to a max of 64"


class InvalidKeyError(Error):
"""
Raised when trying to put an object in Key Value with an invalid key.
"""
pass


class InvalidBucketNameError(Error):
"""
Raised when trying to create a KV or OBJ bucket with invalid name.
Expand Down
24 changes: 24 additions & 0 deletions nats/js/kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
import datetime
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

Expand All @@ -31,6 +32,14 @@
KV_PURGE = "PURGE"
MSG_ROLLUP_SUBJECT = "sub"

VALID_KEY_RE = re.compile(r'^[-/_=\.a-zA-Z0-9]+$')


def _is_key_valid(key: str) -> bool:
if len(key) == 0 or key[0] == '.' or key[-1] == '.':
return False
return bool(VALID_KEY_RE.match(key))


class KeyValue:
"""
Expand Down Expand Up @@ -126,6 +135,9 @@ async def get(self, key: str, revision: Optional[int] = None) -> Entry:
"""
get returns the latest value for the key.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError

entry = None
try:
entry = await self._get(key, revision)
Expand Down Expand Up @@ -182,13 +194,19 @@ async def put(self, key: str, value: bytes) -> int:
put will place the new value for the key into the store
and return the revision number.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError(key)

pa = await self._js.publish(f"{self._pre}{key}", value)
return pa.seq

async def create(self, key: str, value: bytes) -> int:
"""
create will add the key/value pair iff it does not exist.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError(key)

pa = None
try:
pa = await self.update(key, value, last=0)
Expand Down Expand Up @@ -221,6 +239,9 @@ async def update(
"""
update will update the value iff the latest revision matches.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError(key)

hdrs = {}
if not last:
last = 0
Expand All @@ -245,6 +266,9 @@ async def delete(self, key: str, last: Optional[int] = None) -> bool:
"""
delete will place a delete marker and remove all previous revisions.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError(key)

hdrs = {}
hdrs[KV_OP] = KV_DEL

Expand Down
69 changes: 69 additions & 0 deletions tests/test_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -2395,6 +2395,72 @@ async def error_handler(e):
with pytest.raises(BadBucketError):
await js.key_value(bucket="TEST3")

@async_test
async def test_bucket_name_validation(self):
nc = await nats.connect()
js = nc.jetstream()

invalid_bucket_names = [
" x y",
"x ",
"x!",
"xx$",
"*",
">",
"x.>",
"x.*",
".",
".x",
".x.",
"x.",
]

for bucket_name in invalid_bucket_names:
with self.subTest(bucket_name):
with pytest.raises(InvalidBucketNameError):
await js.create_key_value(
bucket=bucket_name, history=5, ttl=3600
)

with pytest.raises(InvalidBucketNameError):
await js.key_value(bucket_name)

with pytest.raises(InvalidBucketNameError):
await js.delete_key_value(bucket_name)

@async_test
async def test_key_validation(self):
nc = await nats.connect()
js = nc.jetstream()

kv = await js.create_key_value(bucket="TEST", history=5, ttl=3600)
invalid_keys = [
" x y",
"x ",
"x!",
"xx$",
"*",
">",
"x.>",
"x.*",
".",
".x",
".x.",
"x.",
]

for key in invalid_keys:
with self.subTest(key):
# Invalid put (empty)
with pytest.raises(InvalidKeyError):
await kv.put(key, b'')

with pytest.raises(InvalidKeyError):
await kv.get(key)

with pytest.raises(InvalidKeyError):
await kv.update(key, b'')

@async_test
async def test_kv_basic(self):
errors = []
Expand All @@ -2406,6 +2472,9 @@ async def error_handler(e):
nc = await nats.connect(error_cb=error_handler)
js = nc.jetstream()

with pytest.raises(nats.js.errors.InvalidBucketNameError):
await js.create_key_value(bucket="notok!")

bucket = "TEST"
kv = await js.create_key_value(
bucket=bucket,
Expand Down