Source code for OxyGenie.learn.model

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
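        # Two 3x3 convolutions with padding=1: the channel count changes from
        # in_channels to out_channels while the spatial size is preserved.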
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv_op(x)


class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.conv(x)
        down = self.pool(skip)

        return skip, down


class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
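        # Upsample the deeper feature map, then concatenate with the skip
        # tensor. ConvTranspose2d with kernel_size=2 and stride=2 exactly
        # doubles H and W, so x1 and x2 only align when the network input's
        # spatial size is divisible by 2 at every pooling level (e.g. the
        # 512x512 images used by EUNet.predict).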
        x1 = self.up(x1)
        x = torch.cat([x1, x2], 1)
        return self.conv(x)


class EUNet(nn.Module):
    """
    A U-Net-based neural network for image processing with an additional
    embedding input.

    This model processes spatial inputs through a series of down-sampling
    and up-sampling layers, integrating additional parameter embeddings at
    the bottleneck.
    """

    def __init__(self):
        super().__init__()
        self.down_1 = DownSample(1, 32)
        self.down_2 = DownSample(32, 64)
        self.down_3 = DownSample(64, 128)
        # self.down_4 = DownSample(in_channels, out_channels)

        self.embedding_layer = nn.Sequential(
            # Maps the 2 input parameters to 128 features; these are later
            # concatenated with the 128-channel bottleneck for 256 channels.
            nn.Linear(2, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )

        self.bottom = DoubleConv(128, 128)

        self.up_1 = UpSample(256, 128)
        self.up_2 = UpSample(128, 64)
        self.up_3 = UpSample(64, 32)
        # self.up_4 = UpSample(in_channels, out_channels)

        self.out = nn.Conv2d(32, 1, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        s1, d1 = self.down_1(x1)
        s2, d2 = self.down_2(d1)
        s3, d3 = self.down_3(d2)

        b = self.bottom(d3)

        embedding = self.embedding_layer(x2)
        # Reshape to [batch_size, 128, 1, 1]
        embedding = embedding.unsqueeze(-1).unsqueeze(-1)
        # Broadcast to match the bottleneck's spatial dimensions
        embedding = embedding.expand(-1, -1, b.size(2), b.size(3))

        b_combined = torch.cat([b, embedding], 1)

        u1 = self.up_1(b_combined, s3)
        u2 = self.up_2(u1, s2)
        u3 = self.up_3(u2, s1)

        y = self.out(u3)
        return y
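
    # Shape walkthrough for a (N, 1, 512, 512) input, derived from the layer
    # definitions above:
    #   down_1: skip (N, 32, 512, 512) -> pooled (N, 32, 256, 256)
    #   down_2: skip (N, 64, 256, 256) -> pooled (N, 64, 128, 128)
    #   down_3: skip (N, 128, 128, 128) -> pooled (N, 128, 64, 64)
    #   bottom and broadcast embedding: (N, 128, 64, 64) each, concatenated
    #   to (N, 256, 64, 64); up_1..up_3 restore (N, 32, 512, 512), and the
    #   final convolution yields (N, 1, 512, 512).
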
    @torch.no_grad()
    def predict(self, img, params):
        """
        Inference method for processing input images and parameters.
        """
        # Min-max normalise the image to [0, 1], then resize to the
        # 512x512 resolution the network was built around.
        test_image = img.copy()
        x_min, x_max = np.min(test_image), np.max(test_image)
        test_image = (test_image - x_min) / (x_max - x_min)
        test_image = Image.fromarray(test_image).resize((512, 512))
        im = TF.to_tensor(test_image).unsqueeze(0)

        # Rescale the two parameters before the embedding layer; cast to
        # float32 so the tensor matches the model weights.
        params = params.copy()
        params[0] = params[0] / 10
        params[1] = params[1] * 100
        params = torch.tensor(params, dtype=torch.float32).unsqueeze(0)

        result_np = self(im, params).squeeze().numpy()
        return (result_np - np.min(result_np)) * 100
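
# Minimal smoke test; a sketch only. It runs an untrained EUNet on random
# placeholder inputs: the 256x256 image and the two parameter values below
# are arbitrary, not values taken from the OxyGenie pipeline.
if __name__ == "__main__":
    model = EUNet()
    model.eval()

    dummy_img = np.random.rand(256, 256).astype(np.float32)
    dummy_params = np.array([5.0, 0.01], dtype=np.float32)

    prediction = model.predict(dummy_img, dummy_params)
    print(prediction.shape)  # (512, 512): predict resizes its input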