Training package

super_gradients.training module

class super_gradients.training.DataAugmentation[source]

Bases: object

static to_tensor()[source]
static normalize(mean, std)[source]
static cutout(mask_size, p=1, cutout_inside=False, mask_color=(0, 0, 0))[source]
class super_gradients.training.TestDatasetInterface(trainset, dataset_params={}, classes=None)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.DatasetInterface

get_data_loaders(batch_size_factor=1, num_workers=8, train_batch_size=None, val_batch_size=None, distributed_sampler=False)[source]

Get self.train_loader, self.val_loader, self.test_loader, self.classes.

If the data loaders haven’t been initialized yet, build them first.

Parameters

kwargs – kwargs are passed to build_data_loaders.

class super_gradients.training.SgModel(experiment_name: str, device: Optional[str] = None, multi_gpu: Union[super_gradients.common.data_types.enum.multi_gpu_mode.MultiGPUMode, str] = <MultiGPUMode.OFF: 'Off'>, model_checkpoints_location: str = 'local', overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth', post_prediction_callback: Optional[super_gradients.training.utils.detection_utils.DetectionPostPredictionCallback] = None, ckpt_root_dir: Optional[str] = None, train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, valid_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, test_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, classes: Optional[List[Any]] = None)[source]

Bases: object

SuperGradient Model - Base Class for Sg Models

train(max_epochs: int, initial_epoch: int, save_model: bool)[source]

the main function used for the training, h.p. updating, logging etc.

predict(idx: int)[source]

returns the predictions and label of the current inputs

test(epoch : int, idx : int, save : bool):

returns the test loss, accuracy and runtime

connect_dataset_interface(dataset_interface: super_gradients.training.datasets.dataset_interfaces.dataset_interface.DatasetInterface, data_loader_num_workers: int = 8)[source]
Parameters
  • dataset_interface – DatasetInterface object; the dataset to be connected.

  • data_loader_num_workers – The number of threads to initialize the data loaders with.
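A minimal usage sketch (the experiment name is illustrative; ClassificationTestDatasetInterface is the toy interface documented later on this page):

    from super_gradients.training import SgModel, ClassificationTestDatasetInterface

    # Toy classification data (10 hypothetical classes) wrapped in a dataset interface.
    dataset_interface = ClassificationTestDatasetInterface(classes=list(range(10)))

    sg_model = SgModel(experiment_name="my_classification_experiment")
    # Attach the interface's loaders and class list to the model.
    sg_model.connect_dataset_interface(dataset_interface, data_loader_num_workers=4)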

build_model(architecture: Union[str, torch.nn.modules.module.Module], arch_params={}, checkpoint_params={}, *args, **kwargs)[source]
Parameters
  • architecture – Defines the network’s architecture from models/ALL_ARCHITECTURES

  • arch_params – Architecture H.P. e.g.: block, num_blocks, num_classes, etc.

  • checkpoint_params

    Dictionary like object with the following key:values:

    load_checkpoint: Load a pre-trained checkpoint.

    strict_load: See the StrictLoad class documentation for details.

    source_ckpt_folder_name: Folder name to load the checkpoint from (self.experiment_name if none is given).

    load_weights_only: Loads only the weights from the checkpoint and zeroes the training params.

    load_backbone: Loads the provided checkpoint to self.net.backbone instead of self.net.

    external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative (ie: path/to/checkpoint.pth). If provided, will automatically attempt to load the checkpoint even if the load_checkpoint flag is not provided.
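A minimal sketch of a call, continuing the example above (the architecture name and arch_params values are illustrative and assume an entry in models/ALL_ARCHITECTURES):

    # Build a network by name, without loading a checkpoint.
    sg_model.build_model(
        architecture="resnet18",              # assumed to exist in ALL_ARCHITECTURES
        arch_params={"num_classes": 10},
        checkpoint_params={"load_checkpoint": False},
    )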

train(training_params: dict = {})[source]

train - Trains the Model

IMPORTANT NOTE: If the data loaders return a tuple, an optional third item can be added as a dictionary of additional batch parameters. The phase context will hold these additional items, each under an attribute named after the corresponding key in the dictionary, so they can be accessed through phase callbacks.

param training_params
  • max_epochs : int

    Number of epochs to run training.

  • lr_updates : list(int)

    List of fixed epoch numbers to perform learning rate updates when lr_mode=’step’.

  • lr_decay_factor : float

    Decay factor to apply to the learning rate at each update when lr_mode=’step’.

  • lr_mode : str

    Learning rate scheduling policy, one of ['step','poly','cosine','function']. 'step' refers to constant updates at epoch numbers passed through lr_updates. 'cosine' refers to the Cosine Annealing policy as mentioned in https://arxiv.org/abs/1608.03983. 'poly' refers to a polynomial decrease, i.e. in each epoch iteration self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9). 'function' refers to a user-defined learning rate scheduling function, passed through lr_schedule_function.

  • lr_schedule_function : Union[callable,None]

    Learning rate scheduling function to be used when lr_mode is ‘function’.

  • lr_warmup_epochs : int (default=0)

    Number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).

  • cosine_final_lr_ratio : float (default=0.01)
    Final learning rate ratio (only relevant when `lr_mode`=’cosine’). The cosine starts from initial_lr and reaches

    initial_lr * cosine_final_lr_ratio in last epoch

  • initial_lr : float

    Initial learning rate.

  • loss : Union[nn.module, str]

    Loss function for training. One of SuperGradient’s built in options:

    “cross_entropy”: LabelSmoothingCrossEntropyLoss, “mse”: MSELoss, “r_squared_loss”: RSquaredLoss, “detection_loss”: YoLoV3DetectionLoss, “shelfnet_ohem_loss”: ShelfNetOHEMLoss, “shelfnet_se_loss”: ShelfNetSemanticEncodingLoss, “ssd_loss”: SSDLoss,

    or user defined nn.module loss function.

    IMPORTANT: forward(…) should return a (loss, loss_items) tuple where loss is the tensor used for backprop (i.e. what your original loss function returns), and loss_items should be a tensor of shape (n_items) holding values computed during the forward pass that we wish to log over the entire epoch. For example, the loss itself should always be logged. Another example is a scenario where the computed loss is the sum of a few components we would like to log; these become the entries of loss_items.

    When training, set the loss_logging_items_names parameter in train_params to a list of strings of length n_items whose ith element is the name of the ith entry in loss_items. Each item will then be logged, rendered on tensorboard and "watched" (i.e. model checkpoints can be saved according to it). A usage sketch appears after this parameter list.

    Since running logs will save the loss_items in some internal state, it is recommended that loss_items are detached from their computational graph for memory efficiency.

  • optimizer : Union[str, torch.optim.Optimizer]

    Optimization algorithm. One of ['Adam','SGD','RMSProp'], corresponding to the torch.optim optimizer implementations, or any object that implements torch.optim.Optimizer.

  • criterion_params : dict

    Loss function parameters.

  • optimizer_params : dict

    When optimizer is one of [‘Adam’,’SGD’,’RMSProp’], it will be initialized with optimizer_params.

    (see https://pytorch.org/docs/stable/optim.html for the full list of parameters for each optimizer).

  • train_metrics_list : list(torchmetrics.Metric)

    Metrics to log during training. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.

  • valid_metrics_list : list(torchmetrics.Metric)

    Metrics to log during validation/testing. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.

  • loss_logging_items_names : list(str)

    The list of names/titles for the outputs returned from the loss function's forward pass (reminder: the loss function should return the tuple (loss, loss_items)). These names will be used for logging their values.

  • metric_to_watch : str (default=”Accuracy”)

    will be the metric which the model checkpoint will be saved according to, and can be set to any of the following:

    a metric name (str) of one of the metric objects from the valid_metrics_list

    a “metric_name” if some metric in valid_metrics_list has an attribute component_names which is a list referring to the names of each entry in the output metric (torch tensor of size n)

    one of “loss_logging_items_names” i.e which will correspond to an item returned during the loss function’s forward pass.

    At the end of each epoch, if a new best metric_to_watch value is achieved, the model's checkpoint is saved in YOUR_PYTHON_PATH/checkpoints/ckpt_best.pth

  • greater_metric_to_watch_is_better : bool

    When choosing a model's checkpoint to be saved, the best achieved model is the one that maximizes the metric_to_watch when this parameter is set to True, and the one that minimizes it otherwise.

  • ema : bool (default=False)

    Whether to use Model Exponential Moving Average (see https://github.com/rwightman/pytorch-image-models ema implementation)

  • batch_accumulate : int (default=1)

    Number of batches to accumulate before every backward pass.

  • ema_params : dict

    Parameters for the ema model.

  • zero_weight_decay_on_bias_and_bn : bool (default=False)

    Whether to apply weight decay on batch normalization parameters or not (ignored when the passed optimizer has already been initialized).

  • load_opt_params : bool (default=True)

    Whether to load the optimizers parameters as well when loading a model’s checkpoint.

  • run_validation_freq : int (default=1)

    The frequency at which validation is performed during training (i.e. validation is run every run_validation_freq epochs).

  • save_model : bool (default=True)

    Whether to save the model checkpoints.

  • silent_mode : bool

    Silences the printouts.

  • mixed_precision : bool

    Whether to use mixed precision or not.

  • save_ckpt_epoch_list : list(int) (default=[])

    List of fixed epoch indices the user wishes to save checkpoints in.

  • average_best_models : bool (default=False)

    If set, a snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when training is completed. The snapshot file will only be deleted upon completing the training. The snapshot dict will be managed on cpu.

  • precise_bn : bool (default=False)

    Whether to use precise_bn calculation during the training.

  • precise_bn_batch_size : int (default=None)

    The effective batch size we want to calculate the batchnorm on. For example, if we are training a model on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192 (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus). If precise_bn_batch_size is not provided in the training_params, the latter heuristic will be taken.

  • seed : int (default=42)

    Random seed to be set for torch, numpy, and random. When using DDP each process will have its seed set to seed + rank.

  • log_installed_packages : bool (default=False)

    When set, the list of all installed packages (and their versions) will be written to the tensorboard and logfile (useful when trying to reproduce results).

  • dataset_statistics : bool (default=False)

    Enable a statistic analysis of the dataset. If set to True the dataset will be analyzed and a report will be added to the tensorboard along with some sample images from the dataset. Currently only detection datasets are supported for analysis.

  • save_full_train_log : bool (default=False)

    When set, a full log (of all super_gradients modules, including uncaught exceptions from any other module) of the training will be saved in the checkpoint directory under full_train_log.log

  • sg_logger : Union[AbstractSGLogger, str] (default=base_sg_logger)

    Define the SGLogger object for this training process. The SGLogger handles all disk writes, logs, TensorBoard, remote logging and remote storage. By overriding the default base_sg_logger, you can change the storage location, support external monitoring and logging or support remote storage.

  • sg_logger_params : dict

    SGLogger parameters

  • clip_grad_norm : float

    Defines a maximal L2 norm of the gradients. Values which exceed the given value will be clipped

  • lr_cooldown_epochs : int (default=0)

    Number of epochs to cooldown LR (i.e the last epoch from scheduling view point=max_epochs-cooldown).

  • pre_prediction_callback : Callable (default=None)

    When not None, this callback will be applied to images and targets and should return them to be used for the forward pass and further computations. The callable's arguments should be in the order (inputs, targets, batch_idx), returning modified_inputs, modified_targets.

  • ckpt_best_name : str (default=’ckpt_best.pth’)

    The best checkpoint (according to metric_to_watch) will be saved under this filename in the checkpoints directory.

  • enable_qat: bool (default=False)

    Adds a QATCallback to the phase callbacks, that triggers quantization aware training starting from qat_params["start_epoch"].

  • qat_params: dict-like object with the following key/values:

    start_epoch: int, first epoch to start QAT.

    quant_modules_calib_method: str, one of [percentile, mse, entropy, max]. Statistics method for amax computation of the quantized modules (default=percentile).

    per_channel_quant_modules: bool, whether quant modules should be per channel (default=False).

    calibrate: bool, whether to perform calibration (default=False).

    calibrated_model_path: str, path to a calibrated checkpoint (default=None).

    calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset. When None, context.train_loader will be used (default=None).

    num_calib_batches: int, number of batches to collect the statistics from.

    percentile: float, percentile value to use when quant_modules_calib_method='percentile'. Discarded when other methods are used (default=99.99).

Returns
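A minimal training sketch, continuing the examples above (epoch counts, learning rates and metric choices are illustrative; "cross_entropy", "SGD" and the metric classes are the built-ins listed on this page):

    from super_gradients.training.metrics import Accuracy, Top5

    training_params = {
        "max_epochs": 20,
        "initial_lr": 0.1,
        "lr_mode": "step",
        "lr_updates": [10, 15],
        "lr_decay_factor": 0.1,
        "loss": "cross_entropy",
        "optimizer": "SGD",
        "optimizer_params": {"momentum": 0.9, "weight_decay": 1e-4},
        "loss_logging_items_names": ["Loss"],
        "train_metrics_list": [Accuracy(), Top5()],
        "valid_metrics_list": [Accuracy(), Top5()],
        "metric_to_watch": "Accuracy",
        "greater_metric_to_watch_is_better": True,
    }
    # Runs the full training loop on the connected data loaders.
    sg_model.train(training_params=training_params)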

predict(inputs, targets=None, half=False, normalize=False, verbose=False, move_outputs_to_cpu=True)[source]

A fast predictor for a batch of inputs.

Parameters
  • inputs – torch.tensor or numpy.array; a batch of inputs

  • targets – torch.tensor() corresponding labels - if none are given, accuracy will not be computed

  • verbose – bool print the results to screen

  • normalize – bool If true, normalizes the tensor according to the dataloader’s normalization values

  • half – Performs half precision evaluation

  • move_outputs_to_cpu – Moves the results from the GPU to the CPU

Returns

outputs, acc, net_time, gross_time – the network's predictions, accuracy calculation, forward pass net time and function gross time
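A short sketch of the call (shapes are illustrative; without targets the accuracy entry is not computed):

    import torch

    batch = torch.randn(4, 3, 32, 32)  # a batch of 4 RGB 32x32 images
    outputs, acc, net_time, gross_time = sg_model.predict(batch, verbose=False)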

property get_arch_params
property get_structure
property get_architecture
set_experiment_name(experiment_name)[source]
property get_module
set_module(module)[source]
test(test_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, loss: Optional[torch.nn.modules.loss._Loss] = None, silent_mode: bool = False, test_metrics_list=None, loss_logging_items_names=None, metrics_progress_verbose=False, test_phase_callbacks=None, use_ema_net=True) → tuple[source]

Evaluates the model on given dataloader and metrics.

Parameters
  • test_loader – dataloader to perform test on.

  • test_metrics_list – (list(torchmetrics.Metric)) metrics list for evaluation.

  • silent_mode – (bool) controls verbosity

  • metrics_progress_verbose – (bool) controls the verbosity of metrics progress (default=False). Slows down the program.

  • use_ema_net – (bool) whether to perform the test on self.ema_model.ema when it exists, otherwise self.net will be tested (default=True)

Returns

results tuple (tuple) containing the loss items and metric values.

All of the above args will override SgModel's corresponding attributes when not equal to None. Evaluation is then run on self.test_loader with self.test_metrics.
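A minimal sketch of an evaluation call on an external loader (the synthetic tensors only illustrate the expected loader format):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from super_gradients.training.metrics import Accuracy

    images = torch.randn(8, 3, 32, 32)
    labels = torch.randint(0, 10, (8,))
    external_loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    # Returns the loss items followed by the metric values as a tuple.
    results = sg_model.test(test_loader=external_loader, test_metrics_list=[Accuracy()])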

evaluate(data_loader: torch.utils.data.dataloader.DataLoader, metrics: torchmetrics.collections.MetricCollection, evaluation_type: super_gradients.common.data_types.enum.evaluation_type.EvaluationType, epoch: Optional[int] = None, silent_mode: bool = False, metrics_progress_verbose: bool = False)[source]

Evaluates the model on given dataloader and metrics.

Parameters
  • data_loader – dataloader to perform evaluation on

  • metrics – (MetricCollection) metrics for evaluation

  • evaluation_type – (EvaluationType) controls which phase callbacks will be used (for example, on batch end, when evaluation_type=EvaluationType.VALIDATION the Phase.VALIDATION_BATCH_END callbacks will be triggered)

  • epoch – (int) epoch idx

  • silent_mode – (bool) controls verbosity

  • metrics_progress_verbose – (bool) controls the verbosity of metrics progress (default=False). Slows down the program significantly.

Returns

results tuple (tuple) containing the loss items and metric values.

property get_net

Getter for network. :return: torch.nn.Module, self.net

set_net(net: torch.nn.modules.module.Module)[source]

Setter for network.

Parameters

net – torch.nn.Module, value to set net

Returns

set_ckpt_best_name(ckpt_best_name)[source]

Setter for best checkpoint filename.

Parameters

ckpt_best_name – str, value to set ckpt_best_name

set_ema(val: bool)[source]

Setter for self.ema

Parameters

val – bool, value to set ema

class super_gradients.training.KDModel(*args, **kwargs)[source]

Bases: super_gradients.training.sg_model.sg_model.SgModel

build_model(architecture: Union[str, super_gradients.training.models.kd_modules.kd_module.KDModule] = 'kd_module', arch_params={}, checkpoint_params={}, *args, **kwargs)[source]
Parameters
  • architecture – (Union[str, KDModule]) Defines the network’s architecture from models/KD_ARCHITECTURES (default=’kd_module’)

  • arch_params – (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc to be passed to kd architecture class (discarded when architecture is KDModule instance)

  • checkpoint_params

    (dict) A dictionary like object with the following keys/values:

    student_pretrained_weights: String describing the dataset of the pretrained weights (for example "imagenet") for the student network.

    teacher_pretrained_weights: String describing the dataset of the pretrained weights (for example "imagenet") for the teacher network.

    teacher_checkpoint_path: Local path to the teacher's checkpoint. Note that when passing pretrained_weights through teacher_arch_params these weights will be overridden by the pretrained checkpoint (default=None).

    load_kd_model_checkpoint: Whether to load an entire KDModule checkpoint (used to continue KD training) (default=False).

    kd_model_source_ckpt_folder_name: Folder name to load an entire KDModule checkpoint from (self.experiment_name if none is given) to resume KD training (default=None).

    kd_model_external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative (ie: path/to/checkpoint.pth). If provided, will automatically attempt to load the checkpoint even if the load_checkpoint flag is not provided (default=None).

Keyword Arguments
  • student_architecture – (Union[str, SgModule]) Defines the student's architecture from models/ALL_ARCHITECTURES (when str), or directly defines the student network (when SgModule).

  • teacher_architecture – (Union[str, SgModule]) Defines the teacher's architecture from models/ALL_ARCHITECTURES (when str), or directly defines the teacher network (when SgModule).

  • student_arch_params – (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for the student net. (default={})

  • teacher_arch_params – (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for the teacher net. (default={})

  • run_teacher_on_eval – (bool) whether to run self.teacher in eval mode regardless of self.train(mode)
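A minimal knowledge-distillation sketch (architecture names, class count and the checkpoint path are illustrative; the dataset interface is any DatasetInterface, as in the SgModel examples above):

    from super_gradients.training import KDModel

    kd_model = KDModel(experiment_name="kd_example")
    kd_model.connect_dataset_interface(dataset_interface)
    kd_model.build_model(
        architecture="kd_module",
        student_architecture="resnet18",                    # assumed entry in ALL_ARCHITECTURES
        teacher_architecture="resnet50",                    # assumed entry in ALL_ARCHITECTURES
        student_arch_params={"num_classes": 10},
        teacher_arch_params={"num_classes": 10},
        checkpoint_params={"teacher_checkpoint_path": "path/to/teacher_ckpt.pth"},  # placeholder path
        run_teacher_on_eval=True,
    )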

class super_gradients.training.MultiGPUMode(value)[source]

Bases: str, enum.Enum

OFF                       - Single GPU Mode / CPU Mode
DATA_PARALLEL             - Multiple GPUs, Synchronous
DISTRIBUTED_DATA_PARALLEL - Multiple GPUs, Asynchronous
OFF = 'Off'
DATA_PARALLEL = 'DP'
DISTRIBUTED_DATA_PARALLEL = 'DDP'
AUTO = 'AUTO'
class super_gradients.training.SegmentationTestDatasetInterface(dataset_params={}, image_size=512, batch_size=4)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.TestDatasetInterface

class super_gradients.training.DetectionTestDatasetInterface(dataset_params={}, image_size=320, batch_size=4, classes=None)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.TestDatasetInterface

class super_gradients.training.ClassificationTestDatasetInterface(dataset_params={}, image_size=32, batch_size=5, classes=None)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.TestDatasetInterface

class super_gradients.training.StrictLoad(value)[source]

Bases: enum.Enum

Wrapper for adding more functionality to torch’s strict_load parameter in load_state_dict().
Attributes:

OFF - Native torch "strict_load = off" behaviour. See nn.Module.load_state_dict() documentation for more details.

ON - Native torch "strict_load = on" behaviour. See nn.Module.load_state_dict() documentation for more details.

NO_KEY_MATCHING - Allows the usage of SuperGradient's adapt_checkpoint function, which loads a checkpoint by matching each layer's shapes (and bypasses the strict matching of the names of each layer, i.e. disregards the state_dict key matching).

OFF = False
ON = True
NO_KEY_MATCHING = 'no_key_matching'
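A hedged sketch of loading an external checkpoint by shape matching (the architecture name and the checkpoint path are placeholders):

    from super_gradients.training import SgModel, StrictLoad

    model = SgModel(experiment_name="finetune_example")
    model.build_model(
        architecture="resnet18",                                  # assumed architecture name
        arch_params={"num_classes": 10},
        checkpoint_params={
            "load_checkpoint": True,
            "strict_load": StrictLoad.NO_KEY_MATCHING,            # match layers by shape, not by key
            "external_checkpoint_path": "path/to/checkpoint.pth"  # placeholder path
        },
    )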

super_gradients.training.datasets module

class super_gradients.training.datasets.DataAugmentation[source]

Bases: object

static to_tensor()[source]
static normalize(mean, std)[source]
static cutout(mask_size, p=1, cutout_inside=False, mask_color=(0, 0, 0))[source]
class super_gradients.training.datasets.ListDataset(root, file, sample_loader: Callable = <function default_loader>, target_loader: Optional[Callable] = None, collate_fn: Optional[Callable] = None, sample_extensions: tuple = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'), sample_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, target_extension='.npy')[source]

Bases: Generic[torch.utils.data.dataset.T_co]

ListDataset - A PyTorch Vision Data Set extension that receives a file with FULL PATH to each of the samples.

Then, the assumption is that for every sample, there is a matching target in the same path but with a different extension, i.e.:

for the samples paths: (That appear in the list file)

/root/dataset/class_x/sample1.png /root/dataset/class_y/sample123.png

the matching labels paths: (That DO NOT appear in the list file)

/root/dataset/class_x/sample1.ext /root/dataset/class_y/sample123.ext
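A minimal instantiation sketch (the root directory, list file name and target extension are illustrative):

    from super_gradients.training.datasets import ListDataset

    # 'samples.txt' is assumed to hold one full sample path per line; each sample is
    # expected to have a matching target next to it with the '.npy' extension.
    dataset = ListDataset(root="/root/dataset", file="samples.txt", target_extension=".npy")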

class super_gradients.training.datasets.DirectoryDataSet(root: str, samples_sub_directory: str, targets_sub_directory: str, target_extension: str, sample_loader: Callable = <function default_loader>, target_loader: Optional[Callable] = None, collate_fn: Optional[Callable] = None, sample_extensions: tuple = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'), sample_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]

Bases: Generic[torch.utils.data.dataset.T_co]

DirectoryDataSet - A PyTorch Vision Data Set extension that receives a root Dir and two separate sub directories:
  • Sub-Directory for Samples

  • Sub-Directory for Targets

class super_gradients.training.datasets.SegmentationDataSet(root: str, list_file: str = None, samples_sub_directory: str = None, targets_sub_directory: str = None, img_size: int = 608, crop_size: int = 512, batch_size: int = 16, augment: bool = False, dataset_hyper_params: dict = None, cache_labels: bool = False, cache_images: bool = False, sample_loader: Callable = None, target_loader: Callable = None, collate_fn: Callable = None, target_extension: str = '.png', image_mask_transforms: torchvision.transforms.transforms.Compose = None, image_mask_transforms_aug: torchvision.transforms.transforms.Compose = None)[source]

Bases: Generic[torch.utils.data.dataset.T_co]

static sample_loader(sample_path: str) → PIL.Image[source]
sample_loader - Loads a dataset image from path using PIL
param sample_path

The path to the sample image

return

The loaded Image

static sample_transform(image)[source]

sample_transform - Transforms the sample image

param image

The input image to transform

return

The transformed image

static target_loader(target_path: str) → PIL.Image[source]
Parameters

target_path – The path to the sample image

Returns

The loaded Image

static target_transform(target)[source]

target_transform - Transforms the sample image

param target

The target mask to transform

return

The transformed target mask

class super_gradients.training.datasets.PascalVOC2012SegmentationDataSet(sample_suffix=None, target_suffix=None, *args, **kwargs)[source]

Bases: Generic[torch.utils.data.dataset.T_co]

PascalVOC2012SegmentationDataSet - Segmentation Data Set Class for Pascal VOC 2012 Data Set

decode_segmentation_mask(label_mask: numpy.ndarray)[source]
decode_segmentation_mask - Decodes the colors for the Segmentation Mask
param label_mask

an (M,N) array of integer values denoting the class label at each spatial location.

Returns

class super_gradients.training.datasets.PascalAUG2012SegmentationDataSet(*args, **kwargs)[source]

Bases: Generic[torch.utils.data.dataset.T_co]

PascalAUG2012SegmentationDataSet - Segmentation Data Set Class for Pascal AUG 2012 Data Set

static target_loader(target_path: str) → PIL.Image[source]
Parameters

target_path – The path to the target data

Returns

The loaded target

class super_gradients.training.datasets.CoCoSegmentationDataSet(dataset_classes_inclusion_tuples_list: Optional[list] = None, *args, **kwargs)[source]

Bases: Generic[torch.utils.data.dataset.T_co]

CoCoSegmentationDataSet - Segmentation Data Set Class for COCO 2017 Segmentation Data Set

target_loader(mask_metadata_tuple) → PIL.Image[source]
Parameters

mask_metadata_tuple – A tuple of (coco_image_id, original_image_height, original_image_width)

Returns

The mask image created from the array

class super_gradients.training.datasets.TestDatasetInterface(trainset, dataset_params={}, classes=None)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.DatasetInterface

get_data_loaders(batch_size_factor=1, num_workers=8, train_batch_size=None, val_batch_size=None, distributed_sampler=False)[source]

Get self.train_loader, self.val_loader, self.test_loader, self.classes.

If the data loaders haven’t been initialized yet, build them first.

Parameters

kwargs – kwargs are passed to build_data_loaders.

class super_gradients.training.datasets.DatasetInterface(dataset_params={}, train_loader=None, val_loader=None, test_loader=None, classes=None)[source]

Bases: object

DatasetInterface - This class manages all of the "communication" the Model has with the Data Sets

download_from_cloud()[source]
build_data_loaders(batch_size_factor=1, num_workers=8, train_batch_size=None, val_batch_size=None, test_batch_size=None, distributed_sampler: bool = False)[source]

Define train, val (and optionally test) loaders. The method deals separately with distributed training and standard (non-distributed, or parallel) training. In the case of distributed training we need to rely on distributed samplers.

Parameters
  • batch_size_factor – int - factor to multiply the batch size (usually for multi gpu)

  • num_workers – int - number of workers (parallel processes) for dataloaders

  • train_batch_size – int - batch size for train loader, if None will be taken from dataset_params

  • val_batch_size – int - batch size for val loader, if None will be taken from dataset_params

  • distributed_sampler – boolean flag for distributed training mode

Returns

train_loader, val_loader, classes: list of classes

get_data_loaders(**kwargs)[source]

Get self.train_loader, self.val_loader, self.test_loader, self.classes.

If the data loaders haven’t been initialized yet, build them first.

Parameters

kwargs – kwargs are passed to build_data_loaders.

get_val_sample(num_samples=1)[source]
get_dataset_params()[source]
print_dataset_details()[source]
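A minimal sketch of wrapping pre-built loaders in a DatasetInterface so they can be connected to an SgModel (the synthetic tensors are only placeholders for real datasets):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from super_gradients.training.datasets import DatasetInterface

    images = torch.randn(16, 3, 32, 32)
    labels = torch.randint(0, 10, (16,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    dataset_interface = DatasetInterface(
        dataset_params={"batch_size": 4},
        train_loader=loader,
        val_loader=loader,
        classes=list(range(10)),
    )
    train_loader, val_loader, test_loader, classes = dataset_interface.get_data_loaders()
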
class super_gradients.training.datasets.Cifar10DatasetInterface(dataset_params={})[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.LibraryDatasetInterface

class super_gradients.training.datasets.CoCoSegmentationDatasetInterface(dataset_params=None, cache_labels: bool = False, cache_images: bool = False, dataset_classes_inclusion_tuples_list: Optional[list] = None)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.CoCoDataSetInterfaceBase

class super_gradients.training.datasets.PascalVOC2012SegmentationDataSetInterface(dataset_params=None, cache_labels=False, cache_images=False)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.DatasetInterface

class super_gradients.training.datasets.PascalAUG2012SegmentationDataSetInterface(dataset_params=None, cache_labels=False, cache_images=False)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.DatasetInterface

class super_gradients.training.datasets.TestYoloDetectionDatasetInterface(dataset_params={}, input_dims=(3, 32, 32), batch_size=5)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.DatasetInterface

note: the output size is (batch_size, 6) in the test while in real training the size of axis 0 can vary (the number of bounding boxes)

class super_gradients.training.datasets.DetectionTestDatasetInterface(dataset_params={}, image_size=320, batch_size=4, classes=None)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.TestDatasetInterface

class super_gradients.training.datasets.ClassificationTestDatasetInterface(dataset_params={}, image_size=32, batch_size=5, classes=None)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.TestDatasetInterface

class super_gradients.training.datasets.SegmentationTestDatasetInterface(dataset_params={}, image_size=512, batch_size=4)[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.TestDatasetInterface

class super_gradients.training.datasets.ImageNetDatasetInterface(dataset_params={}, data_dir='/data/Imagenet')[source]

Bases: super_gradients.training.datasets.dataset_interfaces.dataset_interface.DatasetInterface

class super_gradients.training.datasets.DetectionDataset(data_dir: str, input_dim: tuple, original_target_format: super_gradients.training.utils.detection_utils.DetectionTargetsFormat, max_num_samples: Optional[int] = None, cache: bool = False, cache_path: Optional[str] = None, transforms: List[super_gradients.training.transforms.transforms.DetectionTransform] = [], all_classes_list: Optional[List[str]] = None, class_inclusion_list: Optional[List[str]] = None, ignore_empty_annotations: bool = True, target_fields: Optional[List[str]] = None, output_fields: Optional[List[str]] = None)[source]

Bases: Generic[torch.utils.data.dataset.T_co]

Detection dataset.

This is a boilerplate class to facilitate the implementation of datasets.

HOW TO CREATE A DATASET THAT INHERITS FROM DetectionDataSet ?
  • Inherit from DetectionDataSet

  • implement the method self._load_annotation to return at least the fields “target” and “img_path”

  • Call super().__init__ with the required params.
    Note: super().__init__ will call self._load_annotation, so make sure that every required attribute is set up before calling super().__init__ (ideally just call it last). A minimal sketch of such a subclass appears after the TERMINOLOGY section below.

WORKFLOW:
  • On instantiation:
    • All annotations are cached. If class_inclusion_list was specified, there is also subclassing at this step.

    • If cache is True, the images are also cached

  • On call (__getitem__) for a specific image index:
    • The image and annotations are grouped together in a dict called SAMPLE

    • the sample is processed according to the transforms

    • Only the specified fields are returned by __getitem__

TERMINOLOGY
  • TARGET: Groundtruth, made of bboxes. The format can vary from one dataset to another

  • ANNOTATION: Combination of targets (groundtruth) and metadata of the image, but without the image itself.

    > Has to include the fields “target” and “img_path” > Can include other fields like “crowd_target”, “image_info”, “segmentation”, …

  • SAMPLE: Output of the dataset:

    > Has to include the fields “target” and “image” > Can include other fields like “crowd_target”, “image_info”, “segmentation”, …

  • INDEX: Refers to the index in the dataset.

  • SAMPLE ID: Refers to the id of the sample before dropping any annotation.

    Let's imagine a situation where the downloaded data is made of 120 images but 20 were dropped because they had no annotation. In that case:

    > We have 120 samples, so sample_id will be between 0 and 119
    > But only 100 will be indexed, so index will be between 0 and 99
    > Therefore, we also have len(self) = 100
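A minimal sketch of such a subclass, following the steps above (the annotation source, paths, input size and target format are illustrative; only _load_annotation is shown, and a concrete dataset may need additional attributes expected by the base class):

    import numpy as np
    from super_gradients.training.datasets import DetectionDataset
    from super_gradients.training.utils.detection_utils import DetectionTargetsFormat

    class MyDetectionDataset(DetectionDataset):
        def __init__(self, annotations, *args, **kwargs):
            # Attributes used by _load_annotation must exist BEFORE super().__init__,
            # because the base constructor caches all annotations immediately.
            self.my_annotations = annotations
            super().__init__(*args, **kwargs)

        def _load_annotation(self, sample_id: int) -> dict:
            boxes, img_path = self.my_annotations[sample_id]
            # Must return at least the fields "target" and "img_path".
            return {"target": np.asarray(boxes, dtype=np.float32), "img_path": img_path}

    # Hypothetical usage:
    # dataset = MyDetectionDataset(
    #     annotations=my_annotations,                        # e.g. [(boxes, "img_0.jpg"), ...]
    #     data_dir="/data/my_dataset",
    #     input_dim=(640, 640),
    #     original_target_format=DetectionTargetsFormat.XYXY_LABEL,
    # )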

get_random_item()[source]
get_sample(index: int) → Dict[str, Union[numpy.ndarray, Any]][source]

Get raw sample, before any transform (beside subclassing). :param index: Image index :return: Sample, i.e. a dictionary including at least “image” and “target”

get_resized_image(index: int) → numpy.ndarray[source]

Get the resized image at a specific sample_id, either from cache or by loading from disk, based on self.cached_imgs :param index: Image index :return: Resized image

apply_transforms(sample: Dict[str, Union[numpy.ndarray, Any]]) → Dict[str, Union[numpy.ndarray, Any]][source]

Applies self.transforms sequentially to sample

If a transform has the attribute 'additional_samples_count', additional samples will be loaded and stored in sample["additional_samples"] prior to applying it. Combining this with the attribute "non_empty_annotations" will load only additional samples with objects in them.

Parameters

sample – Sample to apply the transforms on to (loaded with self.get_sample)

Returns

Transformed sample

get_random_samples(count: int, non_empty_annotations_only: bool = False) → List[Dict[str, Union[numpy.ndarray, Any]]][source]

Load random samples.

Parameters
  • count – The number of samples wanted

  • non_empty_annotations_only – If true, only return samples with at least 1 annotation

Returns

A list of samples satisfying input params

get_random_sample(non_empty_annotations_only: bool = False)[source]
property output_target_format
plot(max_samples_per_plot: int = 16, n_plots: int = 1, plot_transformed_data: bool = True)[source]

Combine samples of images with bbox into plots and display the result.

Parameters
  • max_samples_per_plot – Maximum number of images to be displayed per plot

  • n_plots – Number of plots to display (each plot being a combination of img with bbox)

  • plot_transformed_data – If True, the plot will be over samples after applying transforms (i.e. on __getitem__). If False, the plot will be over the raw samples (i.e. on get_sample)

Returns

class super_gradients.training.datasets.COCODetectionDataset(img_size: tuple, data_dir: Optional[str] = None, json_file: str = 'instances_train2017.json', name: str = 'images/train2017', cache: bool = False, cache_dir_path: Optional[str] = None, tight_box_rotation: bool = False, transforms: list = [], with_crowd: bool = True)[source]

Bases: Generic[torch.utils.data.dataset.T_co]

Detection dataset COCO implementation

load_resized_img(index)[source]

Loads image at index, and resizes it to self.input_dim

Parameters

index – index to load the image from

Returns

resized_img

load_sample(index)[source]
Loads the sample at self.ids[index] as a dictionary that holds:

"image": Image resized to self.input_dim
"target": Detection ground truth, np.array shaped (num_targets, 5), format is [class, x1, y1, x2, y2] with image coordinates.
"target_seg": Segmentation map convex hull derived detection target.
"info": Original shape (height, width).
"id": COCO image id

Parameters

index – Sample index

Returns

sample as described above

load_image(index)[source]

Loads image at index with its original resolution :param index: index in self.annotations :return: image (np.array)

apply_transforms(sample: dict)[source]

Applies self.transforms sequentially to sample

If a transform has the attribute 'additional_samples_count', additional samples will be loaded and stored in sample["additional_samples"] prior to applying it. Combining this with the attribute "non_empty_targets" will load only additional samples with objects in them.

Parameters

sample – Sample to apply the transforms on to (loaded with self.load_sample)

Returns

Transformed sample

class super_gradients.training.datasets.PascalVOCDetectionDataset(images_sub_directory: str, *args, **kwargs)[source]

Bases: Generic[torch.utils.data.dataset.T_co]

Dataset for Pascal VOC object detection

static download(data_dir: str)[source]

Download Pascal dataset in XYXY_LABEL format.

Data extracted from http://host.robots.ox.ac.uk/pascal/VOC/

super_gradients.training.exceptions module

super_gradients.training.legacy module

super_gradients.training.losses_models module

class super_gradients.training.losses.FocalLoss(loss_fcn: torch.nn.modules.loss.BCEWithLogitsLoss, gamma=1.5, alpha=0.25)[source]

Bases: torch.nn.modules.loss._Loss

Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)

reduction: str
forward(pred, true)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class super_gradients.training.losses.LabelSmoothingCrossEntropyLoss(weight=None, ignore_index=- 100, reduction='mean', smooth_eps=None, smooth_dist=None, from_logits=True)[source]

Bases: torch.nn.modules.loss.CrossEntropyLoss

CrossEntropyLoss - with the ability to receive a distribution as targets, and optional label smoothing

forward(input, target, smooth_dist=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

ignore_index: int
label_smoothing: float
class super_gradients.training.losses.ShelfNetOHEMLoss(threshold: float = 0.7, mining_percent: float = 0.0001, ignore_lb: int = 255)[source]

Bases: super_gradients.training.losses.ohem_ce_loss.OhemCELoss

forward(predictions_list: list, targets)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

reduction: str
class super_gradients.training.losses.ShelfNetSemanticEncodingLoss(se_weight=0.2, nclass=21, aux_weight=0.4, weight=None, ignore_index=- 1)[source]

Bases: torch.nn.modules.loss.CrossEntropyLoss

2D Cross Entropy Loss with Auxiliary Loss

forward(logits, labels)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

ignore_index: int
label_smoothing: float
class super_gradients.training.losses.YoloXDetectionLoss(strides: list, num_classes: int, use_l1: bool = False, center_sampling_radius: float = 2.5, iou_type='iou')[source]

Bases: torch.nn.modules.loss._Loss

Calculate YOLOX loss: L = L_objectness + L_iou + L_classification + 1[use_l1]*L_l1

where:
  • L_iou, L_classification and L_l1 are calculated only between cells and targets that suit them;

  • L_objectness is calculated for all cells.

L_classification:

for cells that have suitable ground truths in their grid locations, add BCEs to force a prediction of IoU with a GT in a multi-label way. Coef: 1.

L_iou:

for cells that have suitable ground truths in their grid locations, add (1 - IoU^2), IoU between a predicted box and each GT box, to force maximum IoU. Coef: 5.

L_l1:

for cells that have suitable ground truths in their grid locations, the l1 distance between the logits and GTs in "logits" format (the inverse of "logits to predictions" ops). Coef: 1[use_l1].

L_objectness:

for each cell, add BCE with a label of 1 if there is a GT assigned to the cell. Coef: 1.

strides

list: List of Yolo levels output grid sizes (i.e [8, 16, 32]).

num_classes

int: Number of classes.

use_l1

bool: Controls the L_l1 Coef as discussed above (default=False).

center_sampling_radius

float: Sampling radius used for center sampling when creating the fg mask (default=2.5).

iou_type

str: IoU loss type, one of ["iou","giou"] (default="iou").
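A short construction sketch (the class count is illustrative; the strides follow the attribute description above):

    from super_gradients.training.losses import YoloXDetectionLoss

    # Loss for a three-level YOLOX head; typically passed as training_params["loss"].
    yolox_loss = YoloXDetectionLoss(strides=[8, 16, 32], num_classes=80, use_l1=False, iou_type="iou")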

forward(model_output: Union[list, Tuple[torch.Tensor, List]], targets: torch.Tensor)[source]
Parameters
  • model_output

    Union[list, Tuple[torch.Tensor, List]]: When a list - output from all Yolo levels, each of shape [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]. When a tuple - the second item is the described list (the first item is discarded).

  • targets – torch.Tensor: shape [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h

Returns

loss, all losses separately in a detached tensor

prepare_predictions(predictions: List[torch.Tensor]) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor][source]

Convert raw outputs of the network into a format that merges outputs from all levels :param predictions: output from all Yolo levels, each of shape

[Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]

Returns

5 tensors representing predictions:

  • x_shifts: shape [1 x num_cells x 1], where num_cells = grid1X * grid1Y + grid2X * grid2Y + grid3X * grid3Y, x coordinate on the grid cell the prediction is coming from

  • y_shifts: shape [1 x num_cells x 1], y coordinate on the grid cell the prediction is coming from

  • expanded_strides: shape [1 x num_cells x 1], stride of the output grid the prediction is coming from

  • transformed_outputs: shape [batch_size x num_cells x (num_classes + 5)], predictions with boxes in real coordinates and logprobabilities

  • raw_outputs: shape [batch_size x num_cells x (num_classes + 5)], raw predictions with boxes and confidences as logits

get_l1_target(l1_target, gt, stride, x_shifts, y_shifts, eps=1e-08)[source]
Parameters
  • l1_target – tensor of zeros of shape [Num_cell_gt_pairs x 4]

  • gt – targets in coordinates [Num_cell_gt_pairs x (4 + 1 + num_classes)]

Returns

targets in the format corresponding to logits

get_assignments(image_idx, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, expanded_strides, x_shifts, y_shifts, cls_preds, obj_preds, mode='gpu', ious_loss_cost_coeff=3.0, outside_boxes_and_center_cost_coeff=100000.0)[source]
Match cells to ground truth:
  • at most 1 GT per cell

  • dynamic number of cells per GT

Parameters
  • outside_boxes_and_center_cost_coeff – float: Cost coefficient of cells outside the radius and bbox of gts in dynamic matching (default=100000).

  • ious_loss_cost_coeff – float: Cost coefficient for iou loss in dynamic matching (default=3).

  • image_idx – int: Image index in batch.

  • num_gt – int: Number of ground truth targets in the image.

  • total_num_anchors – int: Total number of possible bboxes = sum of all grid cells.

  • gt_bboxes_per_image – torch.Tensor: Tensor of gt bboxes for the image, shape: (num_gt, 4).

  • gt_classes – torch.Tensor: Tensor of the classes in the image, shape: (num_preds, 4).

  • bboxes_preds_per_image – torch.Tensor: Tensor of the predicted bboxes in the image, shape: (num_preds, 4).

  • expanded_strides – torch.Tensor: Stride of the output grid the prediction is coming from, shape (1 x num_cells x 1).

  • x_shifts – torch.Tensor: X’s in cell coordinates, shape (1,num_cells,1).

  • y_shifts – torch.Tensor: Y’s in cell coordinates, shape (1,num_cells,1).

  • cls_preds – torch.Tensor: Class predictions in all cells, shape (batch_size, num_cells).

  • obj_preds – torch.Tensor: Objectness predictions in all cells, shape (batch_size, num_cells).

  • mode – str: One of ["gpu","cpu"], controls the device on which the assignment operation takes place (default="gpu")

get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)[source]
Create a mask for all cells, masking in only the foreground: cells that have a center located:
  • within a GT box;

  • OR within a fixed radius around a GT box (center sampling);

Parameters
  • num_gt – int: Number of ground truth targets in the image.

  • total_num_anchors – int: Sum of all grid cells.

  • gt_bboxes_per_image – torch.Tensor: Tensor of gt bboxes for the image, shape: (num_gt, 4).

  • expanded_strides – torch.Tensor: Stride of the output grid the prediction is coming from, shape (1 x num_cells x 1).

  • x_shifts – torch.Tensor: X’s in cell coordinates, shape (1,num_cells,1).

  • y_shifts – torch.Tensor: Y’s in cell coordinates, shape (1,num_cells,1).

:return is_in_boxes_anchor, is_in_boxes_and_center
where:
  • is_in_boxes_anchor masks the cells whose center is inside a gt bbox and within self.center_sampling_radius cells away, without reduction (i.e. shape=(num_gts, num_fgs))

  • is_in_boxes_and_center masks the cells whose center is either inside a gt bbox or within self.center_sampling_radius cells away, shape (num_fgs)

dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)[source]
Parameters
  • cost – pairwise cost, [num_FGs x num_GTs]

  • pair_wise_ious – pairwise IoUs, [num_FGs x num_GTs]

  • gt_classes – class of each GT

  • num_gt – number of GTs

:return: num_fg (number of foregrounds), gt_matched_classes (the classes that have been matched with fgs), pred_ious_this_matching, matched_gt_inds

reduction: str
class super_gradients.training.losses.RSquaredLoss(size_average=None, reduce=None, reduction: str = 'mean')[source]

Bases: torch.nn.modules.loss._Loss

forward(output, target)[source]

Computes the R-squared for the output and target values :param output: Tensor / Numpy / List

The prediction

Parameters

target – Tensor / Numpy / List The corresponding labels

reduction: str
class super_gradients.training.losses.SSDLoss(dboxes: super_gradients.training.utils.ssd_utils.DefaultBoxes, alpha: float = 1.0, iou_thresh: float = 0.5, neg_pos_ratio: float = 3.0)[source]

Bases: torch.nn.modules.loss._Loss

Implements the loss as the sum of the following: 1. Confidence Loss: All labels, with hard negative mining 2. Localization Loss: Only on positive labels

L = (2 - alpha) * L_l1 + alpha * L_cls, where
  • L_cls is HardMiningCrossEntropyLoss

  • L_l1 = [SmoothL1Loss for all positives]

match_dboxes(targets)[source]

creates tensors with target boxes and labels for each dboxes, so with the same len as dboxes.

  • Each GT is assigned with a grid cell with the highest IoU, this creates a pair for each GT and some cells;

  • The rest of grid cells are assigned to a GT with the highest IoU, assuming it’s > self.iou_thresh; If this condition is not met the grid cell is marked as background

GT-wise: one to many Grid-cell-wise: one to one

Parameters

targets – a tensor containing the boxes for a single image; shape [num_boxes, 6] (image_id, label, x, y, w, h)

Returns

two tensors: boxes - shape of dboxes [4, num_dboxes] (x,y,w,h); labels - shape [num_dboxes]

forward(predictions: Tuple, targets)[source]
Compute the loss

:param predictions - predictions tensor coming from the network, tuple with shapes ([Batch Size, 4, num_dboxes], [Batch Size, num_classes + 1, num_dboxes]) where predictions have logprobs for background and other classes :param targets - targets for the batch. [num targets, 6] (index in batch, label, x,y,w,h)

reduction: str
class super_gradients.training.losses.BCEDiceLoss(loss_weights=[0.5, 0.5], logits=True)[source]

Bases: torch.nn.modules.module.Module

Binary Cross Entropy + Dice Loss

Weighted average of BCE and Dice loss

loss_weights

list of size 2 s.t. loss_weights[0], loss_weights[1] are the weights for BCE and Dice respectively.
forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor[source]

@param input: Network’s raw output shaped (N,1,H,W) @param target: Ground truth shaped (N,H,W)

training: bool
class super_gradients.training.losses.KDLogitsLoss(task_loss_fn: torch.nn.modules.loss._Loss, distillation_loss_fn: torch.nn.modules.loss._Loss = KDklDivLoss(), distillation_loss_coeff: float = 0.5)[source]

Bases: torch.nn.modules.loss._Loss

Knowledge distillation loss, wraps the task loss and distillation loss

forward(kd_module_output, target)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

reduction: str
class super_gradients.training.losses.DiceCEEdgeLoss(num_classes: int, num_aux_heads: int = 2, num_detail_heads: int = 1, weights: Union[tuple, list] = (1, 1, 1, 1), dice_ce_weights: Union[tuple, list] = (1, 1), ignore_index: int = - 100, edge_kernel: int = 3, ce_edge_weights: Union[tuple, list] = (0.5, 0.5))[source]

Bases: torch.nn.modules.loss._Loss

forward(preds: Tuple[torch.Tensor], target: torch.Tensor)[source]
Parameters

preds – Model output predictions, must be in the following format: [Main-feats, Aux-feats[0], …, Aux-feats[num_auxs-1], Detail-feats[0], …, Detail-feats[num_details-1]]

reduction: str

super_gradients.training.metrics module

super_gradients.training.metrics.accuracy(output, target, topk=(1))[source]

Computes the precision@k for the specified values of k :param output: Tensor / Numpy / List

The prediction

Parameters
  • target – Tensor / Numpy / List The corresponding labels

  • topk – tuple The type of accuracy to calculate, e.g. topk=(1,5) returns accuracy for top-1 and top-5
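A minimal sketch of the call (tensor shapes are illustrative):

    import torch
    from super_gradients.training.metrics import accuracy

    logits = torch.randn(8, 10)            # 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))
    top1, top5 = accuracy(logits, labels, topk=(1, 5))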

class super_gradients.training.metrics.Accuracy(dist_sync_on_step=False)[source]

Bases: torchmetrics.classification.accuracy.Accuracy

update(preds: torch.Tensor, target: torch.Tensor)[source]

Update state with predictions and targets. See pages/classification:input types for more information on input types.

Parameters
  • preds – Predictions from model (logits, probabilities, or labels)

  • target – Ground truth labels

correct: torch.Tensor
total: torch.Tensor
class super_gradients.training.metrics.Top5(dist_sync_on_step=False)[source]

Bases: torchmetrics.metric.Metric

update(preds: torch.Tensor, target: torch.Tensor)[source]

Override this method to update the state variables of your metric class.

compute()[source]

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

class super_gradients.training.metrics.ToyTestClassificationMetric(dist_sync_on_step=False)[source]

Bases: torchmetrics.metric.Metric

Dummy classification Metric object that always returns 0 (for testing).

update(preds: torch.Tensor, target: torch.Tensor) → None[source]

Override this method to update the state variables of your metric class.

compute()[source]

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

class super_gradients.training.metrics.DetectionMetrics(num_cls: int, post_prediction_callback: Optional[super_gradients.training.utils.detection_utils.DetectionPostPredictionCallback] = None, normalize_targets: bool = False, iou_thres: super_gradients.training.utils.detection_utils.IouThreshold = <IouThreshold.MAP_05_TO_095: (0.5, 0.95)>, recall_thres: Optional[torch.Tensor] = None, score_thres: float = 0.1, top_k_predictions: int = 100, dist_sync_on_step: bool = False, accumulate_on_cpu: bool = True)[source]

Bases: torchmetrics.metric.Metric

Metric class for computing F1, Precision, Recall and Mean Average Precision.

num_cls

Number of classes.

post_prediction_callback

DetectionPostPredictionCallback to be applied on net’s output prior to the metric computation (NMS).

normalize_targets

Whether to normalize bbox coordinates by image size (default=False).

iou_thresholds

IoU threshold to compute the mAP (default=torch.linspace(0.5, 0.95, 10)).

recall_thresholds

Recall threshold to compute the mAP (default=torch.linspace(0, 1, 101)).

score_threshold

Score threshold to compute Recall, Precision and F1 (default=0.1)

top_k_predictions

Number of predictions per class used to compute metrics, ordered by confidence score (default=100)

dist_sync_on_step

Synchronize metric state across processes at each forward() before returning the value at the step. (default=False)

accumulate_on_cpu: Run on CPU regardless of device used in other parts.

This is to avoid “CUDA out of memory” that might happen on GPU (default False)

update(preds, target: torch.Tensor, device: str, inputs: torch._VariableFunctionsClass.tensor, crowd_targets: Optional[torch.Tensor] = None)[source]

Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.

Parameters
  • preds – Raw output of the model; the format might change from one model to another, but has to fit the input format of the post_prediction_callback

  • target – Targets for all images of shape (total_num_targets, 6) format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]

  • device – Device to run on

  • inputs – Input image tensor of shape (batch_size, n_img, height, width)

  • crowd_targets – Crowd targets for all images of shape (total_num_targets, 6) format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]

compute() → Dict[str, Union[float, torch.Tensor]][source]

Compute the metrics for all the accumulated results. :return: Metrics of interest
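A minimal construction sketch (the class count and thresholds are illustrative; in practice post_prediction_callback should be the NMS callback matching the trained detection architecture, and the metric is typically placed in training_params["valid_metrics_list"]):

    from super_gradients.training.metrics import DetectionMetrics

    detection_metric = DetectionMetrics(
        num_cls=80,                 # number of classes in the (hypothetical) dataset
        normalize_targets=True,
        score_thres=0.1,
        top_k_predictions=100,
    )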

class super_gradients.training.metrics.PreprocessSegmentationMetricsArgs(apply_arg_max: bool = False, apply_sigmoid: bool = False)[source]

Bases: super_gradients.training.metrics.segmentation_metrics.AbstractMetricsArgsPrepFn

Default segmentation inputs preprocess function before updating segmentation metrics, handles multiple inputs and apply normalizations.

class super_gradients.training.metrics.PixelAccuracy(ignore_label=- 100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[super_gradients.training.metrics.segmentation_metrics.AbstractMetricsArgsPrepFn] = None)[source]

Bases: torchmetrics.metric.Metric

update(preds: torch.Tensor, target: torch.Tensor)[source]

Override this method to update the state variables of your metric class.

compute()[source]

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

class super_gradients.training.metrics.IoU(num_classes: int, dist_sync_on_step: bool = False, ignore_index: Optional[int] = None, reduction: str = 'elementwise_mean', threshold: float = 0.5, metrics_args_prep_fn: Optional[super_gradients.training.metrics.segmentation_metrics.AbstractMetricsArgsPrepFn] = None)[source]

Bases: torchmetrics.classification.jaccard.JaccardIndex

update(preds, target: torch.Tensor)[source]

Update state with predictions and targets.

Parameters
  • preds – Predictions from model

  • target – Ground truth values

confmat: torch.Tensor
class super_gradients.training.metrics.Dice(num_classes: int, dist_sync_on_step: bool = False, ignore_index: Optional[int] = None, reduction: str = 'elementwise_mean', threshold: float = 0.5, metrics_args_prep_fn: Optional[super_gradients.training.metrics.segmentation_metrics.AbstractMetricsArgsPrepFn] = None)[source]

Bases: torchmetrics.classification.jaccard.JaccardIndex

update(preds, target: torch.Tensor)[source]

Update state with predictions and targets.

Parameters
  • preds – Predictions from model

  • target – Ground truth values

compute() → torch.Tensor[source]

Computes Dice coefficient

confmat: torch.Tensor
class super_gradients.training.metrics.BinaryIOU(dist_sync_on_step=True, ignore_index: Optional[int] = None, threshold: float = 0.5, metrics_args_prep_fn: Optional[super_gradients.training.metrics.segmentation_metrics.AbstractMetricsArgsPrepFn] = None)[source]

Bases: super_gradients.training.metrics.segmentation_metrics.IoU

compute()[source]

Computes intersection over union (IoU)

confmat: torch.Tensor
training: bool
class super_gradients.training.metrics.BinaryDice(dist_sync_on_step=True, ignore_index: Optional[int] = None, threshold: float = 0.5, metrics_args_prep_fn: Optional[super_gradients.training.metrics.segmentation_metrics.AbstractMetricsArgsPrepFn] = None)[source]

Bases: super_gradients.training.metrics.segmentation_metrics.Dice

compute()[source]

Computes Dice coefficient

confmat: torch.Tensor
training: bool
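
As a rough usage sketch (shapes and label values are assumptions; depending on the configured metrics_args_prep_fn, the predictions may need to be logits or hard labels), the segmentation metrics above follow the standard torchmetrics update()/compute() protocol:

    import torch
    from super_gradients.training.metrics import PixelAccuracy, IoU

    pixel_acc = PixelAccuracy(ignore_label=255)        # ignore value assumed for illustration
    iou = IoU(num_classes=21, ignore_index=255)

    preds = torch.randn(4, 21, 64, 64)                 # per-class logits (assumed layout)
    target = torch.randint(0, 21, (4, 64, 64))         # integer masks

    pixel_acc.update(preds, target)
    iou.update(preds, target)

    print(pixel_acc.compute(), iou.compute())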

super_gradients.training.models module

super_gradients.training.sg_model module

class super_gradients.training.sg_model.SgModel(experiment_name: str, device: Optional[str] = None, multi_gpu: Union[super_gradients.common.data_types.enum.multi_gpu_mode.MultiGPUMode, str] = <MultiGPUMode.OFF: 'Off'>, model_checkpoints_location: str = 'local', overwrite_local_checkpoint: bool = True, ckpt_name: str = 'ckpt_latest.pth', post_prediction_callback: Optional[super_gradients.training.utils.detection_utils.DetectionPostPredictionCallback] = None, ckpt_root_dir: Optional[str] = None, train_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, valid_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, test_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, classes: Optional[List[Any]] = None)[source]

Bases: object

SuperGradient Model - Base Class for Sg Models

train(max_epochs: int, initial_epoch: int, save_model: bool)[source]

The main function used for training, hyperparameter updating, logging, etc.

predict(idx: int)[source]

Returns the predictions and labels for the current inputs.

test(epoch: int, idx: int, save: bool):

Returns the test loss, accuracy, and runtime.

connect_dataset_interface(dataset_interface: super_gradients.training.datasets.dataset_interfaces.dataset_interface.DatasetInterface, data_loader_num_workers: int = 8)[source]
Parameters
  • dataset_interface – DatasetInterface object; the dataset to be connected.

  • data_loader_num_workers – The number of workers to initialize the data loaders with.

build_model(architecture: Union[str, torch.nn.modules.module.Module], arch_params={}, checkpoint_params={}, *args, **kwargs)[source]
Parameters
  • architecture – Defines the network’s architecture from models/ALL_ARCHITECTURES

  • arch_params – Architecture H.P. e.g.: block, num_blocks, num_classes, etc.

  • checkpoint_params

    Dictionary-like object with the following key/values:

    load_checkpoint: Load a pre-trained checkpoint.
    strict_load: See StrictLoad class documentation for details.
    source_ckpt_folder_name: Folder name to load the checkpoint from (self.experiment_name if none is given).
    load_weights_only: Loads only the weights from the checkpoint and zeroizes the training params.
    load_backbone: Loads the provided checkpoint to self.net.backbone instead of self.net.
    external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative (ie: path/to/checkpoint.pth). If provided, will automatically attempt to load the checkpoint even if the load_checkpoint flag is not provided.

train(training_params: dict = {})[source]

train - Trains the Model

IMPORTANT NOTE: Additional batch parameters can be added as an optional third item (a dictionary) in the tuple returned by the data loaders. The phase context will hold the additional items under attributes with the same names as the keys in this dictionary, so they can be accessed through phase callbacks.

Parameters

training_params –
  • max_epochs : int

    Number of epochs to run training.

  • lr_updates : list(int)

    List of fixed epoch numbers to perform learning rate updates when lr_mode=’step’.

  • lr_decay_factor : float

    Decay factor to apply to the learning rate at each update when lr_mode=’step’.

  • lr_mode : str

    Learning rate scheduling policy, one of ['step','poly','cosine','function']. 'step' refers to constant updates at epoch numbers passed through lr_updates. 'cosine' refers to the Cosine Annealing policy as described in https://arxiv.org/abs/1608.03983. 'poly' refers to a polynomial decrease, i.e. in each epoch iteration self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9). 'function' refers to a user-defined learning rate scheduling function, passed through lr_schedule_function.

  • lr_schedule_function : Union[callable,None]

    Learning rate scheduling function to be used when lr_mode is ‘function’.

  • lr_warmup_epochs : int (default=0)

    Number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).

  • cosine_final_lr_ratio : float (default=0.01)

    Final learning rate ratio (only relevant when `lr_mode`='cosine'). The cosine schedule starts from initial_lr and reaches initial_lr * cosine_final_lr_ratio in the last epoch.

  • initial_lr : float

    Initial learning rate.

  • loss : Union[nn.module, str]

    Loss function for training. One of SuperGradient’s built in options:

    “cross_entropy”: LabelSmoothingCrossEntropyLoss, “mse”: MSELoss, “r_squared_loss”: RSquaredLoss, “detection_loss”: YoLoV3DetectionLoss, “shelfnet_ohem_loss”: ShelfNetOHEMLoss, “shelfnet_se_loss”: ShelfNetSemanticEncodingLoss, “ssd_loss”: SSDLoss,

    or user defined nn.module loss function.

    IMPORTANT: forward(...) should return a (loss, loss_items) tuple where loss is the tensor used for backprop (i.e. what your original loss function returns), and loss_items should be a tensor of shape (n_items) holding values computed during the forward pass which we wish to log over the entire epoch. For example, the loss itself should always be logged. Another example is a scenario where the computed loss is the sum of a few components we would like to log; these components would be the entries in loss_items.

    When training, set the loss_logging_items_names parameter in train_params to a list of strings of length n_items whose ith element is the name of the ith entry in loss_items. Each item will then be logged, rendered on tensorboard and "watched" (i.e. model checkpoints are saved according to it).

    Since running logs will save the loss_items in some internal state, it is recommended that loss_items are detached from their computational graph for memory efficiency. A minimal sketch of such a loss appears after this parameter list.

  • optimizer : Union[str, torch.optim.Optimizer]

    Optimization algorithm. One of ['Adam','SGD','RMSProp'] corresponding to the torch.optim optimizers implementations, or any object that implements torch.optim.Optimizer.

  • criterion_params : dict

    Loss function parameters.

  • optimizer_params : dict

    When optimizer is one of [‘Adam’,’SGD’,’RMSProp’], it will be initialized with optimizer_params.

    (see https://pytorch.org/docs/stable/optim.html for the full list of parameters for each optimizer).

  • train_metrics_list : list(torchmetrics.Metric)

    Metrics to log during training. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.

  • valid_metrics_list : list(torchmetrics.Metric)

    Metrics to log during validation/testing. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.

  • loss_logging_items_names : list(str)

    The list of names/titles for the outputs returned from the loss function's forward pass (reminder: the loss function should return the tuple (loss, loss_items)). These names will be used for logging their values.

  • metric_to_watch : str (default=”Accuracy”)

    The metric according to which the model checkpoint will be saved. It can be set to any of the following:

    a metric name (str) of one of the metric objects from the valid_metrics_list

    a “metric_name” if some metric in valid_metrics_list has an attribute component_names which is a list referring to the names of each entry in the output metric (torch tensor of size n)

    one of the "loss_logging_items_names", i.e. a name corresponding to an item returned during the loss function's forward pass.

    At the end of each epoch, if a new best metric_to_watch value is achieved, the model's checkpoint is saved in YOUR_PYTHON_PATH/checkpoints/ckpt_best.pth

  • greater_metric_to_watch_is_better : bool

    When choosing a model's checkpoint to be saved, the best achieved model is the one that maximizes the metric_to_watch when this parameter is set to True, and the one that minimizes it otherwise.

  • ema : bool (default=False)

    Whether to use Model Exponential Moving Average (see https://github.com/rwightman/pytorch-image-models ema implementation)

  • batch_accumulate : int (default=1)

    Number of batches to accumulate before every backward pass.

  • ema_params : dict

    Parameters for the ema model.

  • zero_weight_decay_on_bias_and_bn : bool (default=False)

    Whether to apply weight decay on batch normalization parameters or not (ignored when the passed optimizer has already been initialized).

  • load_opt_params : bool (default=True)

    Whether to load the optimizers parameters as well when loading a model’s checkpoint.

  • run_validation_freq : int (default=1)

    The frequency at which validation is performed during training (i.e. validation is run every run_validation_freq epochs).

  • save_model : bool (default=True)

    Whether to save the model checkpoints.

  • silent_mode : bool

    Silences the printouts.

  • mixed_precision : bool

    Whether to use mixed precision or not.

  • save_ckpt_epoch_list : list(int) (default=[])

    List of fixed epoch indices the user wishes to save checkpoints in.

  • average_best_models : bool (default=False)

    If set, a snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when training is completed. The snapshot file will only be deleted upon completing the training. The snapshot dict will be managed on cpu.

  • precise_bn : bool (default=False)

    Whether to use precise_bn calculation during the training.

  • precise_bn_batch_size : int (default=None)

    The effective batch size we want to calculate the batchnorm on. For example, if we are training a model on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192 (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus). If precise_bn_batch_size is not provided in the training_params, the latter heuristic will be taken.

  • seed : int (default=42)

    Random seed to be set for torch, numpy, and random. When using DDP, each process will have its seed set to seed + rank.

  • log_installed_packages : bool (default=False)

    When set, the list of all installed packages (and their versions) will be written to the tensorboard and logfile (useful when trying to reproduce results).

  • dataset_statistics : bool (default=False)

    Enable a statistic analysis of the dataset. If set to True the dataset will be analyzed and a report will be added to the tensorboard along with some sample images from the dataset. Currently only detection datasets are supported for analysis.

  • save_full_train_log : bool (default=False)

    When set, a full log (of all super_gradients modules, including uncaught exceptions from any other module) of the training will be saved in the checkpoint directory under full_train_log.log

  • sg_logger : Union[AbstractSGLogger, str] (default=base_sg_logger)

    Define the SGLogger object for this training process. The SGLogger handles all disk writes, logs, TensorBoard, remote logging and remote storage. By overriding the default base_sg_logger, you can change the storage location, support external monitoring and logging or support remote storage.

  • sg_logger_params : dict

    SGLogger parameters

  • clip_grad_norm : float

    Defines a maximal L2 norm of the gradients. Values which exceed the given value will be clipped

  • lr_cooldown_epochs : int (default=0)

    Number of epochs to cool down the LR (i.e. from the scheduling point of view, the last epoch is max_epochs - cooldown).

  • pre_prediction_callback : Callable (default=None)

    When not None, this callback will be applied to images and targets, and the returned values will be used for the forward pass and further computations. Args for this callable should be in the order (inputs, targets, batch_idx), returning modified_inputs, modified_targets.

  • ckpt_best_name : str (default=’ckpt_best.pth’)

    The best checkpoint (according to metric_to_watch) will be saved under this filename in the checkpoints directory.

  • enable_qat: bool (default=False)

    Adds a QATCallback to the phase callbacks, that triggers quantization aware training starting from qat_params["start_epoch"].

  • qat_params: dict-like object with the following key/values:

    start_epoch: int, first epoch to start QAT.

    quant_modules_calib_method: str, one of [percentile, mse, entropy, max]. Statistics method for amax computation of the quantized modules (default=percentile).

    per_channel_quant_modules: bool, whether quant modules should be per channel (default=False).

    calibrate: bool, whether to perform calibration (default=False).

    calibrated_model_path: str, path to a calibrated checkpoint (default=None).

    calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset. When None, context.train_loader will be used (default=None).

    num_calib_batches: int, number of batches to collect the statistics from.

    percentile: float, percentile value to use when quant_modules_calib_method='percentile'. Discarded when other methods are used (default=99.99).

Returns
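
To make the parameter list above concrete, here is a hedged sketch of a training_params dictionary together with a custom loss that returns the (loss, loss_items) tuple described under loss. The architecture, metric choices and hyperparameter values are illustrative, and the dataset/model wiring (connect_dataset_interface / build_model) is omitted.

    import torch
    import torch.nn as nn
    from torchmetrics import Accuracy                      # newer torchmetrics may require a task argument
    from super_gradients.training import SgModel

    class MyLoss(nn.Module):
        """Toy loss returning the (loss, loss_items) tuple expected by train()."""
        def __init__(self):
            super().__init__()
            self.ce = nn.CrossEntropyLoss()

        def forward(self, preds, target):
            ce = self.ce(preds, target)
            loss = ce                                       # tensor used for backprop
            loss_items = torch.stack((ce,)).detach()        # logged over the epoch; detached for memory
            return loss, loss_items

    model = SgModel(experiment_name="resnet18_example")    # dataset/model setup omitted for brevity

    train_params = {
        "max_epochs": 100,
        "initial_lr": 0.1,
        "lr_mode": "step",
        "lr_updates": [30, 60, 90],
        "lr_decay_factor": 0.1,
        "loss": MyLoss(),
        "loss_logging_items_names": ["ce_loss"],
        "optimizer": "SGD",
        "optimizer_params": {"momentum": 0.9, "weight_decay": 1e-4},
        "train_metrics_list": [Accuracy()],
        "valid_metrics_list": [Accuracy()],
        "metric_to_watch": "Accuracy",
        "greater_metric_to_watch_is_better": True,
    }

    model.train(training_params=train_params)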

predict(inputs, targets=None, half=False, normalize=False, verbose=False, move_outputs_to_cpu=True)[source]

A fast predictor for a batch of inputs.

Parameters
  • inputs – torch.tensor or numpy.array; a batch of inputs

  • targets – torch.tensor corresponding labels - if none are given, accuracy will not be computed

  • verbose – bool print the results to screen

  • normalize – bool If true, normalizes the tensor according to the dataloader’s normalization values

  • half – Performs half precision evaluation

  • move_outputs_to_cpu – Moves the results from the GPU to the CPU

Returns

outputs, acc, net_time, gross_time – the network's predictions, accuracy calculation, forward pass net time, and function gross time.
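
A brief, hedged usage sketch (model is assumed to be an already-built SgModel; tensor shapes are illustrative):

    import torch

    inputs = torch.randn(8, 3, 224, 224)            # a batch of images
    labels = torch.randint(0, 10, (8,))             # optional targets; omit to skip the accuracy computation

    outputs, acc, net_time, gross_time = model.predict(
        inputs, targets=labels, half=False, normalize=True, verbose=False
    )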

property get_arch_params
property get_structure
property get_architecture
set_experiment_name(experiment_name)[source]
property get_module
set_module(module)[source]
test(test_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, loss: Optional[torch.nn.modules.loss._Loss] = None, silent_mode: bool = False, test_metrics_list=None, loss_logging_items_names=None, metrics_progress_verbose=False, test_phase_callbacks=None, use_ema_net=True) → tuple[source]

Evaluates the model on given dataloader and metrics.

Parameters
  • test_loader – dataloader to perform test on.

  • test_metrics_list – (list(torchmetrics.Metric)) metrics list for evaluation.

  • silent_mode – (bool) controls verbosity

  • metrics_progress_verbose – (bool) controls the verbosity of metrics progress (default=False). Slows down the program.

  • use_ema_net – (bool) whether to perform the test on self.ema_model.ema (when self.ema_model.ema exists; otherwise self.net will be tested) (default=True)

Returns

results tuple (tuple) containing the loss items and metric values.

All of the above args will override SgModel's corresponding attribute when not equal to None. Evaluation is then run on self.test_loader with self.test_metrics.
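
For example, a hedged sketch of evaluating on an external dataloader (my_test_loader and the metric choice are assumptions):

    from torchmetrics import Accuracy

    test_results = model.test(
        test_loader=my_test_loader,                 # hypothetical DataLoader
        test_metrics_list=[Accuracy()],
        metrics_progress_verbose=False,
    )
    # test_results is a tuple containing the loss items and metric values, as documented above.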

evaluate(data_loader: torch.utils.data.dataloader.DataLoader, metrics: torchmetrics.collections.MetricCollection, evaluation_type: super_gradients.common.data_types.enum.evaluation_type.EvaluationType, epoch: Optional[int] = None, silent_mode: bool = False, metrics_progress_verbose: bool = False)[source]

Evaluates the model on given dataloader and metrics.

Parameters
  • data_loader – dataloader to perform evaluation on

  • metrics – (MetricCollection) metrics for evaluation

  • evaluation_type – (EvaluationType) controls which phase callbacks will be used (for example, on batch end, when evaluation_type=EvaluationType.VALIDATION the Phase.VALIDATION_BATCH_END callbacks will be triggered)

  • epoch – (int) epoch idx

  • silent_mode – (bool) controls verbosity

  • metrics_progress_verbose – (bool) controls the verbosity of metrics progress (default=False). Slows down the program significantly.

Returns

results tuple (tuple) containing the loss items and metric values.

property get_net

Getter for the network.

Returns

torch.nn.Module, self.net

set_net(net: torch.nn.modules.module.Module)[source]

Setter for network.

Parameters

net – torch.nn.Module, value to set net

Returns

set_ckpt_best_name(ckpt_best_name)[source]

Setter for best checkpoint filename.

Parameters

ckpt_best_name – str, value to set ckpt_best_name

set_ema(val: bool)[source]

Setter for self.ema

Parameters

val – bool, value to set ema

class super_gradients.training.sg_model.MultiGPUMode(value)[source]

Bases: str, enum.Enum

OFF                       - Single GPU Mode / CPU Mode
DATA_PARALLEL             - Multiple GPUs, Synchronous
DISTRIBUTED_DATA_PARALLEL - Multiple GPUs, Asynchronous
OFF = 'Off'
DATA_PARALLEL = 'DP'
DISTRIBUTED_DATA_PARALLEL = 'DDP'
AUTO = 'AUTO'
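
A minimal sketch of selecting the mode at construction time (the experiment name is illustrative; since the enum derives from str, the signature also accepts the string values shown above):

    from super_gradients.training.sg_model import SgModel, MultiGPUMode

    model = SgModel(experiment_name="ddp_run", multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL)
    # Equivalent, using the string value:
    model = SgModel(experiment_name="ddp_run", multi_gpu="DDP")
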
class super_gradients.training.sg_model.StrictLoad(value)[source]

Bases: enum.Enum

Wrapper for adding more functionality to torch’s strict_load parameter in load_state_dict().
Attributes:

OFF - Native torch "strict_load = off" behaviour. See nn.Module.load_state_dict() documentation for more details.

ON - Native torch "strict_load = on" behaviour. See nn.Module.load_state_dict() documentation for more details.

NO_KEY_MATCHING - Allows the usage of SuperGradient's adapt_checkpoint function, which loads a checkpoint by matching each layer's shapes (and bypasses the strict matching of the names of each layer, ie: disregards the state_dict key matching).

OFF = False
ON = True
NO_KEY_MATCHING = 'no_key_matching'
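
A hedged sketch of passing StrictLoad through build_model's checkpoint_params (the architecture name and checkpoint path are illustrative):

    from super_gradients.training.sg_model import StrictLoad

    model.build_model(
        architecture="resnet18",
        checkpoint_params={
            "load_checkpoint": True,
            "strict_load": StrictLoad.NO_KEY_MATCHING,     # shape-based key matching
            "external_checkpoint_path": "path/to/checkpoint.pth",
        },
    )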

super_gradients.training.utils module

class super_gradients.training.utils.Timer(device: str)[source]

Bases: object

A class to measure time, handling both GPU & CPU processes. Returns time in milliseconds.

start()[source]
stop()[source]
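
A small usage sketch, assuming stop() returns the elapsed time as the class docstring suggests (model and inputs are placeholders):

    from super_gradients.training.utils import Timer

    timer = Timer(device="cuda")
    timer.start()
    _ = model(inputs)          # some work to measure
    elapsed_ms = timer.stop()  # milliseconds, per the docstring above
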
class super_gradients.training.utils.HpmStruct(**entries)[source]

Bases: object

set_schema(schema: dict)[source]
override(**entries)[source]
to_dict()[source]
validate()[source]

Validate the current dict values according to the provided schema.

Raises
  • AttributeError – if the schema was not set

  • jsonschema.exceptions.ValidationError – if the instance is invalid

  • jsonschema.exceptions.SchemaError – if the schema itself is invalid
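
A minimal sketch of HpmStruct; the schema below is an ordinary jsonschema document and is illustrative:

    from super_gradients.training.utils import HpmStruct

    hp = HpmStruct(lr=0.1, momentum=0.9)    # entries become attributes (hp.lr, hp.momentum)
    hp.override(lr=0.01)                    # update a subset of entries
    params = hp.to_dict()                   # back to a plain dict

    hp.set_schema({"type": "object", "properties": {"lr": {"type": "number"}}})
    hp.validate()                           # raises if the current values don't satisfy the schema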

class super_gradients.training.utils.WrappedModel(module)[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
super_gradients.training.utils.convert_to_tensor(array)[source]

Converts numpy arrays and lists to Torch tensors before calculating losses.

Parameters
  • array – torch.tensor / Numpy array / List

super_gradients.training.utils.get_param(params, name, default_val=None)[source]

Retrieves a param from a parameter object/dict. If the parameter does not exist, will return default_val. In case the default_val is of type dictionary, and a value is found in the params - the function will return the default value dictionary with internal values overridden by the found value

i.e.:
default_opt_params = {'lr': 0.1, 'momentum': 0.99, 'alpha': 0.001}
training_params = {'optimizer_params': {'lr': 0.0001}, 'batch': 32, ....}
get_param(training_params, name='optimizer_params', default_val=default_opt_params)
will return {'lr': 0.0001, 'momentum': 0.99, 'alpha': 0.001}

Parameters
  • params – an object (typically HpmStruct) or a dict holding the params

  • name – name of the searched parameter

  • default_val – assumed to be the same type as the value searched in the params

Returns

the found value, or default if not found
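
The worked example above, written out as runnable code:

    from super_gradients.training.utils import get_param

    default_opt_params = {"lr": 0.1, "momentum": 0.99, "alpha": 0.001}
    training_params = {"optimizer_params": {"lr": 0.0001}, "batch": 32}

    merged = get_param(training_params, name="optimizer_params", default_val=default_opt_params)
    # merged == {'lr': 0.0001, 'momentum': 0.99, 'alpha': 0.001}

    warmup = get_param(training_params, "lr_warmup_epochs", default_val=0)   # not present -> 0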

super_gradients.training.utils.tensor_container_to_device(obj: Union[torch.Tensor, tuple, list, dict], device: str, non_blocking=True)[source]
Recursively sends compounded objects to device (sending all tensors to device and maintaining structure).

Parameters
  • obj – the object to send to device (list / tuple / tensor / dict)

  • device – device to send the tensors to

  • non_blocking – used for DistributedDataParallel

Returns

an object with the same structure (tensors, lists, tuples) with the device pointers (like the return value of Tensor.to(device))
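
A brief sketch (the batch structure is illustrative):

    import torch
    from super_gradients.training.utils import tensor_container_to_device

    batch = {
        "image": torch.randn(4, 3, 32, 32),
        "targets": [torch.tensor([1, 2]), torch.tensor([3])],
    }
    batch_on_gpu = tensor_container_to_device(batch, device="cuda", non_blocking=True)
    # The returned object has the same dict/list structure, with every tensor moved to the device.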

super_gradients.training.utils.adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = [], solver: Optional[callable] = None)[source]

Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit the ckpt in order to properly load the weights into the model. If unsuccessful - returns None

Parameters
  • model_state_dict – the model state_dict

  • source_ckpt – checkpoint dict

  • exclude – optional list of excluded layers

  • solver – callable with signature (ckpt_key, ckpt_val, model_key, model_val) that returns a desired weight for ckpt_val.

Returns

renamed checkpoint dict (if possible)
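
A hedged usage sketch (the checkpoint path is illustrative, and whether the returned dict can be fed directly to load_state_dict depends on how the checkpoint stores its weights):

    import torch
    from super_gradients.training.utils import adapt_state_dict_to_fit_model_layer_names

    source_ckpt = torch.load("path/to/checkpoint.pth", map_location="cpu")
    adapted = adapt_state_dict_to_fit_model_layer_names(model.state_dict(), source_ckpt)
    if adapted is not None:                     # per the docs, None is returned on failure
        model.load_state_dict(adapted)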

super_gradients.training.utils.raise_informative_runtime_error(state_dict, checkpoint, exception_msg)[source]

Given a model state dict and source checkpoints, the method calls “adapt_state_dict_to_fit_model_layer_names” and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible

super_gradients.training.utils.random_seed(is_ddp, device, seed)[source]

Sets random seed of numpy, torch and random.

When using DDP, a seed will be set for each process according to its local rank derived from the device number.

Parameters
  • is_ddp – bool, will set a different random seed for each process when using DDP.

  • device – 'cuda', 'cpu', or 'cuda:<device_number>'

  • seed – int, random seed to be set

Module contents