PeptideBERT / train.py
Irwiny123's picture
添加PeptideBERT初始代码
9b00a31
import torch
import wandb
from datetime import datetime
import yaml
import os
import shutil
from data.dataloader import load_data
from model.network import create_model, cri_opt_sch
from model.utils import train, validate, test
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}\n')
def train_model():
print(f'{"="*30}{"TRAINING":^20}{"="*30}')
best_acc = 0
for epoch in range(config['epochs']):
train_loss = train(model, train_data_loader, optimizer, criterion, scheduler, device)
curr_lr = optimizer.param_groups[0]['lr']
print(f'Epoch {epoch+1}/{config["epochs"]} - Train Loss: {train_loss}\tLR: {curr_lr}')
val_loss, val_acc = validate(model, val_data_loader, criterion, device)
print(f'Epoch {epoch+1}/{config["epochs"]} - Validation Loss: {val_loss}\tValidation Accuracy: {val_acc}\n')
scheduler.step(val_acc)
if not config['debug']:
wandb.log({
'train_loss': train_loss,
'val_loss': val_loss,
'val_accuracy': val_acc,
'lr': curr_lr
})
if val_acc >= best_acc and not config['debug']:
best_acc = val_acc
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
'acc': val_acc,
'lr': curr_lr
}, f'{save_dir}/model.pt')
print('Model Saved\n')
wandb.finish()
config = yaml.load(open('./config.yaml', 'r'), Loader=yaml.FullLoader)
config['device'] = device
train_data_loader, val_data_loader, test_data_loader = load_data(config)
config['sch']['steps'] = len(train_data_loader)
model = create_model(config)
criterion, optimizer, scheduler = cri_opt_sch(config, model)
if not config['debug']:
run_name = f'{config["task"]}-{datetime.now().strftime("%m%d_%H%M")}'
wandb.init(project='PeptideBERT', name=run_name)
save_dir = f'./checkpoints/{run_name}'
if not os.path.exists(save_dir):
os.makedirs(save_dir)
shutil.copy('./config.yaml', f'{save_dir}/config.yaml')
shutil.copy('./model/network.py', f'{save_dir}/network.py')
train_model()
if not config['debug']:
model.load_state_dict(torch.load(f'{save_dir}/model.pt')['model_state_dict'], strict=False)
test_acc = test(model, test_data_loader, device)
print(f'Test Accuracy: {test_acc}%')