From 3afba34fc773c0c0a8d413fd71bd81c22ac3e1df Mon Sep 17 00:00:00 2001 From: Ivan Cheung Date: Mon, 2 Dec 2024 14:41:59 -0500 Subject: [PATCH] Added extra analysis step before rewriting --- podcastfy/content_generator.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/podcastfy/content_generator.py b/podcastfy/content_generator.py index d688151..69ae875 100644 --- a/podcastfy/content_generator.py +++ b/podcastfy/content_generator.py @@ -525,21 +525,33 @@ def _clean_transcript_response(self, transcript: str, config: Dict[str, Any]) -> # Run rewriting chain llm = self.llm - rewrite_prompt = PromptTemplate( + + analysis_prompt = PromptTemplate( input_variables=["transcript"], - template=config.get("rewrite_prompt_template", "Clean and improve this podcast transcript by deduping any repeated sections and improving conversational flow. Just output the improved conversation in the same format and nothing else. Do not add or omit any information.: \n\n{transcript}") + template=config.get("analysis_prompt_template", "You are a podcast editor. Analyze this podcast transcript and provide suggestions on how to remove duplicated/repeated facts, clean up flow and improve coherency. The structure of the transcript should be the same, without removal of too much data or addition of new facts: \n\n{transcript} \n\nAnalysis (bullet-points, with line numbers referring to problematic lines.):") + ) + analysis_chain = analysis_prompt | llm | StrOutputParser() + + rewrite_prompt = PromptTemplate( + input_variables=["transcript", "analysis"], + template=config.get("rewrite_prompt_template", "Rewrite the podcast transcript based on the following recommendations: \n\n{analysis}\n\nOriginal Transcript: \n\n{transcript}\n\nRewritten Transcript:") ) - logger.debug("Executing rewriting chain") rewrite_chain = rewrite_prompt | llm | StrOutputParser() try: - rewritten_response = rewrite_chain.invoke({"transcript": transcript}) + logger.debug("Executing analysis chain") + analysis = analysis_chain.invoke({"transcript": transcript}) + logger.debug(f"Successfully analyzed transcript: \n\n{analysis}") + + logger.debug("Executing rewriting chain") + rewritten_response = rewrite_chain.invoke({"analysis": analysis, "transcript": transcript}) if not rewritten_response: logger.warning("Rewriting chain returned empty response") # Fall back to original rewritten_response = transcript logger.debug("Successfully rewrote transcript") - logger.debug("Successfully rewrote transcript, BEFORE = ", transcript, "AFTER = ", rewritten_response) + logger.debug(f"Successfully rewrote transcript, BEFORE = \n\n{transcript}") + logger.debug(f"Successfully rewrote transcript, AFTER = \n\n{rewritten_response}") except Exception as e: logger.error(f"Error in rewriting chain: {str(e)}") rewritten_response = transcript # Fall back to original