diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index 068fa77..4c7a888 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -59,6 +59,7 @@ async def gpt_4o_complete( ) + async def gpt_4o_mini_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py index f06b3c5..4ee9ed5 100644 --- a/nano_graphrag/_op.py +++ b/nano_graphrag/_op.py @@ -373,7 +373,7 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): already_processed % len(PROMPTS["process_tickers"]) ] print( - f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", + f"{now_ticks} Processed {already_processed}({already_processed*100//len(ordered_chunks)}%) chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", end="", flush=True, ) diff --git a/nano_graphrag/_utils.py b/nano_graphrag/_utils.py index 567394d..37b90ee 100644 --- a/nano_graphrag/_utils.py +++ b/nano_graphrag/_utils.py @@ -17,24 +17,93 @@ ENCODER = None -def locate_json_string_body_from_string(content: str) -> Union[str, None]: - """Locate the JSON string body from a string""" - maybe_json_str = re.search(r"{.*}", content, re.DOTALL) - if maybe_json_str is not None: - return maybe_json_str.group(0) - else: +def extract_first_complete_json(s: str): + """Extract the first complete JSON object from the string using a stack to track braces.""" + stack = [] + first_json_start = None + + for i, char in enumerate(s): + if char == '{': + stack.append(i) + if first_json_start is None: + first_json_start = i + elif char == '}': + if stack: + start = stack.pop() + if not stack: + first_json_str = s[first_json_start:i+1] + try: + # Attempt to parse the JSON string + return json.loads(first_json_str.replace("\n", "")) + except json.JSONDecodeError as e: + logger.error(f"JSON decoding failed: {e}. Attempted string: {first_json_str[:50]}...") + return None + finally: + first_json_start = None + logger.warning("No complete JSON object found in the input string.") + return None + +def parse_value(value: str): + """Convert a string value to its appropriate type (int, float, bool, None, or keep as string). Work as a more broad 'eval()'""" + value = value.strip() + + if value == "null": return None + elif value == "true": + return True + elif value == "false": + return False + else: + # Try to convert to int or float + try: + if '.' in value: # If there's a dot, it might be a float + return float(value) + else: + return int(value) + except ValueError: + # If conversion fails, return the value as-is (likely a string) + return value.strip('"') # Remove surrounding quotes if they exist + +def extract_values_from_json(json_string, keys=["reasoning", "answer", "data"], allow_no_quotes=False): + """Extract key values from a non-standard or malformed JSON string, handling nested objects.""" + extracted_values = {} + + # Enhanced pattern to match both quoted and unquoted values, as well as nested objects + regex_pattern = r'(?P"?\w+"?)\s*:\s*(?P{[^}]*}|".*?"|[^,}]+)' + + for match in re.finditer(regex_pattern, json_string, re.DOTALL): + key = match.group('key').strip('"') # Strip quotes from key + value = match.group('value').strip() + + # If the value is another nested JSON (starts with '{' and ends with '}'), recursively parse it + if value.startswith('{') and value.endswith('}'): + extracted_values[key] = extract_values_from_json(value) + else: + # Parse the value into the appropriate type (int, float, bool, etc.) + extracted_values[key] = parse_value(value) + + if not extracted_values: + logger.warning("No values could be extracted from the string.") + + return extracted_values def convert_response_to_json(response: str) -> dict: - json_str = locate_json_string_body_from_string(response) - assert json_str is not None, f"Unable to parse JSON from response: {response}" - try: - data = json.loads(json_str) - return data - except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON: {json_str}") - raise e from None + """Convert response string to JSON, with error handling and fallback to non-standard JSON extraction.""" + prediction_json = extract_first_complete_json(response) + + if prediction_json is None: + logger.info("Attempting to extract values from a non-standard JSON string...") + prediction_json = extract_values_from_json(response, allow_no_quotes=True) + + if not prediction_json: + logger.error("Unable to extract meaningful data from the response.") + else: + logger.info("JSON data successfully extracted.") + + return prediction_json + + def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"): @@ -149,26 +218,23 @@ async def __call__(self, *args, **kwargs) -> np.ndarray: # Decorators ------------------------------------------------------------------------ -def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): - """Add restriction of maximum async calling times for a async func""" - def final_decro(func): - """Not using async.Semaphore to aovid use nest-asyncio""" - __current_size = 0 + +def limit_async_func_call(max_size: int): + """Add restriction of maximum async calling times for a async func using Semaphore""" + + def final_decorator(func): + # Create a semaphore with the given max_size + semaphore = asyncio.Semaphore(max_size) @wraps(func) - async def wait_func(*args, **kwargs): - nonlocal __current_size - while __current_size >= max_size: - await asyncio.sleep(waitting_time) - __current_size += 1 - result = await func(*args, **kwargs) - __current_size -= 1 - return result + async def wrapped_func(*args, **kwargs): + async with semaphore: # Acquire the semaphore + return await func(*args, **kwargs) # Run the async function - return wait_func + return wrapped_func - return final_decro + return final_decorator def wrap_embedding_func_with_attrs(**kwargs): diff --git a/tests/test_json_parsing.py b/tests/test_json_parsing.py new file mode 100644 index 0000000..ea8ab6e --- /dev/null +++ b/tests/test_json_parsing.py @@ -0,0 +1,132 @@ +import unittest +from loguru import logger +from nano_graphrag._utils import convert_response_to_json + +class TestJSONExtraction(unittest.TestCase): + + def setUp(self): + """Set up runs before each test case.""" + logger.remove() + logger.add(lambda msg: None) # disallow output + + def test_standard_json(self): + """Test standard JSON extraction.""" + response = ''' + { + "reasoning": "This is a test.", + "answer": 42, + "data": {"key1": "value1", "key2": "value2"} + } + ''' + expected = { + "reasoning": "This is a test.", + "answer": 42, + "data": {"key1": "value1", "key2": "value2"} + } + self.assertEqual(convert_response_to_json(response), expected) + + def test_non_standard_json_without_quotes(self): + """Test non-standard JSON without quotes on numbers and booleans.""" + response = ''' + { + "reasoning": "Boolean and numbers test.", + "answer": 42, + "isCorrect": true, + "data": {key1: value1} + } + ''' + expected = { + "reasoning": "Boolean and numbers test.", + "answer": 42, + "isCorrect": True, + "data": {"key1": "value1"} + } + self.assertEqual(convert_response_to_json(response), expected) + + def test_nested_json(self): + """Test extraction of nested JSON objects.""" + response = ''' + { + "reasoning": "Nested structure.", + "answer": 42, + "data": {"nested": {"key": "value"}} + } + ''' + expected = { + "reasoning": "Nested structure.", + "answer": 42, + "data": { + "nested": {"key": "value"} + } + } + self.assertEqual(convert_response_to_json(response), expected) + + def test_malformed_json(self): + """Test handling of malformed JSON.""" + response = ''' + Some text before JSON + { + "reasoning": "This is malformed.", + "answer": 42, + "data": {"key": "value"} + } + Some text after JSON + ''' + expected = { + "reasoning": "This is malformed.", + "answer": 42, + "data": {"key": "value"} + } + self.assertEqual(convert_response_to_json(response), expected) + + def test_incomplete_json(self): + """Test handling of incomplete JSON.""" + response = ''' + { + "reasoning": "Incomplete structure", + "answer": 42 + ''' + expected = { + "reasoning": "Incomplete structure", + "answer": 42 + } + self.assertEqual(convert_response_to_json(response), expected) + + def test_value_with_special_characters(self): + """Test JSON with special characters in values.""" + response = ''' + { + "reasoning": "Special characters !@#$%^&*()", + "answer": 42, + "data": {"key": "value with special characters !@#$%^&*()"} + } + ''' + expected = { + "reasoning": "Special characters !@#$%^&*()", + "answer": 42, + "data": {"key": "value with special characters !@#$%^&*()"} + } + self.assertEqual(convert_response_to_json(response), expected) + + def test_boolean_and_null_values(self): + """Test JSON with boolean and null values.""" + response = ''' + { + "reasoning": "Boolean and null test.", + "isCorrect": true, + "isWrong": false, + "unknown": null, + "answer": 42 + } + ''' + expected = { + "reasoning": "Boolean and null test.", + "isCorrect": True, + "isWrong": False, + "unknown": None, + "answer": 42 + } + self.assertEqual(convert_response_to_json(response), expected) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file