Solving problems related to Geometry (#118)
* fix and add tests * minor fix on domain classes --------- Co-authored-by: Dario Coscia <dariocoscia@cli-10-110-0-208.WIFIeduroamSTUD.units.it> Co-authored-by: Dario Coscia <dariocoscia@dhcp-040.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
62ec69ccac
commit
982af4a04d
@@ -2,6 +2,7 @@ import torch
|
||||
from .location import Location
|
||||
from ..utils import check_consistency
|
||||
from ..label_tensor import LabelTensor
|
||||
import random
|
||||
|
||||
|
||||
class Union(Location):
|
||||
@@ -87,7 +88,7 @@ class Union(Location):
|
||||
>>> ellipsoid2 = EllipsoidDomain({'x': [0, 2], 'y': [0, 2]})
|
||||
|
||||
# Create a union of the ellipsoid domains
|
||||
>>> union = GeometryUnion([ellipsoid1, ellipsoid2])
|
||||
>>> union = Union([ellipsoid1, ellipsoid2])
|
||||
|
||||
>>> union.sample(n=1000)
|
||||
LabelTensor([[-0.2025, 0.0072],
|
||||
@@ -108,11 +109,18 @@ class Union(Location):
|
||||
num_points = n // len(self.geometries)
|
||||
|
||||
# sample the points
|
||||
for i, geometry in enumerate(self.geometries):
|
||||
# add to sample total if remainder is not 0
|
||||
if i < remainder:
|
||||
num_points += 1
|
||||
sampled_points.append(geometry.sample(num_points, mode, variables))
|
||||
# NB. geometries as shuffled since if we sample
|
||||
# multiple times just one point, we would end
|
||||
# up sampling only from the first geometry.
|
||||
iter_ = random.sample(self.geometries, len(self.geometries))
|
||||
for i, geometry in enumerate(iter_):
|
||||
# int(i < remainder) is one only if we have a remainder
|
||||
# different than zero. Notice that len(geometries) is
|
||||
# always smaller than remaider.
|
||||
sampled_points.append(geometry.sample(num_points + int(i < remainder), mode, variables))
|
||||
# in case number of sampled points is smaller than the number of geometries
|
||||
if len(sampled_points) >= n:
|
||||
break
|
||||
|
||||
return LabelTensor(torch.cat(sampled_points), labels=[f'{i}' for i in self.variables])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user