Source code for nnsvs.pretrained

import os
import shutil
import tarfile
from os.path import join
from pathlib import Path
from urllib.request import urlretrieve

from nnsvs.util import dynamic_import
from tqdm.auto import tqdm

DEFAULT_CACHE_DIR = join(os.path.expanduser("~"), ".cache", "nnsvs")
CACHE_DIR = os.environ.get("NNSVS_CACHE_DIR", DEFAULT_CACHE_DIR)
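# The cache location can be overridden through the NNSVS_CACHE_DIR environment
# variable, which must be set before this module is imported. A minimal sketch
# (the path is only an illustrative example):
#
#     import os
#     os.environ["NNSVS_CACHE_DIR"] = "/data/nnsvs_cache"
#     from nnsvs.pretrained import CACHE_DIR  # now /data/nnsvs_cache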

model_registry = {
    "r9y9/yoko_latest": {
        "url": "https://www.dropbox.com/s/k8mya65yt52m0ps/yoko_latest.tar.gz?dl=1",
        "_target_": "nnsvs.svs:SPSVS",
    },
    "r9y9/20220322_yoko_timelag_mdn_duration_mdn_acoustic_resf0conv": {
        "url": "https://www.dropbox.com/s/olsfyqol9ryk5kx/"
        "20220322_yoko_timelag_mdn_duration_mdn_acoustic_resf0conv.tar.gz?dl=1",
        "_target_": "nnsvs.svs:SPSVS",
    },
}
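# Each registry entry maps a model ID to a download URL and a "_target_"
# import path ("module:Class") that create_svs_engine() below instantiates.
# A hypothetical extra entry (the ID and URL are made up for illustration)
# would follow the same shape:
#
#     model_registry["someuser/my_model"] = {
#         "url": "https://example.com/my_model.tar.gz",
#         "_target_": "nnsvs.svs:SPSVS",
#     }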


def create_svs_engine(name, *args, **kwargs):
    """Create SVS engine from pretrained models.

    Args:
        name (str): Pre-trained model name
        args (list): Additional args for instantiation
        kwargs (dict): Additional kwargs for instantiation

    Returns:
        object: instance of SVS engine
    """
    if name not in model_registry:
        s = ""
        for model_id in get_available_model_ids():
            s += f"'{model_id}'\n"
        raise ValueError(
            f"""
Pretrained model '{name}' does not exist!

Available models:
{s[:-1]}"""
        )

    # download if not exists
    model_dir = retrieve_pretrained_model(name)

    # create an instance
    return dynamic_import(model_registry[name]["_target_"])(model_dir, *args, **kwargs)
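# Usage sketch for create_svs_engine(): the model ID must be one of the keys
# in model_registry above; extra args/kwargs are forwarded unchanged to the
# engine class named by "_target_" (SPSVS here):
#
#     from nnsvs.pretrained import create_svs_engine
#     engine = create_svs_engine("r9y9/yoko_latest")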
def get_available_model_ids():
    """Get available pretrained model names.

    Returns:
        list: List of available pretrained model names.
    """
    return list(model_registry.keys())
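# Example: listing the registered model IDs only touches the in-memory
# registry, so it is cheap and triggers no download:
#
#     from nnsvs.pretrained import get_available_model_ids
#     print(get_available_model_ids())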
# https://github.com/tqdm/tqdm#hooks-and-callbacks
class _TqdmUpTo(tqdm):  # type: ignore
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        return self.update(b * bsize - self.n)


def is_pretrained_model_ready(name):
    out_dir = Path(CACHE_DIR) / name
    # A cached directory without any *.pth checkpoints is treated as an
    # incomplete download, i.e. the model is not ready.
    if out_dir.exists() and len(list(out_dir.glob("*.pth"))) == 0:
        return False
    return out_dir.exists()
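# Example: is_pretrained_model_ready() can be used to check for a complete
# cached copy without triggering a download:
#
#     from nnsvs.pretrained import is_pretrained_model_ready
#     if not is_pretrained_model_ready("r9y9/yoko_latest"):
#         print("model will be downloaded on first use")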
def retrieve_pretrained_model(name):
    """Retrieve a pretrained model from the local cache, downloading it if needed.

    Args:
        name (str): Name of pretrained model.

    Returns:
        Path: Path to the pretrained model directory.

    Raises:
        ValueError: If the pretrained model is not found.
    """
    global model_registry
    if name not in model_registry:
        s = ""
        for model_id in get_available_model_ids():
            s += f"'{model_id}'\n"
        raise ValueError(
            f"""
Pretrained model '{name}' does not exist!

Available models:
{s[:-1]}"""
        )

    url = model_registry[name]["url"]
    # NOTE: assuming that the archive name and the extracted directory name are the same
    out_dir = Path(CACHE_DIR) / name
    model_dir = out_dir.parent
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = Path(CACHE_DIR) / f"{name}.tar.gz"

    # re-download models if the cached copy has no checkpoints
    if out_dir.exists() and len(list(out_dir.glob("*.pth"))) == 0:
        shutil.rmtree(out_dir)

    if not out_dir.exists():
        print('Downloading: "{}"'.format(url))
        with _TqdmUpTo(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            miniters=1,
            desc=f"{name}.tar.gz",
        ) as t:  # all optional kwargs
            urlretrieve(url, filename, reporthook=t.update_to)
            t.total = t.n
        with tarfile.open(filename, mode="r|gz") as f:
            f.extractall(path=model_dir)
        os.remove(filename)

    return out_dir
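# Example: retrieve_pretrained_model() downloads and extracts the archive on
# the first call and returns the cached directory on subsequent calls:
#
#     from nnsvs.pretrained import retrieve_pretrained_model
#     model_dir = retrieve_pretrained_model("r9y9/yoko_latest")
#     print(model_dir)  # e.g. ~/.cache/nnsvs/r9y9/yoko_latest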