diff --git a/src/gretel_synthetics/batch.py b/src/gretel_synthetics/batch.py
index 2895f494..d3832603 100644
--- a/src/gretel_synthetics/batch.py
+++ b/src/gretel_synthetics/batch.py
@@ -1027,6 +1027,7 @@ def __init__(
         tokenizer: BaseTokenizerTrainer = None,
         mode: str = WRITE,
         checkpoint_dir: str = None,
+        validate_model: bool = True,
     ):
 
         if mode not in (WRITE, READ):  # pragma: no cover
@@ -1115,13 +1116,14 @@ def __init__(
         except FileNotFoundError:
             self.original_headers = None
 
-        logger.info("Validating underlying models exist via generation test...")
-        try:
-            self.generate_all_batch_lines(parallelism=1, num_lines=1)
-        except Exception as err:
-            raise RuntimeError(
-                "Error testing generation during model load"
-            ) from err
+        if validate_model:
+            logger.info("Validating underlying models exist via generation test...")
+            try:
+                self.generate_all_batch_lines(parallelism=1, num_lines=1)
+            except Exception as err:
+                raise RuntimeError(
+                    "Error testing generation during model load"
+                ) from err
 
     def _create_header_batches(self):
         num_batches = ceil(len(self._source_df.columns) / self.batch_size)
diff --git a/src/gretel_synthetics/generate_parallel.py b/src/gretel_synthetics/generate_parallel.py
index 0f0311e5..60d775c9 100644
--- a/src/gretel_synthetics/generate_parallel.py
+++ b/src/gretel_synthetics/generate_parallel.py
@@ -3,6 +3,7 @@
 not have to be used directly. They are used automatically by the
 ``generate.py`` module based on the parallelization settings configured
 for text generation.
 """
+import logging
 import os
 import sys
@@ -20,6 +21,8 @@
 
 from gretel_synthetics.errors import TooManyInvalidError
 
+logger = logging.getLogger(__name__)
+
 
 def get_num_workers(
     parallelism: Union[int, float], total_lines: int, chunk_size: int = 5
@@ -180,32 +183,55 @@ def _loky_init_worker(settings: Settings):
         # Workers should be using CPU only (note, this has no effect on the parent process)
         os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
-        # Suppress stdout and stderr in worker threads. Do so on a best-effort basis only.
-        try:
-            devnull = open(os.devnull, "w")
-            try:
-                os.dup2(devnull.fileno(), sys.stdout.fileno())
-                os.dup2(devnull.fileno(), sys.stderr.fileno())
-            except BaseException:
-                pass
-            # On Windows in a Jupyter notebook, we have observed an
-            # 'OSError [WinError 1]: Incorrect function' when trying to sys.stdout/sys.stderr
-            # after the dup2 invocation. Hence, set the references explicitly to make prevent this
-            # (the native code writing to stdout/stderr directly seems to be unaffected).
-            sys.stdout = devnull
-            sys.stderr = devnull
-        except BaseException:  # pylint: disable=broad-except
-            pass
+        # Re-configure logging for this worker, with a format that includes the PID
+        # to identify the worker.
+        logging.basicConfig(
+            format="%(asctime)s [%(process)d] - %(levelname)s - %(name)s - %(message)s",
+            force=True,
+        )
+        logging.getLogger("gretel_synthetics.tensorflow").setLevel(logging.WARNING)
+
+        logger.info("Initializing new generation worker.")
 
         global _loky_worker_generator
        _loky_worker_generator = settings.generator(settings)
 
+        logger.info("Generation worker is ready.")
+
+        _suppress_stdout()
+
     except BaseException as e:  # pylint: disable=broad-except
         global _loky_worker_init_exception
         _loky_worker_init_exception = e
         # Simply return without raising, to keep the worker alive.
 
 
+def _suppress_stdout() -> None:
+    """
+    Suppress stdout and stderr in worker processes. Do so on a best-effort basis only.
+
+    We have a separate mechanism for workers to report progress back to the main
+    process, so logging/printing is not necessary, and output can get garbled when
+    multiple processes log at once without any synchronization.
+    """
+    try:
+        logger.info(
+            "NOTE: stdout and stderr are suppressed for this process, so nothing will be output."
+        )
+        devnull = open(os.devnull, "w")
+        try:
+            os.dup2(devnull.fileno(), sys.stdout.fileno())
+            os.dup2(devnull.fileno(), sys.stderr.fileno())
+        except BaseException:
+            pass
+        # On Windows in a Jupyter notebook, we have observed an
+        # 'OSError [WinError 1]: Incorrect function' when trying to write to
+        # sys.stdout/sys.stderr after the dup2 invocation. Hence, set the references
+        # explicitly to prevent this (the native code writing to stdout/stderr
+        # directly seems to be unaffected).
+        sys.stdout = devnull
+        sys.stderr = devnull
+    except BaseException:  # pylint: disable=broad-except
+        pass
+
+
 def _loky_worker_process_chunk(
     chunk_size: int, hard_limit: Optional[int] = None
 ) -> Tuple[int, List[GenText], int]:
diff --git a/src/gretel_synthetics/tensorflow/model.py b/src/gretel_synthetics/tensorflow/model.py
index 9d1b5d9b..07c7dc7b 100644
--- a/src/gretel_synthetics/tensorflow/model.py
+++ b/src/gretel_synthetics/tensorflow/model.py
@@ -1,6 +1,8 @@
 """
 Tensorflow - Keras Sequential RNN (GRU)
 """
+import logging
+
 from typing import TYPE_CHECKING
 
 import tensorflow as tf
@@ -16,6 +18,8 @@
     BaseTokenizer = None
 
 from keras import backend as k
 
+logger = logging.getLogger(__name__)
+
 
 def build_model(
     vocab_size: int, batch_size: int, store: BaseConfig
@@ -30,7 +34,7 @@
     else:
         model = build_default_model(store, batch_size, vocab_size)
 
-    print(model.summary())
+    _print_model_summary(model)
 
     return model
@@ -53,11 +57,22 @@ def _prepare_model(
     model.load_weights(tf.train.latest_checkpoint(load_dir)).expect_partial()
 
     model.build(tf.TensorShape([1, None]))
 
-    model.summary()
+    _print_model_summary(model)
 
     return model
 
 
+def _print_model_summary(model: tf.keras.Model) -> None:
+    model_summary = ""
+
+    def print_fn(line: str) -> None:
+        nonlocal model_summary
+        model_summary += "\t" + line + "\n"
+
+    model.summary(print_fn=print_fn)
+    logger.info("Model summary: \n%s", model_summary)
+
+
 def load_model(
     store: BaseConfig,
     tokenizer: BaseTokenizer,
diff --git a/src/gretel_synthetics/tokenizers.py b/src/gretel_synthetics/tokenizers.py
index aebe4f8a..76f777fc 100644
--- a/src/gretel_synthetics/tokenizers.py
+++ b/src/gretel_synthetics/tokenizers.py
@@ -62,6 +62,7 @@
     format="%(asctime)s : %(threadName)s : %(levelname)s : %(message)s",
     level=logging.INFO,
 )
+logger = logging.getLogger(__name__)
 
 spm_logger = logging.getLogger("sentencepiece")
 spm_logger.setLevel(logging.INFO)
@@ -474,7 +475,7 @@ def _train(self, extra_symbols: Optional[List[str]] = None):
             self.newline_str,
             self.config.field_delimiter_token,
         ] + extra_symbols
-        logging.info("Training SentencePiece tokenizer")
+        logger.info("Training SentencePiece tokenizer")
         spm.SentencePieceTrainer.Train(
             input=self.config.training_data_path,
             model_prefix=const.MODEL_PREFIX,
@@ -569,22 +570,24 @@ def _train(self):
 def _log_sample_data(model_dir: str, sp: spm.SentencePieceProcessor):
     training_data_path = Path(model_dir) / const.TRAINING_DATA
     if not training_data_path.is_file():
-        logging.info("Training data not found for SP sampling")
+        logger.info("Training data not found for SP sampling")
         return
 
     with open(training_data_path) as fin:
         sample = fin.readline().strip()
 
-    logging.info(f"Tokenizer model vocabulary size: {len(sp)} tokens")
-    logging.info(
+    logger.info(f"Tokenizer model vocabulary size: {len(sp)} tokens")
+    logger.info(
         "Mapping first line of training data\n\n{}\n ---- sample tokens mapped to pieces ---- > \n{}\n".format(
             repr(sample), ", ".join(sp.SampleEncodeAsPieces(sample, -1, 0.1))
-        )
+        ),
+        extra={"maybe_sensitive": True},
     )
-    logging.info(
+    logger.info(
         "Mapping first line of training data\n\n{}\n ---- sample tokens mapped to int ---- > \n{}\n".format(
             repr(sample), ", ".join([str(idx) for idx in sp.EncodeAsIds(sample)])
-        )
+        ),
+        extra={"maybe_sensitive": True},
     )
 
 
@@ -605,7 +608,7 @@ def load(cls, model_dir: str):
     """
     sp = spm.SentencePieceProcessor()
     model_fname = f"{const.MODEL_PREFIX}.model"
-    logging.info("Loading tokenizer from: %s", model_fname)
+    logger.info("Loading tokenizer from: %s", model_fname)
     model_path = Path(model_dir) / model_fname
     sp.Load(str(model_path))
     _log_sample_data(model_dir, sp)
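
Note on the new validate_model flag in batch.py: a minimal usage sketch, with a hypothetical checkpoint path, showing the behavior this change makes optional. By default, re-loading in read mode still runs the one-line-per-batch generation self-test; passing validate_model=False skips it.

from gretel_synthetics.batch import READ, DataFrameBatch

# Default: loading runs generate_all_batch_lines(parallelism=1, num_lines=1)
# as a smoke test that the underlying models are usable.
batches = DataFrameBatch(mode=READ, checkpoint_dir="/path/to/checkpoints")

# Skip the self-test, e.g. to quickly inspect configs and headers on a host
# where running generation is slow or impossible.
batches = DataFrameBatch(
    mode=READ,
    checkpoint_dir="/path/to/checkpoints",
    validate_model=False,
)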
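
Note on the worker logging re-configuration in _loky_init_worker: logging.basicConfig(force=True) requires Python 3.8+; force=True removes any handlers already attached to the root logger before installing the new one, so each worker ends up with a single handler whose format tags every record with the worker PID. A standalone sketch of the pattern (function name is ours, for illustration):

import logging
import os

def reconfigure_worker_logging() -> None:
    # force=True (Python 3.8+) drops existing root handlers first, so the
    # worker gets exactly one handler with the PID-tagged format.
    logging.basicConfig(
        format="%(asctime)s [%(process)d] - %(levelname)s - %(name)s - %(message)s",
        force=True,
    )
    # Quiet a chatty sub-package; WARNING and above still come through.
    logging.getLogger("gretel_synthetics.tensorflow").setLevel(logging.WARNING)

reconfigure_worker_logging()
logging.getLogger(__name__).info("worker pid=%d configured", os.getpid())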
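
Note on _suppress_stdout: the helper works at two layers, and both matter. os.dup2 redirects the OS-level file descriptors, which also silences native (C/TensorFlow) code writing to fds 1 and 2; rebinding sys.stdout/sys.stderr guards against the Windows/Jupyter case described in the comment, where the original Python wrappers error out once their descriptors are replaced. A minimal sketch of the same two-layer approach, without the best-effort exception handling:

import os
import sys

def silence_output() -> None:
    devnull = open(os.devnull, "w")
    # Layer 1: redirect the OS-level descriptors; silences native code too.
    os.dup2(devnull.fileno(), sys.stdout.fileno())
    os.dup2(devnull.fileno(), sys.stderr.fileno())
    # Layer 2: rebind the Python-level objects so code that holds sys.stdout
    # never touches a wrapper whose underlying descriptor was swapped out.
    sys.stdout = devnull
    sys.stderr = devnull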
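
Note on _print_model_summary in model.py: Keras accepts a print_fn callback and routes the summary text through it instead of printing, which is what lets the summary go to logging. A self-contained sketch with a toy model; the callback takes **kwargs defensively, since some Keras versions pass extra keyword arguments (e.g. line_break) and may deliver chunks rather than strictly one line per call:

import logging

import tensorflow as tf

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def log_model_summary(model: tf.keras.Model) -> None:
    lines = []

    def collect(line: str, **kwargs) -> None:
        lines.append(line)

    model.summary(print_fn=collect)
    logger.info("Model summary:\n\t%s", "\n\t".join(lines))

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(8, input_shape=(4,)), tf.keras.layers.Dense(1)]
)
log_model_summary(model)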
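
Note on the extra={"maybe_sensitive": True} annotations in tokenizers.py: the stdlib copies extra keys onto the emitted LogRecord as attributes, so this marks records that echo training data without changing what is logged. The diff itself ships no consumer for the flag; a hypothetical filter (class name ours, not the library's) that a downstream application could attach to drop such records before they propagate:

import logging

class DropMaybeSensitive(logging.Filter):
    # Keys passed via logging's `extra` dict become LogRecord attributes, so
    # records logged with extra={"maybe_sensitive": True} carry this flag.
    def filter(self, record: logging.LogRecord) -> bool:
        return not getattr(record, "maybe_sensitive", False)

# Logger-level filters run before propagation, so flagged records are dropped
# for every handler downstream of this logger.
logging.getLogger("gretel_synthetics.tokenizers").addFilter(DropMaybeSensitive())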