import os
import random
import shutil
import sys
import types
from glob import glob
from multiprocessing import Manager
from os.path import join
from pathlib import Path
import hydra
import joblib
import librosa
import librosa.display
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pysptk
import pyworld
import torch
import torch.distributed as dist
from hydra.utils import get_original_cwd, to_absolute_path
from nnmnkwii import metrics
from nnsvs.base import PredictionType
from nnsvs.gen import gen_world_params
from nnsvs.logger import getLogger
from nnsvs.mdn import mdn_get_most_probable_sigma_and_mu
from nnsvs.multistream import (
get_static_features,
get_static_stream_sizes,
get_windows,
multi_stream_mlpg,
select_streams,
split_streams,
)
from nnsvs.pitch import lowpass_filter, note_segments
from nnsvs.util import MinMaxScaler, StandardScaler, init_seed, pad_2d
from omegaconf import DictConfig, ListConfig, OmegaConf
from sklearn.preprocessing import MinMaxScaler as SKMinMaxScaler
from torch import nn, optim
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils import data as data_utils
from torch.utils.data.sampler import BatchSampler
from torch.utils.tensorboard import SummaryWriter
plt.style.use("seaborn-whitegrid")
class ShuffleBatchSampler(BatchSampler):
def __init__(
self,
batches,
drop_last=False,
shuffle=True,
):
self.shuffle = shuffle
self.batches = batches
self.drop_last = drop_last
def __iter__(self):
batches = self.batches
if self.shuffle:
random.shuffle(batches)
return iter(batches)
def __len__(self):
return len(self.batches)
def log_params_from_omegaconf_dict(params):
for param_name, element in params.items():
_explore_recursive(param_name, element)
def _explore_recursive(parent_name, element):
if isinstance(element, DictConfig):
for k, v in element.items():
if isinstance(v, DictConfig) or isinstance(v, ListConfig):
_explore_recursive(f"{parent_name}.{k}", v)
else:
mlflow.log_param(f"{parent_name}.{k}", v)
elif isinstance(element, ListConfig):
for i, v in enumerate(element):
mlflow.log_param(f"{parent_name}.{i}", v)
[docs]
def num_trainable_params(model):
"""Count the number of trainable parameters in the model.
Args:
model (torch.nn.Module): Model to count the number of trainable parameters.
Returns:
int: Number of trainable parameters.
"""
parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in parameters])
def get_filtered_files(
data_root,
logger,
filter_long_segments=False,
filter_num_frames=6000,
filter_min_num_frames=0,
):
files = sorted(glob(join(data_root, "*-feats.npy")))
if filter_long_segments:
valid_files = []
num_filtered = 0
for path in files:
length = len(np.load(path))
if length < filter_num_frames and length > filter_min_num_frames:
valid_files.append(path)
else:
if logger is not None:
logger.info(f"Filtered: {path} is too long or short: {length}")
num_filtered += 1
if num_filtered > 0 and logger is not None:
logger.info(f"Filtered {num_filtered} files")
# Print stats of lengths
if logger is not None:
lengths = [len(np.load(f)) for f in files]
logger.debug(f"[before] Size of dataset: {len(files)}")
logger.debug(f"[before] maximum length: {max(lengths)}")
logger.debug(f"[before] minimum length: {min(lengths)}")
logger.debug(f"[before] mean length: {np.mean(lengths)}")
logger.debug(f"[before] std length: {np.std(lengths)}")
logger.debug(f"[before] median length: {np.median(lengths)}")
files = valid_files
lengths = [len(np.load(f)) for f in files]
if logger is not None:
logger.debug(f"[after] Size of dataset: {len(files)}")
logger.debug(f"[after] maximum length: {max(lengths)}")
logger.debug(f"[after] minimum length: {min(lengths)}")
logger.debug(f"[after] mean length: {np.mean(lengths)}")
logger.debug(f"[after] std length: {np.std(lengths)}")
logger.debug(f"[after] median length: {np.median(lengths)}")
else:
lengths = [len(np.load(f)) for f in files]
return files, lengths
def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
if len(batch) == 0:
return 0
if len(batch) == max_sentences:
return 1
if num_tokens > max_tokens:
return 1
return 0
[docs]
def batch_by_size(
indices,
num_tokens_fn,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
sequences of different lengths.
Args:
indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at
a given index
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
"""
max_tokens = max_tokens if max_tokens is not None else sys.maxsize
max_sentences = max_sentences if max_sentences is not None else sys.maxsize
bsz_mult = required_batch_size_multiple
if isinstance(indices, types.GeneratorType):
indices = np.fromiter(indices, dtype=np.int64, count=-1)
sample_len = 0
sample_lens = []
batch = []
batches = []
for i in range(len(indices)):
idx = indices[i]
num_tokens = num_tokens_fn(idx)
sample_lens.append(num_tokens)
sample_len = max(sample_len, num_tokens)
assert (
sample_len <= max_tokens
), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
idx, sample_len, max_tokens
)
num_tokens = (len(batch) + 1) * sample_len
if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
mod_len = max(
bsz_mult * (len(batch) // bsz_mult),
len(batch) % bsz_mult,
)
batches.append(batch[:mod_len])
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
batches.append(batch)
return batches
class Dataset(data_utils.Dataset): # type: ignore
"""Dataset for numpy files
Args:
in_paths (list): List of paths to input files
out_paths (list): List of paths to output files
"""
def __init__(self, in_paths, out_paths, lengths, shuffle=False, allow_cache=True):
self.in_paths = in_paths
self.out_paths = out_paths
self.lengths = lengths
self.sort_by_len = True
self.shuffle = shuffle
self.allow_cache = allow_cache
if allow_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(in_paths))]
def __getitem__(self, idx):
"""Get a pair of input and target
Args:
idx (int): index of the pair
Returns:
tuple: input and target in numpy format
"""
if self.allow_cache and len(self.caches[idx]) != 0:
return self.caches[idx]
x, y = np.load(self.in_paths[idx]), np.load(self.out_paths[idx])
if self.allow_cache:
self.caches[idx] = (x, y)
return x, y
def num_tokens(self, index):
return self.lengths[index]
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order.
"""
if self.shuffle:
indices = np.random.permutation(len(self))
if self.sort_by_len:
indices = indices[
np.argsort(np.array(self.lengths)[indices], kind="mergesort")
]
else:
indices = np.arange(len(self))
return indices
def __len__(self):
"""Returns the size of the dataset
Returns:
int: size of the dataset
"""
return len(self.in_paths)
[docs]
def ensure_divisible_by(feats, N):
"""Ensure that the number of frames is divisible by N.
Args:
feats (np.ndarray): Input features.
N (int): Target number of frames.
Returns:
np.ndarray: Input features with number of frames divisible by N.
"""
if N == 1:
return feats
mod = len(feats) % N
if mod != 0:
feats = feats[: len(feats) - mod]
return feats
[docs]
def collate_fn_default(batch, reduction_factor=1, stream_sizes=None, streams=None):
"""Create batch
Args:
batch(tuple): List of tuples
- x[0] (ndarray,int) : list of (T, D_in)
- x[1] (ndarray,int) : list of (T, D_out)
reduction_factor (int): Reduction factor.
Returns:
tuple: Tuple of batch
- x (FloatTensor) : Network inputs (B, max(T), D_in)
- y (FloatTensor) : Network targets (B, max(T), D_out)
- lengths (LongTensor): Input lengths
"""
lengths = [len(ensure_divisible_by(x[0], reduction_factor)) for x in batch]
max_len = max(lengths)
x_batch = torch.stack(
[
torch.from_numpy(
pad_2d(ensure_divisible_by(x[0], reduction_factor), max_len)
)
for x in batch
]
)
if stream_sizes is not None:
assert streams is not None
y_batch = torch.stack(
[
torch.from_numpy(
pad_2d(
ensure_divisible_by(
select_streams(x[1], stream_sizes, streams),
reduction_factor,
),
max_len,
)
)
for x in batch
]
)
else:
y_batch = torch.stack(
[
torch.from_numpy(
pad_2d(ensure_divisible_by(x[1], reduction_factor), max_len)
)
for x in batch
]
)
l_batch = torch.tensor(lengths, dtype=torch.long)
return x_batch, y_batch, l_batch
[docs]
def collate_fn_random_segments(batch, max_time_frames=256):
"""Collate function with random segments
Use segmented frames instead of padded entire frames. No padding is performed.
.. warning::
max_time_frames must be larger than the shortest sequence in the training data.
Args:
batch (tuple): tupls of lit
- x[0] (ndarray,int) : list of (T, D_in)
- x[1] (ndarray,int) : list of (T, D_out)
max_time_frames (int, optional): Number of time frames. Defaults to 256.
Returns:
tuple: Tuple of batch
- x (FloatTensor) : Network inputs (B, max(T), D_in)
- y (FloatTensor) : Network targets (B, max(T), D_out)
- lengths (LongTensor): Input lengths
"""
xs, ys = [b[0] for b in batch], [b[1] for b in batch]
lengths = [len(x[0]) for x in batch]
start_frames = np.array(
[np.random.randint(0, xl - max_time_frames) for xl in lengths]
)
starts = start_frames
ends = starts + max_time_frames
x_cut = [torch.from_numpy(x[s:e]) for x, s, e in zip(xs, starts, ends)]
y_cut = [torch.from_numpy(y[s:e]) for y, s, e in zip(ys, starts, ends)]
x_batch = torch.stack(x_cut).float()
y_batch = torch.stack(y_cut).float()
# NOTE: we don't actually need lengths since we don't perform padding
# but just for consistency with collate_fn_default
l_batch = torch.tensor([max_time_frames] * len(lengths), dtype=torch.long)
return x_batch, y_batch, l_batch
[docs]
def get_data_loaders(data_config, collate_fn, logger):
"""Get data loaders for training and validation.
Args:
data_config (dict): Data configuration.
collate_fn (callable): Collate function.
logger (logging.Logger): Logger.
Returns:
dict: Data loaders.
"""
if "filter_long_segments" not in data_config:
logger.warning(
"filter_long_segments is not found in the data config. Consider set it explicitly."
)
logger.info("Disable filtering for long segments.")
filter_long_segments = False
else:
filter_long_segments = data_config.filter_long_segments
if "filter_num_frames" not in data_config:
logger.warning(
"filter_num_frames is not found in the data config. Consider set it explicitly."
)
filter_num_frames = 6000
filter_min_num_frames = 0
else:
filter_num_frames = data_config.filter_num_frames
filter_min_num_frames = data_config.filter_min_num_frames
data_loaders = {}
samplers = {}
for phase in ["train_no_dev", "dev"]:
in_dir = to_absolute_path(data_config[phase].in_dir)
out_dir = to_absolute_path(data_config[phase].out_dir)
train = phase.startswith("train")
in_files, lengths = get_filtered_files(
in_dir,
logger,
filter_long_segments=filter_long_segments,
filter_num_frames=filter_num_frames,
filter_min_num_frames=filter_min_num_frames,
)
out_files, _ = get_filtered_files(
out_dir,
None,
filter_long_segments=filter_long_segments,
filter_num_frames=filter_num_frames,
filter_min_num_frames=filter_min_num_frames,
)
# Dynamic batch size
if data_config.batch_max_frames > 0:
logger.debug(
f"Dynamic batch size with batch_max_frames={data_config.batch_max_frames}"
)
dataset = Dataset(
in_files,
out_files,
lengths,
shuffle=train,
allow_cache=data_config.get("allow_cache", False),
)
if dist.is_initialized():
required_batch_size_multiple = dist.get_world_size()
else:
required_batch_size_multiple = 1
indices = dataset.ordered_indices()
batches = batch_by_size(
indices,
dataset.num_tokens,
max_tokens=data_config.batch_max_frames,
required_batch_size_multiple=required_batch_size_multiple,
)
# Split mini-batches for each rank manually
if dist.is_initialized():
num_replicas = dist.get_world_size()
rank = dist.get_rank()
logger.debug(f"Splitting mini-batches for rank {rank}")
batches = [
x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0
]
logger.info(f"Num mini-batches: {len(batches)}")
for batch in batches:
sizes = [dataset.num_tokens(i) for i in batch]
logger.debug(f"Batch-size: {len(batch)}, Lens: {sizes}")
logger.info(f"Average batch size: {np.mean([len(b) for b in batches])}")
data_loader_extra_kwargs = {
"batch_sampler": ShuffleBatchSampler(batches) if train else batches,
}
sampler = None
else:
logger.debug(f"Fixed batch size: {data_config.batch_size}")
dataset = Dataset(in_files, out_files, lengths)
if dist.is_initialized():
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=train
)
shuffle = False
else:
sampler = None
shuffle = train
data_loader_extra_kwargs = {
"batch_size": data_config.batch_size,
"sampler": sampler,
"shuffle": shuffle,
}
data_loaders[phase] = data_utils.DataLoader(
dataset,
collate_fn=collate_fn,
pin_memory=data_config.pin_memory,
num_workers=data_config.num_workers,
**data_loader_extra_kwargs,
)
samplers[phase] = sampler
return data_loaders, samplers
def set_epochs_based_on_max_steps_(train_config, steps_per_epoch, logger):
"""Set epochs based on max steps.
Args:
train_config (TrainConfig): Train config.
steps_per_epoch (int): Number of steps per epoch.
logger (logging.Logger): Logger.
"""
if "max_train_steps" not in train_config:
logger.warning("max_train_steps is not found in the train config.")
return
logger.info(f"Number of iterations per epoch: {steps_per_epoch}")
if train_config.max_train_steps < 0:
# Set max_train_steps based on nepochs
max_train_steps = train_config.nepochs * steps_per_epoch
train_config.max_train_steps = max_train_steps
logger.info(
"Number of max_train_steps is set based on nepochs: {}".format(
max_train_steps
)
)
else:
# Set nepochs based on max_train_steps
max_train_steps = train_config.max_train_steps
epochs = int(np.ceil(max_train_steps / steps_per_epoch))
train_config.nepochs = epochs
logger.info(
"Number of epochs is set based on max_train_steps: {}".format(epochs)
)
logger.info(f"Number of epochs: {train_config.nepochs}")
logger.info(f"Number of iterations: {train_config.max_train_steps}")
[docs]
def save_checkpoint(
logger,
out_dir,
model,
optimizer,
lr_scheduler,
epoch,
is_best=False,
postfix="",
):
"""Save a checkpoint.
Args:
logger (logging.Logger): Logger.
out_dir (str): Output directory.
model (nn.Module): Model.
optimizer (Optimizer): Optimizer.
lr_scheduler (LRScheduler): Learning rate scheduler.
epoch (int): Current epoch.
is_best (bool, optional): Whether or not the current model is the best.
Defaults to False.
postfix (str, optional): Postfix. Defaults to "".
"""
if dist.is_initialized() and dist.get_rank() != 0:
return
if isinstance(model, nn.DataParallel) or isinstance(model, DDP):
model = model.module
out_dir.mkdir(parents=True, exist_ok=True)
if is_best:
path = out_dir / f"best_loss{postfix}.pth"
else:
path = out_dir / "epoch{:04d}{}.pth".format(epoch, postfix)
torch.save(
{
"state_dict": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
"lr_scheduler_state": lr_scheduler.state_dict(),
},
path,
)
logger.info(f"Saved checkpoint at {path}")
if not is_best:
shutil.copyfile(path, out_dir / f"latest{postfix}.pth")
def get_stream_weight(stream_weights, stream_sizes):
if stream_weights is not None:
assert len(stream_weights) == len(stream_sizes)
return torch.tensor(stream_weights)
S = sum(stream_sizes)
w = torch.tensor(stream_sizes).float() / S
return w
def _instantiate_optim(optim_config, model):
# Optimizer
optimizer_class = getattr(optim, optim_config.optimizer.name)
optimizer = optimizer_class(model.parameters(), **optim_config.optimizer.params)
# Scheduler
lr_scheduler_class = getattr(optim.lr_scheduler, optim_config.lr_scheduler.name)
lr_scheduler = lr_scheduler_class(optimizer, **optim_config.lr_scheduler.params)
return optimizer, lr_scheduler
def _resume(logger, resume_config, model, optimizer, lr_scheduler):
if resume_config.checkpoint is not None and len(resume_config.checkpoint) > 0:
logger.info("Load weights from %s", resume_config.checkpoint)
checkpoint = torch.load(to_absolute_path(resume_config.checkpoint))
state_dict = checkpoint["state_dict"]
model_dict = model.state_dict()
valid_state_dict = {
k: v
for k, v in state_dict.items()
if (k in model_dict) and (v.shape == model_dict[k].shape)
}
non_valid_state_dict = {
k: v for k, v in state_dict.items() if k not in valid_state_dict
}
if len(non_valid_state_dict) > 0:
for k, _ in non_valid_state_dict.items():
logger.warning(f"Skip loading {k} from checkpoint")
model_dict.update(valid_state_dict)
model.load_state_dict(model_dict)
if resume_config.load_optimizer:
logger.info("Load optimizer state")
optimizer.load_state_dict(checkpoint["optimizer_state"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state"])
[docs]
def setup(config, device, collate_fn=collate_fn_default):
"""Setup for training
Args:
config (dict): configuration for training
device (torch.device): device to use for training
collate_fn (callable, optional): collate function. Defaults to collate_fn_default.
Returns:
(tuple): tuple containing model, optimizer, learning rate scheduler,
data loaders, tensorboard writer, logger, and scalers.
"""
if dist.is_initialized():
rank = dist.get_rank()
logger = getLogger(config.verbose) if rank == 0 else getLogger(0)
sys.stdout = open(os.devnull, "w") if rank != 0 else sys.stdout
else:
logger = getLogger(config.verbose)
rank = 0
logger.info(OmegaConf.to_yaml(config))
logger.info(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
from torch.backends import cudnn
cudnn.benchmark = config.train.cudnn.benchmark
cudnn.deterministic = config.train.cudnn.deterministic
logger.info(f"cudnn.deterministic: {cudnn.deterministic}")
logger.info(f"cudnn.benchmark: {cudnn.benchmark}")
if torch.backends.cudnn.version() is not None:
logger.info(f"cuDNN version: {torch.backends.cudnn.version()}")
logger.info(f"Random seed: {config.seed}")
init_seed(config.seed)
if config.train.use_detect_anomaly:
torch.autograd.set_detect_anomaly(True)
logger.info("Set to use torch.autograd.detect_anomaly")
if "use_amp" in config.train and config.train.use_amp:
logger.info("Use mixed precision training")
grad_scaler = GradScaler()
else:
grad_scaler = None
# Model
model = hydra.utils.instantiate(config.model.netG).to(device)
logger.info(model)
logger.info(
"Number of trainable params: {:.3f} million".format(
num_trainable_params(model) / 1000000.0
)
)
# Distributed training
if dist.is_initialized():
device_id = rank % torch.cuda.device_count()
model = DDP(model, device_ids=[device_id])
# Optimizer
optimizer_class = getattr(optim, config.train.optim.optimizer.name)
optimizer = optimizer_class(
model.parameters(), **config.train.optim.optimizer.params
)
# Scheduler
lr_scheduler_class = getattr(
optim.lr_scheduler, config.train.optim.lr_scheduler.name
)
lr_scheduler = lr_scheduler_class(
optimizer, **config.train.optim.lr_scheduler.params
)
# DataLoader
data_loaders, samplers = get_data_loaders(config.data, collate_fn, logger)
set_epochs_based_on_max_steps_(
config.train, len(data_loaders["train_no_dev"]), logger
)
# Resume
_resume(logger, config.train.resume, model, optimizer, lr_scheduler)
if config.data_parallel:
model = nn.DataParallel(model)
# Mlflow
if config.mlflow.enabled:
mlflow.set_tracking_uri("file://" + get_original_cwd() + "/mlruns")
mlflow.set_experiment(config.mlflow.experiment)
# NOTE: disable tensorboard if mlflow is enabled
writer = None
logger.info("Using mlflow instead of tensorboard")
else:
# Tensorboard
if rank == 0:
writer = SummaryWriter(to_absolute_path(config.train.log_dir))
else:
writer = None
# Scalers
if "in_scaler_path" in config.data and config.data.in_scaler_path is not None:
in_scaler = joblib.load(to_absolute_path(config.data.in_scaler_path))
in_scaler = MinMaxScaler(
in_scaler.min_, in_scaler.scale_, in_scaler.data_min_, in_scaler.data_max_
)
else:
in_scaler = None
if "out_scaler_path" in config.data and config.data.out_scaler_path is not None:
out_scaler = joblib.load(to_absolute_path(config.data.out_scaler_path))
out_scaler = StandardScaler(
out_scaler.mean_, out_scaler.var_, out_scaler.scale_
)
else:
out_scaler = None
return (
model,
optimizer,
lr_scheduler,
grad_scaler,
data_loaders,
samplers,
writer,
logger,
in_scaler,
out_scaler,
)
def setup_gan(config, device, collate_fn=collate_fn_default):
"""Setup for training GAN
Args:
config (dict): configuration for training
device (torch.device): device to use for training
collate_fn (callable, optional): collate function. Defaults to collate_fn_default.
Returns:
(tuple): tuple containing model, optimizer, learning rate scheduler,
data loaders, tensorboard writer, logger, and scalers.
"""
if dist.is_initialized():
rank = dist.get_rank()
logger = getLogger(config.verbose) if rank == 0 else getLogger(0)
sys.stdout = open(os.devnull, "w") if rank != 0 else sys.stdout
else:
logger = getLogger(config.verbose)
rank = 0
logger.info(OmegaConf.to_yaml(config))
logger.info(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
from torch.backends import cudnn
cudnn.benchmark = config.train.cudnn.benchmark
cudnn.deterministic = config.train.cudnn.deterministic
logger.info(f"cudnn.deterministic: {cudnn.deterministic}")
logger.info(f"cudnn.benchmark: {cudnn.benchmark}")
if torch.backends.cudnn.version() is not None:
logger.info(f"cuDNN version: {torch.backends.cudnn.version()}")
logger.info(f"Random seed: {config.seed}")
init_seed(config.seed)
if config.train.use_detect_anomaly:
torch.autograd.set_detect_anomaly(True)
logger.info("Set to use torch.autograd.detect_anomaly")
if "use_amp" in config.train and config.train.use_amp:
logger.info("Use mixed precision training")
grad_scaler = GradScaler()
else:
grad_scaler = None
# Model G
netG = hydra.utils.instantiate(config.model.netG).to(device)
logger.info(netG)
logger.info(
"[Generator] Number of trainable params: {:.3f} million".format(
num_trainable_params(netG) / 1000000.0
)
)
if dist.is_initialized():
device_id = rank % torch.cuda.device_count()
netG = DDP(netG, device_ids=[device_id])
# Optimizer and LR scheduler for G
optG, schedulerG = _instantiate_optim(config.train.optim.netG, netG)
# Model D
netD = hydra.utils.instantiate(config.model.netD).to(device)
logger.info(netD)
logger.info(
"[Discriminator] Number of trainable params: {:.3f} million".format(
num_trainable_params(netD) / 1000000.0
)
)
if dist.is_initialized():
device_id = rank % torch.cuda.device_count()
netD = DDP(netD, device_ids=[device_id])
# Optimizer and LR scheduler for D
optD, schedulerD = _instantiate_optim(config.train.optim.netD, netD)
# DataLoader
data_loaders, samplers = get_data_loaders(config.data, collate_fn, logger)
set_epochs_based_on_max_steps_(
config.train, len(data_loaders["train_no_dev"]), logger
)
# Resume
_resume(logger, config.train.resume.netG, netG, optG, schedulerG)
_resume(logger, config.train.resume.netD, netD, optD, schedulerD)
if config.data_parallel:
netG = nn.DataParallel(netG)
netD = nn.DataParallel(netD)
# Mlflow
if config.mlflow.enabled:
mlflow.set_tracking_uri("file://" + get_original_cwd() + "/mlruns")
mlflow.set_experiment(config.mlflow.experiment)
# NOTE: disable tensorboard if mlflow is enabled
writer = None
logger.info("Using mlflow instead of tensorboard")
else:
# Tensorboard
writer = SummaryWriter(to_absolute_path(config.train.log_dir))
# Scalers
if "in_scaler_path" in config.data and config.data.in_scaler_path is not None:
in_scaler = joblib.load(to_absolute_path(config.data.in_scaler_path))
if isinstance(in_scaler, SKMinMaxScaler):
in_scaler = MinMaxScaler(
in_scaler.min_,
in_scaler.scale_,
in_scaler.data_min_,
in_scaler.data_max_,
)
else:
in_scaler = None
if "out_scaler_path" in config.data and config.data.out_scaler_path is not None:
out_scaler = joblib.load(to_absolute_path(config.data.out_scaler_path))
out_scaler = StandardScaler(
out_scaler.mean_, out_scaler.var_, out_scaler.scale_
)
else:
out_scaler = None
return (
(netG, optG, schedulerG),
(netD, optD, schedulerD),
grad_scaler,
data_loaders,
samplers,
writer,
logger,
in_scaler,
out_scaler,
)
def save_configs(config):
out_dir = Path(to_absolute_path(config.train.out_dir))
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "model.yaml", "w") as f:
OmegaConf.save(config.model, f)
with open(out_dir / "config.yaml", "w") as f:
OmegaConf.save(config, f)
def check_resf0_config(logger, model, config, in_scaler, out_scaler):
logger.info("Checking model configs for residual F0 prediction")
if in_scaler is None or out_scaler is None:
raise ValueError("in_scaler and out_scaler must be specified")
if isinstance(model, nn.DataParallel) or isinstance(model, DDP):
model = model.module
in_lf0_idx = config.data.in_lf0_idx
in_rest_idx = config.data.in_rest_idx
out_lf0_idx = config.data.out_lf0_idx
if in_lf0_idx is None or in_rest_idx is None or out_lf0_idx is None:
raise ValueError("in_lf0_idx, in_rest_idx and out_lf0_idx must be specified")
logger.info("in_lf0_idx: %s", in_lf0_idx)
logger.info("in_rest_idx: %s", in_rest_idx)
logger.info("out_lf0_idx: %s", out_lf0_idx)
ok = True
if hasattr(model, "in_lf0_idx"):
if model.in_lf0_idx != in_lf0_idx:
logger.warning(
"in_lf0_idx in model and data config must be same",
model.in_lf0_idx,
in_lf0_idx,
)
ok = False
if hasattr(model, "out_lf0_idx"):
if model.out_lf0_idx != out_lf0_idx:
logger.warning(
"out_lf0_idx in model and data config must be same",
model.out_lf0_idx,
out_lf0_idx,
)
ok = False
if hasattr(model, "in_lf0_min") and hasattr(model, "in_lf0_max"):
# Inject values from the input scaler
if model.in_lf0_min is None or model.in_lf0_max is None:
model.in_lf0_min = in_scaler.data_min_[in_lf0_idx]
model.in_lf0_max = in_scaler.data_max_[in_lf0_idx]
logger.info("in_lf0_min: %s", model.in_lf0_min)
logger.info("in_lf0_max: %s", model.in_lf0_max)
if not np.allclose(model.in_lf0_min, in_scaler.data_min_[model.in_lf0_idx]):
logger.warning(
f"in_lf0_min is set to {model.in_lf0_min}, "
f"but should be {in_scaler.data_min_[model.in_lf0_idx]}"
)
ok = False
if not np.allclose(model.in_lf0_max, in_scaler.data_max_[model.in_lf0_idx]):
logger.warning(
f"in_lf0_max is set to {model.in_lf0_max}, "
f"but should be {in_scaler.data_max_[model.in_lf0_idx]}"
)
ok = False
if hasattr(model, "out_lf0_mean") and hasattr(model, "out_lf0_scale"):
# Inject values from the output scaler
if model.out_lf0_mean is None or model.out_lf0_scale is None:
model.out_lf0_mean = float(out_scaler.mean_[out_lf0_idx])
model.out_lf0_scale = float(out_scaler.scale_[out_lf0_idx])
logger.info("model.out_lf0_mean: %s", model.out_lf0_mean)
logger.info("model.out_lf0_scale: %s", model.out_lf0_scale)
if not np.allclose(model.out_lf0_mean, out_scaler.mean_[model.out_lf0_idx]):
logger.warning(
f"out_lf0_mean is set to {model.out_lf0_mean}, "
f"but should be {out_scaler.mean_[model.out_lf0_idx]}"
)
ok = False
if not np.allclose(model.out_lf0_scale, out_scaler.scale_[model.out_lf0_idx]):
logger.warning(
f"out_lf0_scale is set to {model.out_lf0_scale}, "
f"but should be {out_scaler.scale_[model.out_lf0_idx]}"
)
ok = False
if not ok:
if (
model.in_lf0_idx == in_lf0_idx
and hasattr(model, "in_lf0_min")
and hasattr(model, "out_lf0_mean")
):
logger.info(
f"""
If you are 100% sure that you set model.in_lf0_idx and model.out_lf0_idx correctly,
Please consider the following parameters in your model config:
in_lf0_idx: {model.in_lf0_idx}
out_lf0_idx: {model.out_lf0_idx}
in_lf0_min: {in_scaler.data_min_[model.in_lf0_idx]}
in_lf0_max: {in_scaler.data_max_[model.in_lf0_idx]}
out_lf0_mean: {out_scaler.mean_[model.out_lf0_idx]}
out_lf0_scale: {out_scaler.scale_[model.out_lf0_idx]}
"""
)
raise ValueError("The model config has wrong configurations.")
# Overwrite the parameters to the config
for key in ["in_lf0_min", "in_lf0_max", "out_lf0_mean", "out_lf0_scale"]:
if hasattr(model, key):
config.model.netG[key] = float(getattr(model, key))
[docs]
def compute_pitch_regularization_weight(segments, N, decay_size=25, max_w=0.5):
"""Compute pitch regularization weight given note segments
Args:
segments (list): list of note (start, end) indices
N (int): number of frames
decay_size (int): size of the decay window
max_w (float): maximum weight
Returns:
Tensor: weights of shape (N,)
"""
w = torch.zeros(N)
for s, e in segments:
L = e - s
w[s:e] = max_w
if L > decay_size * 2:
w[s : s + decay_size] *= torch.arange(decay_size) / decay_size
w[e - decay_size : e] *= torch.arange(decay_size - 1, -1, -1) / decay_size
else:
# For shote notes (less than decay_size*0.01 sec) we don't use pitch regularization
w[s:e] = 0.0
return w
[docs]
def compute_batch_pitch_regularization_weight(lf0_score_denorm, decay_size):
"""Batch version of computing pitch regularization weight
Args:
lf0_score_denorm (Tensor): (B, T)
Returns:
Tensor: weights of shape (B, N, 1)
"""
B, T = lf0_score_denorm.shape
w = torch.zeros_like(lf0_score_denorm)
for idx in range(len(lf0_score_denorm)):
segments = note_segments(lf0_score_denorm[idx])
w[idx, :] = compute_pitch_regularization_weight(
segments, N=T, decay_size=decay_size
).to(w.device)
return w.unsqueeze(-1)
[docs]
@torch.no_grad()
def compute_distortions(pred_out_feats, out_feats, lengths, out_scaler, model_config):
"""Compute distortion measures between predicted and ground-truth acoustic features
Args:
pred_out_feats (nn.Tensor): predicted acoustic features
out_feats (nn.Tensor): ground-truth acoustic features
lengths (nn.Tensor): lengths of the sequences
out_scaler (nn.Module): scaler to denormalize features
model_config (dict): model configuration
Returns:
dict: a dict that includes MCD for mgc/bap, V/UV error and F0 RMSE
"""
out_feats = out_scaler.inverse_transform(out_feats)
pred_out_feats = out_scaler.inverse_transform(pred_out_feats)
out_streams = get_static_features(
out_feats,
model_config.num_windows,
model_config.stream_sizes,
model_config.has_dynamic_features,
)
pred_out_streams = get_static_features(
pred_out_feats,
model_config.num_windows,
model_config.stream_sizes,
model_config.has_dynamic_features,
)
if len(out_streams) >= 4:
mgc, lf0, vuv, bap = (
out_streams[0],
out_streams[1],
out_streams[2],
out_streams[3],
)
pred_mgc, pred_lf0, pred_vuv, pred_bap = (
pred_out_streams[0],
pred_out_streams[1],
pred_out_streams[2],
pred_out_streams[3],
)
elif len(out_streams) == 3:
mgc, lf0, vuv = out_streams[0], out_streams[1], out_streams[2]
pred_mgc, pred_lf0, pred_vuv = (
pred_out_streams[0],
pred_out_streams[1],
pred_out_streams[2],
)
bap = None
pred_bap = None
# binarize vuv
vuv, pred_vuv = (vuv > 0.5).float(), (pred_vuv > 0.5).float()
dist = {
"ObjEval_MGC_MCD": metrics.melcd(
mgc[:, :, 1:], pred_mgc[:, :, 1:], lengths=lengths
),
"ObjEval_VUV_ERR": metrics.vuv_error(vuv, pred_vuv, lengths=lengths),
}
if bap is not None:
dist["ObjEval_BAP_MCD"] = metrics.melcd(bap, pred_bap, lengths=lengths) / 10.0
try:
f0_mse = metrics.lf0_mean_squared_error(
lf0, vuv, pred_lf0, pred_vuv, lengths=lengths, linear_domain=True
)
dist["ObjEval_F0_RMSE"] = np.sqrt(f0_mse)
except ZeroDivisionError:
pass
return dist
@torch.no_grad()
def eval_pitch_model(
phase,
step,
netG,
in_feats,
out_feats,
lengths,
model_config,
out_scaler,
writer,
sr,
lf0_score_denorm,
max_num_eval_utts=10,
):
if dist.is_initialized() and dist.get_rank() != 0:
return
if writer is None:
return
# make sure to be in eval mode
netG.eval()
prediction_type = (
netG.module.prediction_type()
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP)
else netG.prediction_type()
)
utt_indices = list(range(max_num_eval_utts))
utt_indices = utt_indices[: min(len(utt_indices), len(in_feats))]
assert not np.any(model_config.has_dynamic_features)
for utt_idx in utt_indices:
out_feats_denorm_ = out_scaler.inverse_transform(
out_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0)
)
lf0 = out_feats_denorm_.squeeze(0).cpu().numpy().reshape(-1)
lf0_score_denorm_ = (
lf0_score_denorm[utt_idx, : lengths[utt_idx]].cpu().numpy().reshape(-1)
)
# Run forward
outs = netG(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
out_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
)
# ResF0 case
if netG.has_residual_lf0_prediction():
outs, _ = outs
if prediction_type == PredictionType.PROBABILISTIC:
pi, sigma, mu = outs
pred_out_feats = mdn_get_most_probable_sigma_and_mu(pi, sigma, mu)[1]
else:
pred_out_feats = outs
# NOTE: multiple outputs
if isinstance(pred_out_feats, list):
pred_out_feats = pred_out_feats[-1]
if isinstance(pred_out_feats, tuple):
pred_out_feats = pred_out_feats[0]
if not isinstance(pred_out_feats, list):
pred_out_feats = [pred_out_feats]
# Run inference
if prediction_type == PredictionType.PROBABILISTIC:
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP):
inference_out_feats, _ = netG.module.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
inference_out_feats, _ = netG.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP):
inference_out_feats = netG.module.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
inference_out_feats = netG.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
pred_out_feats.append(inference_out_feats)
assert len(pred_out_feats) == 2
for idx, pred_out_feats_ in enumerate(pred_out_feats):
pred_out_feats_ = pred_out_feats_.squeeze(0).cpu().numpy()
pred_lf0 = (
out_scaler.inverse_transform(
torch.from_numpy(pred_out_feats_).to(in_feats.device)
)
.cpu()
.numpy()
).reshape(-1)
if idx == 1:
group = f"{phase}_utt{np.abs(utt_idx)}_inference"
else:
group = f"{phase}_utt{np.abs(utt_idx)}_forward"
# Continuous log-F0
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
timeaxis = np.arange(len(lf0)) * 0.005
ax.plot(
timeaxis, lf0, linewidth=1.5, color="tab:blue", label="Target log-F0"
)
ax.plot(
timeaxis,
pred_lf0,
linewidth=1.5,
color="tab:orange",
label="Predicted log-F0",
)
ax.plot(
timeaxis,
lf0_score_denorm_,
"--",
color="gray",
linewidth=1.3,
label="Note log-F0",
)
ax.set_xlabel("Time [sec]")
ax.set_ylabel("Log-frequency [Hz]")
ax.set_xlim(timeaxis[0], timeaxis[-1])
ax.set_ylim(
min(
min(lf0_score_denorm_[lf0_score_denorm_ > 0]),
min(lf0),
min(pred_lf0),
)
- 0.1,
max(max(lf0_score_denorm_), max(lf0), max(pred_lf0)) + 0.1,
)
plt.legend(loc="upper right", borderaxespad=0, ncol=3)
plt.tight_layout()
writer.add_figure(f"{group}/ContinuousLogF0", fig, step)
plt.close()
# F0
lf0_score = lf0_score_denorm_.copy()
note_indices = lf0_score > 0
lf0_score[note_indices] = np.exp(lf0_score[note_indices])
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
timeaxis = np.arange(len(lf0)) * 0.005
ax.plot(
timeaxis,
np.exp(lf0) * np.sign(lf0_score),
linewidth=1.5,
color="tab:blue",
label="Target F0",
)
ax.plot(
timeaxis,
np.exp(pred_lf0) * np.sign(lf0_score),
linewidth=1.5,
color="tab:orange",
label="Predicted F0",
)
ax.plot(
timeaxis, lf0_score, "--", linewidth=1.3, color="gray", label="Note F0"
)
ax.set_xlabel("Time [sec]")
ax.set_ylabel("Frequency [Hz]")
ax.set_xlim(timeaxis[0], timeaxis[-1])
ax.set_ylim(
min(
min(lf0_score[lf0_score > 0]),
min(np.exp(lf0)),
min(np.exp(pred_lf0)),
)
- 10,
max(max(lf0_score), max(np.exp(lf0)), max(np.exp(pred_lf0))) + 10,
)
plt.legend(loc="upper right", borderaxespad=0, ncol=3)
plt.tight_layout()
writer.add_figure(f"{group}/F0", fig, step)
plt.close()
def synthesize(
device,
mgc,
lf0,
vuv,
bap,
sr,
use_world_codec=False,
vuv_threshold=0.3,
vocoder=None,
vocoder_in_scaler=None,
vocoder_config=None,
):
if vocoder is not None:
is_usfgan = "generator" in vocoder_config and "discriminator" in vocoder_config
assert vocoder_in_scaler is not None
if not is_usfgan:
# NOTE: So far vocoder models are trained on binary V/UV features
vuv = (vuv > vuv_threshold).astype(np.float32)
voc_inp = (
torch.from_numpy(
vocoder_in_scaler.transform(
np.concatenate([mgc, lf0, vuv, bap], axis=-1)
)
)
.float()
.to(device)
)
wav = vocoder.inference(voc_inp).view(-1).to("cpu").numpy()
else:
fftlen = pyworld.get_cheaptrick_fft_size(sr)
use_mcep_aperiodicity = bap.shape[-1] > 5
if use_mcep_aperiodicity:
mcep_aperiodicity_order = bap.shape[-1] - 1
alpha = pysptk.util.mcepalpha(sr)
aperiodicity = pysptk.mc2sp(
np.ascontiguousarray(bap).astype(np.float64),
fftlen=fftlen,
alpha=alpha,
)
else:
aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(bap).astype(np.float64), sr, fftlen
)
# fill aperiodicity with ones for unvoiced regions
aperiodicity[vuv.reshape(-1) < vuv_threshold, 0] = 1.0
# WORLD fails catastrophically for out of range aperiodicity
aperiodicity = np.clip(aperiodicity, 0.0, 1.0)
# back to bap
if use_mcep_aperiodicity:
bap = pysptk.sp2mc(
aperiodicity,
order=mcep_aperiodicity_order,
alpha=alpha,
)
else:
bap = pyworld.code_aperiodicity(aperiodicity, sr).astype(np.float32)
aux_feats = (
torch.from_numpy(
vocoder_in_scaler.transform(np.concatenate([mgc, bap], axis=-1))
)
.float()
.to(device)
)
contf0 = np.exp(lf0)
if vocoder_config.data.sine_f0_type in ["contf0", "cf0"]:
f0_inp = contf0
elif vocoder_config.data.sine_f0_type == "f0":
f0_inp = contf0
f0_inp[vuv < vuv_threshold] = 0
wav = vocoder.inference(f0_inp, aux_feats).view(-1).to("cpu").numpy()
else:
# Fallback to WORLD
f0, spectrogram, aperiodicity = gen_world_params(
mgc,
lf0,
vuv,
bap,
sr,
use_world_codec=use_world_codec,
)
wav = pyworld.synthesize(f0, spectrogram, aperiodicity, sr, 5)
return wav
def synthesize_from_mel(
device,
logmel,
lf0,
vuv,
sr,
vuv_threshold=0.3,
vocoder=None,
vocoder_in_scaler=None,
vocoder_config=None,
):
if vocoder is not None:
is_usfgan = "generator" in vocoder_config and "discriminator" in vocoder_config
assert vocoder_in_scaler is not None
if not is_usfgan:
# NOTE: So far vocoder models are trained on binary V/UV features
vuv = (vuv > vuv_threshold).astype(np.float32)
voc_inp = (
torch.from_numpy(
vocoder_in_scaler.transform(
np.concatenate([logmel, lf0, vuv], axis=-1)
)
)
.float()
.to(device)
)
wav = vocoder.inference(voc_inp).view(-1).to("cpu").numpy()
else:
# NOTE: So far vocoder models are trained on binary V/UV features
vuv = (vuv > vuv_threshold).astype(np.float32)
aux_feats = (
torch.from_numpy(vocoder_in_scaler.transform(logmel)).float().to(device)
)
contf0 = np.exp(lf0)
if vocoder_config.data.sine_f0_type in ["contf0", "cf0"]:
f0_inp = contf0
elif vocoder_config.data.sine_f0_type == "f0":
f0_inp = contf0
f0_inp[vuv < vuv_threshold] = 0
wav = vocoder.inference(f0_inp, aux_feats).view(-1).to("cpu").numpy()
else:
raise RuntimeError("Not supported")
return wav
@torch.no_grad()
def eval_model(
phase,
step,
netG,
in_feats,
out_feats,
lengths,
model_config,
out_scaler,
writer,
sr,
lf0_score_denorm=None,
trajectory_smoothing=True,
trajectory_smoothing_cutoff=50,
trajectory_smoothing_cutoff_f0=20,
use_world_codec=False,
vocoder=None,
vocoder_in_scaler=None,
vocoder_config=None,
vuv_threshold=0.3,
max_num_eval_utts=10,
):
if len(model_config.stream_sizes) >= 4:
return eval_spss_model(
phase=phase,
step=step,
netG=netG,
in_feats=in_feats,
out_feats=out_feats,
lengths=lengths,
model_config=model_config,
out_scaler=out_scaler,
writer=writer,
sr=sr,
lf0_score_denorm=lf0_score_denorm,
trajectory_smoothing=trajectory_smoothing,
trajectory_smoothing_cutoff=trajectory_smoothing_cutoff,
trajectory_smoothing_cutoff_f0=trajectory_smoothing_cutoff_f0,
use_world_codec=use_world_codec,
vocoder=vocoder,
vocoder_in_scaler=vocoder_in_scaler,
vocoder_config=vocoder_config,
vuv_threshold=vuv_threshold,
max_num_eval_utts=max_num_eval_utts,
)
else:
return eval_mel_model(
phase=phase,
step=step,
netG=netG,
in_feats=in_feats,
out_feats=out_feats,
lengths=lengths,
model_config=model_config,
out_scaler=out_scaler,
writer=writer,
sr=sr,
lf0_score_denorm=lf0_score_denorm,
vocoder=vocoder,
vocoder_in_scaler=vocoder_in_scaler,
vocoder_config=vocoder_config,
vuv_threshold=vuv_threshold,
max_num_eval_utts=max_num_eval_utts,
)
@torch.no_grad()
def eval_spss_model(
phase,
step,
netG,
in_feats,
out_feats,
lengths,
model_config,
out_scaler,
writer,
sr,
lf0_score_denorm=None,
trajectory_smoothing=True,
trajectory_smoothing_cutoff=50,
trajectory_smoothing_cutoff_f0=20,
use_world_codec=False,
vocoder=None,
vocoder_in_scaler=None,
vocoder_config=None,
vuv_threshold=0.3,
max_num_eval_utts=10,
):
if dist.is_initialized() and dist.get_rank() != 0:
return
if writer is None:
return
# make sure to be in eval mode
netG.eval()
prediction_type = (
netG.module.prediction_type()
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP)
else netG.prediction_type()
)
utt_indices = list(range(max_num_eval_utts))
utt_indices = utt_indices[: min(len(utt_indices), len(in_feats))]
if np.any(model_config.has_dynamic_features):
static_stream_sizes = get_static_stream_sizes(
model_config.stream_sizes,
model_config.has_dynamic_features,
model_config.num_windows,
)
else:
static_stream_sizes = model_config.stream_sizes
rawsp_output = False
for utt_idx in utt_indices:
out_feats_denorm_ = out_scaler.inverse_transform(
out_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0)
)
mgc, lf0, vuv, bap = get_static_features(
out_feats_denorm_,
model_config.num_windows,
model_config.stream_sizes,
model_config.has_dynamic_features,
)[:4]
mgc = mgc.squeeze(0).cpu().numpy()
lf0 = lf0.squeeze(0).cpu().numpy()
vuv = vuv.squeeze(0).cpu().numpy()
bap = bap.squeeze(0).cpu().numpy()
if lf0_score_denorm is not None:
lf0_score_denorm_ = (
lf0_score_denorm[utt_idx, : lengths[utt_idx]].cpu().numpy().reshape(-1)
)
else:
lf0_score_denorm_ = None
# log spectrogram case
rawsp_output = mgc.shape[1] >= 128
if rawsp_output:
sp = np.exp(mgc)
# NOTE: 60-dim mgc is asummed
mgc = pyworld.code_spectral_envelope(sp, sr, 60)
assert use_world_codec
else:
sp = None
wav = synthesize(
device=in_feats.device,
mgc=mgc,
lf0=lf0,
vuv=vuv,
bap=bap,
sr=sr,
use_world_codec=use_world_codec,
vuv_threshold=vuv_threshold,
vocoder=vocoder,
vocoder_in_scaler=vocoder_in_scaler,
vocoder_config=vocoder_config,
)
group = f"{phase}_utt{np.abs(utt_idx)}_reference"
wav = wav / np.abs(wav).max() if np.abs(wav).max() > 1.0 else wav
writer.add_audio(group, wav, step, sr)
# Run forward
outs = netG(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
out_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
)
# ResF0 case
if netG.has_residual_lf0_prediction():
outs, _ = outs
# Hybrid
if prediction_type == PredictionType.MULTISTREAM_HYBRID:
pred_mgc, pred_lf0, pred_vuv, pred_bap = outs
if isinstance(pred_lf0, tuple) and len(pred_lf0) == 3:
pred_lf0 = mdn_get_most_probable_sigma_and_mu(*pred_lf0)[1]
elif isinstance(pred_lf0, tuple) and len(pred_lf0) == 2:
pred_lf0 = pred_lf0[1]
if isinstance(pred_mgc, tuple) and len(pred_mgc) == 3:
pred_mgc = mdn_get_most_probable_sigma_and_mu(*pred_mgc)[1]
elif isinstance(pred_mgc, tuple) and len(pred_mgc) == 2:
pred_mgc = pred_mgc[1]
if isinstance(pred_bap, tuple) and len(pred_bap) == 3:
pred_bap = mdn_get_most_probable_sigma_and_mu(*pred_bap)[1]
elif isinstance(pred_bap, tuple) and len(pred_bap) == 2:
pred_bap = pred_bap[1]
pred_out_feats = torch.cat([pred_mgc, pred_lf0, pred_vuv, pred_bap], dim=-1)
elif prediction_type == PredictionType.PROBABILISTIC:
pi, sigma, mu = outs
pred_out_feats = mdn_get_most_probable_sigma_and_mu(pi, sigma, mu)[1]
else:
pred_out_feats = outs
# NOTE: multiple outputs
if isinstance(pred_out_feats, list):
pred_out_feats = pred_out_feats[-1]
if isinstance(pred_out_feats, tuple):
pred_out_feats = pred_out_feats[0]
if not isinstance(pred_out_feats, list):
pred_out_feats = [pred_out_feats]
# Run inference
if prediction_type in [
PredictionType.PROBABILISTIC,
PredictionType.MULTISTREAM_HYBRID,
]:
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP):
inference_out_feats, _ = netG.module.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
inference_out_feats, _ = netG.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP):
inference_out_feats = netG.module.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
inference_out_feats = netG.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
pred_out_feats.append(inference_out_feats)
assert len(pred_out_feats) == 2
for idx, pred_out_feats_ in enumerate(pred_out_feats):
pred_out_feats_ = pred_out_feats_.squeeze(0).cpu().numpy()
pred_out_feats_denorm = (
out_scaler.inverse_transform(
torch.from_numpy(pred_out_feats_).to(in_feats.device)
)
.cpu()
.numpy()
)
if np.any(model_config.has_dynamic_features):
# (T, D_out) -> (T, static_dim)
pred_out_feats_denorm = multi_stream_mlpg(
pred_out_feats_denorm,
(out_scaler.scale_ ** 2).cpu().numpy(),
get_windows(model_config.num_windows),
model_config.stream_sizes,
model_config.has_dynamic_features,
)
pred_mgc, pred_lf0, pred_vuv, pred_bap = split_streams(
pred_out_feats_denorm, static_stream_sizes
)[:4]
# log spectrogram case
if rawsp_output:
pred_sp = np.exp(pred_mgc)
# NOTE: 60-dim mgc is asummed
pred_mgc = pyworld.code_spectral_envelope(pred_sp, sr, 60)
else:
pred_sp = None
# Remove high-frequency components of lf0/mgc/bap
# NOTE: Useful to reduce high-frequency artifacts
if trajectory_smoothing:
modfs = int(1 / 0.005)
pred_lf0[:, 0] = lowpass_filter(
pred_lf0[:, 0], modfs, cutoff=trajectory_smoothing_cutoff_f0
)
for d in range(pred_mgc.shape[1]):
pred_mgc[:, d] = lowpass_filter(
pred_mgc[:, d], modfs, cutoff=trajectory_smoothing_cutoff
)
for d in range(pred_bap.shape[1]):
pred_bap[:, d] = lowpass_filter(
pred_bap[:, d], modfs, cutoff=trajectory_smoothing_cutoff
)
# Generated sample
wav = synthesize(
device=in_feats.device,
mgc=pred_mgc,
lf0=pred_lf0,
vuv=pred_vuv,
bap=pred_bap,
sr=sr,
use_world_codec=use_world_codec,
vuv_threshold=vuv_threshold,
vocoder=vocoder,
vocoder_in_scaler=vocoder_in_scaler,
vocoder_config=vocoder_config,
)
wav = wav / np.abs(wav).max() if np.abs(wav).max() > 1.0 else wav
if idx == 1:
group = f"{phase}_utt{np.abs(utt_idx)}_inference"
else:
group = f"{phase}_utt{np.abs(utt_idx)}_forward"
writer.add_audio(group, wav, step, sr)
try:
plot_spsvs_params(
step,
writer,
mgc,
lf0,
vuv,
bap,
pred_mgc,
pred_lf0,
pred_vuv,
pred_bap,
lf0_score=lf0_score_denorm_,
group=group,
sr=sr,
use_world_codec=use_world_codec,
sp=sp,
pred_sp=pred_sp,
)
except IndexError as e:
# In _quantile_ureduce_func:
# IndexError: index -1 is out of bounds for axis 0 with size 0
print(str(e))
@torch.no_grad()
def eval_mel_model(
phase,
step,
netG,
in_feats,
out_feats,
lengths,
model_config,
out_scaler,
writer,
sr,
lf0_score_denorm=None,
vocoder=None,
vocoder_in_scaler=None,
vocoder_config=None,
vuv_threshold=0.3,
max_num_eval_utts=10,
):
if dist.is_initialized() and dist.get_rank() != 0:
return
if writer is None:
return
# make sure to be in eval mode
netG.eval()
prediction_type = (
netG.module.prediction_type()
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP)
else netG.prediction_type()
)
utt_indices = list(range(max_num_eval_utts))
utt_indices = utt_indices[: min(len(utt_indices), len(in_feats))]
assert not np.any(model_config.has_dynamic_features)
static_stream_sizes = model_config.stream_sizes
for utt_idx in utt_indices:
out_feats_denorm_ = out_scaler.inverse_transform(
out_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0)
)
logmel, lf0, vuv = get_static_features(
out_feats_denorm_,
model_config.num_windows,
model_config.stream_sizes,
model_config.has_dynamic_features,
)
logmel = logmel.squeeze(0).cpu().numpy()
lf0 = lf0.squeeze(0).cpu().numpy()
vuv = vuv.squeeze(0).cpu().numpy()
if lf0_score_denorm is not None:
lf0_score_denorm_ = (
lf0_score_denorm[utt_idx, : lengths[utt_idx]].cpu().numpy().reshape(-1)
)
else:
lf0_score_denorm_ = None
group = f"{phase}_utt{np.abs(utt_idx)}_reference"
if vocoder is not None:
wav = synthesize_from_mel(
device=in_feats.device,
logmel=logmel,
lf0=lf0,
vuv=vuv,
sr=sr,
vuv_threshold=vuv_threshold,
vocoder=vocoder,
vocoder_in_scaler=vocoder_in_scaler,
vocoder_config=vocoder_config,
)
wav = wav / np.abs(wav).max() if np.abs(wav).max() > 1.0 else wav
writer.add_audio(group, wav, step, sr)
# Run forward
outs = netG(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
out_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
)
# ResF0 case
if netG.has_residual_lf0_prediction():
outs, _ = outs
# Hybrid
if prediction_type == PredictionType.MULTISTREAM_HYBRID:
pred_logmel, pred_lf0, pred_vuv = outs
if isinstance(pred_lf0, tuple) and len(pred_lf0) == 3:
pred_lf0 = mdn_get_most_probable_sigma_and_mu(*pred_lf0)[1]
elif isinstance(pred_lf0, tuple) and len(pred_lf0) == 2:
pred_lf0 = pred_lf0[1]
if isinstance(pred_logmel, tuple) and len(pred_logmel) == 3:
pred_logmel = mdn_get_most_probable_sigma_and_mu(*pred_logmel)[1]
elif isinstance(pred_logmel, tuple) and len(pred_logmel) == 2:
pred_logmel = pred_logmel[1]
pred_out_feats = torch.cat([pred_logmel, pred_lf0, pred_vuv], dim=-1)
elif prediction_type == PredictionType.PROBABILISTIC:
pi, sigma, mu = outs
pred_out_feats = mdn_get_most_probable_sigma_and_mu(pi, sigma, mu)[1]
else:
pred_out_feats = outs
# NOTE: multiple outputs
if isinstance(pred_out_feats, list):
pred_out_feats = pred_out_feats[-1]
if isinstance(pred_out_feats, tuple):
pred_out_feats = pred_out_feats[0]
if not isinstance(pred_out_feats, list):
pred_out_feats = [pred_out_feats]
# Run inference
if prediction_type in [
PredictionType.PROBABILISTIC,
PredictionType.MULTISTREAM_HYBRID,
]:
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP):
inference_out_feats, _ = netG.module.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
inference_out_feats, _ = netG.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
if isinstance(netG, nn.DataParallel) or isinstance(netG, DDP):
inference_out_feats = netG.module.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
else:
inference_out_feats = netG.inference(
in_feats[utt_idx, : lengths[utt_idx]].unsqueeze(0),
[lengths[utt_idx]],
)
pred_out_feats.append(inference_out_feats)
assert len(pred_out_feats) == 2
for idx, pred_out_feats_ in enumerate(pred_out_feats):
pred_out_feats_ = pred_out_feats_.squeeze(0).cpu().numpy()
pred_out_feats_denorm = (
out_scaler.inverse_transform(
torch.from_numpy(pred_out_feats_).to(in_feats.device)
)
.cpu()
.numpy()
)
pred_logmel, pred_lf0, pred_vuv = split_streams(
pred_out_feats_denorm, static_stream_sizes
)
if idx == 1:
group = f"{phase}_utt{np.abs(utt_idx)}_inference"
else:
group = f"{phase}_utt{np.abs(utt_idx)}_forward"
# Generated sample
if vocoder is not None:
wav = synthesize_from_mel(
device=in_feats.device,
logmel=pred_logmel,
lf0=pred_lf0,
vuv=pred_vuv,
sr=sr,
vuv_threshold=vuv_threshold,
vocoder=vocoder,
vocoder_in_scaler=vocoder_in_scaler,
vocoder_config=vocoder_config,
)
wav = wav / np.abs(wav).max() if np.abs(wav).max() > 1.0 else wav
writer.add_audio(group, wav, step, sr)
try:
plot_mel_params(
step,
writer,
logmel,
lf0,
vuv,
pred_logmel,
pred_lf0,
pred_vuv,
lf0_score=lf0_score_denorm_,
group=group,
sr=sr,
)
except IndexError as e:
# In _quantile_ureduce_func:
# IndexError: index -1 is out of bounds for axis 0 with size 0
print(str(e))
def _colorbar_wrap(fig, mesh, ax, format="%+2.f dB"):
try:
fig.colorbar(mesh, ax=ax, format=format)
except IndexError as e:
# In _quantile_ureduce_func:
# IndexError: index -1 is out of bounds for axis 0 with size 0
print(str(e))
[docs]
@torch.no_grad()
def plot_spsvs_params(
step,
writer,
mgc,
lf0,
vuv,
bap,
pred_mgc,
pred_lf0,
pred_vuv,
pred_bap,
lf0_score,
group,
sr,
use_world_codec=False,
sp=None,
pred_sp=None,
):
"""Plot acoustic parameters of parametric SVS
Args:
step (int): step of the current iteration
writer (tensorboard.SummaryWriter): tensorboard writer
mgc (np.ndarray): mgc
lf0 (np.ndarray): lf0
vuv (np.ndarray): vuv
bap (np.ndarray): bap
pred_mgc (np.ndarray): predicted mgc
pred_lf0 (np.ndarray): predicted lf0
pred_vuv (np.ndarray): predicted vuv
pred_bap (np.ndarray): predicted bap
f0_score (np.ndarray): lf0 score
group (str): group name
sr (int): sampling rate
use_world_codec (bool): use world codec for spectral envelope or not
"""
if dist.is_initialized() and dist.get_rank() != 0:
return
assert writer is not None
fftlen = pyworld.get_cheaptrick_fft_size(sr)
alpha = pysptk.util.mcepalpha(sr)
hop_length = int(sr * 0.005)
use_mcep_aperiodicity = bap.shape[-1] > 5
# Log-F0
if lf0_score is not None:
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
timeaxis = np.arange(len(lf0)) * 0.005
ax.plot(timeaxis, lf0, linewidth=1.5, color="tab:blue", label="Target log-F0")
ax.plot(
timeaxis,
pred_lf0,
linewidth=1.5,
color="tab:orange",
label="Predicted log-F0",
)
ax.plot(
timeaxis,
lf0_score,
"--",
color="gray",
linewidth=1.3,
label="Note log-F0",
)
ax.set_xlabel("Time [sec]")
ax.set_ylabel("Log-frequency [Hz]")
ax.set_xlim(timeaxis[0], timeaxis[-1])
ax.set_ylim(
min(min(lf0_score[lf0_score > 0]), min(lf0), min(pred_lf0)) - 0.1,
max(max(lf0_score), max(lf0), max(pred_lf0)) + 0.1,
)
plt.legend(loc="upper right", borderaxespad=0, ncol=3)
plt.tight_layout()
writer.add_figure(f"{group}/ContinuousLogF0", fig, step)
plt.close()
f0_score = lf0_score.copy()
note_indices = f0_score > 0
f0_score[note_indices] = np.exp(lf0_score[note_indices])
# F0
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
timeaxis = np.arange(len(lf0)) * 0.005
f0 = np.exp(lf0)
f0[vuv < 0.5] = 0
pred_f0 = np.exp(pred_lf0)
pred_f0[pred_vuv < 0.5] = 0
ax.plot(
timeaxis,
f0,
linewidth=1.5,
color="tab:blue",
label="Target F0",
)
ax.plot(
timeaxis,
pred_f0,
linewidth=1.5,
color="tab:orange",
label="Predicted F0",
)
ax.plot(timeaxis, f0_score, "--", linewidth=1.3, color="gray", label="Note F0")
ax.set_xlabel("Time [sec]")
ax.set_ylabel("Frequency [Hz]")
ax.set_xlim(timeaxis[0], timeaxis[-1])
ax.set_ylim(
min(min(f0_score[f0_score > 0]), min(np.exp(lf0)), min(np.exp(pred_lf0)))
- 10,
max(max(f0_score), max(np.exp(lf0)), max(np.exp(pred_lf0))) + 10,
)
plt.legend(loc="upper right", borderaxespad=0, ncol=3)
plt.tight_layout()
writer.add_figure(f"{group}/F0", fig, step)
plt.close()
# V/UV
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
timeaxis = np.arange(len(lf0)) * 0.005
ax.plot(timeaxis, vuv, linewidth=2, label="Target V/UV")
ax.plot(timeaxis, pred_vuv, "--", linewidth=2, label="Predicted V/UV")
ax.set_xlabel("Time [sec]")
ax.set_ylabel("V/UV")
ax.set_xlim(timeaxis[0], timeaxis[-1])
plt.legend(loc="upper right", borderaxespad=0, ncol=2)
plt.tight_layout()
writer.add_figure(f"{group}/VUV", fig, step)
plt.close()
# Spectrogram
fig, ax = plt.subplots(2, 1, figsize=(8, 6))
ax[0].set_title("Reference spectrogram")
ax[1].set_title("Predicted spectrogram")
if use_world_codec:
if sp is not None:
spectrogram = sp.T
else:
spectrogram = pyworld.decode_spectral_envelope(
np.ascontiguousarray(mgc), sr, fftlen
).T
else:
spectrogram = pysptk.mc2sp(mgc, fftlen=fftlen, alpha=alpha).T
mesh = librosa.display.specshow(
librosa.power_to_db(np.abs(spectrogram), ref=np.max),
sr=sr,
hop_length=hop_length,
x_axis="time",
y_axis="hz",
cmap="viridis",
ax=ax[0],
)
_colorbar_wrap(fig, mesh, ax[0])
if use_world_codec:
if pred_sp is not None:
pred_spectrogram = pred_sp.T
else:
pred_spectrogram = pyworld.decode_spectral_envelope(
np.ascontiguousarray(pred_mgc), sr, fftlen
).T
else:
pred_spectrogram = pysptk.mc2sp(
np.ascontiguousarray(pred_mgc), fftlen=fftlen, alpha=alpha
).T
mesh = librosa.display.specshow(
librosa.power_to_db(np.abs(pred_spectrogram), ref=np.max),
sr=sr,
hop_length=hop_length,
x_axis="time",
y_axis="hz",
cmap="viridis",
ax=ax[1],
)
_colorbar_wrap(fig, mesh, ax[1])
for a in ax:
a.set_ylim(0, sr // 2)
plt.tight_layout()
writer.add_figure(f"{group}/Spectrogram", fig, step)
plt.close()
# Aperiodicity
fig, ax = plt.subplots(2, 1, figsize=(8, 6))
ax[0].set_title("Reference aperiodicity")
ax[1].set_title("Predicted aperiodicity")
if use_mcep_aperiodicity:
aperiodicity = pysptk.mc2sp(bap, fftlen=fftlen, alpha=alpha).T
else:
aperiodicity = pyworld.decode_aperiodicity(bap.astype(np.float64), sr, fftlen).T
mesh = librosa.display.specshow(
20 * np.log10(aperiodicity),
sr=sr,
hop_length=hop_length,
x_axis="time",
y_axis="linear",
cmap="viridis",
ax=ax[0],
)
_colorbar_wrap(fig, mesh, ax[0])
if use_mcep_aperiodicity:
pred_aperiodicity = pysptk.mc2sp(
np.ascontiguousarray(pred_bap), fftlen=fftlen, alpha=alpha
).T
else:
pred_aperiodicity = pyworld.decode_aperiodicity(
np.ascontiguousarray(pred_bap).astype(np.float64), sr, fftlen
).T
mesh = librosa.display.specshow(
20 * np.log10(pred_aperiodicity),
sr=sr,
hop_length=hop_length,
x_axis="time",
y_axis="linear",
cmap="viridis",
ax=ax[1],
)
_colorbar_wrap(fig, mesh, ax[1])
for a in ax:
a.set_ylim(0, sr // 2)
plt.tight_layout()
writer.add_figure(f"{group}/Aperiodicity", fig, step)
plt.close()
# GV for mgc
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
ax.plot(np.var(mgc, axis=0), "--", linewidth=2, label="Natural: global variances")
ax.plot(np.var(pred_mgc, axis=0), linewidth=2, label="Generated: global variances")
ax.legend()
ax.set_yscale("log")
ax.set_xlabel("Dimension of mgc")
min_ = min(np.var(mgc, axis=0).min(), np.var(pred_mgc, axis=0).min(), 1e-4)
ax.set_ylim(min_)
plt.tight_layout()
writer.add_figure(f"{group}/GV_mgc", fig, step)
plt.close()
# GV for bap
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
ax.plot(np.var(bap, axis=0), "--", linewidth=2, label="Natural: global variances")
ax.plot(np.var(pred_bap, axis=0), linewidth=2, label="Generated: global variances")
ax.legend()
ax.set_yscale("log")
ax.set_xlabel("Dimension of bap")
min_ = min(np.var(bap, axis=0).min(), np.var(pred_bap, axis=0).min(), 10)
ax.set_ylim(min_)
plt.tight_layout()
writer.add_figure(f"{group}/GV_bap", fig, step)
plt.close()
@torch.no_grad()
def plot_mel_params(
step,
writer,
logmel,
lf0,
vuv,
pred_logmel,
pred_lf0,
pred_vuv,
lf0_score,
group,
sr,
):
if dist.is_initialized() and dist.get_rank() != 0:
return
assert writer is not None
hop_length = int(sr * 0.005)
# Log-F0
if lf0_score is not None:
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
timeaxis = np.arange(len(lf0)) * 0.005
ax.plot(timeaxis, lf0, linewidth=1.5, color="tab:blue", label="Target log-F0")
ax.plot(
timeaxis,
pred_lf0,
linewidth=1.5,
color="tab:orange",
label="Predicted log-F0",
)
ax.plot(
timeaxis,
lf0_score,
"--",
color="gray",
linewidth=1.3,
label="Note log-F0",
)
ax.set_xlabel("Time [sec]")
ax.set_ylabel("Log-frequency [Hz]")
ax.set_xlim(timeaxis[0], timeaxis[-1])
ax.set_ylim(
min(min(lf0_score[lf0_score > 0]), min(lf0), min(pred_lf0)) - 0.1,
max(max(lf0_score), max(lf0), max(pred_lf0)) + 0.1,
)
plt.legend(loc="upper right", borderaxespad=0, ncol=3)
plt.tight_layout()
writer.add_figure(f"{group}/ContinuousLogF0", fig, step)
plt.close()
f0_score = lf0_score.copy()
note_indices = f0_score > 0
f0_score[note_indices] = np.exp(lf0_score[note_indices])
# F0
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
timeaxis = np.arange(len(lf0)) * 0.005
f0 = np.exp(lf0)
f0[vuv < 0.5] = 0
pred_f0 = np.exp(pred_lf0)
pred_f0[pred_vuv < 0.5] = 0
ax.plot(
timeaxis,
f0,
linewidth=1.5,
color="tab:blue",
label="Target F0",
)
ax.plot(
timeaxis,
pred_f0,
linewidth=1.5,
color="tab:orange",
label="Predicted F0",
)
ax.plot(timeaxis, f0_score, "--", linewidth=1.3, color="gray", label="Note F0")
ax.set_xlabel("Time [sec]")
ax.set_ylabel("Frequency [Hz]")
ax.set_xlim(timeaxis[0], timeaxis[-1])
ax.set_ylim(
min(min(f0_score[f0_score > 0]), min(np.exp(lf0)), min(np.exp(pred_lf0)))
- 10,
max(max(f0_score), max(np.exp(lf0)), max(np.exp(pred_lf0))) + 10,
)
plt.legend(loc="upper right", borderaxespad=0, ncol=3)
plt.tight_layout()
writer.add_figure(f"{group}/F0", fig, step)
plt.close()
# V/UV
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
timeaxis = np.arange(len(lf0)) * 0.005
ax.plot(timeaxis, vuv, linewidth=2, label="Target V/UV")
ax.plot(timeaxis, pred_vuv, "--", linewidth=2, label="Predicted V/UV")
ax.set_xlabel("Time [sec]")
ax.set_ylabel("V/UV")
ax.set_xlim(timeaxis[0], timeaxis[-1])
plt.legend(loc="upper right", borderaxespad=0, ncol=2)
plt.tight_layout()
writer.add_figure(f"{group}/VUV", fig, step)
plt.close()
# Mel-spectrogram
fig, ax = plt.subplots(2, 1, figsize=(8, 6))
ax[0].set_title("Reference spectrogram")
ax[1].set_title("Predicted spectrogram")
mesh = librosa.display.specshow(
logmel.T,
sr=sr,
hop_length=hop_length,
x_axis="time",
y_axis="hz",
cmap="viridis",
ax=ax[0],
)
_colorbar_wrap(fig, mesh, ax[0])
mesh = librosa.display.specshow(
pred_logmel.T,
sr=sr,
hop_length=hop_length,
x_axis="time",
y_axis="hz",
cmap="viridis",
ax=ax[1],
)
_colorbar_wrap(fig, mesh, ax[1])
for a in ax:
a.set_ylim(0, sr // 2)
plt.tight_layout()
writer.add_figure(f"{group}/Spectrogram", fig, step)
plt.close()