Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(docker): Configure TensorBoard port through .env file #2397

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TENSORBOARD_PORT=6006
4 changes: 3 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ services:
- 7860:7860
environment:
SAFETENSORS_FAST_GPU: 1
TENSORBOARD_PORT: ${TENSORBOARD_PORT:-6006}
tmpfs:
- /tmp
volumes:
Expand Down Expand Up @@ -42,7 +43,8 @@ services:
container_name: tensorboard
image: tensorflow/tensorflow:latest-gpu
ports:
- 6006:6006
# NOTE: To change the host port, set TENSORBOARD_PORT in the .env file.
- ${TENSORBOARD_PORT:-6006}:6006
volumes:
- ./dataset/logs:/app/logs
command: tensorboard --logdir=/app/logs --bind_all
Expand Down
137 changes: 1 addition & 136 deletions kohya_gui/class_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,138 +17,6 @@
from .custom_logging import setup_logging


class TensorboardManager:
    """Start and stop a TensorBoard subprocess from a pair of Gradio buttons.

    The port is taken from the ``TENSORBOARD_PORT`` environment variable and
    falls back to ``DEFAULT_TENSORBOARD_PORT`` when the variable is unset.
    """

    DEFAULT_TENSORBOARD_PORT = 6006

    def __init__(self, logging_dir, headless: bool = False, wait_time=5):
        """Store the configuration and build the start/stop buttons.

        Args:
            logging_dir: Passed as the ``inputs`` of the start button's click
                handler (presumably a Gradio component holding the log
                folder path -- confirm against the caller).
            headless: When true, no browser tab is opened after start and
                both buttons stay visible.
            wait_time: Seconds to wait after launching TensorBoard before
                opening its URL in the browser.
        """
        self.logging_dir = logging_dir
        self.headless = headless
        self.wait_time = wait_time
        self.tensorboard_proc = None
        # The environment yields a string while the default is an int; the
        # value is only ever rendered through str()/f-strings below.
        self.tensorboard_port = os.environ.get(
            "TENSORBOARD_PORT", self.DEFAULT_TENSORBOARD_PORT
        )
        self.log = setup_logging()
        self.thread = None
        self.stop_event = Event()

        self.gradio_interface()

    def get_button_states(self, started=False):
        """Return the (start button, stop button) updates for a given state.

        Both buttons are hidden when the module-level ``visibility`` flag is
        false. Otherwise the start button is shown while TensorBoard is not
        started, the stop button while it is, and both in headless mode.
        """
        return gr.Button(
            visible=visibility and (not started or self.headless)
        ), gr.Button(visible=visibility and (started or self.headless))

    def start_tensorboard(self, logging_dir=None):
        """Launch TensorBoard on ``logging_dir`` and return new button states.

        An already running instance is stopped first. When ``logging_dir`` is
        missing, does not exist, or is empty, an error is logged and shown in
        a message box and the "not started" button states are returned.
        """
        if self.tensorboard_proc is not None:
            self.log.info(
                "Tensorboard is already running. Terminating existing process before starting new one..."
            )
            self.stop_tensorboard()

        # ``not logging_dir`` guards the None default (and an empty string):
        # os.path.exists(None) raises TypeError instead of returning False.
        if (
            not logging_dir
            or not os.path.exists(logging_dir)
            or not os.listdir(logging_dir)
        ):
            self.log.error(
                "Error: logging folder does not exist or does not contain logs."
            )
            msgbox(msg="Error: logging folder does not exist or does not contain logs.")
            return self.get_button_states(started=False)

        run_cmd = [
            "tensorboard",
            "--logdir",
            logging_dir,
            "--host",
            "0.0.0.0",
            "--port",
            str(self.tensorboard_port),
        ]

        self.log.info(run_cmd)

        self.log.info("Starting TensorBoard on port {}".format(self.tensorboard_port))
        try:
            env = os.environ.copy()
            env["TF_ENABLE_ONEDNN_OPTS"] = "0"
            self.tensorboard_proc = subprocess.Popen(run_cmd, env=env)
        except Exception as e:
            # A %s placeholder is required: passing ``e`` as a bare extra
            # argument makes the logging module fail to format the record.
            self.log.error("Failed to start Tensorboard: %s", e)
            return self.get_button_states(started=False)

        def open_tensorboard_url():
            # Event.wait() returns True as soon as a stop is requested, so
            # stop_tensorboard() no longer has to sit out the whole delay
            # while joining this thread.
            if self.stop_event.wait(self.wait_time):
                return
            tensorboard_url = f"http://localhost:{self.tensorboard_port}"
            self.log.info(f"Opening TensorBoard URL in browser: {tensorboard_url}")
            webbrowser.open(tensorboard_url)

        if not self.headless:
            self.stop_event.clear()
            self.thread = Thread(target=open_tensorboard_url)
            self.thread.start()

        return self.get_button_states(started=True)

    def stop_tensorboard(self):
        """Terminate TensorBoard and the browser-opener thread, if any.

        Returns the "not started" button states.
        """
        if self.tensorboard_proc is not None:
            self.log.info("Stopping tensorboard process...")
            try:
                self.tensorboard_proc.terminate()
                self.tensorboard_proc = None
                self.log.info("...process stopped")
            except Exception as e:
                self.log.error("Failed to stop Tensorboard: %s", e)

        if self.thread is not None:
            self.stop_event.set()
            self.thread.join()  # Wait for the thread to finish
            self.thread = None
            self.log.info("Thread terminated successfully.")

        return self.get_button_states(started=False)

    def gradio_interface(self):
        """Create the start/stop buttons in a row and wire their click handlers."""
        with gr.Row():
            button_start_tensorboard = gr.Button(
                value="Start tensorboard",
                elem_id="myTensorButton",
                visible=visibility,
            )
            button_stop_tensorboard = gr.Button(
                value="Stop tensorboard",
                visible=visibility and self.headless,
                elem_id="myTensorButtonStop",
            )
            button_start_tensorboard.click(
                self.start_tensorboard,
                inputs=[self.logging_dir],
                outputs=[button_start_tensorboard, button_stop_tensorboard],
                show_progress=False,
            )
            button_stop_tensorboard.click(
                self.stop_tensorboard,
                outputs=[button_start_tensorboard, button_stop_tensorboard],
                show_progress=False,
            )
import os
import gradio as gr
import subprocess
import time
import webbrowser

# Set TF_ENABLE_ONEDNN_OPTS=0 before attempting the import, then record in
# ``visibility`` whether the tensorflow package could be imported.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
try:
    import tensorflow  # Attempt to import tensorflow to check if it is installed
except ImportError:
    visibility = False
else:
    visibility = True

from easygui import msgbox
from threading import Thread, Event
from .custom_logging import setup_logging


class TensorboardManager:
DEFAULT_TENSORBOARD_PORT = 6006

Expand Down Expand Up @@ -256,6 +124,7 @@ def gradio_interface(self):
value="Open tensorboard",
elem_id="myTensorButton",
visible=not visibility,
link=f"http://localhost:{self.tensorboard_port}",
)
button_start_tensorboard.click(
self.start_tensorboard,
Expand All @@ -268,7 +137,3 @@ def gradio_interface(self):
outputs=[button_start_tensorboard, button_stop_tensorboard],
show_progress=False,
)
button_open_tensorboard.click(
self.open_tensorboard_url,
show_progress=False,
)