PLAT-860: Reload LSTM model before generation.

GitOrigin-RevId: 4950a43472c64abd8d63ce6f5aadba078948bbfd
pimlock committed Jun 12, 2023
1 parent 0957a28 commit 051c6b7
Showing 4 changed files with 79 additions and 33 deletions.
16 changes: 9 additions & 7 deletions src/gretel_synthetics/batch.py
@@ -1027,6 +1027,7 @@ def __init__(
         tokenizer: BaseTokenizerTrainer = None,
         mode: str = WRITE,
         checkpoint_dir: str = None,
+        validate_model: bool = True,
     ):
 
         if mode not in (WRITE, READ):  # pragma: no cover
@@ -1115,13 +1116,14 @@ def __init__(
         except FileNotFoundError:
             self.original_headers = None
 
-        logger.info("Validating underlying models exist via generation test...")
-        try:
-            self.generate_all_batch_lines(parallelism=1, num_lines=1)
-        except Exception as err:
-            raise RuntimeError(
-                "Error testing generation during model load"
-            ) from err
+        if validate_model:
+            logger.info("Validating underlying models exist via generation test...")
+            try:
+                self.generate_all_batch_lines(parallelism=1, num_lines=1)
+            except Exception as err:
+                raise RuntimeError(
+                    "Error testing generation during model load"
+                ) from err
 
     def _create_header_batches(self):
         num_batches = ceil(len(self._source_df.columns) / self.batch_size)
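The new `validate_model` flag preserves the default behavior (a one-line generation smoke test when a model is loaded) while letting callers opt out, which matters when the caller is about to reload the model anyway. A minimal usage sketch, assuming `READ` is the string constant `"read"` and using a hypothetical checkpoint path:

    from gretel_synthetics.batch import DataFrameBatch

    # Reload previously trained batch models without the generation smoke
    # test that validate_model=True would run via generate_all_batch_lines().
    batches = DataFrameBatch(
        mode="read",                            # assumed value of the READ constant
        checkpoint_dir="/path/to/checkpoints",  # hypothetical checkpoint location
        validate_model=False,                   # skip the one-line generation test
    )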
58 changes: 42 additions & 16 deletions src/gretel_synthetics/generate_parallel.py
@@ -3,6 +3,7 @@
 not have to be used directly. They are used automatically by the ``generate.py`` module based
 on the parallelization settings configured for text generation.
 """
+import logging
 import os
 import sys
 
@@ -20,6 +21,8 @@
 
 from gretel_synthetics.errors import TooManyInvalidError
 
+logger = logging.getLogger(__name__)
+
 
 def get_num_workers(
     parallelism: Union[int, float], total_lines: int, chunk_size: int = 5
@@ -180,32 +183,55 @@ def _loky_init_worker(settings: Settings):
         # Workers should be using CPU only (note, this has no effect on the parent process)
         os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
-        # Suppress stdout and stderr in worker processes. Do so on a best-effort basis only.
-        try:
-            devnull = open(os.devnull, "w")
-            try:
-                os.dup2(devnull.fileno(), sys.stdout.fileno())
-                os.dup2(devnull.fileno(), sys.stderr.fileno())
-            except BaseException:
-                pass
-            # On Windows in a Jupyter notebook, we have observed an
-            # 'OSError [WinError 1]: Incorrect function' when trying to write to sys.stdout
-            # or sys.stderr after the dup2 invocation. Hence, set the references explicitly
-            # to prevent this (native code writing to stdout/stderr directly seems unaffected).
-            sys.stdout = devnull
-            sys.stderr = devnull
-        except BaseException:  # pylint: disable=broad-except
-            pass
+        # Re-configure logging for this worker, with a format that includes the PID to
+        # identify the worker.
+        logging.basicConfig(
+            format="%(asctime)s [%(process)d] - %(levelname)s - %(name)s - %(message)s",
+            force=True,
+        )
+        logging.getLogger("gretel_synthetics.tensorflow").setLevel(logging.WARNING)
+        logger.info("Initializing new generation worker.")
 
         global _loky_worker_generator
         _loky_worker_generator = settings.generator(settings)
 
+        logger.info("Generation worker is ready.")
+        _suppress_stdout()
+
     except BaseException as e:  # pylint: disable=broad-except
         global _loky_worker_init_exception
         _loky_worker_init_exception = e
         # Simply return without raising, to keep the worker alive.
 
 
+def _suppress_stdout() -> None:
+    """
+    Suppress stdout and stderr in worker processes. Do so on a best-effort basis only.
+    We have a separate mechanism for workers to report progress back to the main
+    process, so logging/printing is not necessary, and output can get garbled when
+    logging from multiple processes at once without any synchronization.
+    """
+    try:
+        logger.info(
+            "NOTE: stdout and stderr are suppressed for this process, so nothing will be output."
+        )
+        devnull = open(os.devnull, "w")
+        try:
+            os.dup2(devnull.fileno(), sys.stdout.fileno())
+            os.dup2(devnull.fileno(), sys.stderr.fileno())
+        except BaseException:
+            pass
+        # On Windows in a Jupyter notebook, we have observed an
+        # 'OSError [WinError 1]: Incorrect function' when trying to write to sys.stdout
+        # or sys.stderr after the dup2 invocation. Hence, set the references explicitly
+        # to prevent this (native code writing to stdout/stderr directly seems unaffected).
+        sys.stdout = devnull
+        sys.stderr = devnull
+    except BaseException:  # pylint: disable=broad-except
+        pass
+
+
 def _loky_worker_process_chunk(
     chunk_size: int, hard_limit: Optional[int] = None
 ) -> Tuple[int, List[GenText], int]:
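Two details of the reworked initializer are worth noting: `logging.basicConfig(force=True)` (available since Python 3.8) removes any handlers inherited from the parent process before installing the PID-tagged format, and the `os.dup2` calls in `_suppress_stdout` redirect the underlying file descriptors, which also silences native code (such as TensorFlow's C++ layer) that writes to fd 1 and 2 directly. A condensed, self-contained sketch of the same pattern for a generic process-pool worker:

    import logging
    import os
    import sys


    def init_worker() -> None:
        # force=True (Python 3.8+) replaces handlers inherited from the parent,
        # so each worker process logs with its own PID-tagged format.
        logging.basicConfig(
            format="%(asctime)s [%(process)d] - %(levelname)s - %(name)s - %(message)s",
            force=True,
        )
        logging.getLogger(__name__).info("worker initialized")

        # Redirect the real file descriptors: unlike reassigning sys.stdout, this
        # also silences C/C++ libraries that write to fd 1/2 directly.
        devnull = open(os.devnull, "w")
        os.dup2(devnull.fileno(), sys.stdout.fileno())
        os.dup2(devnull.fileno(), sys.stderr.fileno())
        # Rebind the Python-level objects as well; on Windows in Jupyter, using
        # the old sys.stdout after dup2 has been observed to raise OSError.
        sys.stdout = devnull
        sys.stderr = devnull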
19 changes: 17 additions & 2 deletions src/gretel_synthetics/tensorflow/model.py
@@ -1,6 +1,8 @@
 """
 Tensorflow - Keras Sequential RNN (GRU)
 """
+import logging
+
 from typing import TYPE_CHECKING
 
 import tensorflow as tf
@@ -16,6 +18,8 @@
     BaseTokenizer = None
 from keras import backend as k
 
+logger = logging.getLogger(__name__)
+
 
 def build_model(
     vocab_size: int, batch_size: int, store: BaseConfig
@@ -30,7 +34,7 @@ def build_model(
     else:
         model = build_default_model(store, batch_size, vocab_size)
 
-    print(model.summary())
+    _print_model_summary(model)
     return model
 
 
@@ -53,11 +57,22 @@ def _prepare_model(
     model.load_weights(tf.train.latest_checkpoint(load_dir)).expect_partial()
 
     model.build(tf.TensorShape([1, None]))
-    model.summary()
 
+    _print_model_summary(model)
     return model
 
 
+def _print_model_summary(model: tf.keras.Model) -> None:
+    model_summary = ""
+
+    def print_fn(line: str) -> None:
+        nonlocal model_summary
+        model_summary += "\t" + line + "\n"
+
+    model.summary(print_fn=print_fn)
+    logger.info("Model summary: \n%s", model_summary)
+
+
 def load_model(
     store: BaseConfig,
     tokenizer: BaseTokenizer,
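`model.summary()` writes to stdout (and `print(model.summary())` additionally prints the method's `None` return value), so the output is lost once workers suppress stdout. `_print_model_summary` instead captures the table through Keras' `print_fn` hook, which is invoked once per summary line, and routes it through the logger. A standalone sketch of the same capture technique on a toy model:

    import logging

    import tensorflow as tf

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Toy stand-in for the synthetics RNN; any built Keras model works.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Embedding(input_dim=100, output_dim=8),
            tf.keras.layers.GRU(16),
            tf.keras.layers.Dense(100),
        ]
    )
    model.build(tf.TensorShape([1, None]))

    lines = []
    model.summary(print_fn=lines.append)  # Keras calls print_fn once per line
    logger.info("Model summary:\n%s", "\n".join("\t" + line for line in lines))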
19 changes: 11 additions & 8 deletions src/gretel_synthetics/tokenizers.py
@@ -62,6 +62,7 @@
     format="%(asctime)s : %(threadName)s : %(levelname)s : %(message)s",
     level=logging.INFO,
 )
+logger = logging.getLogger(__name__)
 
 spm_logger = logging.getLogger("sentencepiece")
 spm_logger.setLevel(logging.INFO)
@@ -474,7 +475,7 @@ def _train(self, extra_symbols: Optional[List[str]] = None):
             self.newline_str,
             self.config.field_delimiter_token,
         ] + extra_symbols
-        logging.info("Training SentencePiece tokenizer")
+        logger.info("Training SentencePiece tokenizer")
         spm.SentencePieceTrainer.Train(
             input=self.config.training_data_path,
             model_prefix=const.MODEL_PREFIX,
@@ -569,22 +570,24 @@ def _train(self):
 def _log_sample_data(model_dir: str, sp: spm.SentencePieceProcessor):
     training_data_path = Path(model_dir) / const.TRAINING_DATA
     if not training_data_path.is_file():
-        logging.info("Training data not found for SP sampling")
+        logger.info("Training data not found for SP sampling")
         return
 
     with open(training_data_path) as fin:
         sample = fin.readline().strip()
 
-    logging.info(f"Tokenizer model vocabulary size: {len(sp)} tokens")
-    logging.info(
+    logger.info(f"Tokenizer model vocabulary size: {len(sp)} tokens")
+    logger.info(
         "Mapping first line of training data\n\n{}\n ---- sample tokens mapped to pieces ---- > \n{}\n".format(
             repr(sample), ", ".join(sp.SampleEncodeAsPieces(sample, -1, 0.1))
-        )
+        ),
+        extra={"maybe_sensitive": True},
     )
-    logging.info(
+    logger.info(
         "Mapping first line of training data\n\n{}\n ---- sample tokens mapped to int ---- > \n{}\n".format(
             repr(sample), ", ".join([str(idx) for idx in sp.EncodeAsIds(sample)])
-        )
+        ),
+        extra={"maybe_sensitive": True},
     )
 
 
@@ -605,7 +608,7 @@ def load(cls, model_dir: str):
         """
         sp = spm.SentencePieceProcessor()
         model_fname = f"{const.MODEL_PREFIX}.model"
-        logging.info("Loading tokenizer from: %s", model_fname)
+        logger.info("Loading tokenizer from: %s", model_fname)
         model_path = Path(model_dir) / model_fname
         sp.Load(str(model_path))
         _log_sample_data(model_dir, sp)
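The `extra={"maybe_sensitive": True}` argument is standard `logging` machinery: the dict's keys become attributes on the emitted `LogRecord`, so handlers or filters downstream can treat records that echo raw training data differently. The commit does not show a consumer of this flag; the filter below is a hypothetical sketch of how flagged records could be dropped:

    import logging


    class DropSensitive(logging.Filter):
        """Hypothetical filter: drop records flagged via extra={"maybe_sensitive": True}."""

        def filter(self, record: logging.LogRecord) -> bool:
            # extra={...} sets attributes on the LogRecord; a missing
            # attribute means the record was not flagged.
            return not getattr(record, "maybe_sensitive", False)


    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("gretel_synthetics.tokenizers")
    logger.addFilter(DropSensitive())

    logger.info("Tokenizer model vocabulary size: %d tokens", 20000)  # emitted
    logger.info("Sample: %r", "raw training line", extra={"maybe_sensitive": True})  # dropped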
