Lightining update (#104)
* multiple functions for version 0.0 * lightining update * minor changes * data pinn loss added --------- Co-authored-by: Nicola Demo <demo.nicola@gmail.com> Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-3-125.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.station> Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local> Co-authored-by: Dario Coscia <dariocoscia@192.168.1.38>
This commit is contained in:
committed by
Nicola Demo
parent
0e3625de80
commit
63fd068988
@@ -10,6 +10,29 @@ from .label_tensor import LabelTensor
|
||||
import torch
|
||||
|
||||
|
||||
def check_consistency(object, object_instance, object_name, subclass=False):
|
||||
"""Helper function to check object inheritance consistency.
|
||||
Given a specific ``'object'`` we check if the object is
|
||||
instance of a specific ``'object_instance'``, or in case
|
||||
``'subclass=True'`` we check if the object is subclass
|
||||
if the ``'object_instance'``.
|
||||
|
||||
:param Object object: The object to check the inheritance
|
||||
:param Object object_instance: The parent class from where the object
|
||||
is expected to inherit
|
||||
:param str object_name: The name of the object
|
||||
:param bool subclass: Check if is a subclass and not instance
|
||||
:raises ValueError: If the object does not inherit from the
|
||||
specified class
|
||||
"""
|
||||
if not subclass:
|
||||
if not isinstance(object, object_instance):
|
||||
raise ValueError(f"{object_name} must be {object_instance}")
|
||||
else:
|
||||
if not issubclass(object, object_instance):
|
||||
raise ValueError(f"{object_name} must be {object_instance}")
|
||||
|
||||
|
||||
def number_parameters(model, aggregate=True, only_trainable=True): # TODO: check
|
||||
"""
|
||||
Return the number of parameters of a given `model`.
|
||||
@@ -189,8 +212,7 @@ class LabelTensorDataset(Dataset):
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
pass
|
||||
# return dict(zip(self.pinn.input_pts.keys(), dataloaders))
|
||||
|
||||
# class SampleDataset(torch.utils.data.Dataset):
|
||||
@@ -239,5 +261,4 @@ class LabelTensorDataset(Dataset):
|
||||
class LabelTensorDataLoader(DataLoader):
|
||||
|
||||
def collate_fn(self, data):
|
||||
print(data)
|
||||
gggggggggg
|
||||
pass
|
||||
Reference in New Issue
Block a user