| import torch |
| import wandb |
| from datetime import datetime |
| import yaml |
| import os |
| import shutil |
| from data.dataloader import load_data |
| from model.network import create_model, cri_opt_sch |
| from model.utils import train, validate, test |
|
|
# Pick the compute device at module load: CUDA when a GPU is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}\n')
|
|
|
|
def train_model():
    """Train for config['epochs'] epochs, validating after every epoch.

    Uses module-level globals prepared by this script: config, model,
    train_data_loader, val_data_loader, criterion, optimizer, scheduler,
    device and (outside debug mode) save_dir.

    Each epoch prints the train loss / learning rate and the validation
    loss / accuracy, then calls scheduler.step(val_acc). Unless
    config['debug'] is set, the epoch metrics are sent to W&B, and whenever
    validation accuracy is >= the best seen so far a checkpoint (model,
    optimizer and scheduler state plus metrics) is written to
    f'{save_dir}/model.pt'. wandb.finish() is called once training ends.
    """
    print(f'{"="*30}{"TRAINING":^20}{"="*30}')

    best_acc = 0
    for epoch in range(config['epochs']):
        # One pass over the training split; the LR is read afterwards.
        train_loss = train(model, train_data_loader, optimizer, criterion, scheduler, device)
        curr_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch+1}/{config["epochs"]} - Train Loss: {train_loss}\tLR: {curr_lr}')

        # Validation pass; the scheduler is stepped with validation accuracy.
        val_loss, val_acc = validate(model, val_data_loader, criterion, device)
        print(f'Epoch {epoch+1}/{config["epochs"]} - Validation Loss: {val_loss}\tValidation Accuracy: {val_acc}\n')
        scheduler.step(val_acc)

        # Debug runs neither log to W&B nor write checkpoints.
        if config['debug']:
            continue

        epoch_metrics = {
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
            'lr': curr_lr,
        }
        wandb.log(epoch_metrics)

        # Keep the best-validation checkpoint; `>=` means a tie overwrites,
        # so the latest of equally good epochs is the one kept on disk.
        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save(
                dict(
                    epoch=epoch,
                    model_state_dict=model.state_dict(),
                    optimizer_state_dict=optimizer.state_dict(),
                    scheduler_state_dict=scheduler.state_dict(),
                    train_loss=train_loss,
                    val_loss=val_loss,
                    acc=val_acc,
                    lr=curr_lr,
                ),
                f'{save_dir}/model.pt',
            )
            print('Model Saved\n')

    wandb.finish()
|
|
|
|
# Load the run configuration. The with-block guarantees the file handle is
# closed; YAML streams are UTF-8 (or UTF-16), so decode explicitly rather
# than relying on the platform's locale encoding.
# NOTE(review): yaml.load with FullLoader is acceptable for this trusted,
# repo-local file; switch to yaml.safe_load if the config could ever come
# from an untrusted source.
with open('./config.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
# Make the selected torch device available to everything that receives config.
config['device'] = device
|
|
# Build the train / validation / test loaders described by the config.
(train_data_loader,
 val_data_loader,
 test_data_loader) = load_data(config)
# Store len(train_data_loader) (batches per epoch, assuming a standard
# DataLoader) under config['sch']['steps'] — presumably consumed by
# cri_opt_sch below to size the LR schedule; confirm against model.network.
config['sch'].update(steps=len(train_data_loader))

# Instantiate the network, then its loss, optimizer and LR scheduler.
model = create_model(config)
criterion, optimizer, scheduler = cri_opt_sch(config, model)
|
|
# Outside debug mode: start a W&B run and create a per-run checkpoint
# directory holding copies of the config and model definition, so each run
# can be reproduced later. In debug mode run_name/save_dir are never defined;
# every later use of them in this script is behind the same
# `not config['debug']` check.
if not config['debug']:
    run_name = f'{config["task"]}-{datetime.now().strftime("%m%d_%H%M")}'
    wandb.init(project='PeptideBERT', name=run_name)

    save_dir = f'./checkpoints/{run_name}'
    # exist_ok removes the exists()/makedirs() race and tolerates a second run
    # started in the same minute (which produces the same run_name).
    os.makedirs(save_dir, exist_ok=True)
    shutil.copy('./config.yaml', f'{save_dir}/config.yaml')
    shutil.copy('./model/network.py', f'{save_dir}/network.py')
|
|
# Run training; outside debug mode this writes the best-validation checkpoint.
train_model()

# Outside debug mode, evaluate the best checkpoint saved by train_model rather
# than the final-epoch weights. strict=False makes load_state_dict ignore
# missing/unexpected keys instead of raising.
if not config['debug']:
    best_state = torch.load(f'{save_dir}/model.pt')['model_state_dict']
    model.load_state_dict(best_state, strict=False)
    del best_state

test_acc = test(model, test_data_loader, device)
print(f'Test Accuracy: {test_acc}%')
|
|