fix tests

This commit is contained in:
Nicola Demo
2025-01-23 09:52:23 +01:00
parent 9aed1a30b3
commit a899327de1
32 changed files with 2331 additions and 2428 deletions

View File

@@ -49,7 +49,7 @@ class R3Refinement(Callback):
"""
# extract the solver and device from trainer
solver = trainer._model
solver = trainer.solver
device = trainer._accelerator_connector._accelerator_flag
precision = trainer.precision
if precision == "64-true":
@@ -67,7 +67,7 @@ class R3Refinement(Callback):
# compute residual
res_loss = {}
tot_loss = []
for location in self._sampling_locations:
for location in self._sampling_locations: #TODO fix for new collector
condition = solver.problem.conditions[location]
pts = solver.problem.input_pts[location]
# send points to correct device
@@ -79,6 +79,8 @@ class R3Refinement(Callback):
res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
tot_loss.append(torch.abs(target))
print(tot_loss)
return torch.vstack(tot_loss), res_loss
def _r3_routine(self, trainer):
@@ -139,7 +141,7 @@ class R3Refinement(Callback):
:rtype: None
"""
# extract locations for sampling
problem = trainer._model.problem
problem = trainer.solver.problem
locations = []
for condition_name in problem.conditions:
condition = problem.conditions[condition_name]