Add check conditions-solver consistency

This commit is contained in:
FilippoOlivo
2025-01-16 19:42:46 +01:00
committed by Nicola Demo
parent f2340cd4ee
commit a6f0336d06
3 changed files with 20 additions and 22 deletions

View File

@@ -10,7 +10,6 @@ import torch
import sys
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver base class. This class inherits is a wrapper of
@@ -133,23 +132,14 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
return super().on_train_start()
def _check_solver_consistency(self, problem):
pass
#TODO : Implement this method for the conditions
'''
for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)):
raise ValueError(
f'{self.__name__} dose not support condition '
f'{condition.condition_type}')
'''
@staticmethod
def get_batch_size(batch):
# Assuming batch is your custom Batch object
batch_size = 0
for data in batch:
batch_size += len(data[1]['input_points'])
return batch_size
return batch_size
def _check_solver_consistency(self, problem):
for condition in problem.conditions.values():
check_consistency(condition, self.accepted_conditions_types)