From c1ac142dd5a235c430932b2dae9fb626342417f3 Mon Sep 17 00:00:00 2001 From: "ahernanzl@aemet.es" <ahernanzl@aemet.es> Date: Wed, 12 Mar 2025 07:52:30 +0000 Subject: [PATCH] DeepESD expanded for other variables --- lib/DeepESD_lib.py | 7 ++++--- lib/down_scene_DeepESD.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/lib/DeepESD_lib.py b/lib/DeepESD_lib.py index 8bf8331..1fe0b66 100644 --- a/lib/DeepESD_lib.py +++ b/lib/DeepESD_lib.py @@ -119,10 +119,11 @@ def train(targetVar, methodName, family, mode, fields): model_name = 'DeepESD-' + targetVar - if targetVar in ('tasmax', 'tasmin', 'tas'): - model = deep_models.DeepESDtas(x_shape=X_train.shape, y_shape=y_train.shape, filters_last_conv=1, stochastic=False) - elif targetVar == 'pr': + if targetVar == 'pr': model = deep_models.DeepESDpr(x_shape=X_train.shape, y_shape=y_train.shape, filters_last_conv=1, stochastic=False) + else: + model = deep_models.DeepESDtas(x_shape=X_train.shape, y_shape=y_train.shape, filters_last_conv=1, + stochastic=False) num_epochs = 10000 patience_early_stopping = 20 diff --git a/lib/down_scene_DeepESD.py b/lib/down_scene_DeepESD.py index 42502ba..b6c3186 100644 --- a/lib/down_scene_DeepESD.py +++ b/lib/down_scene_DeepESD.py @@ -108,10 +108,11 @@ def downscale(targetVar, methodName, family, mode, fields, scene, model): # Load trained model model_name = 'DeepESD-' + targetVar pathModel = pathAux + 'TRAINED_MODELS/' + targetVar.upper() + '/' + methodName + '/' - if targetVar in ('tasmax', 'tasmin', 'tas'): - model_deep = deep_models.DeepESDtas(x_shape=X_test.shape, y_shape=y_shape, filters_last_conv=1, stochastic=False) - elif targetVar == 'pr': + if targetVar == 'pr': model_deep = deep_models.DeepESDpr(x_shape=X_test.shape, y_shape=y_shape, filters_last_conv=1, stochastic=False) + else: + model_deep = deep_models.DeepESDtas(x_shape=X_test.shape, y_shape=y_shape, filters_last_conv=1, stochastic=False) + model_deep.load_state_dict(torch.load(f'{pathModel}/{model_name}.pt')) # Create template -- GitLab