torch.nn.utils.prune.random_unstructured
- torch.nn.utils.prune.random_unstructured(module, name, amount)
Prunes the tensor corresponding to the parameter called name in module by removing the specified amount of (currently unpruned) units selected at random. Modifies module in place (and also returns the modified module) by:
  - adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter given by name by the pruning method;
  - replacing the parameter given by name with its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
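To make the in-place reparametrization concrete, here is a minimal sketch, assuming a recent PyTorch release (the variable names layer and pruned are purely illustrative). It prunes 5 of the 12 weight entries of a small nn.Linear and inspects the result: the returned object is the same module, weight_orig holds the original values as a parameter, weight_mask is a buffer, and weight equals their elementwise product.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)  # fixes which entries are picked; the count does not depend on it

layer = nn.Linear(4, 3)  # weight has shape (3, 4), i.e. 12 entries
pruned = prune.random_unstructured(layer, name="weight", amount=5)

# Pruning happens in place; the return value is the same module object.
print(pruned is layer)  # True

# 'weight' is no longer a parameter; the original values live in 'weight_orig'.
print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight_orig']

# The binary mask is registered as a buffer named 'weight_mask'.
print([n for n, _ in layer.named_buffers()])  # ['weight_mask']

# The pruned 'weight' attribute is the original tensor times the mask.
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True

# Exactly `amount` entries of the mask are zero.
print(int((layer.weight_mask == 0).sum()))  # 5
```

Because weight is recomputed from weight_orig and weight_mask before each forward pass, gradients still reach weight_orig, with the masked entries receiving zero gradient.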
- Parameters:
module (nn.Module) – module containing the tensor to prune
name (str) – parameter name within module on which pruning will act.
amount (int or float) – quantity of parameters to prune. If float, it should be between 0.0 and 1.0 and represents the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
- Returns:
modified (i.e. pruned) version of the input module
- Return type:
module (nn.Module)
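The two accepted types of amount can be compared with a short sketch (the helper count_pruned is hypothetical, written only for this illustration). On a 10 x 10 weight, an int amount of 30 and a float amount of 0.3 both prune 30 entries, and a float outside [0.0, 1.0] is rejected with a ValueError.

```python
from torch import nn
from torch.nn.utils import prune


def count_pruned(module: nn.Module) -> int:
    """Hypothetical helper: number of zeros in the weight mask (pruned entries)."""
    return int((module.weight_mask == 0).sum())


# int amount: an absolute number of entries to prune (30 of 100).
a = prune.random_unstructured(nn.Linear(10, 10), "weight", amount=30)
print(count_pruned(a))  # 30

# float amount: a fraction of the entries (0.3 of 100 entries is 30).
b = prune.random_unstructured(nn.Linear(10, 10), "weight", amount=0.3)
print(count_pruned(b))  # 30

# A float outside [0.0, 1.0] is rejected.
rejected = False
try:
    prune.random_unstructured(nn.Linear(10, 10), "weight", amount=1.5)
except ValueError:
    rejected = True
print(rejected)  # True
```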
Examples
>>> import torch
>>> from torch import nn
>>> from torch.nn.utils import prune
>>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
>>> torch.sum(m.weight_mask == 0)
tensor(1)
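Building on the example above, the following sketch (variable names are illustrative) shows two behaviors implied by the description: a repeated call selects only among currently unpruned units, so pruned counts accumulate, and torch.nn.utils.prune.remove, a separate function in the same module, makes the pruning permanent by turning weight back into an ordinary parameter.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
m = nn.Linear(2, 3)  # weight has 6 entries

prune.random_unstructured(m, "weight", amount=1)
first_mask = m.weight_mask.clone()

# The second call picks only among the 5 entries that are still unpruned,
# so the pruned count accumulates: 1 + 2 = 3.
prune.random_unstructured(m, "weight", amount=2)
n_pruned = int((m.weight_mask == 0).sum())
print(n_pruned)  # 3

# Entries pruned by the first call remain pruned.
kept_first = bool((m.weight_mask[first_mask == 0] == 0).all())
print(kept_first)  # True

# prune.remove makes the pruning permanent: 'weight' becomes an ordinary
# parameter again, and 'weight_orig' and 'weight_mask' are dropped.
final_mask = m.weight_mask.clone()
prune.remove(m, "weight")
print(sorted(n for n, _ in m.named_parameters()))  # ['bias', 'weight']
print(hasattr(m, "weight_mask"))  # False

# The zeros introduced by the mask are baked into the weight values.
zeros_kept = bool((m.weight[final_mask == 0] == 0).all())
print(zeros_kept)  # True
```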