Migrate A Keras Network to the PyTorch framework.



Using Clair3 as an example



Clair3 is a DeepVariant-like variant calling model widely used with long-read sequencing technologies, especially Nanopore sequencing. It was developed at HKU and published in Nature Computational Science, and the code is available on GitHub. However, Clair3 is implemented in TensorFlow Keras, which does not support some GPU architectures well. It is therefore worth migrating Clair3 to the PyTorch framework.

One option is Keras 3 (formerly named Keras Core), which supports TensorFlow, JAX, and PyTorch backends and may allow a migration with only minor code modifications. However, Clair3 contains a lot of messy code, and its network architecture is simple, so it is much easier to rewrite Clair3 from scratch in PyTorch.

Clair3 contains two main networks: the pileup network and the full-alignment network. The pileup network is built on Bi-LSTM layers, while the full-alignment network is built on CNN layers; both are common building blocks of small machine learning models. These migration notes may be helpful to anyone migrating small Keras networks to the PyTorch framework.

Network Reconstruction

The networks contain common layers such as Bi-LSTM, Conv2D, Pooling2D, Flatten, SELU, ReLU, and fully connected (FC) layers. Some are simple to migrate because PyTorch provides an equivalent API, while others are tricky.

Easy-to-migrate Layers

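Bi-LSTM layer: keras.layers.Bidirectional(keras.layers.LSTM(units=num_units, return_sequences=True)) -> torch.nn.LSTM(input_size=input_dim, hidden_size=num_units, batch_first=True, bidirectional=True). Keras layers take (batch, time, features) input, whereas torch.nn.LSTM defaults to (time, batch, features), so batch_first=True is needed. torch.nn.LSTM also returns a tuple (output, (h_n, c_n)); output is the equivalent of the Keras result with return_sequences=True.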
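A minimal sketch of this mapping as a drop-in module, assuming Keras' default merge_mode='concat' for Bidirectional and batch-major input. The wrapper name BiLSTM is hypothetical; it simply discards the (h_n, c_n) states so that the module returns the same tensor shape as the Keras layer:

```python
import torch
import torch.nn as nn


class BiLSTM(nn.Module):
    """Sketch of Bidirectional(LSTM(num_units, return_sequences=True)) in PyTorch.

    Assumes Keras' default merge_mode='concat' and (batch, time, features) input.
    """

    def __init__(self, input_size, num_units):
        super().__init__()
        # batch_first=True keeps the Keras (batch, time, features) layout
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=num_units,
                            batch_first=True, bidirectional=True)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); only `output` matches return_sequences=True
        output, _ = self.lstm(x)
        # Forward and backward states are concatenated on the last axis
        return output


layer = BiLSTM(input_size=8, num_units=16)
print(layer(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 32])
```

The last dimension is 2 * num_units because both directions are concatenated, just as with the Keras 'concat' merge mode.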

FC layer: keras.layers.Dense(units=num_units, activation='selu') -> torch.nn.Sequential(torch.nn.Linear(in_features, num_units), torch.nn.SELU()).

Should-be-careful Layers

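Batch Normalization Layers: Mathematically, batch normalization computes \(y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\). Although batch normalization layers exist in both Keras and PyTorch, their default parameters differ. In Keras, the momentum is 0.99 and \(\epsilon\) is 0.001, while in PyTorch, the momentum is 0.1 and \(\epsilon\) is 1e-5. The two momentum values also follow opposite conventions: Keras updates the running statistics as \(\text{running} = m \cdot \text{running} + (1 - m) \cdot \text{batch}\), whereas PyTorch uses \(\text{running} = (1 - m) \cdot \text{running} + m \cdot \text{batch}\). A Keras momentum of 0.99 therefore corresponds to a PyTorch momentum of 0.01. For inference only \(\epsilon\) has to match, since momentum only affects training.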
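A sketch of a batch normalization layer configured with the Keras defaults, assuming NCHW input (so PyTorch's channel axis 1 plays the role of Keras' default axis=-1 on NHWC data). The helper name keras_like_batchnorm2d is hypothetical; it only converts the momentum and passes epsilon through:

```python
import torch
import torch.nn as nn

KERAS_BN_MOMENTUM = 0.99  # keras.layers.BatchNormalization default
KERAS_BN_EPSILON = 1e-3   # keras.layers.BatchNormalization default


def keras_like_batchnorm2d(num_features, momentum=KERAS_BN_MOMENTUM, epsilon=KERAS_BN_EPSILON):
    """Hypothetical helper: nn.BatchNorm2d that behaves like a Keras BatchNormalization."""
    # Keras:   running = momentum * running + (1 - momentum) * batch_stat
    # PyTorch: running = (1 - momentum) * running + momentum * batch_stat
    # so the PyTorch momentum is 1 - (Keras momentum).
    return nn.BatchNorm2d(num_features, eps=epsilon, momentum=1.0 - momentum)


bn = keras_like_batchnorm2d(3)
print(bn.eps, round(bn.momentum, 6))  # 0.001 0.01
```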

Flatten Layers: Keras uses the channels-last format (NHWC) by default, whereas PyTorch uses NCHW. Flattening the two layouts directly yields the elements in different orders, so the FC layer that follows would receive permuted inputs. To reproduce the Keras order, permute the PyTorch tensor back to NHWC before flattening: keras.layers.Flatten()(x) -> torch.flatten(x.permute(0, 2, 3, 1), start_dim=1).
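The ordering difference can be checked with a small sketch. KerasFlatten is a hypothetical helper name, and the tensor sizes are arbitrary; the Keras result is emulated by flattening the NHWC tensor directly:

```python
import torch
import torch.nn as nn


class KerasFlatten(nn.Module):
    """Flatten an NCHW tensor in the element order keras.layers.Flatten uses on NHWC."""

    def forward(self, x):
        # NCHW -> NHWC, then flatten everything except the batch dimension
        return torch.flatten(x.permute(0, 2, 3, 1), start_dim=1)


# The same data in both layouts: N=2, H=3, W=4, C=5
x_nhwc = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # Keras layout
x_nchw = x_nhwc.permute(0, 3, 1, 2)                       # PyTorch layout

keras_order = torch.flatten(x_nhwc, start_dim=1)  # what Keras' Flatten would produce
naive = torch.flatten(x_nchw, start_dim=1)        # plain flatten on NCHW
fixed = KerasFlatten()(x_nchw)

print(torch.equal(naive, keras_order))  # False
print(torch.equal(fixed, keras_order))  # True
```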

Hard-to-migrate Layers

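Conv2D and Pooling2D layers are harder to migrate because of Keras' padding='same' option, which is not easy to reproduce in PyTorch: torch.nn.Conv2d accepts padding='same' only when the stride is 1, and the pooling layers do not accept it at all. For PyTorch's Conv2d, the output size \(H_o, W_o\) can be calculated by the following equation: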

$$\begin{aligned}H_o &= \left\lfloor \frac{H_i + 2\,\text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1 \right\rfloor \\ W_o &= \left\lfloor \frac{W_i + 2\,\text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1 \right\rfloor\end{aligned}$$

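With Keras' padding='same', the output size is \(H_o = \lceil H_i / \text{stride}[0] \rceil\) and \(W_o = \lceil W_i / \text{stride}[1] \rceil\), so with stride=1 the output has the same size as the input (\(H_o = H_i\) and \(W_o = W_i\)). The padding therefore has to be computed from the input size. When the total padding along an axis is odd, Keras puts the extra row or column at the bottom or right, which PyTorch's symmetric padding argument cannot express.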
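Solving the output-size formula above for the padding gives the total amount \(P\) that Keras/TensorFlow adds along one axis. This is a derivation sketch, writing \(s\), \(d\), \(k\) for stride[0], dilation[0], kernel_size[0] and replacing the symmetric \(2\,\text{padding}[0]\) with an arbitrary total \(P\):

$$\begin{aligned} H_o &= \left\lceil \frac{H_i}{s} \right\rceil \\ P &= \max\left((H_o - 1)\,s + d\,(k - 1) + 1 - H_i,\ 0\right) \\ P_{\text{top}} &= \left\lfloor P / 2 \right\rfloor, \qquad P_{\text{bottom}} = P - P_{\text{top}} \end{aligned}$$

For example, with \(H_i = 6\), \(k = 3\), \(s = 2\), \(d = 1\): \(H_o = 3\) and \(P = 2 \cdot 2 + 2 + 1 - 6 = 1\), so Keras pads 0 rows at the top and 1 row at the bottom. Substituting back, \(\lfloor (6 + 1 - 2 - 1)/2 + 1 \rfloor = 3 = H_o\), as required.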

A Conv2D layer with Keras-style same padding can be implemented as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2dSamePadding(nn.Conv2d):
    """nn.Conv2d that pads its input like Keras/TensorFlow padding='same'."""

    def __init__(self, *args, **kwargs):
        # Padding is computed from the input size in forward(), so disable the built-in one
        kwargs["padding"] = 0
        super().__init__(*args, **kwargs)

    @staticmethod
    def _same_pad(in_size, kernel_size, stride, dilation):
        # Total padding TensorFlow uses so that out_size == ceil(in_size / stride)
        effective_kernel = (kernel_size - 1) * dilation + 1
        if in_size % stride == 0:
            return max(effective_kernel - stride, 0)
        return max(effective_kernel - (in_size % stride), 0)

    def forward(self, input):
        in_height, in_width = input.shape[2:]

        # nn.Conv2d already stores kernel_size, stride and dilation as tuples
        pad_height = self._same_pad(in_height, self.kernel_size[0], self.stride[0], self.dilation[0])
        pad_width = self._same_pad(in_width, self.kernel_size[1], self.stride[1], self.dilation[1])

        # Like TensorFlow, put the extra row/column (if any) at the bottom/right
        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        # F.pad lists the last dimension first: (left, right, top, bottom)
        input = F.pad(input, (pad_left, pad_right, pad_top, pad_bottom), value=0.0)
        return self._conv_forward(input, self.weight, self.bias)

MaxPooling2D can be handled the same way as Conv2D, except that the padding value must be -inf instead of 0, so that padded cells are never selected as the maximum.
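A minimal sketch of such a pooling layer, assuming square pool windows, an integer stride that defaults to the pool size (as in keras.layers.MaxPooling2D), and NCHW input. MaxPool2dSamePadding and same_padding are hypothetical names:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def same_padding(in_size, kernel_size, stride):
    # Total padding Keras/TensorFlow uses for padding='same' along one axis
    if in_size % stride == 0:
        return max(kernel_size - stride, 0)
    return max(kernel_size - (in_size % stride), 0)


class MaxPool2dSamePadding(nn.Module):
    """Sketch of keras.layers.MaxPooling2D(pool_size, padding='same') for NCHW input."""

    def __init__(self, pool_size, stride=None):
        super().__init__()
        self.pool_size = pool_size
        # Keras defaults the stride to the pool size
        self.stride = stride if stride is not None else pool_size

    def forward(self, x):
        ph = same_padding(x.shape[2], self.pool_size, self.stride)
        pw = same_padding(x.shape[3], self.pool_size, self.stride)
        # Pad with -inf so that padded cells can never be selected as the maximum;
        # the extra row/column goes to the bottom/right, as in TensorFlow
        x = F.pad(x, (pw // 2, pw - pw // 2, ph // 2, ph - ph // 2), value=float("-inf"))
        return F.max_pool2d(x, kernel_size=self.pool_size, stride=self.stride)


pool = MaxPool2dSamePadding(2)
out = pool(torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5))
print(out.shape)  # torch.Size([1, 1, 3, 3]), i.e. ceil(5 / 2) = 3
```

With zero padding instead of -inf, a window containing only negative activations plus padding would wrongly return 0.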

Weights Migration

Weights are easy to migrate by reading the Keras .h5 file with h5py. One non-obvious detail is the LSTM bias.
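Besides the LSTM bias, most kernels also need transposing, because Keras and PyTorch store them in different layouts. The sketch below assumes the arrays have already been read from the .h5 file with h5py as NumPy arrays (the dataset paths depend on how the model was saved, so they are omitted here); the load_keras_* helper names are hypothetical:

```python
import numpy as np
import torch
import torch.nn as nn


def load_keras_dense(linear, kernel, bias):
    """Keras Dense kernel is (in_features, out_features); nn.Linear.weight is (out, in)."""
    with torch.no_grad():
        linear.weight.copy_(torch.as_tensor(kernel.T))
        linear.bias.copy_(torch.as_tensor(bias))


def load_keras_conv2d(conv, kernel, bias):
    """Keras Conv2D kernel is (kH, kW, in, out); nn.Conv2d.weight is (out, in, kH, kW)."""
    with torch.no_grad():
        conv.weight.copy_(torch.as_tensor(kernel).permute(3, 2, 0, 1))
        conv.bias.copy_(torch.as_tensor(bias))


def load_keras_batchnorm(bn, gamma, beta, moving_mean, moving_variance):
    """Keras BatchNormalization weights map one-to-one onto nn.BatchNorm*d."""
    with torch.no_grad():
        bn.weight.copy_(torch.as_tensor(gamma))
        bn.bias.copy_(torch.as_tensor(beta))
        bn.running_mean.copy_(torch.as_tensor(moving_mean))
        bn.running_var.copy_(torch.as_tensor(moving_variance))


# Usage with random arrays standing in for the ones read via h5py
rng = np.random.default_rng(0)
kernel = rng.standard_normal((6, 4)).astype(np.float32)
bias = rng.standard_normal(4).astype(np.float32)
linear = nn.Linear(6, 4)
load_keras_dense(linear, kernel, bias)

x = rng.standard_normal((2, 6)).astype(np.float32)
keras_out = x @ kernel + bias  # what a Keras Dense layer without activation computes
torch_out = linear(torch.from_numpy(x)).detach().numpy()
print(np.abs(keras_out - torch_out).max())  # only float32 rounding error remains
```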

PyTorch computes each LSTM time step as follows:

$$\begin{array}{l} i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \end{array}$$

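PyTorch has two bias vectors, bias_ih and bias_hh, whereas Keras has a single bias vector, bias. In every gate equation above, the input-hidden bias and the hidden-hidden bias are simply added together (for example, \(b_{ii} + b_{hi}\)), so only their sum matters: \(b_{ih} + b_{hh} = b\). Setting \(b_{ih} = 0\) and \(b_{hh} = b\) therefore reproduces the Keras layer.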
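A sketch of the full LSTM conversion, assuming a single-layer torch.nn.LSTM and the Keras default activations (tanh and sigmoid). Keras stores kernel as (input_size, 4 * units) and recurrent_kernel as (units, 4 * units), i.e. transposed relative to weight_ih_l0 and weight_hh_l0, and both frameworks order the gates as input, forget, cell, output. The helper names are hypothetical; keras_lstm_reference re-implements the equations above in NumPy to check the result. For a Bi-LSTM, the backward direction would be loaded the same way with suffix="_reverse":

```python
import numpy as np
import torch
import torch.nn as nn


def load_keras_lstm_direction(lstm, kernel, recurrent_kernel, bias, suffix=""):
    """Copy one direction of a Keras LSTM into a single-layer torch.nn.LSTM."""
    with torch.no_grad():
        getattr(lstm, "weight_ih_l0" + suffix).copy_(torch.as_tensor(kernel.T))
        getattr(lstm, "weight_hh_l0" + suffix).copy_(torch.as_tensor(recurrent_kernel.T))
        # Only b_ih + b_hh matters, so the whole Keras bias goes into bias_hh
        getattr(lstm, "bias_ih_l0" + suffix).zero_()
        getattr(lstm, "bias_hh_l0" + suffix).copy_(torch.as_tensor(bias))


def keras_lstm_reference(x, kernel, recurrent_kernel, bias):
    """NumPy version of the LSTM equations with a single bias, gate order i, f, g, o."""
    units = recurrent_kernel.shape[0]
    h, c = np.zeros(units), np.zeros(units)
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    outputs = []
    for x_t in x:
        z = x_t @ kernel + h @ recurrent_kernel + bias
        i, f, g, o = np.split(z, 4)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs)


rng = np.random.default_rng(0)
input_size, units, steps = 3, 4, 5
kernel = rng.standard_normal((input_size, 4 * units)).astype(np.float32)
recurrent_kernel = rng.standard_normal((units, 4 * units)).astype(np.float32)
bias = rng.standard_normal(4 * units).astype(np.float32)

lstm = nn.LSTM(input_size, units, batch_first=True)
load_keras_lstm_direction(lstm, kernel, recurrent_kernel, bias)

x = rng.standard_normal((steps, input_size)).astype(np.float32)
out, _ = lstm(torch.from_numpy(x)[None])  # add a batch dimension
expected = keras_lstm_reference(x, kernel, recurrent_kernel, bias)
max_err = np.abs(out[0].detach().numpy() - expected).max()
print(max_err)  # float32 rounding error only
```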


After migration, the PyTorch outputs match the original version within a maximum error of 1e-6, which does not affect the variant calling results. With this acceptable precision, however, the PyTorch version takes 1.5 times as long to run as the original.