diff --git a/configs_backup/single_stem_models/config_apollo.yaml b/configs_backup/single_stem_models/config_apollo.yaml
index b041129..aa8378f 100644
--- a/configs_backup/single_stem_models/config_apollo.yaml
+++ b/configs_backup/single_stem_models/config_apollo.yaml
@@ -29,5 +29,5 @@ augmentations:
   enable: false # enable or disable all augmentations (to fast disable if needed)
 
 inference:
-  batch_size: 4
-  num_overlap: 4
+  batch_size: 1
+  num_overlap: 2
diff --git a/data_backup/msst_model_map.json b/data_backup/msst_model_map.json
index 43b37a6..f922e4f 100644
--- a/data_backup/msst_model_map.json
+++ b/data_backup/msst_model_map.json
@@ -136,19 +136,19 @@
         },
         {
             "name": "Apollo_LQ_MP3_restoration.ckpt",
-            "config_path": "configs_backup/single_stem_models/config_apollo.yaml",
+            "config_path": "configs/single_stem_models/config_apollo.yaml",
             "model_type": "apollo",
             "link": "https://huggingface.co/Sucial/Music_Source_Sepetration_Models/resolve/main/Apollo_LQ_MP3_restoration.ckpt"
         },
         {
             "name": "aspiration_mel_band_roformer_sdr_18.9845.ckpt",
-            "config_path": "configs_backup/single_stem_models/config_aspiration_mel_band_roformer.yaml",
+            "config_path": "configs/single_stem_models/config_aspiration_mel_band_roformer.yaml",
             "model_type": "mel_band_roformer",
             "link": "https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/resolve/main/aspiration_mel_band_roformer_sdr_18.9845.ckpt"
         },
         {
             "name": "aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt",
-            "config_path": "configs_backup/single_stem_models/config_aspiration_mel_band_roformer.yaml",
+            "config_path": "configs/single_stem_models/config_aspiration_mel_band_roformer.yaml",
             "model_type": "mel_band_roformer",
             "link": "https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/resolve/main/aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt"
         }
diff --git a/models/look2hear/apollo.py b/models/look2hear/apollo.py
index ef37b06..ccf1beb 100644
--- a/models/look2hear/apollo.py
+++ b/models/look2hear/apollo.py
@@ -313,6 +313,7 @@ def forward(self, input):
             this_RI = self.output[i](feature[:, i]).view(B * nch, 2, self.band_width[i], -1)
             est_spec.append(torch.complex(this_RI[:, 0], this_RI[:, 1]))
         est_spec = torch.cat(est_spec, 1)
+        est_spec = est_spec.to(dtype=torch.complex64)
         output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride,
                              window=torch.hann_window(self.win).to(input.device),
                              length=nsample).view(B, nch, -1)
diff --git a/webUI.py b/webUI.py
index 304f747..4f22de8 100644
--- a/webUI.py
+++ b/webUI.py
@@ -468,11 +468,11 @@ def save_config(selected_model, batch_size, dim_t, num_overlap, normalize):
     _, config_path, _, _ = get_msst_model(selected_model)
     config = load_configs(config_path)
     if config.inference.get('batch_size'):
-        config.inference['batch_size'] = int(batch_size) if batch_size.isdigit() else None
+        config.inference['batch_size'] = int(batch_size)
     if config.inference.get('dim_t'):
-        config.inference['dim_t'] = int(dim_t) if dim_t.isdigit() else None
+        config.inference['dim_t'] = int(dim_t)
     if config.inference.get('num_overlap'):
-        config.inference['num_overlap'] = int(num_overlap) if num_overlap.isdigit() else None
+        config.inference['num_overlap'] = int(num_overlap)
     if config.inference.get('normalize'):
         config.inference['normalize'] = normalize
     save_configs(config, config_path)