Skip to content

Commit

Permalink
#322 - Add possibility to keep files in variable
Browse files Browse the repository at this point in the history
  • Loading branch information
TytoCapensis committed Apr 26, 2024
1 parent 3c279bf commit e19055a
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
7 changes: 5 additions & 2 deletions thehive4py/endpoints/alert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json as jsonlib
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional, Union

from thehive4py.endpoints._base import EndpointBase
from thehive4py.query import QueryExpr
Expand Down Expand Up @@ -96,7 +96,10 @@ def add_attachment(
)["attachments"]

def download_attachment(
self, alert_id: str, attachment_id: str, attachment_path: str
self,
alert_id: str,
attachment_id: str,
attachment_path: Optional[Union[str, Literal[False]]] = False,
) -> None:
return self._session.make_request(
"GET",
Expand Down
14 changes: 11 additions & 3 deletions thehive4py/endpoints/case.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json as jsonlib
from typing import List, Optional, Sequence, Union
from typing import List, Literal, Optional, Sequence, Union

from thehive4py.endpoints._base import EndpointBase
from thehive4py.query import QueryExpr
Expand Down Expand Up @@ -82,7 +82,12 @@ def import_from_file(self, import_case: InputImportCase, import_path: str) -> di
files={"file": self._fileinfo_from_filepath(import_path)},
)

def export_to_file(self, case_id: CaseId, password: str, export_path: str) -> None:
def export_to_file(
self,
case_id: CaseId,
password: str,
export_path: Optional[Union[str, Literal[False]]] = False,
) -> None:
return self._session.make_request(
"GET",
path=f"/api/v1/case/{case_id}/export",
Expand All @@ -105,7 +110,10 @@ def add_attachment(
)["attachments"]

def download_attachment(
self, case_id: CaseId, attachment_id: str, attachment_path: str
self,
case_id: CaseId,
attachment_id: str,
attachment_path: Optional[Union[str, Literal[False]]] = False,
) -> None:
return self._session.make_request(
"GET",
Expand Down
4 changes: 2 additions & 2 deletions thehive4py/endpoints/observable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Literal, Optional, Union

from thehive4py.endpoints._base import EndpointBase
from thehive4py.query import QueryExpr
Expand Down Expand Up @@ -116,7 +116,7 @@ def download_attachment(
self,
observable_id: str,
attachment_id: str,
observable_path: str,
observable_path: Optional[Union[str, Literal[False]]] = False,
as_zip=False,
) -> None:
return self._session.make_request(
Expand Down
24 changes: 16 additions & 8 deletions thehive4py/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json as jsonlib
from collections import UserDict
from os import PathLike
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union

import requests
import requests.auth
Expand Down Expand Up @@ -58,7 +58,7 @@ def make_request(
data=None,
json=None,
files=None,
download_path: Union[str, PathLike, None] = None,
download_path: Union[str, PathLike, Literal[False], None] = None,
) -> Any:
endpoint_url = f"{self.hive_url}{path}"

Expand All @@ -75,21 +75,21 @@ def make_request(
files=files,
headers=headers,
verify=self.verify,
stream=bool(download_path),
stream=True if download_path is not None else False,
)

return self._process_response(response, download_path=download_path)

def _process_response(
self,
response: requests.Response,
download_path: Union[str, PathLike, None] = None,
download_path: Union[str, PathLike, Literal[False], None] = None,
):
if response.ok:
if download_path is None:
return self._process_text_response(response)
else:
self._process_stream_response(
return self._process_stream_response(
response=response, download_path=download_path
)

Expand All @@ -107,11 +107,19 @@ def _process_text_response(self, response: requests.Response):
return json_data

def _process_stream_response(
self, response: requests.Response, download_path: Union[str, PathLike]
self,
response: requests.Response,
download_path: Union[str, PathLike, Literal[False]],
):
with open(download_path, "wb") as download_fp:
if download_path:
with open(download_path, "wb") as download_fp:
for chunk in response.iter_content(chunk_size=4096):
download_fp.write(chunk)
else:
download_file = bytes()
for chunk in response.iter_content(chunk_size=4096):
download_fp.write(chunk)
download_file += chunk
return download_file

def _process_error_response(self, response: requests.Response):
try:
Expand Down

0 comments on commit e19055a

Please sign in to comment.