From 07dc94a068ad036f91b9c023ecb8469b001a97a3 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 2 Feb 2024 20:30:44 -0500
Subject: [PATCH 1/2] pt: apply global logger

Signed-off-by: Jinzhe Zeng
---
 deepmd/env.py                   |  6 ++++--
 deepmd/pt/entrypoints/main.py   | 34 +++++++++++++--------------------
 deepmd/pt/infer/inference.py    | 31 +++++++++++++++---------------
 deepmd/pt/model/model/model.py  | 10 ++++++----
 deepmd/pt/model/task/dipole.py  |  4 +++-
 deepmd/pt/model/task/ener.py    |  6 ++++--
 deepmd/pt/model/task/fitting.py |  8 +++++---
 deepmd/pt/optimizer/LKF.py      |  4 +++-
 deepmd/pt/train/training.py     | 20 ++++++++++---------
 deepmd/pt/utils/dataloader.py   | 11 ++++++-----
 deepmd/pt/utils/dataset.py      |  2 +-
 deepmd/pt/utils/finetune.py     |  6 ++++--
 deepmd/pt/utils/preprocess.py   |  4 +++-
 deepmd/pt/utils/stat.py         |  4 +++-
 deepmd/utils/pair_tab.py        |  8 +++++---
 15 files changed, 86 insertions(+), 72 deletions(-)

diff --git a/deepmd/env.py b/deepmd/env.py
index 1a8da63f8e..451b79d94f 100644
--- a/deepmd/env.py
+++ b/deepmd/env.py
@@ -13,6 +13,8 @@
     "global_float_prec",
 ]
 
+log = logging.getLogger(__name__)
+
 # FLOAT_PREC
 dp_float_prec = os.environ.get("DP_INTERFACE_PREC", "high").lower()
 if dp_float_prec in ("high", ""):
@@ -47,7 +49,7 @@ def set_env_if_empty(key: str, value: str, verbose: bool = True):
     if os.environ.get(key) is None:
         os.environ[key] = value
         if verbose:
-            logging.warning(
+            log.warning(
                 f"Environment variable {key} is empty. Use the default value {value}"
             )
 
@@ -72,7 +74,7 @@ def set_default_nthreads():
             and "TF_INTER_OP_PARALLELISM_THREADS" not in os.environ
        )
     ):
-        logging.warning(
+        log.warning(
             "To get the best performance, it is recommended to adjust "
             "the number of threads by setting the environment variables "
             "OMP_NUM_THREADS, DP_INTRA_OP_PARALLELISM_THREADS, and "
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index ad5e92d495..680d3313a6 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -30,6 +30,9 @@
 from deepmd.infer.model_devi import (
     make_model_devi,
 )
+from deepmd.loggers.loggers import (
+    set_log_handles,
+)
 from deepmd.main import (
     parse_args,
 )
@@ -42,9 +45,6 @@
 from deepmd.pt.train import (
     training,
 )
-from deepmd.pt.utils import (
-    env,
-)
 from deepmd.pt.utils.dataloader import (
     DpLoaderSet,
 )
@@ -58,6 +58,8 @@
     make_stat_input,
 )
 
+log = logging.getLogger(__name__)
+
 
 def get_trainer(
     config,
@@ -237,7 +239,7 @@ def prepare_trainer_input_single(
 
 
 def train(FLAGS):
-    logging.info("Configuration path: %s", FLAGS.INPUT)
+    log.info("Configuration path: %s", FLAGS.INPUT)
     with open(FLAGS.INPUT) as fin:
         config = json.load(fin)
     trainer = get_trainer(
@@ -278,28 +280,18 @@ def freeze(FLAGS):
     )
 
 
-# avoid logger conflicts of tf version
-def clean_loggers():
-    logger = logging.getLogger()
-    while logger.hasHandlers():
-        logger.removeHandler(logger.handlers[0])
-
-
 @record
 def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
-    clean_loggers()
-
     if not isinstance(args, argparse.Namespace):
         FLAGS = parse_args(args=args)
     else:
         FLAGS = args
     dict_args = vars(FLAGS)
 
-    logging.basicConfig(
-        level=logging.WARNING if env.LOCAL_RANK else logging.INFO,
-        format=f"%(asctime)-15s {os.environ.get('RANK') or ''} [%(filename)s:%(lineno)d] %(levelname)s %(message)s",
-    )
-    logging.info("DeepMD version: %s", __version__)
+    set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None)
+    log.debug("Log handles were successfully set")
+
+    log.info("DeepMD version: %s", __version__)
 
     if FLAGS.command == "train":
         train(FLAGS)
@@ -315,9 +307,9 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
         FLAGS.model = FLAGS.checkpoint_folder
         FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
         freeze(FLAGS)
-    elif args.command == "doc-train-input":
+    elif FLAGS.command == "doc-train-input":
         doc_train_input(**dict_args)
-    elif args.command == "model-devi":
+    elif FLAGS.command == "model-devi":
         dict_args["models"] = [
             str(Path(mm).with_suffix(".pt"))
             if Path(mm).suffix not in (".pb", ".pt")
@@ -325,7 +317,7 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
             for mm in dict_args["models"]
         ]
         make_model_devi(**dict_args)
-    elif args.command == "gui":
+    elif FLAGS.command == "gui":
         start_dpgui(**dict_args)
     else:
         raise RuntimeError(f"Invalid command {FLAGS.command}!")
diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py
index 4906bb7a46..6a9f0d99d2 100644
--- a/deepmd/pt/infer/inference.py
+++ b/deepmd/pt/infer/inference.py
@@ -40,6 +40,7 @@
 if torch.__version__.startswith("2"):
     import torch._dynamo
 
+log = logging.getLogger(__name__)
 
 
 class Tester:
@@ -95,9 +96,7 @@ def __init__(
                 ), f"Validation systems not found in {input_script}!"
                 self.systems = training_params["validation_data"]["systems"]
                 self.batchsize = training_params["validation_data"]["batch_size"]
-                logging.info(
-                    f"Testing validation systems in input script: {input_script}"
-                )
+                log.info(f"Testing validation systems in input script: {input_script}")
             else:
                 assert (
                     "data_dict" in training_params
@@ -115,18 +114,18 @@ def __init__(
                 self.batchsize = training_params["data_dict"][head]["validation_data"][
                     "batch_size"
                 ]
-                logging.info(
+                log.info(
                     f"Testing validation systems in head {head} of input script: {input_script}"
                 )
         elif system is not None:
             self.systems = expand_sys_str(system)
             self.batchsize = "auto"
-            logging.info("Testing systems in path: %s", system)
+            log.info("Testing systems in path: %s", system)
         elif datafile is not None:
             with open(datafile) as fin:
                 self.systems = fin.read().splitlines()
             self.batchsize = "auto"
-            logging.info("Testing systems in file: %s", datafile)
+            log.info("Testing systems in file: %s", datafile)
         else:
             self.systems = None
             self.batchsize = None
@@ -210,8 +209,8 @@ def run(self):
         system_results = {}
         global_sum_natoms = 0
         for cc, system in enumerate(systems):
-            logging.info("# ---------------output of dp test--------------- ")
-            logging.info(f"# testing system : {system}")
+            log.info("# ---------------output of dp test--------------- ")
+            log.info(f"# testing system : {system}")
             system_pred = []
             system_label = []
             dataset = DpLoaderSet(
@@ -226,7 +225,7 @@ def run(self):
                 dataset, replacement=True, num_samples=dataset.total_batch
             )
             if sampler is None:
-                logging.warning(
+                log.warning(
                     "Sampler not specified!"
                 )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
             dataloader = DataLoader(
@@ -296,8 +295,8 @@ def run(self):
                 for k, v in single_results.items()
             }
             for item in sorted(results.keys()):
-                logging.info(f"{item}: {results[item]:.4f}")
-            logging.info("# ----------------------------------------------- ")
+                log.info(f"{item}: {results[item]:.4f}")
+            log.info("# ----------------------------------------------- ")
             for k, v in single_results.items():
                 system_results[k] = system_results.get(k, 0.0) + v
             global_sum_natoms += sum_natoms
@@ -306,14 +305,14 @@ def run(self):
             k: v / global_sum_natoms if "mae" in k else math.sqrt(v / global_sum_natoms)
             for k, v in system_results.items()
         }
-        logging.info("# ----------weighted average of errors----------- ")
+        log.info("# ----------weighted average of errors----------- ")
         if not self.multi_task:
-            logging.info(f"# number of systems : {len(systems)}")
+            log.info(f"# number of systems : {len(systems)}")
         else:
-            logging.info(f"# number of systems for {self.head}: {len(systems)}")
+            log.info(f"# number of systems for {self.head}: {len(systems)}")
         for item in sorted(global_results.keys()):
-            logging.info(f"{item}: {global_results[item]:.4f}")
-        logging.info("# ----------------------------------------------- ")
+            log.info(f"{item}: {global_results[item]:.4f}")
+        log.info("# ----------------------------------------------- ")
         return global_results
 
 
diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py
index 01c2d7b9d6..000746a213 100644
--- a/deepmd/pt/model/model/model.py
+++ b/deepmd/pt/model/model/model.py
@@ -12,6 +12,8 @@
     compute_output_stats,
 )
 
+log = logging.getLogger(__name__)
+
 
 class BaseModel(torch.nn.Module):
     def __init__(self):
@@ -55,7 +57,7 @@ def compute_or_load_stat(
             if not os.path.exists(stat_file_dir):
                 os.mkdir(stat_file_dir)
             if not isinstance(stat_file_path, list):
-                logging.info(f"Saving stat file to {stat_file_path}")
+                log.info(f"Saving stat file to {stat_file_path}")
                 np.savez_compressed(
                     stat_file_path,
                     sumr=sumr,
@@ -68,7 +70,7 @@ def compute_or_load_stat(
                 )
             else:
                 for ii, file_path in enumerate(stat_file_path):
-                    logging.info(f"Saving stat file to {file_path}")
+                    log.info(f"Saving stat file to {file_path}")
                     np.savez_compressed(
                         file_path,
                         sumr=sumr[ii],
@@ -82,7 +84,7 @@ def compute_or_load_stat(
         else:  # load stat
             target_type_map = type_map
             if not isinstance(stat_file_path, list):
-                logging.info(f"Loading stat file from {stat_file_path}")
+                log.info(f"Loading stat file from {stat_file_path}")
                 stats = np.load(stat_file_path)
                 stat_type_map = list(stats["type_map"])
                 missing_type = [
@@ -105,7 +107,7 @@ def compute_or_load_stat(
                 sumr, suma, sumn, sumr2, suma2 = [], [], [], [], []
                 id_bias_atom_e = None
                 for ii, file_path in enumerate(stat_file_path):
-                    logging.info(f"Loading stat file from {file_path}")
+                    log.info(f"Loading stat file from {file_path}")
                     stats = np.load(file_path)
                     stat_type_map = list(stats["type_map"])
                     missing_type = [
diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py
index 4906987bf8..d911613a5b 100644
--- a/deepmd/pt/model/task/dipole.py
+++ b/deepmd/pt/model/task/dipole.py
@@ -10,6 +10,8 @@
     Fitting,
 )
 
+log = logging.getLogger(__name__)
+
 
 class DipoleFittingNetType(Fitting):
     def __init__(
@@ -37,7 +39,7 @@ def __init__(
         self.filter_layers = torch.nn.ModuleList(filter_layers)
 
         if "seed" in kwargs:
-            logging.info("Set seed to %d in fitting net.", kwargs["seed"])
+            log.info("Set seed to %d in fitting net.", kwargs["seed"])
             torch.manual_seed(kwargs["seed"])
 
     def forward(self, inputs, atype, atype_tebd, rot_mat):
diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py
index 484e477b6a..5e3cd87367 100644
--- a/deepmd/pt/model/task/ener.py
+++ b/deepmd/pt/model/task/ener.py
@@ -40,6 +40,8 @@
 dtype = env.GLOBAL_PT_FLOAT_PRECISION
 device = env.DEVICE
 
+log = logging.getLogger(__name__)
+
 
 @fitting_check_output
 class InvarFitting(Fitting):
@@ -153,7 +155,7 @@ def __init__(
 
         # very bad design...
         if "seed" in kwargs:
-            logging.info("Set seed to %d in fitting net.", kwargs["seed"])
+            log.info("Set seed to %d in fitting net.", kwargs["seed"])
             torch.manual_seed(kwargs["seed"])
 
     def output_def(self) -> FittingOutputDef:
@@ -451,7 +453,7 @@ def __init__(
         self.filter_layers = torch.nn.ModuleList(filter_layers)
 
         if "seed" in kwargs:
-            logging.info("Set seed to %d in fitting net.", kwargs["seed"])
+            log.info("Set seed to %d in fitting net.", kwargs["seed"])
             torch.manual_seed(kwargs["seed"])
 
     def output_def(self):
diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py
index 551fb9640b..b03aee7539 100644
--- a/deepmd/pt/model/task/fitting.py
+++ b/deepmd/pt/model/task/fitting.py
@@ -23,6 +23,8 @@
     make_stat_input,
 )
 
+log = logging.getLogger(__name__)
+
 
 class Fitting(torch.nn.Module, BaseFitting):
     __plugins = Plugin()
@@ -115,7 +117,7 @@ def change_energy_bias(
         ntest : int
             The number of test samples in a system to change the energy bias.
         """
-        logging.info(
+        log.info(
            "Changing energy bias in pretrained model for types {}... "
            "(this step may take long time)".format(str(new_type_map))
        )
@@ -188,7 +190,7 @@ def change_energy_bias(
             self.bias_atom_e[idx_type_map] += torch.from_numpy(
                 delta_bias.reshape(-1)
             ).to(DEVICE)
-            logging.info(
+            log.info(
                 f"RMSE of atomic energy after linear regression is: {rmse_ae:10.5e} eV/atom."
             )
         elif bias_shift == "statistic":
@@ -202,7 +204,7 @@ def change_energy_bias(
             )
         else:
             raise RuntimeError("Unknown bias_shift mode: " + bias_shift)
-        logging.info(
+        log.info(
             "Change energy bias of {} from {} to {}.".format(
                 str(new_type_map),
                 str(old_bias.detach().cpu().numpy()),
diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index 5e18797c7b..ebc9242d49 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -7,6 +7,8 @@
     Optimizer,
 )
 
+log = logging.getLogger(__name__)
+
 
 class LKFOptimizer(Optimizer):
     def __init__(
@@ -59,7 +61,7 @@ def __init_P(self):
 
         P = []
         params_packed_index = []
-        logging.info("LKF parameter nums: %s" % param_nums)
+        log.info("LKF parameter nums: %s" % param_nums)
         for param_num in param_nums:
             if param_num >= block_size:
                 block_num = math.ceil(param_num / block_size)
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index ee0e7a54cc..02367f4aee 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -59,6 +59,8 @@
     DataLoader,
 )
 
+log = logging.getLogger(__name__)
+
 
 class Trainer:
     def __init__(
@@ -140,7 +142,7 @@ def get_data_loader(_training_data, _validation_data, _training_params):
             valid_sampler = get_weighted_sampler(_validation_data, "prob_sys_size")
 
             if train_sampler is None or valid_sampler is None:
-                logging.warning(
+                log.warning(
                     "Sampler not specified!"
                 )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
             training_dataloader = DataLoader(
@@ -299,7 +301,7 @@ def get_loss(loss_params, start_lr, _ntypes):
             origin_model = (
                 finetune_model if finetune_model is not None else resume_model
             )
-            logging.info(f"Resuming from {origin_model}.")
+            log.info(f"Resuming from {origin_model}.")
             state_dict = torch.load(origin_model, map_location=DEVICE)
             if "model" in state_dict:
                 optimizer_state_dict = (
@@ -332,7 +334,7 @@ def get_loss(loss_params, start_lr, _ntypes):
                            tmp_keys = ".".join(item.split(".")[:3])
                            slim_keys.append(tmp_keys)
                        slim_keys = [i + ".*" for i in slim_keys]
-                       logging.warning(
+                       log.warning(
                            f"Force load mode allowed! These keys are not in ckpt and will re-init: {slim_keys}"
                        )
             elif self.finetune_multi_task:
@@ -451,9 +453,9 @@ def run(self):
         if SAMPLER_RECORD:
             record_file = f"Sample_rank_{self.rank}.txt"
             fout1 = open(record_file, mode="w", buffering=1)
-        logging.info("Start to train %d steps.", self.num_steps)
+        log.info("Start to train %d steps.", self.num_steps)
         if dist.is_initialized():
-            logging.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}")
+            log.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}")
         if self.enable_tensorboard:
             from torch.utils.tensorboard import (
                 SummaryWriter,
@@ -655,7 +657,7 @@ def log_loss_valid(_task_key="Default"):
                 train_time = time.time() - self.t0
                 self.t0 = time.time()
                 msg += f", speed={train_time:.2f} s/{self.disp_freq if _step_id else 1} batches"
-                logging.info(msg)
+                log.info(msg)
 
                 if fout:
                     if self.lcurve_should_print_header:
@@ -674,7 +676,7 @@ def log_loss_valid(_task_key="Default"):
 
                 module = self.wrapper.module if dist.is_initialized() else self.wrapper
                 self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
-                logging.info(f"Saved model to {self.latest_model}")
+                log.info(f"Saved model to {self.latest_model}")
                 symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
                 with open("checkpoint", "w") as f:
                     f.write(str(self.latest_model))
@@ -714,10 +716,10 @@ def log_loss_valid(_task_key="Default"):
                     "frozen_model.pth"  # We use .pth to denote the frozen model
                 )
                 self.model.save(pth_model_path)
-                logging.info(
+                log.info(
                     f"Frozen model for inferencing has been saved to {pth_model_path}"
                 )
-            logging.info(f"Trained model has been saved to: {self.save_ckpt}")
+            log.info(f"Trained model has been saved to: {self.save_ckpt}")
 
         if fout:
             fout.close()
diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py
index 7a6684e82e..0ec43f5a75 100644
--- a/deepmd/pt/utils/dataloader.py
+++ b/deepmd/pt/utils/dataloader.py
@@ -40,6 +40,7 @@
     process_sys_probs,
 )
 
+log = logging.getLogger(__name__)
 torch.multiprocessing.set_sharing_strategy("file_system")
 
 
@@ -69,7 +70,7 @@ def __init__(
 
         self.systems: List[DeepmdDataSetForLoader] = []
         if len(systems) >= 100:
-            logging.info(f"Constructing DataLoaders from {len(systems)} systems")
+            log.info(f"Constructing DataLoaders from {len(systems)} systems")
 
         def construct_dataset(system):
             ### this design requires "rcut" and "sel" in the descriptor
@@ -119,7 +120,7 @@ def construct_dataset(system):
                 rule = int(batch_size.split(":")[1])
             else:
                 rule = None
-                logging.error("Unsupported batch size type")
+                log.error("Unsupported batch size type")
             self.batch_size = rule // system._natoms
             if self.batch_size * system._natoms < rule:
                 self.batch_size += 1
@@ -155,7 +156,7 @@ def __len__(self):
         return len(self.dataloaders)
 
     def __getitem__(self, idx):
-        # logging.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx]))
+        # log.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx]))
index: "+str(self.index[idx])) try: batch = next(self.iters[idx]) except StopIteration: @@ -216,7 +217,7 @@ def __next__(self): self.warning_time is None or time.time() - self.warning_time > 15 * 60 ): - logging.warning( + log.warning( "Data loading buffer is empty or nearly empty. This may " "indicate a data loading bottleneck, and increasing the " "number of workers (--num-workers) may help." @@ -310,7 +311,7 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False): probs = prob_sys_size_ext(style, len(training_data), training_data.index) else: probs = process_sys_probs(prob_style, training_data.index) - logging.info("Generated weighted sampler with prob array: " + str(probs)) + log.info("Generated weighted sampler with prob array: " + str(probs)) # training_data.total_batch is the size of one epoch, you can increase it to avoid too many rebuilding of iteraters len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1) sampler = WeightedRandomSampler(probs, len_sampler, replacement=True) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 68d4a09ce4..aca4a9ce5b 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -449,7 +449,7 @@ def _load_data( if atomic: ndof *= self._natoms path = os.path.join(set_name, key + ".npy") - # logging.info('Loading data from: %s', path) + # log.info('Loading data from: %s', path) if os.path.isfile(path): if high_prec: data = np.load(path).astype(env.GLOBAL_ENER_FLOAT_PRECISION) diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index 9d82783cc0..82ff04071c 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -7,6 +7,8 @@ env, ) +log = logging.getLogger(__name__) + def change_finetune_model_params( ckpt, finetune_model, model_config, multi_task=False, model_branch="" @@ -45,7 +47,7 @@ def change_finetune_model_params( old_type_map ), "Only support for smaller type map when finetuning or resuming." model_config = last_model_params - logging.info( + log.info( "Change the model configurations according to the pretrained one..." ) model_config["new_type_map"] = new_type_map @@ -83,7 +85,7 @@ def change_finetune_model_params( model_config["fitting_net"] = model_dict_params[model_branch_chosen][ "fitting_net" ] - logging.info( + log.info( f"Change the model configurations according to the model branch " f"{model_branch_chosen} in the pretrained one..." ) diff --git a/deepmd/pt/utils/preprocess.py b/deepmd/pt/utils/preprocess.py index 18c798138e..806bacdcd2 100644 --- a/deepmd/pt/utils/preprocess.py +++ b/deepmd/pt/utils/preprocess.py @@ -10,6 +10,8 @@ env, ) +log = logging.getLogger(__name__) + class Region3D: def __init__(self, boxt): @@ -263,7 +265,7 @@ def make_env_mat( ) merged_coord = coord[merged_mapping] - merged_coord_shift if merged_coord.shape[0] <= coord.shape[0]: - logging.warning("No ghost atom is added for system ") + log.warning("No ghost atom is added for system ") else: merged_coord_shift = torch.zeros_like(coord) merged_atype = atype.clone() diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index eec7179bcd..5fde03c74a 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -8,6 +8,8 @@ env, ) +log = logging.getLogger(__name__) + def make_stat_input(datasets, dataloaders, nbatches): """Pack data for statistics. 
@@ -36,7 +38,7 @@ def make_stat_input(datasets, dataloaders, nbatches):
     ]
     if datasets[0].mixed_type:
         keys.append("real_natoms_vec")
-    logging.info(f"Packing data for statistics from {len(datasets)} systems")
+    log.info(f"Packing data for statistics from {len(datasets)} systems")
     for i in range(len(datasets)):
         sys_stat = {key: [] for key in keys}
         iterator = iter(dataloaders[i])
diff --git a/deepmd/utils/pair_tab.py b/deepmd/utils/pair_tab.py
index 56f8e618df..57157fbd00 100644
--- a/deepmd/utils/pair_tab.py
+++ b/deepmd/utils/pair_tab.py
@@ -12,6 +12,8 @@
     CubicSpline,
 )
 
+log = logging.getLogger(__name__)
+
 
 class PairTab:
     """Pairwise tabulated potential.
@@ -114,7 +116,7 @@ def _check_table_upper_boundary(self) -> None:
         if np.all(upper_val == 0):
             # if table values decay to `0` after rcut
             if self.rcut < self.rmax and np.any(self.vdata[rcut_idx - 1][1:] != 0):
-                logging.warning(
+                log.warning(
                     "The energy provided in the table does not decay to 0 at rcut."
                 )
             # if table values decay to `0` at rcut, do nothing
@@ -131,12 +133,12 @@ def _check_table_upper_boundary(self) -> None:
         else:
             # if table values do not decay to `0` at rcut
             if self.rcut <= self.rmax:
-                logging.warning(
+                log.warning(
                     "The energy provided in the table does not decay to 0 at rcut."
                 )
             # if rcut goes beyond table upper bond, need extrapolation, ensure values decay to `0` before rcut.
             else:
-                logging.warning(
+                log.warning(
                     "The rcut goes beyond table upper boundary, performing extrapolation."
                 )
                 pad_extrapolation = np.zeros((rcut_idx - upper_idx, self.ncol))

From 2333b4c3644acb60964efea3d6b44c63f88f6d0b Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 2 Feb 2024 20:33:40 -0500
Subject: [PATCH 2/2] avoid print

Signed-off-by: Jinzhe Zeng
---
 deepmd/pt/train/wrapper.py  | 8 ++++++--
 deepmd/pt/utils/finetune.py | 4 ++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py
index fe423e6318..2207f111a0 100644
--- a/deepmd/pt/train/wrapper.py
+++ b/deepmd/pt/train/wrapper.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import logging
 from typing import (
     Dict,
     Optional,
@@ -11,6 +12,9 @@
     import torch._dynamo
 
 
+log = logging.getLogger(__name__)
+
+
 class ModelWrapper(torch.nn.Module):
     def __init__(
         self,
@@ -124,7 +128,7 @@ def share_params(self, shared_links, resume=False):
                     link_class.share_params(
                         base_class, shared_level_link, resume=resume
                     )
-                    print(
+                    log.warning(
                         f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!"
                     )
                 else:
@@ -146,7 +150,7 @@ def share_params(self, shared_links, resume=False):
                         link_class.share_params(
                             base_class, shared_level_link, resume=resume
                         )
-                        print(
+                        log.warning(
                            f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!"
                        )
 
diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py
index 82ff04071c..13749da151 100644
--- a/deepmd/pt/utils/finetune.py
+++ b/deepmd/pt/utils/finetune.py
@@ -21,7 +21,7 @@ def change_finetune_model_params(
     """
     if multi_task:
         # TODO
-        print("finetune mode need modification for multitask mode!")
+        log.error("finetune mode need modification for multitask mode!")
     if finetune_model is not None:
         state_dict = torch.load(finetune_model, map_location=env.DEVICE)
         if "model" in state_dict:
@@ -59,7 +59,7 @@ def change_finetune_model_params(
             model_branch_chosen = next(iter(model_dict_params.keys()))
             new_fitting = True
             model_config["bias_shift"] = "statistic"  # fitting net re-init
-            print(
+            log.warning(
                 "The fitting net will be re-init instead of using that in the pretrained model! "
                 "The bias_shift will be statistic!"
            )
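
The idiom both patches converge on is the standard Python module-level logger: each module creates one named logger at import time and calls it instead of the root-level logging functions, while only the entry point attaches handlers. Below is a minimal standalone sketch of that idiom, not taken from the patches; the module names are hypothetical, and logging.basicConfig stands in for deepmd's set_log_handles.

    # worker.py -- library module: get a named logger, attach no handlers here
    import logging

    log = logging.getLogger(__name__)


    def do_work(x: int) -> int:
        # %-style arguments are formatted lazily, only if the record is emitted
        log.info("processing %d", x)
        return x * 2


    # cli.py -- entry point: the single place that configures logging output
    import logging

    from worker import do_work

    if __name__ == "__main__":
        logging.basicConfig(level=logging.INFO)
        do_work(21)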