diff --git a/podcastfy/content_generator.py b/podcastfy/content_generator.py index f3cfd91..aa1454d 100644 --- a/podcastfy/content_generator.py +++ b/podcastfy/content_generator.py @@ -14,7 +14,7 @@ from langchain_community.chat_models import ChatLiteLLM from langchain_google_genai import ChatGoogleGenerativeAI from langchain_community.llms.llamafile import Llamafile -from langchain_core.prompts import ChatPromptTemplate +from langchain_core.prompts import ChatPromptTemplate, PromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain import hub from podcastfy.utils.config_conversation import load_conversation_config @@ -503,6 +503,7 @@ def clean(self, # Then apply additional long-form specific cleaning return self._clean_transcript_response(standard_clean, config) + def _clean_transcript_response(self, transcript: str, config: Dict[str, Any]) -> str: """ Clean transcript using a two-step process with LLM-based cleaning. @@ -522,7 +523,40 @@ def _clean_transcript_response(self, transcript: str, config: Dict[str, Any]) -> """ logger.debug("Starting transcript cleaning process") - final_transcript = self._fix_alternating_tags(transcript) + # Run rewriting chain + llm = self.llm + + analysis_prompt = PromptTemplate( + input_variables=["transcript"], + template=config.get("analysis_prompt_template", "You are a podcast editor. Analyze this podcast transcript and identify duplicated/repeated lines and recommendations to improve flow. Do not remove too many facts or add any new facts: \n\n{transcript} \n\nAnalysis (bullet-points, with line numbers referring to problematic lines.):") + ) + analysis_chain = analysis_prompt | llm | StrOutputParser() + + rewrite_prompt = PromptTemplate( + input_variables=["transcript", "analysis"], + template=config.get("rewrite_prompt_template", "Rewrite the podcast transcript by applying only the following recommendations. Refrain from shortening the transcript too much.\n\nRecommendations: \n\n{analysis}\n\nOriginal Transcript: \n\n{transcript}\n\nRewritten Transcript:") + ) + rewrite_chain = rewrite_prompt | llm | StrOutputParser() + + try: + logger.debug("Executing analysis chain") + analysis = analysis_chain.invoke({"transcript": transcript}) + logger.debug(f"Successfully analyzed transcript: \n\n{analysis}") + + logger.debug("Executing rewriting chain") + rewritten_response = rewrite_chain.invoke({"analysis": analysis, "transcript": transcript}) + if not rewritten_response: + logger.warning("Rewriting chain returned empty response") + # Fall back to original + rewritten_response = transcript + logger.debug("Successfully rewrote transcript") + logger.debug(f"Successfully rewrote transcript, BEFORE = \n\n{transcript}") + logger.debug(f"Successfully rewrote transcript, AFTER = \n\n{rewritten_response}") + except Exception as e: + logger.error(f"Error in rewriting chain: {str(e)}") + rewritten_response = transcript # Fall back to original + + final_transcript = self._fix_alternating_tags(rewritten_response) logger.debug("Completed transcript cleaning process")