jrpytorch/__init__.py

__version__ = '0.1.0'

from . import datasets
from . import visualiser

import torch
import matplotlib.pyplot as plt
import numpy as np


def vis_binary_classification(model, loss_fn, optimiser, X, y, epochs=1000):
    x_tensor = torch.tensor(X).float()
    y_tensor = torch.tensor(y.reshape(-1, 1)).float()
    plt.ion()
    x_min = np.min(X) - 0.2
    x_max = np.max(X) + 0.2
    x1 = np.linspace(x_min, x_max, 100)
    grid = np.array([(gx, gy) for gx in x1 for gy in x1])
    fig = plt.figure()
    ax = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    real_x = x1
    real_y = x1
    dx = (real_x[1] - real_x[0]) / 2.
    dy = (real_y[1] - real_y[0]) / 2.
    extent = [real_x[0] - dx, real_x[-1] + dx, real_y[0] - dy, real_y[-1] + dy]
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([x_min, x_max])
    ax2.set_xlim([0, epochs])
    line1, = ax2.plot([], [], 'r-')
    # first image
    output = model(torch.from_numpy(grid).float()).detach().numpy()
    img = ax.imshow(output.reshape(100, 100), extent=extent)
    ax.scatter(X[:, 0], X[:, 1], c=y, edgecolor='black')
    losses = []
    for epoch in range(epochs):
        optimiser.zero_grad()
        output = model(x_tensor)
        loss = loss_fn(output, y_tensor)
        loss.backward()
        optimiser.step()
        losses.append(loss.item())
        if epoch == 0:
            ax2.set_ylim([0, loss.item() * 1.1])
        with torch.no_grad():
            # update plot
            output = model(torch.from_numpy(grid).float())
            arr_output = output.detach().numpy().reshape(100, 100)
            img.set_data(arr_output)
            ax.set_title('Epoch: %d Loss %.3f' % (epoch + 1, loss.item()))
            line1.set_data(list(range(epoch + 1)), losses)
            fig.canvas.draw()
            fig.canvas.flush_events()
    plt.ioff()


jrpytorch/datasets/__init__.py

from sklearn.datasets import make_circles


def load_circles():
    return make_circles(200, random_state=1, noise=0.05)
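A minimal usage sketch, not part of the packaged files: the two-layer network, the learning rate and the epoch count are illustrative choices, and nn.BCELoss is assumed because vis_binary_classification reshapes y to an (n, 1) float tensor, which matches a single sigmoid output.

import torch
from torch import nn

from jrpytorch import vis_binary_classification
from jrpytorch.datasets import load_circles

# hypothetical classifier: 2 inputs, one sigmoid output in [0, 1]
model = nn.Sequential(
    nn.Linear(2, 10),
    nn.ReLU(),
    nn.Linear(10, 1),
    nn.Sigmoid(),
)

X, y = load_circles()                                    # 200 noisy points on two circles
loss_fn = nn.BCELoss()                                   # binary cross-entropy on the sigmoid output
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)

# animates the decision surface (left panel) and the loss curve (right panel)
vis_binary_classification(model, loss_fn, optimiser, X, y, epochs=200)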

jrpytorch/visualiser.py

from datetime import datetime

import visdom
import torch


class Classification_Visualiser:
    def __init__(self, env_name=None, host='http://localhost', port=8097):
        if env_name is None:
            env_name = str(datetime.now().strftime('%d-%m-%Hh%M'))
        self.env_name = env_name
        self.vis = visdom.Visdom(server=host, port=port, env=self.env_name)
        self.loss_win = None
        self.acc_win = None
        self.text_win = None

    def plot_loss(self, loss, epoch):
        self.loss_win = self.vis.line(
            [loss], [epoch],
            win=self.loss_win,
            update='append' if self.loss_win else None,
            opts={
                'xlabel': 'Epoch',
                'ylabel': 'Loss',
                'legend': ['train', 'test'],
                'layoutopts': {'legend': {'train': 0, 'test': 1}}
            }
        )

    def plot_accuracy(self, acc, epoch):
        self.acc_win = self.vis.line(
            [acc], [epoch],
            win=self.acc_win,
            update='append' if self.acc_win else None,
            opts={
                'xlabel': 'Epoch',
                'ylabel': 'Accuracy',
                'legend': ['train', 'test'],
                'layoutopts': {'legend': {'train': 0, 'test': 1}}
            }
        )

    def show_text(self, text):
        self.text_win = self.vis.text(text, win=self.text_win)


def train_classifier(model, criterion, optimiser, epochs, visualiser,
                     loaders, sizes, device='cpu'):
    model = model.to(device)
    losses = {'train': None, 'test': None}
    acc = {'train': None, 'test': None}
    for epoch in range(epochs):
        # visdom text panes render HTML; the original markup in these status
        # strings was lost in extraction, so plain <br> line breaks are used here
        text = 'Epoch %d / %d<br>' % (epoch + 1, epochs)
        for stage in ['train', 'test']:
            if stage == 'train':
                model.train()
            else:
                model.eval()
            running_loss = .0
            running_correct = 0
            for input, label in loaders[stage]:
                input, label = input.to(device), label.to(device)
                optimiser.zero_grad()
                with torch.set_grad_enabled(stage == 'train'):
                    output = model(input)
                    _, pred = torch.max(output, 1)
                    loss = criterion(output, label)
                    if stage == 'train':
                        loss.backward()
                        optimiser.step()
                running_loss += loss.item() * input.size(0)
                running_correct += torch.sum(pred == label.data)
            stage_loss = running_loss / sizes[stage]
            stage_accuracy = running_correct.double() / sizes[stage]
            losses[stage] = stage_loss
            acc[stage] = stage_accuracy.cpu().numpy()
            text += '%s Loss: %.3f Acc: %.3f<br>' % (stage, stage_loss, stage_accuracy)
        visualiser.plot_loss(list(losses.values()), [epoch, epoch])
        visualiser.plot_accuracy(list(acc.values()), [epoch, epoch])
        visualiser.show_text(text)
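A usage sketch for the visualiser and training loop, again not part of the packaged files: it assumes a visdom server is already running (python -m visdom.server, default port 8097), and the train/test split, batch size, network and optimiser settings are illustrative; nn.CrossEntropyLoss is assumed because train_classifier takes class predictions with torch.max(output, 1).

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from jrpytorch.datasets import load_circles
from jrpytorch.visualiser import Classification_Visualiser, train_classifier

# hypothetical train/test split of the circles data
X, y = load_circles()
X, y = torch.tensor(X).float(), torch.tensor(y).long()
train_ds = TensorDataset(X[:150], y[:150])
test_ds = TensorDataset(X[150:], y[150:])

loaders = {
    'train': DataLoader(train_ds, batch_size=32, shuffle=True),
    'test': DataLoader(test_ds, batch_size=32),
}
sizes = {'train': len(train_ds), 'test': len(test_ds)}

# two output logits so that CrossEntropyLoss and torch.max(output, 1) both apply
model = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2))
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.SGD(model.parameters(), lr=0.1)

# the visualiser needs a running visdom server: python -m visdom.server
visualiser = Classification_Visualiser(env_name='circles')
train_classifier(model, criterion, optimiser, epochs=50,
                 visualiser=visualiser, loaders=loaders, sizes=sizes)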