CLRFinder

This repository contains PyTorch implementations of a learning rate finder and a cyclic learning rate scheduler.
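For readers new to the idea, here is a minimal sketch of what a learning rate finder typically does. It assumes the common "range test" approach: sweep the learning rate geometrically between two bounds, record the loss at each step, and pick the rate where the loss falls fastest. The function names `lr_sweep` and `steepest_descent_lr` are hypothetical and the loss values are toy numbers. This is not the code in this package, whose selection rule may differ.

```python
import math


def lr_sweep(min_lr: float, max_lr: float, n_steps: int) -> list:
    """Return n_steps learning rates spaced geometrically from min_lr to max_lr."""
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")
    ratio = (max_lr / min_lr) ** (1.0 / (n_steps - 1))
    return [min_lr * ratio ** i for i in range(n_steps)]


def steepest_descent_lr(lrs: list, losses: list) -> float:
    """Return the learning rate at which the loss dropped fastest per unit of log(lr)."""
    best_index, best_slope = 0, math.inf
    for i in range(len(lrs) - 1):
        slope = (losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        if slope < best_slope:
            best_index, best_slope = i, slope
    return lrs[best_index]


# Toy example: in practice one loss would be recorded per training batch.
lrs = lr_sweep(1e-3, 0.9, 5)
toy_losses = [2.3, 2.2, 1.5, 1.4, 3.0]  # illustrative values only
suggested_lr = steepest_descent_lr(lrs, toy_losses)
```

In a real run the losses come from training the model for one batch at each learning rate, and they are usually smoothed before the slope is computed.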

Install

Clone the repository

$ git clone https://github.com/rahowa/CLRFinder.git

Install Poetry, which is used to build the package

$ pip install poetry

Build and install the Python package with Poetry

$ cd CLRFinder
$ poetry install

Usage

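import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor

from clr_finder.scalers import Scaler
from clr_finder.torch_pipeline import LRFinderPytorch, PytorchScheduler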

...

model = TestModel()
model.to('cpu')
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), 1.0)

transforms = Compose([ToTensor()])
train_dataset = MNIST('./', train=True, download=True, transform=transforms)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)

finder = LRFinderPytorch(model=model, min_lr=1e-3, max_lr=0.9, loader=train_loader, n_epochs=1)
finder.run(criterion, optimizer)
best_lr = finder.best_lr()
scheduler = PytorchScheduler(best_lr[0], best_lr[1], Scaler.triangle1, len(train_loader), 1.)

optimizer = optim.Adam(model.parameters(), best_lr[1])
for ep in range(10):
    model, acc, loss = train(model, criterion, optimizer, train_loader, 'cpu')
    print(f"Epoch: {ep}, loss: {loss}, acc: {acc}")
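
The usage above passes two learning rates and `Scaler.triangle1` to the scheduler. As an illustration of what a triangular cyclic policy computes, here is a minimal sketch. It assumes the two values act as the lower and upper bounds of the cycle and that the rate moves linearly between them. The function `triangular_lr` and its parameters are hypothetical names, not part of this package, and the package's own scaling may differ.

```python
import math


def triangular_lr(step: int, base_lr: float, max_lr: float, step_size: int) -> float:
    """Learning rate at a given step for a triangular cycle.

    The rate rises linearly from base_lr to max_lr over step_size steps,
    then falls back to base_lr over the next step_size steps, and repeats.
    """
    cycle = math.floor(1 + step / (2 * step_size))   # 1-based index of the current cycle
    x = abs(step / step_size - 2 * cycle + 1)        # distance from the cycle peak, in [0, 1]
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)


# One full cycle with a half-cycle length of 4 steps.
schedule = [triangular_lr(s, base_lr=0.001, max_lr=0.1, step_size=4) for s in range(9)]
```

With a plain PyTorch optimizer, a value like this can be applied before each batch by assigning it to `param_group["lr"]` for every entry in `optimizer.param_groups`.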

License

License
