Skip to content

Remove imperative filling functions _ #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions tests/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,17 @@ def enumerate(semiring, edge, lengths=None):
semiring = semiring
ssize = semiring.size()
edge, batch, N, C, lengths = model._check_potentials(edge, lengths)
chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
chains = [
[
(
[c],
semiring.fill(
torch.zeros(ssize, batch), torch.tensor(True), semiring.one
),
)
for c in range(C)
]
]

enum_lengths = torch.LongTensor(lengths.shape)
for n in range(1, N):
Expand Down Expand Up @@ -128,7 +138,13 @@ def enumerate(semiring, edge):
edge = semiring.convert(edge)
chains = {}
chains[0] = [
([(c, 0)], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)
(
[(c, 0)],
semiring.fill(
torch.zeros(ssize, batch), torch.tensor(True), semiring.one
),
)
for c in range(C)
]

for n in range(1, N + 1):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_generic_lengths(model_test, data):
part = model().sum(vals, lengths=lengths)

# Check that max is correct
assert (maxes <= part).all()
assert (maxes <= part + 1e-3).all()
m_part = model(MaxSemiring).sum(vals, lengths=lengths)
assert (torch.isclose(maxes, m_part)).all(), maxes - m_part

Expand Down
6 changes: 4 additions & 2 deletions torch_struct/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,10 @@ def log_prob(self, value, sparse=False):
return wrap(scores, sample)

def _beam_search(self, semiring, gumbel=False):
beam = semiring.one_(
torch.zeros((semiring.size(),) + self.batch_shape, device=self.device)
beam = semiring.fill(
torch.zeros((semiring.size(),) + self.batch_shape, device=self.device),
torch.tensor(True),
semiring.one,
)
ssize = semiring.size()

Expand Down
29 changes: 22 additions & 7 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,22 @@ def logpartition(self, arc_scores_in, lengths=None, force_grad=False):
]
for _ in range(2)
]
semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)
mask = torch.zeros(alpha[A][C][L].data.shape).bool()
mask[:, :, :, 0].fill_(True)
alpha[A][C][L].data[:] = semiring.fill(
alpha[A][C][L].data[:], mask, semiring.one
)
alpha[A][C][R].data[:] = semiring.fill(
alpha[A][C][R].data[:], mask, semiring.one
)
mask = torch.zeros(alpha[B][C][L].data[:].shape).bool()
mask[:, :, :, -1].fill_(True)
alpha[B][C][L].data[:] = semiring.fill(
alpha[B][C][L].data[:], mask, semiring.one
)
alpha[B][C][R].data[:] = semiring.fill(
alpha[B][C][R].data[:], mask, semiring.one
)

if multiroot:
start_idx = 0
Expand Down Expand Up @@ -119,10 +131,13 @@ def _check_potentials(self, arc_scores, lengths=None):
lengths = torch.LongTensor([N - 1] * batch).to(arc_scores.device)
assert max(lengths) <= N, "Length longer than N"
arc_scores = semiring.convert(arc_scores)
for b in range(batch):
semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :])
semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :])

# Set the extra elements of the log-potentials to zero.
keep = torch.ones_like(arc_scores).bool()
for b in range(batch):
keep[:, b, lengths[b] + 1 :, :].fill_(0.0)
keep[:, b, :, lengths[b] + 1 :].fill_(0.0)
arc_scores = semiring.fill(arc_scores, ~keep, semiring.zero)
return arc_scores, batch, N, lengths

def _arrange_marginals(self, grads):
Expand Down
1 change: 1 addition & 0 deletions torch_struct/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class StructDistribution(Distribution):
log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
lengths (long tensor, batch_shape) : integers for length masking
"""
validate_args = False

def __init__(self, log_potentials, lengths=None, args={}):
batch_shape = log_potentials.shape[:1]
Expand Down
34 changes: 17 additions & 17 deletions torch_struct/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

class Chart:
def __init__(self, size, potentials, semiring):
self.data = semiring.zero_(
torch.zeros(
*((semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
c = torch.zeros(
*((semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
c[:] = semiring.zero.view((semiring.size(),) + len(size) * (1,))

self.data = c
self.grad = self.data.detach().clone().fill_(0.0)

def __getitem__(self, ind):
Expand Down Expand Up @@ -50,18 +51,17 @@ def _chart(self, size, potentials, force_grad):
return self._make_chart(1, size, potentials, force_grad)[0]

def _make_chart(self, N, size, potentials, force_grad=False):
return [
(
self.semiring.zero_(
torch.zeros(
*((self.semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
).requires_grad_(force_grad and not potentials.requires_grad)
chart = []
for _ in range(N):
c = torch.zeros(
*((self.semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
for _ in range(N)
]
c[:] = self.semiring.zero.view((self.semiring.size(),) + len(size) * (1,))
c.requires_grad_(force_grad and not potentials.requires_grad)
chart.append(c)
return chart

def sum(self, logpotentials, lengths=None, _raw=False):
"""
Expand Down
8 changes: 5 additions & 3 deletions torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)

# Init
semiring.one_(chart[:, :, :].diagonal(0, 3, 4))
init = torch.zeros(*chart.shape).bool()
init.diagonal(0, 3, 4).fill_(True)
chart = semiring.fill(chart, init, semiring.one)

# Length mask
big = torch.zeros(
Expand All @@ -71,8 +73,8 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N).type_as(c)
mask = mask >= (lengths - 1).view(batch, 1)
mask = mask.view(batch * bin_N, 1, 1).to(lp.device)
semiring.zero_mask_(lp.data, mask)
semiring.zero_mask_(c.data, (~mask))
lp.data[:] = semiring.fill(lp.data, mask, semiring.zero)
c.data[:] = semiring.fill(c.data, ~mask, semiring.zero)

c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1))

Expand Down
14 changes: 8 additions & 6 deletions torch_struct/semimarkov.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
)

# Init.
semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))
mask = torch.zeros(*init.shape).bool()
mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
init = semiring.fill(init, mask, semiring.one)

# Length mask
big = torch.zeros(
Expand All @@ -54,16 +56,16 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
mask = mask.to(log_potentials.device)
mask = mask >= (lengths - 1).view(batch, 1)
mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device)
semiring.zero_mask_(lp.data, mask)
semiring.zero_mask_(c.data[:, :, :, 0], (~mask))
lp.data[:] = semiring.fill(lp.data, mask, semiring.zero)
c.data[:, :, :, 0] = semiring.fill(c.data[:, :, :, 0], (~mask), semiring.zero)
c[:, :, : K - 1, 0] = semiring.sum(
torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
)
end = torch.min(lengths) - 1
mask = torch.zeros(*init.shape).bool()
for k in range(1, K - 1):
semiring.one_(
init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1)
)
mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
init = semiring.fill(init, mask, semiring.one)

K_1 = K - 1

Expand Down
1 change: 1 addition & 0 deletions torch_struct/semirings/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
try:
import genbmm
from genbmm import BandedMatrix

has_genbmm = True
except ImportError:
pass
Expand Down
Loading