Source code for nnsvs.tacotron.decoder

# The code was adapted from ttslearn https://github.com/r9y9/ttslearn
# NonAttentiveDecoder was added to the original code.
# Acknowledgement: some of the code was adapted from ESPnet
#  Copyright 2019 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torch.nn.functional as F
from nnsvs.base import BaseModel, PredictionType
from nnsvs.mdn import MDNLayer, mdn_get_most_probable_sigma_and_mu, mdn_get_sample
from nnsvs.util import init_weights
from torch import nn


def decoder_init(m):
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, nn.init.calculate_gain("tanh"))

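# NOTE: ``decoder_init`` is presumably meant to be used as a callback for
# ``nn.Module.apply``, which visits every submodule recursively. A minimal,
# illustrative sketch (not part of the original module; module names are
# assumptions):
#
#   postnet = nn.Sequential(nn.Conv1d(80, 512, 5, padding=2), nn.Tanh())
#   postnet.apply(decoder_init)  # Xavier-uniform init with tanh gain on Conv1d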

class ZoneOutCell(nn.Module):
    def __init__(self, cell, zoneout=0.1):
        super().__init__()
        self.cell = cell
        self.hidden_size = cell.hidden_size
        self.zoneout = zoneout

    def forward(self, inputs, hidden):
        next_hidden = self.cell(inputs, hidden)
        next_hidden = self._zoneout(hidden, next_hidden, self.zoneout)
        return next_hidden

    def _zoneout(self, h, next_h, prob):
        h_0, c_0 = h
        h_1, c_1 = next_h
        h_1 = self._apply_zoneout(h_0, h_1, prob)
        c_1 = self._apply_zoneout(c_0, c_1, prob)
        return h_1, c_1

    def _apply_zoneout(self, h, next_h, prob):
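        # Zoneout (Krueger et al., 2016): during training, each unit keeps its
        # previous state ``h`` with probability ``prob`` (mask == 1) and takes
        # the updated state ``next_h`` otherwise. At evaluation time the
        # deterministic expectation prob * h + (1 - prob) * next_h is used.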
        if self.training:
            if prob > 0.0:
                mask = torch.empty_like(h).bernoulli_(prob)
            else:
                mask = 0
            return mask * h + (1 - mask) * next_h
        else:
            return prob * h + (1 - prob) * next_h

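# A minimal usage sketch for ZoneOutCell (illustrative only; tensor sizes are
# assumptions): the wrapped cell is stepped exactly like a bare nn.LSTMCell.
#
#   cell = ZoneOutCell(nn.LSTMCell(768, 256), zoneout=0.1)
#   x = torch.randn(4, 768)                          # (B, input_size)
#   hc = (torch.zeros(4, 256), torch.zeros(4, 256))  # (h_0, c_0)
#   hc = cell(x, hc)                                 # (h_1, c_1), each (B, 256)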

class Prenet(nn.Module):
    """Pre-Net of Tacotron.

    Args:
        in_dim (int) : dimension of input
        layers (int) : number of pre-net layers
        hidden_dim (int) : dimension of hidden layer
        dropout (float) : dropout rate
        eval_dropout (bool) : if True, keep dropout enabled at evaluation time
    """

    def __init__(
        self, in_dim, layers=2, hidden_dim=256, dropout=0.5, eval_dropout=True
    ):
        super().__init__()
        self.dropout = dropout
        self.eval_dropout = eval_dropout
        prenet = nn.ModuleList()
        for layer in range(layers):
            prenet += [
                nn.Linear(in_dim if layer == 0 else hidden_dim, hidden_dim),
                nn.ReLU(),
            ]
        self.prenet = nn.Sequential(*prenet)

    def forward(self, x):
        """Forward step

        Args:
            x (torch.Tensor) : input tensor

        Returns:
            torch.Tensor : output tensor
        """
        for layer in self.prenet:
            if self.eval_dropout:
                x = F.dropout(layer(x), self.dropout, training=True)
            else:
                x = F.dropout(layer(x), self.dropout, training=self.training)
        return x

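# A minimal usage sketch for Prenet (illustrative only): note that with the
# default ``eval_dropout=True``, dropout stays active even in ``eval()`` mode,
# which injects noise into the autoregressive feedback loop at inference time
# (the same trick as in Tacotron 2).
#
#   prenet = Prenet(in_dim=80, layers=2, hidden_dim=256, dropout=0.5)
#   y = prenet(torch.randn(4, 100, 80))  # (B, T, 80) -> (B, T, 256)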

class NonAttentiveDecoder(BaseModel):
    """Decoder of Tacotron w/o attention mechanism

    Args:
        in_dim (int) : dimension of encoder hidden layer
        out_dim (int) : dimension of output
        layers (int) : number of LSTM layers
        hidden_dim (int) : dimension of hidden layer
        prenet_layers (int) : number of pre-net layers
        prenet_hidden_dim (int) : dimension of pre-net hidden layer
        prenet_dropout (float) : dropout rate of pre-net
        zoneout (float) : zoneout rate
        reduction_factor (int) : reduction factor
        downsample_by_conv (bool) : if True, downsample by convolution
        init_type (str) : initialization type
        eval_dropout (bool) : if True, keep dropout enabled at evaluation time
        prenet_noise_std (float) : standard deviation of the Gaussian noise added
            to the previous frame when no pre-net is used
        initial_value (float) : initial value for the autoregressive decoder.
    """

    def __init__(
        self,
        in_dim=512,
        out_dim=80,
        layers=2,
        hidden_dim=1024,
        prenet_layers=2,
        prenet_hidden_dim=256,
        prenet_dropout=0.5,
        zoneout=0.1,
        reduction_factor=1,
        downsample_by_conv=False,
        init_type="none",
        eval_dropout=True,
        prenet_noise_std=0.0,
        initial_value=0.0,
    ):
        super().__init__()
        self.out_dim = out_dim
        self.reduction_factor = reduction_factor
        self.prenet_dropout = prenet_dropout
        self.prenet_noise_std = prenet_noise_std
        self.initial_value = initial_value

        if prenet_layers > 0:
            self.prenet = Prenet(
                out_dim,
                prenet_layers,
                prenet_hidden_dim,
                prenet_dropout,
                eval_dropout=eval_dropout,
            )
            lstm_in_dim = in_dim + prenet_hidden_dim
        else:
            self.prenet = None
            prenet_hidden_dim = 0
            lstm_in_dim = in_dim + out_dim

        self.lstm = nn.ModuleList()
        for layer in range(layers):
            lstm = nn.LSTMCell(
                lstm_in_dim if layer == 0 else hidden_dim,
                hidden_dim,
            )
            self.lstm += [ZoneOutCell(lstm, zoneout)]

        proj_in_dim = in_dim + hidden_dim
        self.feat_out = nn.Linear(proj_in_dim, out_dim * reduction_factor, bias=False)

        if reduction_factor > 1 and downsample_by_conv:
            self.conv_downsample = nn.Conv1d(
                in_dim,
                in_dim,
                kernel_size=reduction_factor,
                stride=reduction_factor,
                groups=in_dim,
            )
        else:
            self.conv_downsample = None

        init_weights(self, init_type)

    def _zero_state(self, hs):
        init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
        return init_hs

    def is_autoregressive(self):
        return True

    def forward(self, encoder_outs, in_lens, decoder_targets=None):
        """Forward step

        Args:
            encoder_outs (torch.Tensor): encoder outputs (B, T, C)
            in_lens (torch.Tensor): input lengths (unused)
            decoder_targets (torch.Tensor): decoder targets for teacher forcing (B, T, C)

        Returns:
            torch.Tensor: the output (B, T, C)
        """
        is_inference = decoder_targets is None
        if not is_inference:
            assert encoder_outs.shape[1] == decoder_targets.shape[1]

        # Adjust number of frames according to the reduction factor
        # (B, Lmax, out_dim) ->  (B, Lmax/r, out_dim)
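        # e.g., with reduction_factor r=2, every other frame (indices 1, 3, 5, ...)
        # is kept, halving the sequence length; feat_out later predicts r frames
        # per decoder step, so the original frame rate is restored when the
        # outputs are concatenated.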
        if self.reduction_factor > 1 and not is_inference:
            decoder_targets = decoder_targets[
                :, self.reduction_factor - 1 :: self.reduction_factor
            ]
        if self.reduction_factor > 1:
            if self.conv_downsample is not None:
                encoder_outs = self.conv_downsample(
                    encoder_outs.transpose(1, 2)
                ).transpose(1, 2)
            else:
                encoder_outs = encoder_outs[
                    :, self.reduction_factor - 1 :: self.reduction_factor
                ]

        h_list, c_list = [], []
        for _ in range(len(self.lstm)):
            h_list.append(self._zero_state(encoder_outs))
            c_list.append(self._zero_state(encoder_outs))

        go_frame = (
            encoder_outs.new_zeros(encoder_outs.size(0), self.out_dim)
            + self.initial_value
        )
        prev_out = go_frame

        if not is_inference and self.prenet is not None:
            prenet_outs = self.prenet(decoder_targets)

        outs = []
        for t in range(encoder_outs.shape[1]):
            # Pre-Net
            if self.prenet is not None:
                if is_inference:
                    prenet_out = self.prenet(prev_out)
                else:
                    prenet_out = prenet_outs[:, t, :]
            elif self.prenet_noise_std > 0:
                prenet_out = (
                    prev_out + torch.randn_like(prev_out) * self.prenet_noise_std
                )
            else:
                prenet_out = F.dropout(prev_out, self.prenet_dropout, training=True)

            # LSTM
            xs = torch.cat([encoder_outs[:, t], prenet_out], dim=1)
            h_list[0], c_list[0] = self.lstm[0](xs, (h_list[0], c_list[0]))
            for i in range(1, len(self.lstm)):
                h_list[i], c_list[i] = self.lstm[i](
                    h_list[i - 1], (h_list[i], c_list[i])
                )
            # Output
            hcs = torch.cat([h_list[-1], encoder_outs[:, t]], dim=1)
            outs.append(self.feat_out(hcs).view(encoder_outs.size(0), self.out_dim, -1))

            # Update decoder input for the next time step
            if is_inference:
                prev_out = outs[-1][:, :, -1]  # (B, out_dim)
            else:
                # Teacher forcing
                prev_out = decoder_targets[:, t, :]

        outs = torch.cat(outs, dim=2)  # (B, out_dim, Lmax)

        if self.reduction_factor > 1:
            outs = outs.view(outs.size(0), self.out_dim, -1)  # (B, out_dim, Lmax)

        # (B, C, T) -> (B, T, C)
        return outs.transpose(1, 2)

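# A minimal usage sketch for NonAttentiveDecoder (illustrative only; shapes
# are assumptions). Teacher forcing is used when targets are given; otherwise
# the decoder runs autoregressively:
#
#   decoder = NonAttentiveDecoder(in_dim=512, out_dim=80)
#   enc = torch.randn(4, 100, 512)   # encoder outputs (B, T, C)
#   tgt = torch.randn(4, 100, 80)    # acoustic targets (B, T, out_dim)
#   out = decoder(enc, None, tgt)    # training: (B, T, out_dim)
#   out = decoder(enc, None)         # inference: (B, T, out_dim)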

class MDNNonAttentiveDecoder(BaseModel):
    """Non-attentive decoder with MDN.

    Each decoder step outputs the parameters of the MDN.

    Args:
        in_dim (int): input dimension
        out_dim (int): output dimension
        layers (int): number of LSTM layers
        hidden_dim (int): hidden dimension
        prenet_layers (int): number of prenet layers
        prenet_hidden_dim (int): prenet hidden dimension
        prenet_dropout (float): prenet dropout rate
        zoneout (float): zoneout rate
        reduction_factor (int): reduction factor
        downsample_by_conv (bool): if True, use conv1d to downsample the input
        num_gaussians (int): number of Gaussians
        sampling_mode (str): sampling mode
        init_type (str): initialization type
        eval_dropout (bool): if True, use dropout in evaluation
        prenet_noise_std (float): standard deviation of the Gaussian noise added
            to the previous frame when no pre-net is used
        initial_value (float): initial value for the autoregressive decoder
    """

    def __init__(
        self,
        in_dim=512,
        out_dim=80,
        layers=2,
        hidden_dim=1024,
        prenet_layers=2,
        prenet_hidden_dim=256,
        prenet_dropout=0.5,
        zoneout=0.1,
        reduction_factor=1,
        downsample_by_conv=False,
        num_gaussians=8,
        sampling_mode="mean",
        init_type="none",
        eval_dropout=True,
        prenet_noise_std=0.0,
        initial_value=0.0,
    ):
        super().__init__()
        self.out_dim = out_dim
        self.reduction_factor = reduction_factor
        self.prenet_dropout = prenet_dropout
        self.prenet_noise_std = prenet_noise_std
        self.num_gaussians = num_gaussians
        self.sampling_mode = sampling_mode
        assert sampling_mode in ["mean", "random"]
        self.initial_value = initial_value

        if prenet_layers > 0:
            self.prenet = Prenet(
                out_dim,
                prenet_layers,
                prenet_hidden_dim,
                prenet_dropout,
                eval_dropout=eval_dropout,
            )
            lstm_in_dim = in_dim + prenet_hidden_dim
        else:
            self.prenet = None
            prenet_hidden_dim = 0
            lstm_in_dim = in_dim + out_dim

        self.lstm = nn.ModuleList()
        for layer in range(layers):
            lstm = nn.LSTMCell(
                lstm_in_dim if layer == 0 else hidden_dim,
                hidden_dim,
            )
            self.lstm += [ZoneOutCell(lstm, zoneout)]

        proj_in_dim = in_dim + hidden_dim
        self.feat_out = MDNLayer(
            proj_in_dim,
            out_dim * reduction_factor,
            num_gaussians=num_gaussians,
            dim_wise=True,
        )

        if reduction_factor > 1 and downsample_by_conv:
            self.conv_downsample = nn.Conv1d(
                in_dim,
                in_dim,
                kernel_size=reduction_factor,
                stride=reduction_factor,
                groups=in_dim,
            )
        else:
            self.conv_downsample = None

        init_weights(self, init_type)

    def _zero_state(self, hs):
        init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
        return init_hs

    def is_autoregressive(self):
        return True

    def prediction_type(self):
        return PredictionType.PROBABILISTIC

    def forward(self, encoder_outs, in_lens, decoder_targets=None):
        """Forward step

        Args:
            encoder_outs (torch.Tensor): encoder outputs (B, T, C)
            in_lens (torch.Tensor): input lengths (unused)
            decoder_targets (torch.Tensor): decoder targets for teacher forcing (B, T, C)

        Returns:
            tuple: (log_pi, log_sigma, mu) of the MDN when teacher forcing is
                used, or (mu, mu) at inference time
        """
        is_inference = decoder_targets is None
        if not is_inference:
            assert encoder_outs.shape[1] == decoder_targets.shape[1]

        # Adjust number of frames according to the reduction factor
        # (B, Lmax, out_dim) -> (B, Lmax/r, out_dim)
        if self.reduction_factor > 1 and not is_inference:
            decoder_targets = decoder_targets[
                :, self.reduction_factor - 1 :: self.reduction_factor
            ]
        if self.reduction_factor > 1:
            if self.conv_downsample is not None:
                encoder_outs = self.conv_downsample(
                    encoder_outs.transpose(1, 2)
                ).transpose(1, 2)
            else:
                encoder_outs = encoder_outs[
                    :, self.reduction_factor - 1 :: self.reduction_factor
                ]

        h_list, c_list = [], []
        for _ in range(len(self.lstm)):
            h_list.append(self._zero_state(encoder_outs))
            c_list.append(self._zero_state(encoder_outs))

        go_frame = (
            encoder_outs.new_zeros(encoder_outs.size(0), self.out_dim)
            + self.initial_value
        )
        prev_out = go_frame

        if not is_inference and self.prenet is not None:
            prenet_outs = self.prenet(decoder_targets)

        mus = []
        log_pis = []
        log_sigmas = []
        # NOTE: only used for inference
        mus_inf = []
        for t in range(encoder_outs.shape[1]):
            # Pre-Net
            if self.prenet is not None:
                if is_inference:
                    prenet_out = self.prenet(prev_out)
                else:
                    prenet_out = prenet_outs[:, t, :]
            elif self.prenet_noise_std > 0:
                prenet_out = (
                    prev_out + torch.randn_like(prev_out) * self.prenet_noise_std
                )
            else:
                prenet_out = F.dropout(prev_out, self.prenet_dropout, training=True)

            # LSTM
            xs = torch.cat([encoder_outs[:, t], prenet_out], dim=1)
            h_list[0], c_list[0] = self.lstm[0](xs, (h_list[0], c_list[0]))
            for i in range(1, len(self.lstm)):
                h_list[i], c_list[i] = self.lstm[i](
                    h_list[i - 1], (h_list[i], c_list[i])
                )
            # Output
            hcs = torch.cat([h_list[-1], encoder_outs[:, t]], dim=1)
            log_pi, log_sigma, mu = self.feat_out(hcs.unsqueeze(1))

            # (B, 1, num_gaussians, out_dim*reduction_factor)
            # -> (B, reduction_factor, num_gaussians, out_dim)
            log_pi = (
                log_pi.transpose(1, 2)
                .view(encoder_outs.size(0), self.num_gaussians, -1, self.out_dim)
                .transpose(1, 2)
            )
            log_sigma = (
                log_sigma.transpose(1, 2)
                .view(encoder_outs.size(0), self.num_gaussians, -1, self.out_dim)
                .transpose(1, 2)
            )
            mu = (
                mu.transpose(1, 2)
                .view(encoder_outs.size(0), self.num_gaussians, -1, self.out_dim)
                .transpose(1, 2)
            )
            mus.append(mu)
            log_pis.append(log_pi)
            log_sigmas.append(log_sigma)

            # Update decoder input for the next time step
            if is_inference:
                # (B, reduction_factor, out_dim)
                if self.sampling_mode == "mean":
                    _, mu = mdn_get_most_probable_sigma_and_mu(log_pi, log_sigma, mu)
                elif self.sampling_mode == "random":
                    mu = mdn_get_sample(log_pi, log_sigma, mu)
                # Feed the last sample back for the feedback loop
                prev_out = mu[:, -1]
                mus_inf.append(mu)
            else:
                # Teacher forcing
                prev_out = decoder_targets[:, t, :]

        mus = torch.cat(mus, dim=1)  # (B, Lmax, num_gaussians, out_dim)
        log_pis = torch.cat(log_pis, dim=1)  # (B, Lmax, num_gaussians, out_dim)
        log_sigmas = torch.cat(log_sigmas, dim=1)  # (B, Lmax, num_gaussians, out_dim)

        if is_inference:
            mu = torch.cat(mus_inf, dim=1)
            # TODO: may need to track sigma. For now we only use mu
            return mu, mu
        else:
            return log_pis, log_sigmas, mus
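

# A minimal usage sketch for MDNNonAttentiveDecoder (illustrative only; shapes
# are assumptions). Training returns the MDN parameters; inference returns a
# point estimate (mu, mu):
#
#   decoder = MDNNonAttentiveDecoder(in_dim=512, out_dim=80, num_gaussians=4)
#   enc = torch.randn(4, 100, 512)                   # (B, T, C)
#   tgt = torch.randn(4, 100, 80)                    # (B, T, out_dim)
#   log_pi, log_sigma, mu = decoder(enc, None, tgt)  # each (B, T, 4, 80)
#   mu, _ = decoder(enc, None)                       # inference: (B, T, 80)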