pytorch learn looper [on hold]

This code contains some routines for training a CNN. Could you please point out anything you find wrong or ugly? Thanks.

# Imports inferred from the code below. AverageMeter, create_logger and SummaryWriter
# are not shown in the post: SummaryWriter presumably comes from tensorboardX or
# torch.utils.tensorboard, and the other two are project-local helpers.
import time

import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm


class Trainer(object):
    def __init__(self, criterion,
                 metric,
                 optimizer,
                 model_name,
                 model,
                 base_checkpoint_name=None,
                 device=0,
                 dummy_input=None):
        '''
        :param watcher_env: environment for visdom
        :param criterion - loss function
        '''
        if base_checkpoint_name is None:
            self.base_checkpoint_name = model_name
        else:
            self.base_checkpoint_name = base_checkpoint_name

        self.metric = metric
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = ReduceLROnPlateau(optimizer, patience=2, verbose=True)
        self.best_loss = np.inf
        self.model_name = model_name
        self.device = device
        self.epoch_num = 0
        self.model = model
        self.logger = create_logger(model_name + '.log')
        self.writer = SummaryWriter(log_dir='/tmp/runs/')
        self.counters = {}

        if dummy_input is not None:
            self._plot_graph(dummy_input)

    @staticmethod
    def save_checkpoint(state, name):
        print('saving state at', name)
        torch.save(state, name)

    def get_checkpoint_name(self, loss):
        return self.base_checkpoint_name + '_best.pth.tar'

    def is_best(self, avg_loss):
        best = avg_loss < self.best_loss
        if best:
            self.best_loss = avg_loss
        return best

    def validate(self, val_loader):
        batch_time = AverageMeter()
        losses = AverageMeter()
        metrics = AverageMeter()

        self.model.eval()
        end = time.time()
        tqdm_val_loader = tqdm(enumerate(val_loader))
        for batch_idx, (input, target) in tqdm_val_loader:
            with torch.no_grad():
                input_var = input.to(self.device)
                target_var = target.to(self.device)

                output = self.model(input_var)
                loss = self.criterion(output, target_var)

                loss_scalar = loss.item()
                losses.update(loss_scalar)
                metric_val = self.metric(output, target_var)
                metrics.update(metric_val)

                tqdm_val_loader.set_description('val loss:%s, val metric: %s' %
                                                (str(loss_scalar), str(metric_val)))
                batch_time.update(time.time() - end)

                self._log_data(input, target, output, 'val_it_data')
                self._log_metric({
                    'metric': metric_val,
                    'loss': loss_scalar,
                    'batch_time': time.time() - end
                }, 'val_it_metric')
                end = time.time()

        self._log_metric({
            'metric': metrics.avg,
            'loss': losses.avg,
            'batch_time': batch_time.avg
        }, 'val_epoch_metric')

        self.scheduler.step(losses.avg)
        if self.is_best(losses.avg):
            self.save_checkpoint(self.model.state_dict(), self.get_checkpoint_name(losses.avg))

        self.epoch_num += 1
        return losses.avg, metrics.avg

    def update_train_epoch_stats(self, loss, metric):
        self.epoch_train_losses.append(loss)
        self.epoch_train_metrics.append(metric)

    def train(self, train_loader):
        batch_time, data_time, losses, metric = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()

        self.model.train()
        end = time.time()
        train_tqdm_iterator = tqdm(enumerate(train_loader))
        for batch_idx, (input, target) in train_tqdm_iterator:
            data_time.update(time.time() - end)

            input_var = input.to(self.device)
            target_var = target.to(self.device)

            output = self.model(input_var)
            loss = self.criterion(output, target_var)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            with torch.no_grad():
                loss_scalar = loss.item()
                losses.update(loss_scalar)
                metric_val = self.metric(output, target_var)  # todo - add output dimention assertion
                metric.update(metric_val)

                train_tqdm_iterator.set_description('train loss:%s, train metric: %s' %
                                                    (str(loss_scalar), str(metric_val)))
                batch_time.update(time.time() - end)
                end = time.time()

                self._log_data(input, target, output, 'train_it_data')
                self._log_metric({
                    'metric': metric_val,
                    'loss': loss_scalar,
                    'batch_time': time.time() - end
                }, 'train_it_metric')

        self._log_metric({
            'metric': metric.avg,
            'loss': losses.avg,
            'batch_time': batch_time.avg
        }, 'train_epoch_metric')

        return losses.avg, metric.avg

    def _log_data(self, input, target, output, tag):
        it = self._get_it(tag)
        self.writer.add_image(tag, input[:, 0:3, :, :], it)

    def _log_metric(self, metrics_dict, tag):
        it = self._get_it(tag)
        result = 'tag: ' + tag
        for k in metrics_dict:
            self.writer.add_scalar(tag + '_' + k, metrics_dict[k], it)
            result += ' ,' + k + '=' + str(metrics_dict[k])
        result += ', iteration ' + str(it)
        self.logger.debug(result)

    def _get_it(self, tag):
        if tag in self.counters.keys():
            result = self.counters[tag]
            self.counters[tag] += 1
            return result
        else:
            self.counters[tag] = 0
            return 0

    def _plot_graph(self, dummy_input):
        self.writer.add_graph(self.model, dummy_input)
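The code above calls AverageMeter and create_logger without showing them. Plausible minimal versions (my assumption of what they look like, not the author's actual helpers) would be:

import logging


class AverageMeter:
    '''Keeps a running average of the values passed to update().'''

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def create_logger(log_path):
    '''Returns a logger that writes debug messages to the given file.'''
    logger = logging.getLogger(log_path)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:  # avoid adding duplicate handlers on repeated calls
        logger.addHandler(logging.FileHandler(log_path))
    return logger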
python pytorch
put on hold as off-topic by Toby Speight, Ludisposed, IEatBagels, Sᴀᴍ Onᴇᴌᴀ, alecxe 9 hours ago
This question appears to be off-topic. The users who voted to close gave this specific reason:
- "Lacks concrete context: Code Review requires concrete code from a project, with sufficient context for reviewers to understand how that code is used. Pseudocode, stub code, hypothetical code, obfuscated code, and generic best practices are outside the scope of this site." – Toby Speight, Ludisposed, Sᴀᴍ Onᴇᴌᴀ, alecxe
If this question can be reworded to fit the rules in the help center, please edit the question.
Can you add some more information on what this code is trying to achieve and how to run it? Currently this post lacks context.
– Ludisposed, 12 hours ago
edited 12 hours ago by Ludisposed
asked 13 hours ago by Артем Лян (new contributor)
1 Answer
You have a docstring for __init__ (which is good; not enough people do this), but it's both wrong and incomplete. Add entries for every parameter and remove watcher_env.
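For example, a fuller docstring (wording inferred from how the parameters are used, so treat it as a sketch rather than the author's intent) might read:

def __init__(self, criterion, metric, optimizer, model_name, model,
             base_checkpoint_name=None, device=0, dummy_input=None):
    '''
    :param criterion: loss function
    :param metric: callable computing an evaluation metric from (output, target)
    :param optimizer: optimizer wrapping the model's parameters
    :param model_name: name used for the log file and the default checkpoint prefix
    :param model: the network being trained
    :param base_checkpoint_name: checkpoint file prefix; defaults to model_name
    :param device: device (index or torch.device) that each batch is moved to
    :param dummy_input: example input used to add the model graph to the SummaryWriter
    '''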
This:

if base_checkpoint_name is None:
    self.base_checkpoint_name = model_name
else:
    self.base_checkpoint_name = base_checkpoint_name

can be

self.base_checkpoint_name = base_checkpoint_name or model_name
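One caveat worth adding (mine, not part of the original answer): or falls back to model_name for any falsy value, not just None, so an empty string would also be replaced:

'' or 'my_model'      # -> 'my_model'
None or 'my_model'    # -> 'my_model'
'ckpt' or 'my_model'  # -> 'ckpt'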
This:

def is_best(self, avg_loss):
    best = avg_loss < self.best_loss
    if best:
        self.best_loss = avg_loss

has a few problems. First, the name is_best suggests that it returns a boolean and doesn't change anything, but it also mutates self.best_loss. Perhaps you want to rename it to take_best. Also, the contents of the function can be replaced with

if self.best_loss > avg_loss:
    self.best_loss = avg_loss
    return True
return False
Try replacing this string:
'val loss:%s, val metric: %s' % (str(loss_scalar), str(metric_val))
with this:
f'val loss: {loss_scalar}, val metric: {metric_val}'
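If loss_scalar and metric_val are plain Python floats, a format specifier keeps the progress-bar text compact (a small variation on the same suggestion):

tqdm_val_loader.set_description(f'val loss: {loss_scalar:.4f}, val metric: {metric_val:.4f}')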
This has a typo:
# todo - add output dimention assertion
It's "dimension".
This code:

if tag in self.counters.keys():
    result = self.counters[tag]
    self.counters[tag] += 1
    return result
else:
    self.counters[tag] = 0
    return 0

has a few issues. Don't pair a membership test with separate lookups of the same key when you can avoid it; it performs more dictionary lookups than necessary (and .keys() is redundant, since in already works on the dict itself). Also, the else is redundant because of the previous return. So:

counter = self.counters.get(tag)
if counter is None:
    self.counters[tag] = 0
    return 0
self.counters[tag] = counter + 1
return counter
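A further variation (my sketch, not from the answer): if returning 0 for the first two calls of a tag is not intentional, a collections.defaultdict shortens _get_it even more, at the cost of slightly different behaviour, since the counter now advances on the very first call:

from collections import defaultdict

# in __init__:
# self.counters = defaultdict(int)

def _get_it(self, tag):
    it = self.counters[tag]      # unseen tags start at 0
    self.counters[tag] = it + 1  # advance for the next call
    return it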
answered 11 hours ago by Reinderien