OxyGenie.learn

Module contents

OxyGenie.learn.SimuDataset

A PyTorch dataset class for loading and preprocessing simulation data.

OxyGenie.learn.EUNet

A U-Net-based neural network for image processing with additional embedding input.

class OxyGenie.learn.EUNet

A U-Net-based neural network for image processing with additional embedding input.

This model processes spatial inputs through a series of down-sampling and up-sampling layers, integrating additional parameter embeddings at the bottleneck.
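EUNet's exact layer layout is not documented here. The following is only a minimal PyTorch sketch of the general pattern the description names: a two-level U-Net whose bottleneck concatenates a learned embedding of a parameter vector, broadcast over the spatial grid. The class name TinyEUNet, its channel widths, and the assumption that x1 is an image batch (B, C, H, W) and x2 a parameter batch (B, P) are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEUNet(nn.Module):
    """Hypothetical two-level U-Net with a parameter embedding fused at the bottleneck."""

    def __init__(self, in_ch=1, out_ch=1, n_params=4, base=16, emb_dim=16):
        super().__init__()
        # Encoder: one block at full resolution, one at half resolution.
        self.enc1 = nn.Sequential(nn.Conv2d(in_ch, base, 3, padding=1), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv2d(base, 2 * base, 3, padding=1), nn.ReLU())
        # Embedding of the scalar parameters (hypothetical one-layer MLP).
        self.embed = nn.Sequential(nn.Linear(n_params, emb_dim), nn.ReLU())
        # Bottleneck sees image features concatenated with the broadcast embedding.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(2 * base + emb_dim, 2 * base, 3, padding=1), nn.ReLU()
        )
        # Up-sampling back to full resolution.
        self.up = nn.ConvTranspose2d(2 * base, base, kernel_size=2, stride=2)
        # Decoder block after concatenating the skip connection from enc1.
        self.dec1 = nn.Sequential(nn.Conv2d(2 * base, base, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(base, out_ch, 1)

    def forward(self, x1, x2):
        s1 = self.enc1(x1)                      # (B, base, H, W)
        h = self.enc2(F.max_pool2d(s1, 2))      # (B, 2*base, H/2, W/2)
        e = self.embed(x2)                      # (B, emb_dim)
        # Broadcast the embedding over the bottleneck's spatial grid.
        e = e[:, :, None, None].expand(-1, -1, h.shape[2], h.shape[3])
        h = self.bottleneck(torch.cat([h, e], dim=1))
        h = self.up(h)                          # (B, base, H, W)
        h = self.dec1(torch.cat([h, s1], dim=1))
        return self.head(h)                     # (B, out_ch, H, W)
```

With even H and W, the output has the same spatial size as x1. Concatenating a broadcast embedding is one common way to condition a convolutional network on scalar inputs; adding the embedding or using it to scale feature maps would also match the description.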

forward(x1, x2)

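Defines the computation performed at every call. EUNet overrides the generic torch.nn.Module.forward to run its two inputs, x1 and x2, through the down-sampling and up-sampling path, integrating the parameter embedding at the bottleneck.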

Note

Although the forward pass is defined in this method, call the module instance (for example, model(x1, x2)) rather than calling forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
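The difference is easy to observe with a forward hook. This sketch uses a plain nn.Linear as a stand-in for any module, EUNet included:

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(3, 2)
# Forward hook: appends to `calls` each time the hook machinery fires.
handle = layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 3)
layer(x)          # __call__ runs forward() and the registered hooks
layer.forward(x)  # direct call runs forward() only; the hook is skipped
print(len(calls))  # 1
handle.remove()
```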

predict(img, params)

Runs inference on input images (img) together with their associated parameters (params).
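The body of predict is not shown here. A typical PyTorch inference helper, sketched below under that assumption, switches the model to evaluation mode, disables gradient tracking, and adds a batch dimension for single samples. TwoInputNet and predict_sketch are hypothetical stand-ins, not part of OxyGenie.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for a model whose forward takes (image, params)."""

    def __init__(self, n_params=4):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.scale = nn.Linear(n_params, 1)

    def forward(self, img, params):
        # Scale the conv output by a per-sample scalar derived from params.
        return self.conv(img) * self.scale(params)[:, :, None, None]


def predict_sketch(model, img, params):
    """Hypothetical inference helper: eval mode, no autograd, batch-dim handling."""
    model.eval()  # disable dropout, use running batch-norm statistics
    img = torch.as_tensor(img, dtype=torch.float32)
    params = torch.as_tensor(params, dtype=torch.float32)
    single = img.dim() == 3  # a single (C, H, W) sample: add a batch dimension
    if single:
        img, params = img.unsqueeze(0), params.unsqueeze(0)
    with torch.no_grad():
        out = model(img, params)
    return out[0] if single else out
```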

class OxyGenie.learn.SimuDataset(dataset_path, transform=None, random_flip=True)

A PyTorch dataset class for loading and preprocessing simulation data.

This dataset handles input features (X_1, X_2) and corresponding target outputs (Y), applying specified transformations and augmentations during loading.
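How SimuDataset reads dataset_path is not documented, so this sketch keeps the triples in memory and only illustrates the loading pattern: apply transform to the input, flip input and target together when random_flip is set, and return float32 tensors. InMemorySimuDataset, its seed argument, and the array layouts in the docstring are assumptions.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class InMemorySimuDataset(Dataset):
    """Hypothetical dataset of (X_1, X_2, Y) triples held in NumPy arrays.

    X_1: spatial inputs (N, C, H, W); X_2: parameter vectors (N, P);
    Y: spatial targets (N, C_out, H, W).
    """

    def __init__(self, x1, x2, y, transform=None, random_flip=True, seed=None):
        self.x1, self.x2, self.y = x1, x2, y
        self.transform = transform
        self.random_flip = random_flip
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        x1, x2, y = self.x1[idx], self.x2[idx], self.y[idx]
        if self.transform is not None:
            x1 = self.transform(x1)  # e.g. input normalization (assumed)
        if self.random_flip:
            # Flip input and target together so they stay spatially aligned.
            if self.rng.random() < 0.5:
                x1, y = x1[..., ::-1], y[..., ::-1]
            if self.rng.random() < 0.5:
                x1, y = x1[..., ::-1, :], y[..., ::-1, :]
        # Reversed views have negative strides; copy before torch.from_numpy.
        return (
            torch.from_numpy(np.ascontiguousarray(x1, dtype=np.float32)),
            torch.from_numpy(np.ascontiguousarray(x2, dtype=np.float32)),
            torch.from_numpy(np.ascontiguousarray(y, dtype=np.float32)),
        )
```

Because the generator lives on the dataset object, DataLoader workers created with num_workers > 0 start from copies of the same state; reseeding each worker (for example via worker_init_fn) avoids repeated flip sequences.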

descale(y_scaled)

Converts normalized target data back to its original range.
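The normalization scheme behind descale is not specified. Assuming the targets were min-max scaled to [0, 1] with stored bounds, the inverse is y = y_scaled * (y_max - y_min) + y_min. descale_minmax is a hypothetical helper showing that inverse:

```python
import numpy as np


def descale_minmax(y_scaled, y_min, y_max):
    """Invert min-max scaling: map values in [0, 1] back to [y_min, y_max].

    Hypothetical: assumes targets were normalized as
    y_scaled = (y - y_min) / (y_max - y_min).
    """
    return np.asarray(y_scaled) * (y_max - y_min) + y_min


y = np.array([2.0, 5.0, 10.0])
y_min, y_max = y.min(), y.max()
scaled = (y - y_min) / (y_max - y_min)  # [0.0, 0.375, 1.0]
restored = descale_minmax(scaled, y_min, y_max)
```

If the targets were standardized instead, the inverse would be y_scaled * std + mean.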

hvflip(x, y)

Applies random horizontal and vertical flips to inputs and targets.
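This kind of augmentation requires input and target to receive the same flips; otherwise the spatial correspondence between them is lost. hvflip_sketch is a hypothetical NumPy version assuming (..., H, W) arrays and independent flips, each with probability p:

```python
import numpy as np


def hvflip_sketch(x, y, rng, p=0.5):
    """Randomly flip input x and target y using the same random draws.

    Hypothetical helper: arrays are (..., H, W); each flip is applied
    independently with probability p.
    """
    if rng.random() < p:  # horizontal flip: reverse the width axis
        x, y = np.flip(x, axis=-1), np.flip(y, axis=-1)
    if rng.random() < p:  # vertical flip: reverse the height axis
        x, y = np.flip(x, axis=-2), np.flip(y, axis=-2)
    return x, y


rng = np.random.default_rng(0)
img = np.arange(16).reshape(1, 4, 4)
fx, fy = hvflip_sketch(img, img.copy(), rng)
```

np.flip returns a view with negative strides, so call np.ascontiguousarray (or .copy()) before handing the result to torch.from_numpy, which rejects negative strides.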