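"""Training component of the mwm pipeline.

Builds the segmentation model, the train/validation datasets, the loss, and the
optimizer from the project config and params files, runs the training loop with
MLflow metric logging, and periodically saves model checkpoints.
"""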
import os
import torch
from tqdm import tqdm
import psutil
from datetime import datetime
from mwm.constants import *
from mwm.utils.common import read_yaml, load_json
from mwm.config.configuration import get_params
from mwm import logger
# Model architecture
import segmentation_models_pytorch as smp
from mwm.components.model_architecture import *
# Dataset
from mwm.components.dataset import *
from torch.utils.data import DataLoader
# Loss
from mwm.components.loss import *
# Metrics logger
import mlflow
from mwm.components.metrics_logger import *
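
# Keys read from the "training" section of the params file (inferred from the
# attribute accesses in this module; the exact file format is defined elsewhere
# in the repository):
#   network, encoder_weights, freeze_encoder_layers, dataset, image_size,
#   num_training_samples, batch_size, steps_per_epoch, epochs, learning_rate,
#   loss (+ weighted_dice_bce_2ch.*), optimizer, metrics_logger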

class Training:
    """Training pipeline stage: builds the model, datasets, loss, and optimizer, then runs training."""

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH
    ):
        self.config = read_yaml(config_filepath)
        self.params = get_params(params_filepath, "training")
        # Make model
        self.model = make_model(self.params.network, self.params.encoder_weights)
        freeze_encoder(self.model, self.params.freeze_encoder_layers)
        # Make dataset
        self.image_dir = os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.image_dir)
        self.mask_dir = os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.mask_dir)
        self.sdm_dir = os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.sdm_dir)
        # TODO: update with cross-validation
        # - Train dataset
        with open(os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.training_set_file), "r") as f:
            self.image_list_train = f.read().splitlines()[:self.params.num_training_samples]
        self.train_dataset = make_dataset(
            self.params.dataset,
            self.image_dir,
            self.mask_dir,
            self.sdm_dir,
            self.image_list_train,
            "train",
            self.params.image_size
        )
        # - Validation dataset
        with open(os.path.join(self.config.data_ingestion.unzip_dir, self.config.dataset.validation_set_file), "r") as f:
            self.image_list_val = f.read().splitlines()
        self.val_dataset = make_dataset(
            self.params.dataset,
            self.image_dir,
            self.mask_dir,
            self.sdm_dir,
            self.image_list_val,
            "val",
            self.params.image_size
        )

    def handle_device(self):
        # Move model to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def make_criterion(self):
        if self.params.loss == "weighted_dice_bce_2ch":
            self.criterion = WeightedDiceBCELoss(
                weight_1=self.params.weighted_dice_bce_2ch.weight_1,
                weight_2=self.params.weighted_dice_bce_2ch.weight_2,
                weight_3=self.params.weighted_dice_bce_2ch.weight_3,
                bce_weight=self.params.weighted_dice_bce_2ch.bce_weight,
                grad_weight=self.params.weighted_dice_bce_2ch.grad_weight,
                use_focal=self.params.weighted_dice_bce_2ch.use_focal,
                use_gradient_loss=self.params.weighted_dice_bce_2ch.use_gradient_loss,
                use_dist_loss=self.params.weighted_dice_bce_2ch.use_dist_loss
            )
            logger.info(f"Loss: {self.params.loss} selected. ")
        else:
            logger.error(f"Invalid loss: {self.params.loss}")
            raise ValueError(f"Invalid loss: {self.params.loss}")

    def make_optimizer(self):
        if self.params.optimizer == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params.learning_rate)
            logger.info(f"Optimizer: {self.params.optimizer} selected. ")
        else:
            logger.error(f"Invalid optimizer: {self.params.optimizer}")
            raise ValueError(f"Invalid optimizer: {self.params.optimizer}")

    def train_epoch(self):
        batch_progress_bar = tqdm(range(self.params.steps_per_epoch), desc=f"Epoch {self.this_epoch}/{self.params.epochs-1}", leave=True)
        ### Training Phase ###
        self.model.train()
        train_iter = iter(self.train_loader)  # Create the iterator once instead of rebuilding (and reshuffling) it every step
        for step in batch_progress_bar:
            try:
                images, masks, sdms = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_loader)  # Loader exhausted: reshuffle and continue
                images, masks, sdms = next(train_iter)
            images, masks, sdms = images.to(self.device), masks.to(self.device), sdms.to(self.device)
            self.optimizer.zero_grad()  # Reset gradients
            outputs = self.model(images)  # Forward pass
            loss = self.criterion(outputs, masks, sdms)  # Compute loss
            loss.backward()  # Backpropagation
            self.optimizer.step()  # Update weights
            self.metrics_logger.update_sum(loss, outputs.cpu(), masks.cpu())
            # Get CPU & RAM usage for display/monitoring
            ram_used = psutil.virtual_memory().used / 1024**3
            batch_progress_bar.set_postfix(loss=loss.item(), ram_used=f"{ram_used:.2f} GB", cpu_usage=f"{psutil.cpu_percent()}%")
        self.metrics_logger.update_mean(
            self.params.steps_per_epoch,
            self.params.steps_per_epoch * self.params.batch_size
        )
        ### Validation Phase ###
        self.model.eval()
        batch_progress_bar = tqdm(self.val_loader, desc=f"Epoch {self.this_epoch}/{self.params.epochs-1} validation", leave=True)
        with torch.no_grad():
            for images, masks, sdms in batch_progress_bar:
                images, masks, sdms = images.to(self.device), masks.to(self.device), sdms.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, masks, sdms)
                self.metrics_logger.update_sum_val(loss)
                batch_progress_bar.set_postfix(val_loss=loss.item())
        self.metrics_logger.update_mean_val(len(self.val_loader))
        self.metrics_logger.log_metrics_mlflow(self.this_epoch)  # Logger is reset afterwards

    def train(self, save_model=False, save_interval=10):
        # Initialize metrics logger
        if self.params.metrics_logger == "metrics_logger_2ch":
            self.metrics_logger = MetricsLogger2Channel()
        else:
            logger.error(f"Invalid metrics logger: {self.params.metrics_logger}")
            raise ValueError(f"Invalid metrics logger: {self.params.metrics_logger}")
        # Define data loaders
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.params.batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=self.params.batch_size, shuffle=False)
        # Start training
        mlflow.set_experiment("Training")
        with mlflow.start_run():
            for epoch in range(self.params.epochs):
                self.this_epoch = epoch
                self.train_epoch()
                if save_model:
                    if (epoch + 1) % save_interval == 0:
                        self.save_model()
            mlflow.log_params(self.params.to_dict())  # Log hyperparameters once per run
        logger.info("Training completed. ")

    def save_model(self):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(self.config.model.model_dir, exist_ok=True)  # Ensure the checkpoint directory exists
        save_path = os.path.join(self.config.model.model_dir, f"model_epoch{self.this_epoch}_{timestamp}.pth")
        torch.save(self.model.state_dict(), save_path)  # Save model weights
        mlflow.log_param(f"model_epoch{self.this_epoch}_path", save_path)
        logger.info(f"Model saved successfully! Location: {save_path}")