import os
from tqdm import tqdm
import psutil
from datetime import datetime

import torch

from mwm.constants import *
from mwm.utils.common import read_yaml, load_json
from mwm.config.configuration import get_params
from mwm import logger

# Model architecture
import segmentation_models_pytorch as smp
from mwm.components.model_architecture import *

# Dataset
from mwm.components.dataset import *
from torch.utils.data import DataLoader

# Loss
from mwm.components.loss import *

# Metrics logger
import mlflow
from mwm.components.metrics_logger import *
class Training:
    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH
    ):
        self.config = read_yaml(config_filepath)
        self.params = get_params(params_filepath, "training")

        # Make model
        self.model = make_model(self.params.network, self.params.encoder_weights)
        freeze_encoder(self.model, self.params.freeze_encoder_layers)

        # Make dataset
        self.image_dir = os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.image_dir)
        self.mask_dir = os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.mask_dir)
        self.sdm_dir = os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.sdm_dir)

        # TODO: update with cross-validation
        # - Train dataset
        with open(os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.training_set_file), "r") as f:
            self.image_list_train = f.read().splitlines()[:self.params.num_training_samples]
        self.train_dataset = make_dataset(
            self.params.dataset,
            self.image_dir,
            self.mask_dir,
            self.sdm_dir,
            self.image_list_train,
            "train",
            self.params.image_size
        )

        # - Validation dataset
        with open(os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.validation_set_file), "r") as f:
            self.image_list_val = f.read().splitlines()
        self.val_dataset = make_dataset(
            self.params.dataset,
            self.image_dir,
            self.mask_dir,
            self.sdm_dir,
            self.image_list_val,
            "val",
            self.params.image_size
        )
    def handle_device(self):
        # Move model to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def make_criterion(self):
        if self.params.loss == "weighted_dice_bce_2ch":
            self.criterion = WeightedDiceBCELoss(
                weight_1=self.params.weighted_dice_bce_2ch.weight_1,
                weight_2=self.params.weighted_dice_bce_2ch.weight_2,
                weight_3=self.params.weighted_dice_bce_2ch.weight_3,
                bce_weight=self.params.weighted_dice_bce_2ch.bce_weight,
                grad_weight=self.params.weighted_dice_bce_2ch.grad_weight,
                use_focal=self.params.weighted_dice_bce_2ch.use_focal,
                use_gradient_loss=self.params.weighted_dice_bce_2ch.use_gradient_loss,
                use_dist_loss=self.params.weighted_dice_bce_2ch.use_dist_loss
            )
            logger.info(f"Loss: {self.params.loss} selected.")
        else:
            logger.error(f"Invalid loss: {self.params.loss}")
            raise ValueError(f"Invalid loss: {self.params.loss}")

    def make_optimizer(self):
        if self.params.optimizer == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params.learning_rate)
            logger.info(f"Optimizer: {self.params.optimizer} selected.")
        else:
            logger.error(f"Invalid optimizer: {self.params.optimizer}")
            raise ValueError(f"Invalid optimizer: {self.params.optimizer}")
    def train_epoch(self):
        batch_progress_bar = tqdm(
            range(self.params.steps_per_epoch),
            desc=f"Epoch {self.this_epoch}/{self.params.epochs - 1}",
            leave=True
        )

        ### Training Phase ###
        self.model.train()
        # Reuse a single DataLoader iterator instead of rebuilding it every step,
        # restarting it once the loader is exhausted
        train_iter = iter(self.train_loader)
        for step in batch_progress_bar:
            try:
                images, masks, sdms = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_loader)
                images, masks, sdms = next(train_iter)
            images, masks, sdms = images.to(self.device), masks.to(self.device), sdms.to(self.device)

            self.optimizer.zero_grad()                   # Reset gradients
            outputs = self.model(images)                 # Forward pass
            loss = self.criterion(outputs, masks, sdms)  # Compute loss
            loss.backward()                              # Backpropagation
            self.optimizer.step()                        # Update weights

            self.metrics_logger.update_sum(loss, outputs.cpu(), masks.cpu())

            # Get CPU & RAM usage for display/monitoring
            ram_used = psutil.virtual_memory().used / 1024**3
            batch_progress_bar.set_postfix(loss=loss.item(), ram_used=f"{ram_used:.2f} GB", cpu_usage=f"{psutil.cpu_percent()}%")

        self.metrics_logger.update_mean(
            self.params.steps_per_epoch,
            self.params.steps_per_epoch * self.params.batch_size
        )

        ### Validation Phase ###
        self.model.eval()
        batch_progress_bar = tqdm(self.val_loader, desc=f"Epoch {self.this_epoch}/{self.params.epochs - 1} validation", leave=True)
        with torch.no_grad():
            for images, masks, sdms in batch_progress_bar:
                images, masks, sdms = images.to(self.device), masks.to(self.device), sdms.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, masks, sdms)
                self.metrics_logger.update_sum_val(loss)
                batch_progress_bar.set_postfix(val_loss=loss.item())
        self.metrics_logger.update_mean_val(len(self.val_loader))
        self.metrics_logger.log_metrics_mlflow(self.this_epoch)  # Logger is reset afterwards
    def train(self, save_model=False, save_interval=10):
        # Initialize metrics logger
        if self.params.metrics_logger == "metrics_logger_2ch":
            self.metrics_logger = MetricsLogger2Channel()
        else:
            logger.error(f"Invalid metrics logger: {self.params.metrics_logger}")
            raise ValueError(f"Invalid metrics logger: {self.params.metrics_logger}")

        # Define data loaders
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.params.batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=self.params.batch_size, shuffle=False)

        # Start training
        mlflow.set_experiment("Training")
        with mlflow.start_run():
            for epoch in range(self.params.epochs):
                self.this_epoch = epoch
                self.train_epoch()
                if save_model:
                    if (epoch + 1) % save_interval == 0:
                        self.save_model()
            mlflow.log_params(self.params.to_dict())
        logger.info("Training completed.")

    def save_model(self):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = os.path.join(self.config.model.model_dir, f"model_epoch{self.this_epoch}_{timestamp}.pth")
        torch.save(self.model.state_dict(), save_path)  # Save model weights
        mlflow.log_param(f"model_epoch{self.this_epoch}_path", save_path)
        logger.info(f"Model saved successfully! Location: {save_path}")