diff --git a/dnn/torch/osce/make_default_setup.py b/dnn/torch/osce/make_default_setup.py
index 06add8faef2f0d222e1c76dc4092f07b406f0d8a..2b2956625b6c843a9131f2bce134930135198d21 100644
--- a/dnn/torch/osce/make_default_setup.py
+++ b/dnn/torch/osce/make_default_setup.py
@@ -36,7 +36,7 @@ from utils.templates import setup_dict
 parser = argparse.ArgumentParser()
 
 parser.add_argument('name', type=str, help='name of default setup file')
-parser.add_argument('--model', choices=['lace'], help='model name', default='lace')
+parser.add_argument('--model', choices=['lace', 'nolace'], help='model name', default='lace')
 parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
 
 args = parser.parse_args()
diff --git a/dnn/torch/osce/models/__init__.py b/dnn/torch/osce/models/__init__.py
index c8dfc5d98e2190e65552a1e423f5774ad4224640..49a88ae27ed4b1b04df2f35c7181a362e0063d1d 100644
--- a/dnn/torch/osce/models/__init__.py
+++ b/dnn/torch/osce/models/__init__.py
@@ -28,9 +28,11 @@
 """
 
 from .lace import LACE
+from .no_lace import NoLACE
 
 model_dict = {
-    'lace': LACE
+    'lace': LACE,
+    'nolace': NoLACE
 }
diff --git a/dnn/torch/osce/models/no_lace.py b/dnn/torch/osce/models/no_lace.py
index 8b2de87f03259f554c824b63c6c649ba77da84e9..4524906d3b5f88a0588ea9b0a1ef5f4f36ed0687 100644
--- a/dnn/torch/osce/models/no_lace.py
+++ b/dnn/torch/osce/models/no_lace.py
@@ -1,33 +1,4 @@
-"""
-/* Copyright (c) 2023 Amazon
-   Written by Jan Buethe */
-/*
-   Redistribution and use in source and binary forms, with or without
-   modification, are permitted provided that the following conditions
-   are met:
-
-   - Redistributions of source code must retain the above copyright
-   notice, this list of conditions and the following disclaimer.
-
-   - Redistributions in binary form must reproduce the above copyright
-   notice, this list of conditions and the following disclaimer in the
-   documentation and/or other materials provided with the distribution.
-
-   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
-   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
-   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
-   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-*/
-"""
-
 
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -64,7 +35,8 @@ class NoLACE(NNSBase):
                  hidden_feature_dim=64,
                  partial_lookahead=True,
                  norm_p=2,
-                 avg_pool_k=4):
+                 avg_pool_k=4,
+                 pool_after=False):
 
         super().__init__(skip=skip, preemph=preemph)
 
@@ -103,9 +75,9 @@ class NoLACE(NNSBase):
         self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
 
         # non-linear transforms
-        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
-        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
-        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
+        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
+        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, pool_after=pool_after)
 
         # combinators
         self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
@@ -128,6 +100,7 @@ class NoLACE(NNSBase):
         feature_net_flops = self.feature_net.flop_count(frame_rate)
         comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
         af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate)
+        shape_flops = self.tdshape1.flop_count(rate) + self.tdshape2.flop_count(rate) + self.tdshape3.flop_count(rate)
         feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
                          + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
 
@@ -137,7 +110,7 @@ class NoLACE(NNSBase):
             print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
             print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
 
-        return feature_net_flops + comb_flops + af_flops + feature_flops
+        return feature_net_flops + comb_flops + af_flops + feature_flops + shape_flops
 
     def feature_transform(self, f, layer):
         f = f.permute(0, 2, 1)
@@ -180,4 +153,4 @@ class NoLACE(NNSBase):
         y = torch.cat((y1, y2), dim=1)
         y = self.af4(y, cf, debug=debug)
 
-        return y
\ No newline at end of file
+        return y
diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py
index 1232710f5908674f8cd7924e6f7b9fb1de10e006..c9648f440b8bf062e117a2dc4492cbb4881ba4b9 100644
--- a/dnn/torch/osce/utils/templates.py
+++ b/dnn/torch/osce/utils/templates.py
@@ -70,7 +70,6 @@ lace_setup = {
         'lr': 5.e-4,
         'lr_decay_factor': 2.5e-5,
         'epochs': 50,
-        'frames_per_sample': 50,
         'loss': {
             'w_l1': 0,
             'w_lm': 0,
@@ -86,7 +85,63 @@
 }
 
 
+nolace_setup = {
+    'dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/training',
+    'validation_dataset': '/local/datasets/silk_enhancement_v2_full_6to64kbps/validation',
+    'model': {
+        'name': 'nolace',
+        'args': [],
+        'kwargs': {
+            'avg_pool_k': 4,
+            'comb_gain_limit_db': 10,
+            'cond_dim': 256,
+            'conv_gain_limits_db': [-12, 12],
+            'global_gain_limits_db': [-6, 6],
+            'hidden_feature_dim': 96,
+            'kernel_size': 15,
+            'num_features': 93,
+            'numbits_embedding_dim': 8,
+            'numbits_range': [50, 650],
+            'partial_lookahead': True,
+            'pitch_embedding_dim': 64,
+            'pitch_max': 300,
+            'preemph': 0.85,
+            'skip': 91
+        }
+    },
+    'data': {
+        'frames_per_sample': 100,
+        'no_pitch_value': 7,
+        'preemph': 0.85,
+        'skip': 91,
+        'pitch_hangover': 8,
+        'acorr_radius': 2,
+        'num_bands_clean_spec': 64,
+        'num_bands_noisy_spec': 18,
+        'noisy_spec_scale': 'opus',
+        'pitch_hangover': 8,
+    },
+    'training': {
+        'batch_size': 256,
+        'lr': 5.e-4,
+        'lr_decay_factor': 2.5e-5,
+        'epochs': 50,
+        'loss': {
+            'w_l1': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_wsc': 0,
+            'w_xcorr': 0,
+            'w_sxcorr': 1,
+            'w_l2': 10,
+            'w_slm': 2
+        }
+    }
+}
+
 setup_dict = {
     'lace': lace_setup,
+    'nolace': nolace_setup
 }
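
For context, a minimal usage sketch of the new 'nolace' option (not part of the patch). It assumes the interpreter is started from dnn/torch/osce, that a setup file has already been generated with the extended make_default_setup.py (the file name nolace_setup.yml is a placeholder), and that NoLACE.flop_count accepts the same rate keyword used inside the hunks above; treat it as illustrative rather than as the project's documented workflow.

# Hypothetical check, run from dnn/torch/osce after e.g.
#   python make_default_setup.py nolace_setup.yml --model nolace
import yaml

from models import model_dict

# load the generated setup file (placeholder name)
with open('nolace_setup.yml', 'r') as f:
    setup = yaml.load(f, Loader=yaml.FullLoader)

# 'nolace' now resolves to NoLACE via model_dict
model_cls = model_dict[setup['model']['name']]
model = model_cls(*setup['model']['args'], **setup['model']['kwargs'])

# with this patch the returned total also includes the TDShaper term (shape_flops)
print(f"complexity: {model.flop_count(rate=16000) / 1e6:.2f} MFLOPS")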