From 9d83f46702ef130be87dbea21d6c4b17e467ea43 Mon Sep 17 00:00:00 2001 From: "ahernanzl@aemet.es" <ahernanzl@aemet.es> Date: Wed, 12 Mar 2025 07:45:59 +0000 Subject: [PATCH] DeepESD expanded for other variables --- config/default_gui_settings.py | 15 +++++++++++++++ config/manual_settings.py | 15 +++++++++++++++ lib/DeepESD_lib.py | 7 +++---- src/gui_mode.py | 2 -- 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/config/default_gui_settings.py b/config/default_gui_settings.py index 33841f5..7d0de63 100644 --- a/config/default_gui_settings.py +++ b/config/default_gui_settings.py @@ -175,6 +175,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', 'WG-PDF', ], 'vas': [ @@ -202,6 +203,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', 'WG-PDF', ], 'sfcWind': [ @@ -228,6 +230,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'hurs': [ 'RAW', @@ -253,6 +256,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'huss': [ 'RAW', @@ -278,6 +282,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'clt': [ 'RAW', @@ -303,6 +308,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'rsds': [ 'RAW', @@ -328,6 +334,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'rlds': [ 'RAW', @@ -353,6 +360,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'evspsbl': [ 'RAW', @@ -378,6 +386,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'evspsblpot': [ 'RAW', @@ -403,6 +412,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'psl': [ 'RAW', @@ -428,6 +438,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'ps': [ 'RAW', @@ -453,6 +464,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'mrro': [ 'RAW', @@ -478,6 +490,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'mrso': [ 'RAW', @@ -503,6 +516,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'myTargetVar': [ 'RAW', @@ -529,6 +543,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', # 'WG-PDF', ], } diff --git a/config/manual_settings.py b/config/manual_settings.py index d66a8cc..3c64617 100644 --- a/config/manual_settings.py +++ b/config/manual_settings.py @@ -175,6 +175,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', 'WG-PDF', ], 'vas': [ @@ -202,6 +203,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', 'WG-PDF', ], 'sfcWind': [ @@ -228,6 +230,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'hurs': [ 'RAW', @@ -253,6 +256,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'huss': [ 'RAW', @@ -278,6 +282,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'clt': [ 'RAW', @@ -303,6 +308,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'rsds': [ 'RAW', @@ -328,6 +334,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'rlds': [ 'RAW', @@ -353,6 +360,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'evspsbl': [ 'RAW', @@ -378,6 +386,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'evspsblpot': [ 'RAW', @@ -403,6 +412,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'psl': [ 'RAW', @@ -428,6 +438,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'ps': [ 'RAW', @@ -453,6 +464,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'mrro': [ 'RAW', @@ -478,6 +490,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'mrso': [ 'RAW', @@ -503,6 +516,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', ], 'myTargetVar': [ 'RAW', @@ -529,6 +543,7 @@ methods = { # 'XGB', # 'ANN', # 'CNN', + # 'DeepESD', # 'WG-PDF', ], } diff --git a/lib/DeepESD_lib.py b/lib/DeepESD_lib.py index dbefc46..8bf8331 100644 --- a/lib/DeepESD_lib.py +++ b/lib/DeepESD_lib.py @@ -82,9 +82,8 @@ def train(targetVar, methodName, family, mode, fields): X_train = X_train[valid] y_train = y_train[valid] - if targetVar in ('tasmax', 'tasmin', 'tas'): - loss_function = deep_loss.MseLoss(ignore_nans=True) - elif targetVar == 'pr': + + if targetVar == 'pr': asym_path = pathAux + 'DeepESD/ASYM/' os.makedirs(asym_path, exist_ok=True) loss_function = deep_loss.Asym(ignore_nans=True, asym_path=asym_path) @@ -98,7 +97,7 @@ def train(targetVar, methodName, family, mode, fields): loss_function.compute_parameters(data=y_train_ds, var_target=targetVar) loss_function.prepare_parameters(device=device) else: - exit('Variable ' + targetVar + ' not implemented for DeepESD') + loss_function = deep_loss.MseLoss(ignore_nans=True) # Create Dataset train_dataset = deep_utils.StandardDataset(x=X_train, y=y_train) diff --git a/src/gui_mode.py b/src/gui_mode.py index d0233f2..8de68a9 100644 --- a/src/gui_mode.py +++ b/src/gui_mode.py @@ -1164,8 +1164,6 @@ class frameMethodsClass(tk.Frame): disabled_methods = ['MLR', 'MLR-ANA', 'MLR-WT'] else: disabled_methods = ['GLM-LIN', 'GLM-EXP', 'GLM-CUB', 'WG-NMM'] - if targetVar not in ['tasmax', 'tasmin', 'tas', 'pr']: - disabled_methods.append('DeepESD') gaussian_variables = ['tasmax', 'tasmin', 'tas', 'uas', 'vas', 'psl', 'ps', ] if (targetVar == 'myTargetVar' and isGaussian != True): disabled_methods.append('PSDM') -- GitLab