Skip to content

Commit

Permalink
Enforce valid values for key value keys
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervonb committed Jun 17, 2024
1 parent d952419 commit cf16997
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
7 changes: 7 additions & 0 deletions nats/js/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ def __str__(self) -> str:
return "nats: history limited to a max of 64"


class InvalidKeyError(Error):
"""
Raised when trying to put an object in Key Value with an invalid key.
"""
pass


class InvalidBucketNameError(Error):
"""
Raised when trying to create a KV or OBJ bucket with invalid name.
Expand Down
23 changes: 23 additions & 0 deletions nats/js/kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
import datetime
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

Expand All @@ -31,6 +32,13 @@
KV_PURGE = "PURGE"
MSG_ROLLUP_SUBJECT = "sub"

VALID_KEY_RE = re.compile(r'^[-/_=\.a-zA-Z0-9]+$')

def _is_key_valid(key: str) -> bool:
if len(key) == 0 or key[0] == '.' or key[-1] == '.':
return False
return bool(VALID_KEY_RE.match(key))


class KeyValue:
"""
Expand Down Expand Up @@ -126,6 +134,9 @@ async def get(self, key: str, revision: Optional[int] = None) -> Entry:
"""
get returns the latest value for the key.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError

entry = None
try:
entry = await self._get(key, revision)
Expand Down Expand Up @@ -182,13 +193,19 @@ async def put(self, key: str, value: bytes) -> int:
put will place the new value for the key into the store
and return the revision number.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError(key)

pa = await self._js.publish(f"{self._pre}{key}", value)
return pa.seq

async def create(self, key: str, value: bytes) -> int:
"""
create will add the key/value pair iff it does not exist.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError(key)

pa = None
try:
pa = await self.update(key, value, last=0)
Expand Down Expand Up @@ -221,6 +238,9 @@ async def update(
"""
update will update the value iff the latest revision matches.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError(key)

hdrs = {}
if not last:
last = 0
Expand All @@ -245,6 +265,9 @@ async def delete(self, key: str, last: Optional[int] = None) -> bool:
"""
delete will place a delete marker and remove all previous revisions.
"""
if not _is_key_valid(key):
raise nats.js.errors.InvalidKeyError(key)

hdrs = {}
hdrs[KV_OP] = KV_DEL

Expand Down
69 changes: 69 additions & 0 deletions tests/test_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -2395,6 +2395,72 @@ async def error_handler(e):
with pytest.raises(BadBucketError):
await js.key_value(bucket="TEST3")

@async_test
async def test_bucket_name_validation(self):
nc = await nats.connect()
js = nc.jetstream()

invalid_bucket_names = [
" x y",
"x ",
"x!",
"xx$",
"*",
">",
"x.>",
"x.*",
".",
".x",
".x.",
"x.",
]

for bucket_name in invalid_bucket_names:
with self.subTest(bucket_name):
with pytest.raises(InvalidBucketNameError):
await js.create_key_value(
bucket=bucket_name, history=5, ttl=3600
)

with pytest.raises(InvalidBucketNameError):
await js.key_value(bucket_name)

with pytest.raises(InvalidBucketNameError):
await js.delete_key_value(bucket_name)

@async_test
async def test_key_validation(self):
nc = await nats.connect()
js = nc.jetstream()

kv = await js.create_key_value(bucket="TEST", history=5, ttl=3600)
invalid_keys = [
" x y",
"x ",
"x!",
"xx$",
"*",
">",
"x.>",
"x.*",
".",
".x",
".x.",
"x.",
]

for key in invalid_keys:
with self.subTest(key):
# Invalid put (empty)
with pytest.raises(InvalidKeyError):
await kv.put(key, b'')

with pytest.raises(InvalidKeyError):
await kv.get(key)

with pytest.raises(InvalidKeyError):
await kv.update(key, b'')

@async_test
async def test_kv_basic(self):
errors = []
Expand All @@ -2406,6 +2472,9 @@ async def error_handler(e):
nc = await nats.connect(error_cb=error_handler)
js = nc.jetstream()

with pytest.raises(nats.js.errors.InvalidBucketNameError):
await js.create_key_value(bucket="notok!")

bucket = "TEST"
kv = await js.create_key_value(
bucket=bucket,
Expand Down

0 comments on commit cf16997

Please sign in to comment.