From 42cf7fcaf1774a14f9179acb4eb2376558f48db3 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Wed, 22 Nov 2023 15:30:50 +0100 Subject: [PATCH] temporary fix for py37 (combinedloader) (#215) * temporary fix for py37 (combinedloader) --- .gitignore | 3 +++ pina/dataset.py | 2 +- pina/solvers/garom.py | 7 ++++++- pina/solvers/pinn.py | 6 +++++- pina/solvers/supervised.py | 6 +++++- 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 73358ad..fd0e93a 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,6 @@ dmypy.json # Cython debug symbols cython_debug/ +# Lightning logs dir +*/lightning_logs/* +lightning_logs/* diff --git a/pina/dataset.py b/pina/dataset.py index 650818c..1cf2f96 100644 --- a/pina/dataset.py +++ b/pina/dataset.py @@ -237,4 +237,4 @@ class SamplePointLoader: 'output': self.batch_output_pts[idx_], 'condition': self.batch_data_conditions[idx_], } - yield d \ No newline at end of file + yield d diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py index 433e2bc..d7781bd 100644 --- a/pina/solvers/garom.py +++ b/pina/solvers/garom.py @@ -1,6 +1,7 @@ """ Module for GAROM """ import torch +import sys try: from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 except ImportError: @@ -251,7 +252,11 @@ class GAROM(SolverInterface): for condition_id in range(condition_idx.min(), condition_idx.max()+1): - condition_name = dataloader.condition_names[condition_id] + if sys.version_info >= (3, 8): + condition_name = dataloader.condition_names[condition_id] + else: + condition_name = dataloader.loaders.condition_names[condition_id] + condition = self.problem.conditions[condition_name] pts = batch['pts'].detach() out = batch['output'] diff --git a/pina/solvers/pinn.py b/pina/solvers/pinn.py index 37d968c..5cafbba 100644 --- a/pina/solvers/pinn.py +++ b/pina/solvers/pinn.py @@ -5,6 +5,7 @@ try: except ImportError: from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # torch < 2.0 +import sys from torch.optim.lr_scheduler import ConstantLR from .solver import SolverInterface @@ -143,7 +144,10 @@ class PINN(SolverInterface): for condition_id in range(condition_idx.min(), condition_idx.max()+1): - condition_name = dataloader.condition_names[condition_id] + if sys.version_info >= (3, 8): + condition_name = dataloader.condition_names[condition_id] + else: + condition_name = dataloader.loaders.condition_names[condition_id] condition = self.problem.conditions[condition_name] pts = batch['pts'] diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index c98146b..8abf8a6 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -1,5 +1,6 @@ """ Module for SupervisedSolver """ import torch +import sys try: from torch.optim.lr_scheduler import LRScheduler # torch >= 2.0 except ImportError: @@ -98,7 +99,10 @@ class SupervisedSolver(SolverInterface): for condition_id in range(condition_idx.min(), condition_idx.max()+1): - condition_name = dataloader.condition_names[condition_id] + if sys.version_info >= (3, 8): + condition_name = dataloader.condition_names[condition_id] + else: + condition_name = dataloader.loaders.condition_names[condition_id] condition = self.problem.conditions[condition_name] pts = batch['pts'] out = batch['output']