From 9d83f46702ef130be87dbea21d6c4b17e467ea43 Mon Sep 17 00:00:00 2001
From: "ahernanzl@aemet.es" <ahernanzl@aemet.es>
Date: Wed, 12 Mar 2025 07:45:59 +0000
Subject: [PATCH] DeepESD expanded for other variables

---
 config/default_gui_settings.py | 15 +++++++++++++++
 config/manual_settings.py      | 15 +++++++++++++++
 lib/DeepESD_lib.py             |  7 +++----
 src/gui_mode.py                |  2 --
 4 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/config/default_gui_settings.py b/config/default_gui_settings.py
index 33841f5..7d0de63 100644
--- a/config/default_gui_settings.py
+++ b/config/default_gui_settings.py
@@ -175,6 +175,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
         'WG-PDF',
     ],
     'vas': [
@@ -202,6 +203,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
         'WG-PDF',
     ],
     'sfcWind': [
@@ -228,6 +230,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'hurs': [
         'RAW',
@@ -253,6 +256,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'huss': [
         'RAW',
@@ -278,6 +282,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'clt': [
         'RAW',
@@ -303,6 +308,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'rsds': [
         'RAW',
@@ -328,6 +334,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'rlds': [
         'RAW',
@@ -353,6 +360,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'evspsbl': [
         'RAW',
@@ -378,6 +386,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'evspsblpot': [
         'RAW',
@@ -403,6 +412,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'psl': [
         'RAW',
@@ -428,6 +438,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'ps': [
         'RAW',
@@ -453,6 +464,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'mrro': [
         'RAW',
@@ -478,6 +490,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'mrso': [
         'RAW',
@@ -503,6 +516,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'myTargetVar': [
         'RAW',
@@ -529,6 +543,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
         # 'WG-PDF',
     ],
 }
diff --git a/config/manual_settings.py b/config/manual_settings.py
index d66a8cc..3c64617 100644
--- a/config/manual_settings.py
+++ b/config/manual_settings.py
@@ -175,6 +175,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
         'WG-PDF',
     ],
     'vas': [
@@ -202,6 +203,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
         'WG-PDF',
     ],
     'sfcWind': [
@@ -228,6 +230,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'hurs': [
         'RAW',
@@ -253,6 +256,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'huss': [
         'RAW',
@@ -278,6 +282,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'clt': [
         'RAW',
@@ -303,6 +308,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'rsds': [
         'RAW',
@@ -328,6 +334,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'rlds': [
         'RAW',
@@ -353,6 +360,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'evspsbl': [
         'RAW',
@@ -378,6 +386,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'evspsblpot': [
         'RAW',
@@ -403,6 +412,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'psl': [
         'RAW',
@@ -428,6 +438,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'ps': [
         'RAW',
@@ -453,6 +464,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'mrro': [
         'RAW',
@@ -478,6 +490,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'mrso': [
         'RAW',
@@ -503,6 +516,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
     ],
     'myTargetVar': [
         'RAW',
@@ -529,6 +543,7 @@ methods = {
         # 'XGB',
         # 'ANN',
         # 'CNN',
+        # 'DeepESD',
         # 'WG-PDF',
     ],
 }
diff --git a/lib/DeepESD_lib.py b/lib/DeepESD_lib.py
index dbefc46..8bf8331 100644
--- a/lib/DeepESD_lib.py
+++ b/lib/DeepESD_lib.py
@@ -82,9 +82,8 @@ def train(targetVar, methodName, family, mode, fields):
     X_train = X_train[valid]
     y_train = y_train[valid]
 
-    if targetVar in ('tasmax', 'tasmin', 'tas'):
-        loss_function = deep_loss.MseLoss(ignore_nans=True)
-    elif targetVar == 'pr':
+
+    if targetVar == 'pr':
         asym_path = pathAux + 'DeepESD/ASYM/'
         os.makedirs(asym_path, exist_ok=True)
         loss_function = deep_loss.Asym(ignore_nans=True, asym_path=asym_path)
@@ -98,7 +97,7 @@ def train(targetVar, methodName, family, mode, fields):
             loss_function.compute_parameters(data=y_train_ds, var_target=targetVar)
         loss_function.prepare_parameters(device=device)
     else:
-        exit('Variable ' + targetVar + ' not implemented for DeepESD')
+        loss_function = deep_loss.MseLoss(ignore_nans=True)
 
     # Create Dataset
     train_dataset = deep_utils.StandardDataset(x=X_train, y=y_train)
diff --git a/src/gui_mode.py b/src/gui_mode.py
index d0233f2..8de68a9 100644
--- a/src/gui_mode.py
+++ b/src/gui_mode.py
@@ -1164,8 +1164,6 @@ class frameMethodsClass(tk.Frame):
             disabled_methods = ['MLR', 'MLR-ANA', 'MLR-WT']
         else:
             disabled_methods = ['GLM-LIN', 'GLM-EXP', 'GLM-CUB', 'WG-NMM']
-            if targetVar not in ['tasmax', 'tasmin', 'tas', 'pr']:
-                disabled_methods.append('DeepESD')
             gaussian_variables = ['tasmax', 'tasmin', 'tas', 'uas', 'vas', 'psl', 'ps', ]
             if (targetVar == 'myTargetVar' and isGaussian != True):
                 disabled_methods.append('PSDM')
-- 
GitLab