pruning
- ledidi.pruning.greedy_pruning(model, X, X_hat, threshold=1, target=None, verbose=False)
A method for pruning proposed edits that are irrelevant, i.e., edits whose removal barely changes the model output.
This method greedily evaluates the effect of removing each proposed edit, one at a time. At each iteration it scans over all remaining edits and reverts the one whose removal causes the smallest change in model output, provided that change is below the predefined threshold. Once even the smallest change is above the threshold, the procedure stops and returns the remaining edits.
Note: Only one sequence is pruned at a time.
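The greedy loop described above can be sketched in plain Python. This is a minimal illustration, not ledidi's implementation: the name `greedy_prune_sketch` is hypothetical, sequences are plain strings rather than one-hot tensors, and the model is assumed to return a single float. It also assumes the change is measured against the output of the current, partially pruned sequence; the library may instead measure it against the output of `X_hat`.

```python
from typing import Callable


def greedy_prune_sketch(
    model: Callable[[str], float],
    x: str,
    x_hat: str,
    threshold: float = 1.0,
    verbose: bool = False,
) -> str:
    """Hypothetical sketch: greedily revert edits in x_hat back to x."""
    if len(x) != len(x_hat):
        raise ValueError("x and x_hat must have the same length")

    x_m = list(x_hat)
    while True:
        # Positions where the current sequence still differs from the original.
        edits = [i for i in range(len(x)) if x_m[i] != x[i]]
        if not edits:
            break

        y_current = model("".join(x_m))
        best_idx, best_delta = None, float("inf")
        for i in edits:
            # Revert a single edit and measure how far the output moves.
            candidate = x_m.copy()
            candidate[i] = x[i]
            delta = abs(model("".join(candidate)) - y_current)
            if delta < best_delta:
                best_idx, best_delta = i, delta

        if verbose:
            print(best_idx, best_delta)

        # Stop once even the least important edit changes the output too much.
        if best_delta > threshold:
            break
        x_m[best_idx] = x[best_idx]

    return "".join(x_m)
```

For example, with a toy model that counts `G` characters, `greedy_prune_sketch(lambda s: float(s.count("G")), "AAAA", "AGCA", threshold=0.5)` reverts the `C` edit (no change in output) but keeps the `G` edit (reverting it changes the output by 1), returning `"AGAA"`.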
Parameters
- model: torch.nn.Module
A PyTorch model used to evaluate the edits.
- X: torch.Tensor, shape=(1, d, length)
A one-hot encoded tensor where the second dimension is the number of categories (e.g., 4 for DNA) and the third dimension is the length of the sequence.
- X_hat: torch.Tensor, shape=(1, d, length)
A tensor of the same shape as X that contains the proposed edits.
- threshold: float, optional
The maximum change in model output that removing an edit may cause for that edit to still be pruned. Default is 1.
- target: int or None, optional
When given a multi-task model, the index of the output to slice out before the change in model output is calculated. If None, no slicing is performed. Default is None.
- verbose: bool, optional
Whether to print the index and delta at each iteration. Default is False.
Returns
- X_m: torch.Tensor, shape=(1, d, length)
A tensor of the same shape as X_hat in which the pruned edits have been reverted to their values in X.
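To make the (1, d, length) layout concrete, the sketch below uses nested Python lists so it runs without PyTorch. The helpers `one_hot`, `edited_positions`, and `revert` are hypothetical and not part of ledidi; they assume an `ACGT` alphabet and show that an edit is a sequence position (a column along the last dimension) where X_hat differs from X, and that reverting an edit copies that column of X back into X_m.

```python
import copy

ALPHABET = "ACGT"


def one_hot(seq: str, alphabet: str = ALPHABET) -> list[list[list[int]]]:
    """Encode seq as nested lists with shape (1, len(alphabet), len(seq))."""
    return [[[1 if ch == a else 0 for ch in seq] for a in alphabet]]


def edited_positions(X, X_hat) -> list[int]:
    """Sequence positions (columns) where X_hat differs from X."""
    d, length = len(X[0]), len(X[0][0])
    return [
        j for j in range(length)
        if any(X[0][c][j] != X_hat[0][c][j] for c in range(d))
    ]


def revert(X, X_m, j: int) -> None:
    """Copy column j of X into X_m in place, undoing the edit at position j."""
    for c in range(len(X[0])):
        X_m[0][c][j] = X[0][c][j]


X = one_hot("ACGT")
X_hat = one_hot("AGGA")
X_m = copy.deepcopy(X_hat)   # start pruning from the fully edited sequence
revert(X, X_m, 1)            # undo the C -> G edit at position 1
```

With PyTorch installed, `torch.tensor(one_hot("ACGT"), dtype=torch.float32)` would produce a tensor with the documented shape of (1, 4, 4).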