-
Notifications
You must be signed in to change notification settings - Fork 283
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(core): Add support for ollama module (#618)
- Added a new class OllamaContainer with few methods to handle the Ollama container. - The `_check_and_add_gpu_capabilities` method checks if the host has GPUs and adds the necessary capabilities to the container. - The `commit_to_image` allows to save somehow the state of a container into an image so that we can reuse it, especially for the ones having some models pulled. - Added tests to check the functionality of the new class. > Note: I inspired myself from the java implementation of the Ollama module. Fixes #617 --------- Co-authored-by: David Ankin <[email protected]>
- Loading branch information
1 parent
ead0f79
commit 5442d05
Showing
5 changed files
with
188 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.. autoclass:: testcontainers.ollama.OllamaContainer | ||
.. title:: testcontainers.ollama.OllamaContainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
# not use this file except in compliance with the License. You may obtain | ||
# a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
# License for the specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from os import PathLike | ||
from typing import Any, Optional, TypedDict, Union | ||
|
||
from docker.types.containers import DeviceRequest | ||
from requests import get | ||
|
||
from testcontainers.core.container import DockerContainer | ||
from testcontainers.core.waiting_utils import wait_for_logs | ||
|
||
|
||
class OllamaModel(TypedDict): | ||
name: str | ||
model: str | ||
modified_at: str | ||
size: int | ||
digest: str | ||
details: dict[str, Any] | ||
|
||
|
||
class OllamaContainer(DockerContainer): | ||
""" | ||
Ollama Container | ||
Example: | ||
.. doctest:: | ||
>>> from testcontainers.ollama import OllamaContainer | ||
>>> with OllamaContainer() as ollama: | ||
... ollama.list_models() | ||
[] | ||
""" | ||
|
||
OLLAMA_PORT = 11434 | ||
|
||
def __init__( | ||
self, | ||
image: str = "ollama/ollama:0.1.44", | ||
ollama_dir: Optional[Union[str, PathLike]] = None, | ||
**kwargs, | ||
# | ||
): | ||
super().__init__(image=image, **kwargs) | ||
self.ollama_dir = ollama_dir | ||
self.with_exposed_ports(OllamaContainer.OLLAMA_PORT) | ||
self._check_and_add_gpu_capabilities() | ||
|
||
def _check_and_add_gpu_capabilities(self): | ||
info = self.get_docker_client().client.info() | ||
if "nvidia" in info["Runtimes"]: | ||
self._kwargs = {**self._kwargs, "device_requests": DeviceRequest(count=-1, capabilities=[["gpu"]])} | ||
|
||
def start(self) -> "OllamaContainer": | ||
""" | ||
Start the Ollama server | ||
""" | ||
if self.ollama_dir: | ||
self.with_volume_mapping(self.ollama_dir, "/root/.ollama", "rw") | ||
super().start() | ||
wait_for_logs(self, "Listening on ", timeout=30) | ||
|
||
return self | ||
|
||
def get_endpoint(self): | ||
""" | ||
Return the endpoint of the Ollama server | ||
""" | ||
host = self.get_container_host_ip() | ||
exposed_port = self.get_exposed_port(OllamaContainer.OLLAMA_PORT) | ||
url = f"http://{host}:{exposed_port}" | ||
return url | ||
|
||
@property | ||
def id(self) -> str: | ||
""" | ||
Return the container object | ||
""" | ||
return self._container.id | ||
|
||
def pull_model(self, model_name: str) -> None: | ||
""" | ||
Pull a model from the Ollama server | ||
Args: | ||
model_name (str): Name of the model | ||
""" | ||
self.exec(f"ollama pull {model_name}") | ||
|
||
def list_models(self) -> list[OllamaModel]: | ||
endpoint = self.get_endpoint() | ||
response = get(url=f"{endpoint}/api/tags") | ||
response.raise_for_status() | ||
return response.json().get("models", []) | ||
|
||
def commit_to_image(self, image_name: str) -> None: | ||
""" | ||
Commit the current container to a new image | ||
Args: | ||
image_name (str): Name of the new image | ||
""" | ||
docker_client = self.get_docker_client() | ||
existing_images = docker_client.client.images.list(name=image_name) | ||
if not existing_images and self.id: | ||
docker_client.client.containers.get(self.id).commit( | ||
repository=image_name, conf={"Labels": {"org.testcontainers.session-id": ""}} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import random | ||
import string | ||
from pathlib import Path | ||
|
||
import requests | ||
from testcontainers.ollama import OllamaContainer | ||
|
||
|
||
def random_string(length=6): | ||
return "".join(random.choices(string.ascii_lowercase, k=length)) | ||
|
||
|
||
def test_ollama_container(): | ||
with OllamaContainer() as ollama: | ||
url = ollama.get_endpoint() | ||
response = requests.get(url) | ||
assert response.status_code == 200 | ||
assert response.text == "Ollama is running" | ||
|
||
|
||
def test_with_default_config(): | ||
with OllamaContainer("ollama/ollama:0.1.26") as ollama: | ||
ollama.start() | ||
response = requests.get(f"{ollama.get_endpoint()}/api/version") | ||
version = response.json().get("version") | ||
assert version == "0.1.26" | ||
|
||
|
||
def test_download_model_and_commit_to_image(): | ||
new_image_name = f"tc-ollama-allminilm-{random_string(length=4).lower()}" | ||
with OllamaContainer("ollama/ollama:0.1.26") as ollama: | ||
ollama.start() | ||
# Pull the model | ||
ollama.pull_model("all-minilm") | ||
|
||
response = requests.get(f"{ollama.get_endpoint()}/api/tags") | ||
model_name = ollama.list_models()[0].get("name") | ||
assert "all-minilm" in model_name | ||
|
||
# Commit the container state to a new image | ||
ollama.commit_to_image(new_image_name) | ||
|
||
# Verify the new image | ||
with OllamaContainer(new_image_name) as ollama: | ||
ollama.start() | ||
response = requests.get(f"{ollama.get_endpoint()}/api/tags") | ||
model_name = response.json().get("models", [])[0].get("name") | ||
assert "all-minilm" in model_name | ||
|
||
|
||
def test_models_saved_in_folder(tmp_path: Path): | ||
with OllamaContainer("ollama/ollama:0.1.26", ollama_dir=tmp_path) as ollama: | ||
assert len(ollama.list_models()) == 0 | ||
ollama.pull_model("all-minilm") | ||
assert len(ollama.list_models()) == 1 | ||
assert "all-minilm" in ollama.list_models()[0].get("name") | ||
|
||
with OllamaContainer("ollama/ollama:0.1.26", ollama_dir=tmp_path) as ollama: | ||
assert len(ollama.list_models()) == 1 | ||
assert "all-minilm" in ollama.list_models()[0].get("name") |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters