Add check conditions-solver consistency
This commit is contained in:
committed by
Nicola Demo
parent
f2340cd4ee
commit
a6f0336d06
@@ -10,7 +10,6 @@ import torch
|
||||
import sys
|
||||
|
||||
|
||||
|
||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
"""
|
||||
Solver base class. This class inherits is a wrapper of
|
||||
@@ -133,23 +132,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||
|
||||
return super().on_train_start()
|
||||
|
||||
def _check_solver_consistency(self, problem):
|
||||
pass
|
||||
#TODO : Implement this method for the conditions
|
||||
'''
|
||||
|
||||
|
||||
for _, condition in problem.conditions.items():
|
||||
if not set(condition.condition_type).issubset(
|
||||
set(self.accepted_condition_types)):
|
||||
raise ValueError(
|
||||
f'{self.__name__} dose not support condition '
|
||||
f'{condition.condition_type}')
|
||||
'''
|
||||
@staticmethod
|
||||
def get_batch_size(batch):
|
||||
# Assuming batch is your custom Batch object
|
||||
batch_size = 0
|
||||
for data in batch:
|
||||
batch_size += len(data[1]['input_points'])
|
||||
return batch_size
|
||||
return batch_size
|
||||
|
||||
def _check_solver_consistency(self, problem):
|
||||
for condition in problem.conditions.values():
|
||||
check_consistency(condition, self.accepted_conditions_types)
|
||||
|
||||
Reference in New Issue
Block a user