diff --git a/lmdeploy/lite/quantization/calibration.py b/lmdeploy/lite/quantization/calibration.py index 4ae21e3f45..77ff74e234 100644 --- a/lmdeploy/lite/quantization/calibration.py +++ b/lmdeploy/lite/quantization/calibration.py @@ -253,9 +253,11 @@ def export(self, out_dir): inp_stats = self.collect_inputs_stats() torch.save(inp_stats, out_dir / 'inputs_stats.pth') + torch.cuda.empty_cache() out_stats = self.collect_outputs_stats() torch.save(out_stats, out_dir / 'outputs_stats.pth') + torch.cuda.empty_cache() def calibrate(self, data): """Forward pass through the model in inference mode with given data.""" @@ -267,6 +269,7 @@ def calibrate(self, data): model = self.model.model with torch.inference_mode(): _ = model(data.to(self.device)) + torch.cuda.empty_cache() def __enter__(self): """Prepares the Calibration object for a 'with' statement by @@ -440,6 +443,7 @@ def export(self, out_dir): inputs_stats['absmean'][name] = obs.absmean_val inputs_stats['ratios'][name] = obs.ratio torch.save(inputs_stats, out_dir / 'inputs_stats.pth') + torch.cuda.empty_cache() def _wrap_decoder_layers_for_search(self): """Method to wrap the decoder layers' forward functions for observing