Custom Loss Functions

Ledidi optimizes an objective function that is made up of two terms: the input loss, and the output loss. The input loss is, by default, an L1 loss between the original sequence and the proposed sequence and so effectively counts the number of edits that have been made. The output loss is, by default, an L2 loss between the model predictions given the edited sequence and the desired output from the model. Combined, this objective function tries to minimize the number of edits made in order to achieve a desired prediction from the model.

However, the losses do not have to be these defaults. In general, using L1 as the input loss probably makes the most conceptual sense given the formulation of Ledidi, but there may be alternatives with better convergence properties. We anticipate that the most creativity will come from using custom loss functions for the output loss, where a great deal of flexibility can be used to control the design process or work with models with interesting outputs.
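As a rough sketch of how the two terms combine, the verbose logs printed later in this tutorial are consistent with a total loss of the form output_loss + l * input_loss, with l = 0.1 when no value is passed. Both the form and the default are inferred here from those logged numbers rather than taken from the Ledidi source, and total_loss is a hypothetical helper, not part of the Ledidi API.

```python
def total_loss(input_loss, output_loss, l=0.1):
    """Combine the two terms.

    Assumption: the additive form and the default l=0.1 are inferred from
    the logs printed in this tutorial, not read from the Ledidi source.
    """
    return output_loss + l * input_loss


# (input_loss, output_loss, l, logged total_loss) copied from logged iterations.
logged_rows = [
    (108.4, 1.235, 0.1, 12.08),
    (744.4, -4.478, 0.001, -3.734),
]

for inp, out, l, logged in logged_rows:
    print(f"computed={total_loss(inp, out, l):.4f}  logged={logged}")
```

Under this reading, passing l=0.0 removes the edit penalty entirely, which is why total_loss equals output_loss in the runs that use it.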

Handling models with structured outputs

In most of our evaluations, we have either used models whose output takes the form of a count (e.g., Beluga, BPNet) or models whose output can be easily converted into something that acts like a count (e.g., Enformer, Borzoi). This is mostly because it is easy to work with models whose outputs are like this. But not all models can easily have their output converted into something that acts as a count, and even when they can, sometimes this eliminates useful information that is in the output.

For example, BPNet models predict both the total counts in a region and the basepair-resolution profiles describing where those counts are predicted to map. These profiles take the form of probability distributions, where the output corresponds either to the probability of a read mapping to each basepair or to something like an unnormalized logit that can be easily converted to such a profile. We have been discarding these profiles because they are harder to work with, but they also give us the potential to finely control where binding sites are inserted as opposed to just how many.

Let’s take a look. Here, we will quickly define a wrapper that pulls out the predicted profile instead of the predicted total counts, as we have been using before. Because the BPNet objects from bpnetlite predict logits by default, we will quickly convert those to probabilities in a numerically stable way using log_softmax. Note that the ProfileWrapper object that is built into bpnet-lite does not simply return the profile, but rather returns a conversion of the profile into a single number for use in attribution algorithms. tl;dr, do not simply use that one.

[1]:
import torch

from bpnetlite.bpnet import ControlWrapper

class ProfileWrapper(torch.nn.Module):
    def __init__(self, model):
        super(ProfileWrapper, self).__init__()
        self.model = model

    def forward(self, X):
        y = self.model(X)[0]
        return torch.exp(torch.nn.functional.log_softmax(y, dim=-1))


model = torch.load("../../../../models/bpnet/GATA2.torch", weights_only=False)
model = ProfileWrapper(ControlWrapper(model))
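The reason for going through log_softmax rather than exponentiating the logits directly can be seen in a small pure-Python sketch. The function names below are illustrative only; the wrapper above relies on torch.nn.functional.log_softmax, and this sketch just shows the kind of shift that keeps such a computation from overflowing.

```python
import math


def naive_softmax(logits):
    # Exponentiates the raw logits, which overflows for large values.
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stable_softmax(logits):
    # Subtract the max so every exponent is <= 0, compute the normalizer
    # in log space (log-sum-exp), and exponentiate only once at the end.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [math.exp(v - log_norm) for v in logits]


logits = [1000.0, 999.0, 998.0]

try:
    naive_softmax(logits)
except OverflowError:
    print("naive softmax overflowed")

probs = stable_softmax(logits)
print(probs, sum(probs))
```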

Now that we have our model, let’s generate a random sequence and apply our BPNet model to it to get a sense for what the profile predictions look like.

[2]:
from tangermeme.utils import random_one_hot
from tangermeme.predict import predict

X = random_one_hot((1, 4, 2114), random_state=0).float()
y_hat = predict(model, X)
[3]:
%matplotlib inline
from matplotlib import pyplot as plt
import seaborn; seaborn.set_style('whitegrid')

plt.figure(figsize=(8, 3))
plt.plot(y_hat[0, 0], c='0.3', label="fwd strand")
plt.plot(y_hat[0, 1], c='0.7', label="bwd strand")

plt.title("Original BPNet Predictions")
plt.xlabel("Genomic Position")
plt.ylabel("Predicted Probability")

plt.xlim(0, 1000)
plt.legend()

seaborn.despine(bottom=True, left=True)
plt.tight_layout()
plt.show()
[Figure: Original BPNet Predictions (predicted probability by genomic position, fwd and bwd strands)]

There are a few things to note here. First, we are no longer dealing with a single number but rather one number per position in the 1 kbp output. In fact, it is two numbers per position, because the profile predictions from BPNet are stranded. An important property is that the predictions for the two strands will be offset from each other, because BPNet was trained on fragment starts but the binding event happens in the middle of the fragment. This means that reads mapping to the + strand will map to the left of the actual binding event, reads mapping to the - strand will map to the right of it, and the distance to the binding event is proportional to the fragment length, which is experiment-specific. Finally, each strand has been normalized separately, so the predicted probabilities for each strand sum to 1.

These factors matter when creating the desired output from the model. We can begin by specifying a tensor in a conceptually similar manner to when we worked with multiple outputs. After all, one can view predictions for each basepair as basically the same thing as predictions from multiple models. But when we set the desired output, we need to keep in mind an appropriate magnitude for these predictions (a value of 2 at a certain basepair is infeasible if the sum across the region has to be equal to 1) and also that we expect the peaks to be offset across the two strands.

Fortunately, these are both somewhat easy to handle. We can just specify where we want a peak to be and calculate spans from there, making sure to divide by the total at the end so the desired output sums to 1 across each strand.

[4]:
mid = 250

y_bar = torch.zeros(1, 2, 1000).float()
y_bar[:, 0, mid-150:mid+50] = 1
y_bar[:, 1, mid-50 :mid+150] = 1
y_bar /= y_bar.sum(dim=-1, keepdims=True)
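The span arithmetic in the cell above can be wrapped into a small helper so that the peak position, width, and strand offset are explicit. This is a plain-Python sketch; make_profile_target and its parameters are hypothetical names introduced here, and the 50 bp shift simply mirrors the cell above rather than any measured fragment length.

```python
def make_profile_target(length=1000, mid=250, width=200, shift=50):
    """Return [fwd, bwd] target profiles, each summing to 1.

    Hypothetical helper: the fwd-strand window is shifted left of `mid`
    and the bwd-strand window right of it, mirroring the cell above.
    """
    half = width // 2
    windows = [
        (mid - half - shift, mid + half - shift),  # fwd strand
        (mid - half + shift, mid + half + shift),  # bwd strand
    ]

    target = []
    for start, end in windows:
        if start < 0 or end > length:
            raise ValueError("window falls outside the output region")
        strand = [0.0] * length
        for i in range(start, end):
            # Uniform mass inside the window so the strand sums to 1.
            strand[i] = 1.0 / (end - start)
        target.append(strand)
    return target


fwd, bwd = make_profile_target()
print(sum(fwd), sum(bwd), max(fwd))
```

With the defaults, the windows are 100:300 on the fwd strand and 200:400 on the bwd strand, as in the cell above. To hand the result to Ledidi, something like torch.tensor(make_profile_target()).unsqueeze(0) should give the (1, 2, 1000) shape used for y_bar.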

Given this desired output we can just plop it into Ledidi! As a technical note, given how small the predicted and desired values are, we would need to drop lambda by quite a bit to get the input and output losses onto the same scale. For the purpose of demonstration we will simply set lambda to zero (l=0.0), which turns off the penalty on edits.

[5]:
from ledidi import ledidi

X_hat = ledidi(model, X, y_bar, l=0.0, verbose=True)
y_hat = predict(model, X_hat)
iter=I  input_loss=0.0  output_loss=4.414e-06   total_loss=4.414e-06    time=0.0
iter=100        input_loss=38.38        output_loss=1.457e-06   total_loss=1.457e-06    time=2.501
iter=200        input_loss=39.56        output_loss=9.372e-07   total_loss=9.372e-07    time=1.196
iter=300        input_loss=38.75        output_loss=8.945e-07   total_loss=8.945e-07    time=1.195
iter=400        input_loss=40.12        output_loss=8.436e-07   total_loss=8.436e-07    time=1.195
iter=F  input_loss=42.19        output_loss=8.31e-07    total_loss=8.31e-07     time=7.136

We did not need to do anything particularly special here to use the BPNet profiles with Ledidi. Let’s take a look at the design, noting where we wanted the peaks to happen.

[6]:
%matplotlib inline
from matplotlib import pyplot as plt
import seaborn; seaborn.set_style('whitegrid')

plt.figure(figsize=(8, 3))
plt.plot(y_hat[0, 0], c='0.3', label="fwd strand")
plt.plot(y_hat[0, 1], c='0.7', label="bwd strand")

plt.plot([100, 300], [0.0090, 0.0090], c='0.3', linestyle='--', linewidth=2)
plt.plot([200, 400], [0.0087, 0.0087], c='0.7', linestyle='--', linewidth=2)

plt.title("Ledidi Design using MSE")
plt.xlabel("Genomic Position")
plt.ylabel("Predicted Probability")

plt.xlim(0, 1000)
plt.legend()

seaborn.despine(bottom=True, left=True)
plt.tight_layout()
plt.show()
[Figure: Ledidi Design using MSE (predicted probability by genomic position, fwd and bwd strands; dashed lines mark the desired peak spans)]

Great! Looks like we have peak predictions at the desired locations.

But is MSE the best loss to optimize here? BPNet intentionally chooses to use a different loss when optimizing the profile predictions because MSE has been shown to have weird artifacts when working with probabilities, like this. For example, note how much density of the predictions falls out of the desired range.

Maybe we can do better if we use a loss function that is explicitly meant to operate on probability distributions. KL-divergence has been shown to work well in this setting. To use a custom loss function, all we need to do is pass in a function to the output_loss parameter with the signature f(y_hat, y_bar) where y_bar is the desired output and y_hat is the predicted output given the current edited sequence. As a technical note, we are wrapping it a little bit because the implementation we are using reports the KL divergence for each example but we need the average divergence across all examples.

[7]:
from bpnetlite.performance import _kl_divergence

kld = lambda x, y: _kl_divergence(y, x).mean()

X_hat = ledidi(model, X, y_bar, l=0.0, output_loss=kld, verbose=True)
iter=I  input_loss=0.0  output_loss=1.755       total_loss=1.755        time=0.0
iter=100        input_loss=751.0        output_loss=0.1578      total_loss=0.1578       time=1.26
iter=200        input_loss=702.8        output_loss=0.1326      total_loss=0.1326       time=1.282
iter=300        input_loss=760.6        output_loss=0.1126      total_loss=0.1126       time=1.258
iter=400        input_loss=762.8        output_loss=0.09138     total_loss=0.09138      time=1.258
iter=500        input_loss=776.6        output_loss=0.08533     total_loss=0.08533      time=1.262
iter=600        input_loss=793.1        output_loss=0.07705     total_loss=0.07705      time=1.254
iter=700        input_loss=800.4        output_loss=0.07333     total_loss=0.07333      time=1.256
iter=800        input_loss=823.8        output_loss=0.07241     total_loss=0.07241      time=1.255
iter=F  input_loss=807.0        output_loss=0.06802     total_loss=0.06802      time=10.29
[8]:
y_hat = predict(model, X_hat)

plt.figure(figsize=(6, 3))
plt.plot(y_hat[0, 0], c='0.3', label="fwd strand")
plt.plot(y_hat[0, 1], c='0.7', label="bwd strand")

plt.title("Ledidi Design using KL divergence")
plt.xlabel("Genomic Position")
plt.ylabel("Predicted Probability")

plt.plot([100, 300], [0.0090, 0.0090], c='0.3', linestyle='--', linewidth=2)
plt.plot([200, 400], [0.0087, 0.0087], c='0.7', linestyle='--', linewidth=2)

plt.xlim(0, 1000)
plt.legend()

seaborn.despine(bottom=True, left=True)
plt.tight_layout()
plt.show()
[Figure: Ledidi Design using KL divergence (predicted probability by genomic position, fwd and bwd strands; dashed lines mark the desired peak spans)]

That looks remarkably sharper. Not only are both peaks at basically the same height, but the predictions also drop off sharply at the desired edges. It is likely, especially given the number of edits necessary to make these profiles, that the designed sequence will look sort of weird. But this is just a demonstration of (1) how to use a custom loss function and (2) the differences in output one might expect when using a loss function more tailored to the model output.
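The very different loss magnitudes in the two runs above (around 1e-6 for MSE versus around 1 for KL divergence) can be reproduced by hand. The sketch below compares the desired fwd-strand profile against a flat profile, which is only a stand-in for the model's prediction on a random sequence. The direction KL(y_bar || y_hat) is an assumption made for this illustration and may not match the convention used inside bpnetlite's _kl_divergence.

```python
import math

length, start, end = 1000, 100, 300

# Desired profile: all mass spread evenly over the 200 bp window.
y_bar = [1.0 / (end - start) if start <= i < end else 0.0 for i in range(length)]
# Stand-in prediction: a flat profile over the whole region.
y_hat = [1.0 / length] * length


def mse(y_hat, y_bar):
    return sum((p - t) ** 2 for p, t in zip(y_hat, y_bar)) / len(y_bar)


def kl_divergence(y_hat, y_bar):
    # KL(y_bar || y_hat); positions with zero target mass contribute zero.
    return sum(t * math.log(t / p) for p, t in zip(y_hat, y_bar) if t > 0)


print(f"MSE = {mse(y_hat, y_bar):.3e}")
print(f"KL  = {kl_divergence(y_hat, y_bar):.3f}")
```

The MSE here works out to 4e-6 and the KL divergence to ln 5 (about 1.61), the same orders of magnitude as the first logged values in the two runs (4.414e-06 and 1.755). This is why lambda would have to shrink so much for MSE on profiles, while KL divergence sits on a scale closer to that of the input loss.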

Let’s try to take this one step further. Can we design edits that create two peaks where one peak is much stronger than the other?

[9]:
y_bar = torch.zeros(1, 2, 1000).float()
y_bar[:, 0, 100:300] = 1
y_bar[:, 1, 200:400] = 1
y_bar[:, 0, 600:800] = 3
y_bar[:, 1, 700:900] = 3
y_bar /= y_bar.sum(dim=-1, keepdims=True)

X_hat = ledidi(model, X, y_bar, l=0.0, output_loss=kld, verbose=True)
iter=I  input_loss=0.0  output_loss=1.398       total_loss=1.398        time=0.0
iter=100        input_loss=718.0        output_loss=0.1765      total_loss=0.1765       time=1.259
iter=200        input_loss=722.9        output_loss=0.1369      total_loss=0.1369       time=1.259
iter=300        input_loss=748.1        output_loss=0.1254      total_loss=0.1254       time=1.258
iter=400        input_loss=778.4        output_loss=0.1056      total_loss=0.1056       time=1.256
iter=500        input_loss=757.3        output_loss=0.1188      total_loss=0.1188       time=1.257
iter=F  input_loss=763.6        output_loss=0.09379     total_loss=0.09379      time=7.296
[10]:
y_hat = predict(model, X_hat)

plt.figure(figsize=(8, 3))
plt.plot(y_hat[0, 0], c='0.3', label="fwd strand")
plt.plot(y_hat[0, 1], c='0.7', label="bwd strand")

plt.title("Ledidi Design using KL divergence")
plt.xlabel("Genomic Position")
plt.ylabel("Predicted Probability")

plt.xlim(0, 1000)
plt.legend()

seaborn.despine(bottom=True, left=True)
plt.tight_layout()
plt.show()
[Figure: Ledidi Design using KL divergence with a two-peak target (predicted probability by genomic position, fwd and bwd strands)]

Looks like the answer is yes! Just because we have changed the output loss to be KL divergence instead of MSE does not mean that we lose any flexibility to design edits with whatever properties we would like.

Unconstrained Rewards

A downside of using MSE in the multi-output setting is that we have to have a precise target value that we want to match for each of the outputs. Sometimes, this is the case and so using MSE is not a problem. However, other times, we may not have or even be able to reasonably guess what the output should be from each of the models. In these cases, we usually have a general sense that we would like designs that generally increase or decrease predictions from the models but do not know what the exact values should be.

As a concrete example, consider the design of accessible sites whose accessibility is driven by a specific TF. In this setting, we likely have a target accessibility value (e.g., choosing a target that makes the region in the top percentile of accessible sites). But we may not know what the corresponding predictions should be for the TF model. This is because the dynamic range for each TF model will depend on the read depth and quality of the data used to train it and also because the effect on accessibility of the binding of each TF differs based on the structure and activity of the protein. For example, a pioneer factor will open chromatin significantly more than a “settler” factor. Trying to estimate the connection between TF binding and accessibility and then converting the desired accessibility value into a TF binding value is tricky: set the value too low and other TFs will be used to drive accessibility, set the value too high and suddenly you have an infeasible objective where either accessibility will be too high or TF binding will be too low.

A solution to this problem is to have an output loss function where some of the outputs have target values and other outputs are just rewarded for being higher or lower without constraint. In the previous example, we set a target accessibility value and then just reward higher predictions from the TF binding model without requiring that they meet a certain value. One way of viewing this is that there are many ways to design an accessible site and this approach pushes the designer towards those that utilize the binding of a specific TF.

To see this in action, let’s load up a ChromBPNet model and a BPNet model that makes predictions for the binding of MAX.

[11]:
from bpnetlite import BPNet
from bpnetlite.bpnet import CountWrapper

X = random_one_hot((1, 4, 2114), random_state=0).float()

chrombpnet = BPNet.from_chrombpnet("../../../../models/chrombpnet/fold_0/model.chrombpnet_nobias.fold_0.ENCSR868FGK.h5")
chrombpnet = CountWrapper(chrombpnet).cuda()

bpnet_max = torch.load("../../../../models/bpnet/MAX.torch", weights_only=False)
bpnet_max = CountWrapper(ControlWrapper(bpnet_max))

We can combine the two models and use them with Ledidi to design edits, just as we did in the multi-output tutorial.

[12]:
from ledidi.wrappers import DesignWrapper

designer = DesignWrapper([chrombpnet, bpnet_max])

y_bar = torch.tensor([[13.0, 7.0]])

X_bar = ledidi(designer, X, y_bar, verbose=True)
y_hat = predict(designer, X_bar)
y_hat.mean(axis=0)
iter=I  input_loss=0.0  output_loss=26.07       total_loss=26.07        time=0.0
iter=100        input_loss=108.4        output_loss=1.235       total_loss=12.08        time=46.12
iter=200        input_loss=121.2        output_loss=0.6067      total_loss=12.73        time=24.36
iter=300        input_loss=114.0        output_loss=0.773       total_loss=12.17        time=24.37
iter=F  input_loss=99.06        output_loss=0.8359      total_loss=10.74        time=115.3
[12]:
tensor([12.4062,  5.8811])

Looks like it is able to generate edits that get close to the two targets.

Now, let’s consider the alternate loss function. In this loss, we want to meet a desired accessibility output but we will also reward increases in the MAX prediction. There are a few ways that one could implement this, but probably the simplest is a function that takes in a y_hat tensor with two elements (accessibility and MAX predictions) and a y_bar tensor with one element (just the desired accessibility output), calculates an MSE loss on the accessibility component, and simply rewards higher predictions from the MAX model.

[13]:
def mse_max_loss(y_hat, y_bar):
    # Compare only the accessibility column (index 0) against the target.
    return torch.nn.MSELoss()(y_hat[:, :1], y_bar) - y_hat[:, 1].mean()

A convenient aspect of framing the loss like this is that we do not have to make up some values for the “desired output” of MAX which we would then subsequently ignore. Rather, we are only passing in desired values for components that we actually want to control. Note that we are subtracting the MAX predictions (the second column in y_hat) because Ledidi minimizes the output loss and we want to maximize these predictions. Also note that we do not have to use the built-in MSE loss for the first component. As long as the function is differentiable we can do whatever we want. We are just using it here for convenience.
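To make the arithmetic of this kind of loss explicit, here is a plain-Python sketch of the same idea: squared error on the first column only, minus the mean of the second column. It is not differentiable and cannot be passed to Ledidi; it only illustrates the computation. The weight argument is a hypothetical knob that is not part of the loss above, added to show one way the reward could be scaled relative to the accessibility error.

```python
def mse_plus_reward(y_hat, y_bar, weight=1.0):
    """Squared error on column 0 minus a reward for column 1.

    y_hat: rows of [accessibility, tf_binding]
    y_bar: rows of [accessibility_target]
    weight: hypothetical scaling of the reward term (not in the loss above)
    """
    n = len(y_hat)
    mse = sum((row[0] - tgt[0]) ** 2 for row, tgt in zip(y_hat, y_bar)) / n
    reward = sum(row[1] for row in y_hat) / n
    return mse - weight * reward


y_bar = [[13.0]]
low_tf = mse_plus_reward([[12.4, 5.9]], y_bar)    # lower TF prediction
high_tf = mse_plus_reward([[12.4, 10.7]], y_bar)  # higher TF prediction

print(low_tf, high_tf)
```

With the same accessibility error, the design with the higher TF prediction gets the lower loss. The loss can also go negative, since the reward term has no upper bound.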

After specifying the output loss function, we can just pass that in like we did before.

[14]:
X_bar = ledidi(designer, X, y_bar[:, :1], output_loss=mse_max_loss, verbose=True)

y_hat = predict(designer, X_bar)
y_hat.mean(axis=0)
iter=I  input_loss=0.0  output_loss=76.21       total_loss=76.21        time=0.0
/users/jacob.schreiber/anaconda3/lib/python3.12/site-packages/torch/nn/modules/loss.py:608: UserWarning: Using a target size (torch.Size([1, 1])) that is different to the input size (torch.Size([1, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
/users/jacob.schreiber/anaconda3/lib/python3.12/site-packages/torch/nn/modules/loss.py:608: UserWarning: Using a target size (torch.Size([16, 1])) that is different to the input size (torch.Size([16, 2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
iter=100        input_loss=326.2        output_loss=1.18        total_loss=33.81        time=24.38
iter=200        input_loss=360.9        output_loss=-4.321      total_loss=31.77        time=24.39
iter=300        input_loss=354.4        output_loss=-4.972      total_loss=30.47        time=24.37
iter=400        input_loss=347.4        output_loss=-5.126      total_loss=29.62        time=24.38
iter=500        input_loss=344.4        output_loss=-6.973      total_loss=27.47        time=24.38
iter=600        input_loss=335.4        output_loss=-7.331      total_loss=26.21        time=24.38
iter=F  input_loss=329.9        output_loss=-7.888      total_loss=25.11        time=164.6
[14]:
tensor([12.3439, 10.7223])

By using this loss function we are able to get much higher predictions of MAX while still getting roughly the same accessibility predictions. Because the binding of many TFs (including MAX) subsequently causes an increase in accessibility, one can infer that much higher predictions for MAX mean that other TFs must be binding less to achieve the same accessibility. Although we will not go deeper into it here, one can augment this approach by additionally using models for any other TFs that might be relevant and rewarding low predictions from them. Together, that would yield designs that explicitly favor higher MAX predictions and lower predictions for other TFs.

Cell Type-Specific Design and Min Gap

A common goal when designing elements is trying to create cell type-specific activity. This can be achieved in the default setting by setting high predicted values in the cell type(s) that one wants the element to be active in and low values in the cell type(s) that the element should not be active in. However, there can be two issues with simply using MSE to design specific elements like this. First, because MSE has a smoothing property (like we saw at the beginning of the notebook), it may not yield elements that are truly as specific as one might like. Second, an increasing amount of recent work has shown that the most cell type-specific elements are those that only exhibit weak signal and that making these regions stronger might risk making them less specific. Put another way: the target value for a cell type-specific element might be on a different range than the normal dynamic range of the model.

A potential solution to this was proposed by Gosai et al. in what they called the “min gap.” Essentially, the loss function involves maximizing the gap between (1) the minimum on-target value and (2) the maximum off-target value. A strength of this loss function is that it gets around the smoothing property of MSE. It does not matter if the element is off in 9 out of 10 off-target cell lines. Each component of the loss is only as good as the worst element in that group.

Let’s demonstrate this using the Malinois model from the same work.

[15]:
from boda.model import BassetBranched

checkpoint = torch.load("../../../../models/malinois/torch_checkpoint.pt", weights_only=False)
malinois = BassetBranched(**vars(checkpoint['model_hparams']))
malinois.load_state_dict(checkpoint['model_state_dict'])

X = random_one_hot((1, 4, 600), random_state=0).float()

We can load the MinGap loss from ledidi.losses. Because there are two groups – the on-target and the off-target outputs – we have to pass in a binary vector indicating which elements are part of the on-target group. To support any composition of classes, MinGap is a torch.nn.Module instance that you pass a binary vector into and whose forward function is defined as the loss function. Here, let’s say that we want to design edits that turn the sequence into something specific to the second output.

[16]:
from ledidi.losses import MinGap

in_mask = torch.tensor([False, True, False])
output_loss = MinGap(in_mask)

An important aspect of the min gap loss is that there are no target values. This has two practical ramifications. First, there is nothing saying that the predictions have to be high in the on-target cell lines and low in the off-target ones. Because the loss only cares about maximizing the gap, both values could in theory be negative so long as the gap between them is large. Second, Ledidi’s implementation expects something for y_bar even if it gets ignored by the loss. Here, we just have to make up some values, but they do not get used by min gap at all.

[17]:
y_bar = torch.tensor([[0, 0, 0]])
X_bar = ledidi(malinois, X, y_bar, output_loss=output_loss, verbose=True)

y_hat = predict(malinois, X_bar)
y_hat.mean(axis=0)
iter=I  input_loss=0.0  output_loss=-0.1533     total_loss=-0.1533      time=0.0
iter=100        input_loss=92.38        output_loss=-21.22      total_loss=-11.98       time=0.8177
iter=200        input_loss=117.4        output_loss=-37.31      total_loss=-25.57       time=0.6526
iter=300        input_loss=123.6        output_loss=-38.47      total_loss=-26.11       time=0.6503
iter=400        input_loss=124.4        output_loss=-40.66      total_loss=-28.23       time=0.6506
iter=500        input_loss=120.2        output_loss=-41.29      total_loss=-29.27       time=0.6505
iter=600        input_loss=122.0        output_loss=-41.84      total_loss=-29.64       time=0.6471
iter=700        input_loss=124.6        output_loss=-42.18      total_loss=-29.72       time=0.6456
iter=800        input_loss=128.4        output_loss=-43.39      total_loss=-30.55       time=0.6451
iter=900        input_loss=126.2        output_loss=-42.93      total_loss=-30.31       time=0.6477
iter=F  input_loss=125.4        output_loss=-43.88      total_loss=-31.33       time=6.318
[17]:
tensor([ 2.5212, 46.3993,  2.4737])

Looks like the predictions from the second output are significantly higher than the predictions from the other two outputs.
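The definition of the min gap can be checked against the predictions reported above. The function below is a plain-Python sketch of the idea as described in the text (negative of the weakest on-target value minus the strongest off-target value), not the ledidi.losses.MinGap source, and min_gap_loss is a name introduced here for illustration.

```python
def min_gap_loss(y_hat, in_mask):
    """Negative gap between the weakest on-target and strongest off-target output."""
    on = [v for v, m in zip(y_hat, in_mask) if m]
    off = [v for v, m in zip(y_hat, in_mask) if not m]
    if not on or not off:
        raise ValueError("need at least one on-target and one off-target output")
    return -(min(on) - max(off))


# Mean predictions reported for the design above, second output on-target.
loss = min_gap_loss([2.5212, 46.3993, 2.4737], [False, True, False])
print(f"{loss:.2f}")
```

The result is close to the final logged output_loss of -43.88. Note that only the larger of the two off-target values (2.5212) enters the loss; the other off-target output has no influence until it becomes the maximum.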

To show how general-purpose this loss is, let’s say we want to make elements that are active in the first two but not the last output.

[18]:
y_bar = torch.tensor([[0, 0, 0]])
X_bar = ledidi(malinois, X, y_bar, output_loss=MinGap(torch.tensor([True, True, False])), verbose=True)

y_hat = predict(malinois, X_bar)
y_hat.mean(axis=0)
iter=I  input_loss=0.0  output_loss=-0.2093     total_loss=-0.2093      time=0.0
iter=100        input_loss=66.06        output_loss=-12.86      total_loss=-6.255       time=0.6808
iter=200        input_loss=77.12        output_loss=-14.54      total_loss=-6.827       time=0.7468
iter=300        input_loss=85.38        output_loss=-15.81      total_loss=-7.273       time=0.6467
iter=F  input_loss=82.06        output_loss=-15.91      total_loss=-7.704       time=2.674
[18]:
tensor([16.6600, 16.7327,  0.7471])

Assay-specific Design and Min Gap

As a side-note about the min gap loss, you do not have to use it only for cell type-specific design. More broadly, it can be used for output-specific design regardless of what those outputs are. If we combine two of our previous examples, let’s say that we want to design elements that are TF specific. We can apply the min gap across TFs by designating them as either being on-target or off-target and design elements where the predictions should only be high for the on-target outputs.

[19]:
bpnet_ctcf = torch.load("../../../../models/bpnet/CTCF.torch", weights_only=False)
bpnet_ctcf = CountWrapper(ControlWrapper(bpnet_ctcf))

bpnet_e2f6 = torch.load("../../../../models/bpnet/E2F6.torch", weights_only=False)
bpnet_e2f6 = CountWrapper(ControlWrapper(bpnet_e2f6))

bpnet_junb = torch.load("../../../../models/bpnet/JUNB.torch", weights_only=False)
bpnet_junb = CountWrapper(ControlWrapper(bpnet_junb))

designer = DesignWrapper([bpnet_max, bpnet_ctcf, bpnet_e2f6, bpnet_junb])

Just as before we can specify a min gap loss according to these four outputs. If we want to try designing something that is CTCF-specific all we need to do is set it to be the in-group and everything else to be the out-group.

[20]:
X = random_one_hot((1, 4, 2114), random_state=0).float()

y_bar = torch.tensor([[0, 0, 0]])
X_bar = ledidi(designer, X, y_bar, l=0.001, output_loss=MinGap(torch.tensor([False, True, False, False])), verbose=True)

y_hat = predict(designer, X_bar)
y_hat.mean(axis=0)
iter=I  input_loss=0.0  output_loss=4.06        total_loss=4.06 time=0.0
iter=100        input_loss=744.4        output_loss=-4.478      total_loss=-3.734       time=5.709
iter=200        input_loss=868.8        output_loss=-6.906      total_loss=-6.038       time=5.703
iter=300        input_loss=869.7        output_loss=-7.69       total_loss=-6.82        time=5.703
iter=400        input_loss=878.2        output_loss=-8.609      total_loss=-7.731       time=5.704
iter=500        input_loss=911.9        output_loss=-9.325      total_loss=-8.414       time=5.717
iter=600        input_loss=932.3        output_loss=-9.352      total_loss=-8.42        time=5.703
iter=700        input_loss=970.6        output_loss=-9.956      total_loss=-8.985       time=5.702
iter=800        input_loss=996.0        output_loss=-9.996      total_loss=-9.0 time=5.706
iter=F  input_loss=1.005e+03    output_loss=-10.18      total_loss=-9.174       time=50.97
[20]:
tensor([ 1.3769, 14.2045,  3.2331,  4.0257])

Looks like we are getting very strong predictions from the CTCF model and lower predictions from everything else.

We can also specify groups of proteins that we would like to have high values just by setting several values to True.

[21]:
y_bar = torch.tensor([[0, 0, 0]])
X_bar = ledidi(designer, X, y_bar, l=0.001, output_loss=MinGap(torch.tensor([False, False, True, True])), verbose=True)

y_hat = predict(designer, X_bar)
y_hat.mean(axis=0)
iter=I  input_loss=0.0  output_loss=-0.9792     total_loss=-0.9792      time=0.0
iter=100        input_loss=422.1        output_loss=-5.737      total_loss=-5.315       time=5.718
iter=200        input_loss=548.4        output_loss=-6.782      total_loss=-6.233       time=5.709
iter=300        input_loss=598.2        output_loss=-7.302      total_loss=-6.704       time=5.706
iter=400        input_loss=625.4        output_loss=-7.944      total_loss=-7.318       time=5.711
iter=500        input_loss=655.7        output_loss=-8.127      total_loss=-7.472       time=5.705
iter=600        input_loss=670.2        output_loss=-8.68       total_loss=-8.009       time=5.706
iter=700        input_loss=663.0        output_loss=-8.955      total_loss=-8.292       time=5.711
iter=800        input_loss=660.1        output_loss=-9.117      total_loss=-8.457       time=5.708
iter=900        input_loss=661.3        output_loss=-9.458      total_loss=-8.796       time=5.707
iter=1000       input_loss=679.2        output_loss=-9.66       total_loss=-8.981       time=5.707
iter=F  input_loss=679.2        output_loss=-9.66       total_loss=-8.981       time=57.09
[21]:
tensor([-0.9855, -0.9598,  8.7180,  8.7110])

One thing to check when using min gap in the multi-model setting is that the dynamic range of predicted values is comparable across the models. This is important because it will be challenging (if not impossible) to get an off-target model to predict a smaller value than an on-target model when the range of the off-target model is between 10 and 20 and the range of the on-target model is between 0 and 4. Because experimental artifacts such as read depth and quality will influence the range of the predictions, one may wish to carefully inspect the models and potentially even z-score normalize the predictions to circumvent these potential issues.
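One way the z-score idea could look is sketched below in plain Python. Everything here is an assumption for illustration: zscore_outputs is a hypothetical helper, the reference predictions are made-up numbers standing in for each model's outputs on some background set of sequences, and in practice the rescaling would need to happen inside a differentiable torch wrapper around each model.

```python
from statistics import mean, stdev


def zscore_outputs(y_hat, reference):
    """Rescale each model's prediction by that model's own reference statistics.

    reference[j] is a hypothetical list of model j's predictions on a
    background set of sequences (e.g. shuffled or genomic controls).
    """
    return [(v - mean(ref)) / stdev(ref) for v, ref in zip(y_hat, reference)]


# Made-up numbers: the off-target model lives on a much larger scale.
reference = [
    [10.0, 12.0, 15.0, 18.0, 20.0],  # off-target model, range roughly 10 to 20
    [0.0, 1.0, 2.0, 3.0, 4.0],       # on-target model, range roughly 0 to 4
]
raw = [11.0, 3.5]  # [off-target prediction, on-target prediction]

scaled = zscore_outputs(raw, reference)
print("raw gap:   ", raw[1] - raw[0])
print("scaled gap:", scaled[1] - scaled[0])
```

On the raw scale the on-target prediction looks far below the off-target one, even though it is high for its model and the off-target prediction is low for its model. After rescaling, the gap has the sign one would want the min gap loss to see.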

Custom Loss for Less Specificity

A point that we have repeated about using MSE is that you have to set specific target values and that figuring out what these values should be can be challenging. One way to reduce the impact of this decision is to use a loss function that is more resistant to the precise target value and focuses more on getting in the right ballpark. As an example, we can define a custom “ballpark loss” that says a perfect match is anything within a radius of the target value. By using this loss, if you specify a desired value of 7.5, a predicted value of 7.2 is just as good as 7.5 – you just want the predictions to be in the ballpark.

[22]:
def ballpark_loss(y_hat, y_bar, z=1):
    return torch.maximum(torch.nn.MSELoss(reduction='none')(y_hat, y_bar) - z, torch.zeros_like(y_hat))

We can take a look at the loss directly to see what the different curves look like.

[23]:
y = torch.tensor([5.0])
x = torch.arange(0, 10.1, 0.1)

loss1 = torch.nn.MSELoss(reduction='none')(x, y)
loss2 = ballpark_loss(x, y, 0.5).numpy(force=True)
loss3 = ballpark_loss(x, y, 1).numpy(force=True)
loss4 = ballpark_loss(x, y, 2).numpy(force=True)

plt.title("Loss Functions (Target=5.0)")
plt.plot(x, loss1, label="MSE Loss")
plt.plot(x, loss2, label="Ballpark Loss (0.5)")
plt.plot(x, loss3, label="Ballpark Loss (1.0)")
plt.plot(x, loss4, label="Ballpark Loss (2.0)")
plt.xlabel("Prediction")
plt.ylabel("Loss")

plt.legend()
seaborn.despine(bottom=True, left=True)
plt.tight_layout()
plt.show()
/users/jacob.schreiber/anaconda3/lib/python3.12/site-packages/torch/nn/modules/loss.py:608: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([101])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
[Figure: Loss Functions (Target=5.0); loss by prediction for MSE loss and for ballpark loss with z = 0.5, 1.0, and 2.0]

We can just plug the loss directly into Ledidi with the small modification that we need to take the mean to get a single number.

[24]:
def ballpark_loss(y_hat, y_bar, z=1):
    return torch.maximum(torch.nn.MSELoss(reduction='none')(y_hat, y_bar) - z, torch.zeros_like(y_hat)).mean()

y_bar = torch.tensor([[7.0]])
X_bar = ledidi(bpnet_ctcf, X, y_bar, output_loss=ballpark_loss, verbose=True)

y_hat = predict(bpnet_ctcf, X_bar)
y_hat.mean()
iter=I  input_loss=0.0  output_loss=44.93       total_loss=44.93        time=0.0
iter=100        input_loss=85.56        output_loss=1.468       total_loss=10.02        time= 1.2
iter=200        input_loss=93.56        output_loss=0.1706      total_loss=9.527        time=1.204
iter=300        input_loss=52.75        output_loss=0.1505      total_loss=5.425        time=1.201
iter=400        input_loss=48.0 output_loss=0.04736     total_loss=4.847        time=1.184
iter=500        input_loss=41.31        output_loss=0.007231    total_loss=4.138        time=1.182
iter=600        input_loss=43.44        output_loss=0.05472     total_loss=4.398        time=1.185
iter=700        input_loss=39.62        output_loss=0.1141      total_loss=4.077        time=1.183
iter=F  input_loss=34.44        output_loss=0.05474     total_loss=3.498        time=9.109
[24]:
tensor(5.9989)

We get a value near 6. This is not surprising: the target value is 7 and, with z=1, the loss gives the prediction slack to be anywhere between 6 and 8. Because we are also penalizing edits, the design ends up near the lower end of that range, since reaching it requires the fewest edits.
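
The range quoted above follows from the definition of the loss: because z is subtracted from the squared error, the loss is zero whenever the prediction is within sqrt(z) of the target. The helper below, `zero_loss_interval`, is a hypothetical name introduced only for this illustration and is not part of Ledidi.

```python
import math

def zero_loss_interval(target, z=1.0):
    # The ballpark loss is max((y_hat - target)^2 - z, 0), which is zero
    # exactly when |y_hat - target| <= sqrt(z).
    radius = math.sqrt(z)
    return target - radius, target + radius

# Target of 7 with z=1, as in the design above.
print(zero_loss_interval(7.0, z=1.0))  # (6.0, 8.0)

# The three z values used in the plot, with a target of 5.
for z in (0.5, 1.0, 2.0):
    low, high = zero_loss_interval(5.0, z)
    print(f"z={z}: zero loss for predictions in [{low:.2f}, {high:.2f}]")
```

Note that z acts on the squared error, so the width of the flat region grows with the square root of z rather than with z itself.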

One-Directional Loss

Sometimes what you really want are designs that are at least as strong as the target, with no penalty for exceeding it. In other words, the designs should reach some minimum strength, and if the procedure can do better than that, even better! There are a few ways to specify this. First, we can simply not penalize predictions that are greater than the target value.

[25]:
def one_directional_loss(y_hat, y_bar):
    return torch.where(y_hat < y_bar, torch.nn.MSELoss(reduction='none')(y_hat, y_bar), torch.zeros_like(y_hat)).mean()


y_bar = torch.tensor([[7.0]])
X_bar = ledidi(bpnet_ctcf, X, y_bar, l=0, output_loss=one_directional_loss, verbose=True)

y_hat = predict(bpnet_ctcf, X_bar)
y_hat.mean()
iter=I  input_loss=0.0  output_loss=45.93       total_loss=45.93        time=0.0
iter=100        input_loss=808.9        output_loss=0.7674      total_loss=0.7674       time=1.184
iter=200        input_loss=563.3        output_loss=0.0157      total_loss=0.0157       time=1.181
iter=F  input_loss=795.7        output_loss= 0.0        total_loss= 0.0 time=2.884
[25]:
tensor(7.3477)

Another way that we could do this is to offer a reward for being above the desired output that is not as strong as the penalty for being below it.

[26]:
def one_directional_loss(y_hat, y_bar):
    mse = torch.nn.MSELoss(reduction='none')(y_hat, y_bar)
    return torch.where(y_hat < y_bar, mse, -0.1 * mse).mean()


y_bar = torch.tensor([[7.0]])
X_bar = ledidi(bpnet_ctcf, X, y_bar, l=0, output_loss=one_directional_loss, verbose=True)

y_hat = predict(bpnet_ctcf, X_bar)
y_hat.mean()
iter=I  input_loss=0.0  output_loss=45.93       total_loss=45.93        time=0.0
iter=100        input_loss=782.8        output_loss=0.03437     total_loss=0.03437      time=1.181
iter=200        input_loss=606.1        output_loss=-0.002722   total_loss=-0.002722    time=1.178
iter=300        input_loss=857.7        output_loss=-0.8195     total_loss=-0.8195      time=1.18
iter=400        input_loss=971.6        output_loss=-1.479      total_loss=-1.479       time=1.178
iter=500        input_loss=1.018e+03    output_loss=-2.272      total_loss=-2.272       time=1.179
iter=600        input_loss=1.055e+03    output_loss=-2.686      total_loss=-2.686       time=1.177
iter=700        input_loss=1.092e+03    output_loss=-3.353      total_loss=-3.353       time=1.179
iter=800        input_loss=1.117e+03    output_loss=-3.812      total_loss=-3.812       time=1.178
iter=900        input_loss=1.145e+03    output_loss=-4.589      total_loss=-4.589       time=1.178
iter=1000       input_loss=1.16e+03     output_loss=-5.064      total_loss=-5.064       time=1.18
iter=F  input_loss=1.159e+03    output_loss=-5.171      total_loss=-5.171       time=11.79
[26]:
tensor(14.1902)

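In both examples, we set lambda (the l argument) to 0 so that we can see the effect of the output loss on its own, but balancing lambda against the other terms of the loss will be important for design in practice. Note that the reward in the second version is unbounded, so with lambda set to 0 the optimization keeps pushing the prediction upward (here to about 14.2) rather than stopping once the target is met, as it does in the first version (about 7.3). Regardless, this is just another example of how creative one can be when coming up with custom losses.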
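
To make the difference between the two one-directional variants concrete, the sketch below evaluates both at a few hypothetical predictions against a target of 7, without taking the mean. The names `no_penalty_above` and `reward_above` are introduced here only to keep the two versions apart; in the cells above, both are called `one_directional_loss`.

```python
import torch

def no_penalty_above(y_hat, y_bar):
    # First variant: squared error below the target, zero at or above it.
    mse = torch.nn.MSELoss(reduction='none')(y_hat, y_bar)
    return torch.where(y_hat < y_bar, mse, torch.zeros_like(y_hat))

def reward_above(y_hat, y_bar):
    # Second variant: squared error below the target, and a negative loss
    # (a reward) equal to -0.1 times the squared error at or above it.
    mse = torch.nn.MSELoss(reduction='none')(y_hat, y_bar)
    return torch.where(y_hat < y_bar, mse, -0.1 * mse)

# Hypothetical predictions compared against a target of 7.
y_hat = torch.tensor([5.0, 7.0, 9.0, 14.0])
y_bar = torch.full_like(y_hat, 7.0)

print(no_penalty_above(y_hat, y_bar))  # 4 below the target, then 0 everywhere else
print(reward_above(y_hat, y_bar))      # 4 below the target, then 0, about -0.4, about -4.9
```

The first variant is flat above the target, so there is nothing left to gain once the target is reached. The second keeps decreasing as the prediction grows, which is consistent with the final output loss of roughly -5.17 at a prediction of about 14.19 in the run above.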