Lightining update (#104)

* multiple functions for version 0.0
* lightining update
* minor changes
* data pinn  loss added
---------

Co-authored-by: Nicola Demo <demo.nicola@gmail.com>
Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-3-125.WIFIeduroamSTUD.units.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.station>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@192.168.1.38>
This commit is contained in:
Dario Coscia
2023-06-07 15:34:43 +02:00
committed by Nicola Demo
parent 0e3625de80
commit 63fd068988
16 changed files with 710 additions and 603 deletions

View File

@@ -10,6 +10,29 @@ from .label_tensor import LabelTensor
import torch
def check_consistency(object, object_instance, object_name, subclass=False):
"""Helper function to check object inheritance consistency.
Given a specific ``'object'`` we check if the object is
instance of a specific ``'object_instance'``, or in case
``'subclass=True'`` we check if the object is subclass
if the ``'object_instance'``.
:param Object object: The object to check the inheritance
:param Object object_instance: The parent class from where the object
is expected to inherit
:param str object_name: The name of the object
:param bool subclass: Check if is a subclass and not instance
:raises ValueError: If the object does not inherit from the
specified class
"""
if not subclass:
if not isinstance(object, object_instance):
raise ValueError(f"{object_name} must be {object_instance}")
else:
if not issubclass(object, object_instance):
raise ValueError(f"{object_name} must be {object_instance}")
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
"""
Return the number of parameters of a given `model`.
@@ -189,8 +212,7 @@ class LabelTensorDataset(Dataset):
class LabelTensorDataLoader(DataLoader):
def collate_fn(self, data):
print(data)
gggggggggg
pass
# return dict(zip(self.pinn.input_pts.keys(), dataloaders))
# class SampleDataset(torch.utils.data.Dataset):
@@ -239,5 +261,4 @@ class LabelTensorDataset(Dataset):
class LabelTensorDataLoader(DataLoader):
def collate_fn(self, data):
print(data)
gggggggggg
pass