One solution is Keras 3, formerly named Keras Core, which supports both TensorFlow and PyTorch backends, so migration can be done with only minor code modifications. However, Clair3 contains a lot of legacy code and its network architecture is simple, so it is much easier to rewrite Clair3 from scratch in PyTorch.

Clair3 contains two main networks: a pileup network and a full-alignment network. The pileup network contains a Bi-LSTM layer, while the full-alignment network contains CNN layers; both are common in small machine-learning models. This migration note may be helpful to anyone migrating small Keras networks to the PyTorch framework.

The networks contain common layers such as Bi-LSTM, Conv2D, Pooling2D, Flatten, SELU, ReLU, and fully connected (FC) layers. Some are simple to migrate because PyTorch provides the same API as Keras, while others are tricky.

**Bi-LSTM layer:** `keras.layers.Bidirectional(keras.layers.LSTM(units=num_units, return_sequences=True))`

-> `torch.nn.LSTM(input_size=x, hidden_size=num_units, bidirectional=True)`.
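One detail worth checking: Keras LSTM layers expect batch-first input of shape `(batch, time, features)`, while `torch.nn.LSTM` defaults to `(time, batch, features)`, so `batch_first=True` is usually needed. A minimal sketch of the mapping (the layer sizes here are made up for illustration):

```python
import torch

# Keras: keras.layers.Bidirectional(keras.layers.LSTM(128, return_sequences=True))
# PyTorch sketch; batch_first=True matches Keras' (batch, time, features) layout
bilstm = torch.nn.LSTM(input_size=64, hidden_size=128,
                       bidirectional=True, batch_first=True)

x = torch.randn(8, 33, 64)   # (batch, time, features)
out, _ = bilstm(x)           # out: (8, 33, 256), forward and backward states concatenated
```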

**FC layer:** `keras.layers.Dense(units=num_units, activation='selu')`

-> `torch.nn.Sequential(torch.nn.Linear(input, num_units), torch.nn.SELU())`.

**Batch Normalization Layers:** Mathematically, batch normalization is computed as \(y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta\). Although batch normalization layers exist in both Keras and PyTorch, the default parameters differ: in Keras, the momentum is 0.99 and \(\epsilon\) is 0.001, while in PyTorch, the momentum is 0.1 and \(\epsilon\) is 1e-5.
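Note also that the two frameworks define momentum in complementary ways: PyTorch updates the running statistics as `running = (1 - momentum) * running + momentum * batch`, whereas Keras uses `moving = momentum * moving + (1 - momentum) * batch`. A sketch of a PyTorch layer configured to mimic the Keras defaults (the channel count is illustrative):

```python
import torch.nn as nn

# Keras default: BatchNormalization(momentum=0.99, epsilon=0.001)
# PyTorch's momentum is the complement of Keras', so 0.99 in Keras ~ 0.01 in PyTorch
bn = nn.BatchNorm2d(num_features=64, eps=1e-3, momentum=0.01)
```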

**Flatten Layers:** In Keras, the default data layout is channels-last (NHWC), but in PyTorch it is NCHW. So before flattening, the tensor has to be permuted back to channels-last so that the flattened values end up in the same order. `keras.layers.Flatten()(x)`

-> `torch.flatten(x.permute(0, 2, 3, 1), start_dim=1)`.
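A small sketch (shapes are made up) of what this looks like in practice:

```python
import torch

x = torch.randn(2, 8, 5, 7)                      # PyTorch conv output: (N, C, H, W)
flat = torch.flatten(x.permute(0, 2, 3, 1), 1)   # same element order as Keras Flatten on the (N, H, W, C) tensor
print(flat.shape)                                 # torch.Size([2, 280])
```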

Conv2D and Pooling2D layers are more complex to migrate because of Keras' `padding='same'` feature, which is not easy to implement in PyTorch. The output shape \(H_o, W_o\) can be calculated by the following equation:

$$\begin{array}{l} H_o = \left\lfloor \dfrac{H_i + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1 \right\rfloor \\[6pt] W_o = \left\lfloor \dfrac{W_i + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1 \right\rfloor \end{array}$$

When `padding='same'` and `stride=1`, the output has the same size as the input, i.e. \(H_o = H_i\) and \(W_o = W_i\). Choosing the right padding size is what makes this condition hold.
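Plugging \(\text{stride}[0] = 1\) and \(H_o = H_i\) into the formula above shows how much padding is needed (the width works the same way):

$$H_i = H_i + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) \;\Rightarrow\; \text{padding}[0] = \frac{\text{dilation}[0] \times (\text{kernel\_size}[0] - 1)}{2}$$

When \(\text{dilation}[0] \times (\text{kernel\_size}[0] - 1)\) is odd, the padding cannot be split evenly between the two sides, which is why the implementation below pads the tensor manually with `F.pad` instead of relying on the layer's symmetric `padding` argument.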

An example implementation of Conv2D with same padding:

```python
import torch.nn as nn
import torch.nn.functional as F


class Conv2dSamePadding(nn.Conv2d):
    """Conv2d reproducing Keras/TensorFlow `padding='same'` (construct with the default padding=0)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        in_height, in_width = input.shape[2:]
        stride_h, stride_w = self.stride
        kernel_h, kernel_w = self.kernel_size

        # Total padding needed so that the output size is ceil(input / stride)
        if in_height % stride_h == 0:
            pad_height = max(kernel_h - stride_h, 0)
        else:
            pad_height = max(kernel_h - (in_height % stride_h), 0)
        if in_width % stride_w == 0:
            pad_width = max(kernel_w - stride_w, 0)
        else:
            pad_width = max(kernel_w - (in_width % stride_w), 0)

        # Keras/TF puts the extra pixel on the bottom/right when the total padding is odd
        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        # F.pad takes (left, right, top, bottom) for the last two dimensions
        input = F.pad(input, (pad_left, pad_right, pad_top, pad_bottom), value=0.0)
        return self._conv_forward(input, self.weight, self.bias)
```

MaxPooling2D can be handled the same way as Conv2D, except that the padding value should be `-inf` instead of `0`.
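A minimal sketch of the pooling counterpart, reusing the same padding computation but filling with `-inf` so that the padded border can never win the max:

```python
import torch.nn as nn
import torch.nn.functional as F


class MaxPool2dSamePadding(nn.MaxPool2d):
    """MaxPool2d mimicking Keras `padding='same'` (sketch; assumes int or 2-tuple arguments)."""

    def forward(self, input):
        in_height, in_width = input.shape[2:]
        kernel_h, kernel_w = (self.kernel_size,) * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        stride_h, stride_w = (self.stride,) * 2 if isinstance(self.stride, int) else self.stride

        pad_height = max(kernel_h - (in_height % stride_h or stride_h), 0)
        pad_width = max(kernel_w - (in_width % stride_w or stride_w), 0)
        pad_top, pad_left = pad_height // 2, pad_width // 2

        # Pad with -inf so padded positions never become the maximum
        input = F.pad(input,
                      (pad_left, pad_width - pad_left, pad_top, pad_height - pad_top),
                      value=float('-inf'))
        return F.max_pool2d(input, self.kernel_size, self.stride, padding=0)
```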

Weights are easy to migrate with h5py. One non-obvious part is the LSTM bias.

The LSTM is computed by the following equations:

$$\begin{array}{l} i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \end{array}$$

PyTorch has two bias variables, `bias_ih` and `bias_hh`, while Keras has only one bias variable, `bias`. Since only the sum of the two biases appears in the equations, any split with \(b_{ih} + b_{hh} = b\) is equivalent, so simply set \(b_{ih} = 0\) and \(b_{hh} = b\).
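A sketch of how this can be applied when copying weights (the `keras_bias` array here is a hypothetical NumPy array already read from the h5 file; dataset paths vary between Keras versions):

```python
import numpy as np
import torch

lstm = torch.nn.LSTM(input_size=64, hidden_size=128, batch_first=True)

# keras_bias: hypothetical array of shape (4 * hidden_size,) loaded via h5py;
# both frameworks order the gates as input, forget, cell, output
keras_bias = np.zeros(4 * 128, dtype=np.float32)

with torch.no_grad():
    lstm.bias_ih_l0.zero_()                              # b_ih = 0
    lstm.bias_hh_l0.copy_(torch.from_numpy(keras_bias))  # b_hh = Keras bias
```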

After migration, the final result matches the original version with a maximum error of 1e-6, which does not affect the variant calling result. With this acceptable precision, PyTorch takes 1.5 times more time than the original version.
