From c1ac142dd5a235c430932b2dae9fb626342417f3 Mon Sep 17 00:00:00 2001
From: "ahernanzl@aemet.es" <ahernanzl@aemet.es>
Date: Wed, 12 Mar 2025 07:52:30 +0000
Subject: [PATCH] DeepESD expanded for other variables

---
 lib/DeepESD_lib.py        | 7 ++++---
 lib/down_scene_DeepESD.py | 7 ++++---
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/lib/DeepESD_lib.py b/lib/DeepESD_lib.py
index 8bf8331..1fe0b66 100644
--- a/lib/DeepESD_lib.py
+++ b/lib/DeepESD_lib.py
@@ -119,10 +119,11 @@ def train(targetVar, methodName, family, mode, fields):
 
     model_name = 'DeepESD-' + targetVar
 
-    if targetVar in ('tasmax', 'tasmin', 'tas'):
-        model = deep_models.DeepESDtas(x_shape=X_train.shape, y_shape=y_train.shape, filters_last_conv=1, stochastic=False)
-    elif targetVar == 'pr':
+    if targetVar == 'pr':
         model = deep_models.DeepESDpr(x_shape=X_train.shape, y_shape=y_train.shape, filters_last_conv=1, stochastic=False)
+    else:
+        model = deep_models.DeepESDtas(x_shape=X_train.shape, y_shape=y_train.shape, filters_last_conv=1,
+                                       stochastic=False)
 
     num_epochs = 10000
     patience_early_stopping = 20
diff --git a/lib/down_scene_DeepESD.py b/lib/down_scene_DeepESD.py
index 42502ba..b6c3186 100644
--- a/lib/down_scene_DeepESD.py
+++ b/lib/down_scene_DeepESD.py
@@ -108,10 +108,11 @@ def downscale(targetVar, methodName, family, mode, fields, scene, model):
     # Load trained model
     model_name = 'DeepESD-' + targetVar
     pathModel = pathAux + 'TRAINED_MODELS/' + targetVar.upper() + '/' + methodName + '/'
-    if targetVar in ('tasmax', 'tasmin', 'tas'):
-        model_deep = deep_models.DeepESDtas(x_shape=X_test.shape, y_shape=y_shape, filters_last_conv=1, stochastic=False)
-    elif targetVar == 'pr':
+    if targetVar == 'pr':
         model_deep = deep_models.DeepESDpr(x_shape=X_test.shape, y_shape=y_shape, filters_last_conv=1, stochastic=False)
+    else:
+        model_deep = deep_models.DeepESDtas(x_shape=X_test.shape, y_shape=y_shape, filters_last_conv=1, stochastic=False)
+
     model_deep.load_state_dict(torch.load(f'{pathModel}/{model_name}.pt'))
 
     # Create template
-- 
GitLab