Skip to content

Commit

Permalink
fix: sub module conflict error (Thriftpy#295)
Browse files Browse the repository at this point in the history
* fix: sub module name conflict error
  • Loading branch information
StellarisW authored Jan 14, 2025
1 parent b83fbae commit 463478f
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 13 deletions.
1 change: 1 addition & 0 deletions tests/parser-cases/foo.bar.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include "foo/bar.thrift"
Empty file.
1 change: 1 addition & 0 deletions tests/parser-cases/include.thrift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include "included.thrift"
include "include/included_1.thrift"

const included.Timestamp datetime = 1422009523
1 change: 1 addition & 0 deletions tests/parser-cases/include/included_1.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include "included_2.thrift"
Empty file.
8 changes: 5 additions & 3 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def test_load_struct():
def test_load_union():
assert storm_tt.JavaObjectArg.__base__ == TPayload
assert storm.JavaObjectArg.thrift_spec == \
storm_tt.JavaObjectArg.thrift_spec
storm_tt.JavaObjectArg.thrift_spec


def test_load_exc():
assert ab_tt.PersonNotExistsError.__base__ == TException
assert ab.PersonNotExistsError.thrift_spec == \
ab_tt.PersonNotExistsError.thrift_spec
ab_tt.PersonNotExistsError.thrift_spec


def test_load_service():
Expand All @@ -70,4 +70,6 @@ def test_load_include():
g = load("parent.thrift")

ts = g.Greet.thrift_spec
assert ts[1][2] == b.Hello and ts[2][0] == TType.I64 and ts[3][2] == b.Code
assert (ts[1][2].thrift_spec == b.Hello.thrift_spec and
ts[2][0] == TType.I64 and
ts[3][2]._NAMES_TO_VALUES == b.Code._NAMES_TO_VALUES)
25 changes: 23 additions & 2 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-

import sys
import threading

import pytest
Expand Down Expand Up @@ -36,8 +36,26 @@ def test_constants():

def test_include():
thrift = load('parser-cases/include.thrift', include_dirs=[
'./parser-cases'])
'./parser-cases'], module_name='include_thrift')
assert thrift.datetime == 1422009523
assert sys.modules['include_thrift'] is not None
assert sys.modules['included_thrift'] is not None
assert sys.modules['include.included_1_thrift'] is not None
assert sys.modules['include.included_2_thrift'] is not None


def test_include_with_module_name_prefix():
load('parser-cases/include.thrift', module_name='parser_cases.include_thrift')
assert sys.modules['parser_cases.include_thrift'] is not None
assert sys.modules['parser_cases.included_thrift'] is not None
assert sys.modules['parser_cases.include.included_1_thrift'] is not None
assert sys.modules['parser_cases.include.included_2_thrift'] is not None


def test_include_conflict():
with pytest.raises(ThriftParserError) as excinfo:
load('parser-cases/foo.bar.thrift', module_name='foo.bar_thrift')
assert 'Module name conflict between' in str(excinfo.value)


def test_cpp_include():
Expand Down Expand Up @@ -295,6 +313,9 @@ def test_thrift_meta():


def test_load_fp():
from thriftpy2.parser import threadlocal
threadlocal.__dict__.clear()

thrift = None
with open('parser-cases/shared.thrift') as thrift_fp:
thrift = load_fp(thrift_fp, 'shared_thrift')
Expand Down
23 changes: 16 additions & 7 deletions thriftpy2/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import types

from .parser import parse, parse_fp, threadlocal, _cast
from .exc import ThriftParserError
from .exc import ThriftParserError, ThriftModuleNameConflict
from ..thrift import TPayloadMeta


Expand All @@ -41,12 +41,21 @@ def load(path,
# add sub modules to sys.modules recursively
if real_module:
sys.modules[module_name] = thrift
sub_modules = thrift.__thrift_meta__["includes"][:]
while sub_modules:
module = sub_modules.pop()
if module not in sys.modules:
sys.modules[module.__name__] = module
sub_modules.extend(module.__thrift_meta__["includes"])
include_thrifts = thrift.__thrift_meta__["includes"][:]
while include_thrifts:
include_thrift = include_thrifts.pop()
registered_thrift = sys.modules.get(include_thrift.__thrift_module_name__)
if registered_thrift is None:
sys.modules[include_thrift.__thrift_module_name__] = include_thrift
if hasattr(include_thrift, "__thrift_meta__"):
include_thrifts.extend(
include_thrift.__thrift_meta__["includes"][:])
else:
if registered_thrift.__thrift_file__ != include_thrift.__thrift_file__:
raise ThriftModuleNameConflict(
'Module name conflict between "%s" and "%s"' %
(registered_thrift.__thrift_file__, include_thrift.__thrift_file__)
)
return thrift


Expand Down
4 changes: 4 additions & 0 deletions thriftpy2/parser/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class ThriftParserError(Exception):
pass


class ThriftModuleNameConflict(ThriftParserError):
pass


class ThriftLexerError(ThriftParserError):
pass

Expand Down
16 changes: 15 additions & 1 deletion thriftpy2/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,21 @@ def p_include(p):
for include_dir in replace_include_dirs:
path = os.path.join(include_dir, p[2])
if os.path.exists(path):
child = parse(path)
thrift_file_name_module = os.path.basename(thrift.__thrift_file__)
if thrift_file_name_module.endswith(".thrift"):
thrift_file_name_module = thrift_file_name_module[:-7] + "_thrift"
module_prefix = str(thrift.__name__).rstrip(thrift_file_name_module)

child_rel_path = os.path.relpath(str(path), os.path.dirname(thrift.__thrift_file__))
child_module_name = str(child_rel_path).replace(os.sep, ".").replace(".thrift", "_thrift")
child_module_name = module_prefix + child_module_name

child = parse(path, module_name=child_module_name)
child_include_module_name = os.path.basename(path)
if child_include_module_name.endswith(".thrift"):
child_include_module_name = child_include_module_name[:-7]
setattr(child, '__name__', child_include_module_name)
setattr(child, '__thrift_module_name__', child_module_name)
setattr(thrift, child.__name__, child)
_add_thrift_meta('includes', child)
return
Expand Down

0 comments on commit 463478f

Please sign in to comment.