classification_report.report

Module Contents

class classification_report.report.Report(classes: TrainType, dir_name: str = None)

Generates a report for a classification model by tracking model training and providing different types of metrics to evaluate the model.

For any classification problem, it is important to track the model's weights, biases, and gradients during training. After training, the important part is evaluation, where we measure the model's performance. The Report class simplifies evaluation by automatically generating all the evaluation metrics for the model, and it uses TensorBoard to visualize them.

write_a_batch(self, loss: LossType, batch_size: int, actual: ActualType, prediction: PredictionType, train: bool = True)

Records the batch information during the training and validation phases: the loss, the actual labels of the batch, and the predicted labels.

Note

For prediction, pass the softmax output, not the raw logits.

Parameters:
  • loss – The batch loss.
  • batch_size – The batch size on which the loss was calculated. The batch size may be smaller in the last iteration, so compute batch_size from the data.
  • actual – The actual labels.
  • prediction – The predictions, given as softmax output.
  • train – True for training mode, False for validation mode.
Returns:

Report class instance.
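
The note above asks for softmax output rather than raw logits. A minimal pure-Python sketch of that conversion, assuming the logits arrive as one list of floats per sample; in a PyTorch training loop the equivalent would typically be `torch.softmax(logits, dim=1)`. The helper name `softmax_rows` is hypothetical and not part of this library.

```python
import math
from typing import List


def softmax_rows(logits: List[List[float]]) -> List[List[float]]:
    """Turn each row of raw logits into probabilities that sum to 1."""
    probabilities = []
    for row in logits:
        # Subtract the row maximum before exponentiating to avoid overflow.
        row_max = max(row)
        exps = [math.exp(x - row_max) for x in row]
        total = sum(exps)
        probabilities.append([e / total for e in exps])
    return probabilities


raw_logits = [[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]]
# Pass `prediction`, not `raw_logits`, to write_a_batch.
prediction = softmax_rows(raw_logits)
```

Subtracting the maximum does not change the result, since softmax is invariant to adding a constant to every logit in a row.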

update_actual_prediction(self, actual: ActualType, prediction: PredictionType, train_type: TrainType)

Stores actual and predicted labels separately for training and validation; the values from every batch call are appended.

Parameters:
  • actual – The actual labels.
  • prediction – The predicted labels.
  • train_type – Whether the labels belong to train or valid.
Returns:

Report class instance.
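
A minimal sketch of per-phase label storage that is appended to on every batch, assuming `train_type` is a string such as "train" or "valid" and the labels are integer class indices. `LabelStore` is a hypothetical name, not the library's class; returning the instance mirrors the documented return value and allows chained calls.

```python
from collections import defaultdict
from typing import DefaultDict, List


class LabelStore:
    """Hypothetical stand-in that keeps labels per phase across batches."""

    def __init__(self) -> None:
        self.actual: DefaultDict[str, List[int]] = defaultdict(list)
        self.prediction: DefaultDict[str, List[int]] = defaultdict(list)

    def update(self, actual: List[int], prediction: List[int], train_type: str) -> "LabelStore":
        # Extend rather than overwrite, so metrics can later use the whole epoch.
        self.actual[train_type].extend(actual)
        self.prediction[train_type].extend(prediction)
        # Returning self allows chained calls.
        return self


store = LabelStore()
store.update([0, 1], [0, 0], "train").update([2], [2], "train")
store.update([1], [1], "valid")
```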

update_loss(self, loss: LossType, batch_size: int, train_type: TrainType)

Accumulates the loss of every batch, separately for training and validation.

Parameters:
  • loss – The batch loss.
  • batch_size – The batch size on which the loss was calculated. The batch size may be smaller in the last iteration, so compute batch_size from the data.
  • train_type – Whether the loss belongs to train or valid.
Returns:

Report class instance.
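
The batch_size argument matters because averaging the per-batch losses directly would give a small final batch the same weight as a full one. A minimal sketch, assuming the loss passed in is already averaged over the batch (the PyTorch loss default `reduction="mean"`); `LossAccumulator` is hypothetical and not the library's implementation.

```python
from typing import Dict


class LossAccumulator:
    """Hypothetical sketch of per-phase loss accumulation."""

    def __init__(self) -> None:
        self.loss_sum: Dict[str, float] = {"train": 0.0, "valid": 0.0}
        self.samples: Dict[str, int] = {"train": 0, "valid": 0}

    def update(self, loss: float, batch_size: int, train_type: str) -> "LossAccumulator":
        # Assumes `loss` is a per-sample mean, so multiply back to a batch sum.
        self.loss_sum[train_type] += loss * batch_size
        self.samples[train_type] += batch_size
        return self

    def epoch_loss(self, train_type: str) -> float:
        # Sample-weighted mean over every batch seen in this phase.
        return self.loss_sum[train_type] / self.samples[train_type]


acc = LossAccumulator()
acc.update(0.5, 32, "train")
acc.update(0.3, 16, "train")  # the smaller final batch
epoch_mean = acc.epoch_loss("train")
```

Here the weighted mean is (0.5 × 32 + 0.3 × 16) / 48 ≈ 0.433, whereas a plain average of the two batch losses would report 0.4.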

update_data_iter(self, batch_size: int, train_type: TrainType)

Accumulates the iteration count and the data point count for training and validation.

Parameters:
  • batch_size – The size of the current batch. The batch size may be smaller in the last iteration, so compute batch_size from the data.
  • train_type – Whether the batch belongs to train or valid.
Returns:

Report class instance.
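
A small sketch of why batch_size should be computed from the data: with 100 samples and a configured batch size of 32, the last batch holds only 4 samples. `batches` is a hypothetical stand-in for a data loader; with PyTorch tensors the size would typically be read as `labels.size(0)`.

```python
from typing import Iterator, List


def batches(data: List[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive slices; the last slice may be shorter."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


iterations = 0
data_points = 0
sizes = []
for batch in batches(list(range(100)), batch_size=32):
    size = len(batch)  # measured from the data, not the configured 32
    sizes.append(size)
    iterations += 1
    data_points += size
```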

plot_an_epoch(self, detail: bool = False)

Simplifies plotting the standard items that need to be plotted after each epoch. For granular control, call this with detail = False and call the other plotting methods on top of it.

Parameters: detail – Whether to use detail mode.
Returns: Report class instance.

init_data_storage(self)

Clears the data storage after every epoch.

change_data_type(self, data: LossType, required_data_type: LossType)

Changes the data type of the input to the required data type.

Parameters:
  • data – The input data.
  • required_data_type – The format to convert to; can be either np or f.
Returns:

The data in required data type.

close(self)

Closes the TensorBoard writer object.

After this method is called, the report no longer tracks anything.

write_to_tensorboard(self)

Calls the various other methods that write to TensorBoard.

plot_loss(self)

Plots the loss at the end of the epoch.

Returns: Report class instance.

plot_model(self, model: torch.nn.Module, data: torch.Tensor)

Plots the model graph.

Parameters:
  • model – The model architecture.
  • data – The input to the model.
Returns:

Report class instance.

plot_confusion_matrix(self, at_which_epoch)

Plots the confusion matrix.

Parameters: at_which_epoch – The epoch interval at which the confusion matrix is plotted. For example, if the model is trained for 10 epochs and you want to plot the confusion matrix every 5 epochs, pass 5.
Returns: Report class instance.
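
A minimal sketch of the two ideas behind plot_confusion_matrix: building the matrix from integer labels, and deciding from at_which_epoch when to plot. It assumes 1-based epoch numbering and rows = actual class, columns = predicted class (the convention scikit-learn's `confusion_matrix` also uses); both helper names are hypothetical.

```python
from typing import List


def confusion_matrix(actual: List[int], pred: List[int], num_classes: int) -> List[List[int]]:
    """Rows are actual classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for a, p in zip(actual, pred):
        matrix[a][p] += 1
    return matrix


def should_plot(epoch: int, at_which_epoch: int) -> bool:
    # With 1-based epochs and at_which_epoch=5, this fires at epochs 5 and 10.
    return epoch % at_which_epoch == 0


cm = confusion_matrix([0, 0, 1, 2, 2], [0, 1, 1, 2, 0], num_classes=3)
plot_epochs = [e for e in range(1, 11) if should_plot(e, 5)]
```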

plot_precision_recall(self)

Plots precision, recall, and F1-score for all classes, together with their weighted average and macro average.

Returns: Report class instance.
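
To make the averaged rows concrete: the macro average weights every class equally, while the weighted average weights each class by its support (its number of true samples). A minimal pure-Python sketch of these standard definitions, the same ones scikit-learn's `classification_report` prints as "macro avg" and "weighted avg"; whether this library computes them exactly this way, including how it handles 0/0, is an assumption.

```python
from typing import Dict, List


def per_class_scores(actual: List[int], pred: List[int], num_classes: int) -> Dict[str, List[float]]:
    """Precision, recall, F1 and support for each class index."""
    scores: Dict[str, List[float]] = {"precision": [], "recall": [], "f1": [], "support": []}
    for k in range(num_classes):
        tp = sum(1 for a, p in zip(actual, pred) if a == k and p == k)
        fp = sum(1 for a, p in zip(actual, pred) if a != k and p == k)
        fn = sum(1 for a, p in zip(actual, pred) if a == k and p != k)
        # Undefined ratios (0/0) are treated as 0.0 in this sketch.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores["precision"].append(precision)
        scores["recall"].append(recall)
        scores["f1"].append(f1)
        scores["support"].append(tp + fn)  # samples that truly belong to class k
    return scores


def macro_average(values: List[float]) -> float:
    # Every class counts equally.
    return sum(values) / len(values)


def weighted_average(values: List[float], support: List[float]) -> float:
    # Each class is weighted by its number of true samples.
    return sum(v * s for v, s in zip(values, support)) / sum(support)


result = per_class_scores([0, 0, 0, 1], [0, 0, 1, 1], num_classes=2)
macro_f1 = macro_average(result["f1"])
weighted_f1 = weighted_average(result["f1"], result["support"])
```

With three samples of class 0 and one of class 1, the weighted F1 (≈ 0.767) leans toward class 0's score, while the macro F1 (≈ 0.733) does not.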

plot_missclassification_count(self, at_which_epoch)

Plots the misclassification count for each class as a bar graph of false positive and false negative counts.

Parameters: at_which_epoch – The epoch interval at which the misclassification count is plotted. For example, if the model is trained for 10 epochs and you want to plot it every 5 epochs, pass 5.
Returns: Report class instance.

calculate_fp_fn(self, actual: ActualType, pred: PredictionType)

Calculates the false positive and false negative counts per class.

Parameters:
  • actual – The actual labels.
  • pred – The predicted labels.
Returns:

Report class instance.
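
A minimal sketch of per-class false positive and false negative counting, the quantities the misclassification bar graph shows. It assumes integer class labels (for softmax output, take the argmax first). `fp_fn_counts` is a hypothetical helper that returns the counts directly, whereas the documented method returns the Report instance.

```python
from typing import Dict, List


def fp_fn_counts(actual: List[int], pred: List[int], num_classes: int) -> Dict[str, List[int]]:
    """Count false positives and false negatives for each class index."""
    fp = [0] * num_classes
    fn = [0] * num_classes
    for a, p in zip(actual, pred):
        if a != p:
            fp[p] += 1  # class p was predicted for a sample that is not p
            fn[a] += 1  # class a was missed for this sample
    return {"fp": fp, "fn": fn}


counts = fp_fn_counts([0, 0, 1, 2, 2], [0, 1, 1, 2, 0], num_classes=3)
```

Each misclassified sample adds exactly one false positive and one false negative, so the two bars always have the same total across classes.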

plot_mcc(self)

Plots the Matthews correlation coefficient.

Returns: Report class instance.
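
For more than two classes, a common generalization of the Matthews correlation coefficient (the one scikit-learn's `matthews_corrcoef` uses) is computed from the number of correct predictions c, the sample count s, and the per-class true counts t_k and predicted counts p_k. A minimal sketch of that formula; whether this library computes MCC the same way is an assumption, and `mcc` is a hypothetical helper.

```python
import math
from typing import List


def mcc(actual: List[int], pred: List[int], num_classes: int) -> float:
    """Multiclass MCC: (c*s - sum p_k*t_k) / sqrt((s^2 - sum p_k^2)(s^2 - sum t_k^2))."""
    s = len(actual)
    c = sum(1 for a, p in zip(actual, pred) if a == p)
    true_counts = [actual.count(k) for k in range(num_classes)]
    pred_counts = [pred.count(k) for k in range(num_classes)]
    numerator = c * s - sum(pk * tk for pk, tk in zip(pred_counts, true_counts))
    denominator = math.sqrt(
        (s * s - sum(pk * pk for pk in pred_counts))
        * (s * s - sum(tk * tk for tk in true_counts))
    )
    # A zero denominator (e.g. a single predicted class) is reported as 0.0 here.
    return numerator / denominator if denominator else 0.0


perfect = mcc([0, 0, 1, 1], [0, 0, 1, 1], num_classes=2)
inverted = mcc([0, 1, 0, 1], [1, 0, 1, 0], num_classes=2)
partial = mcc([1, 1, 0, 0, 1], [1, 0, 0, 1, 1], num_classes=2)
```

For two classes this reduces to the familiar (TP·TN − FP·FN) / √((TP+FP)(TP+FN)(TN+FP)(TN+FN)); the `partial` case has TP = 2, TN = 1, FP = 1, FN = 1, giving 1/6 either way.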

plot_pred_prob(self, at_which_epoch: int)

Plots a scatter plot of the predicted probabilities for each class.

Parameters: at_which_epoch – The epoch interval at which the predicted probabilities are plotted. For example, if the model is trained for 10 epochs and you want to plot them every 5 epochs, pass 5.
Returns: Report class instance.

plot_model_data_grad(self, at_which_iter: int)

Plots histograms and distributions of each layer's weights, biases, and gradients.

Parameters: at_which_iter – The iteration interval at which to plot. Ideally, plot this every one-half or one-third of train_iterator.
Returns: Report class instance.
Examples:
>>> report.plot_model_data_grad(at_which_iter=len(train_iterator) // 2)

plot_hparams(self, config: HyperParameters)

Plots the model's hyperparameters. Call this method once training is over.

Parameters: config – Hyperparameter configs.
Returns: Report class instance.