Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

i18n: Support SQLAlchemy 2 MappedColumn #705

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions sqlalchemy_utils/i18n.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
except ImportError:
babel = None

try:
from sqlalchemy.orm import MappedColumn
except ImportError:
MappedColumn = None


def get_locale():
try:
Expand All @@ -24,6 +29,13 @@ def get_locale():
)


def get_key(attr):
if MappedColumn is not None and isinstance(attr, MappedColumn):
return attr.column.key
else:
return attr.key


def cast_locale(obj, locale, attr):
"""
Cast given locale to string. Supports also callbacks that return locales.
Expand All @@ -35,7 +47,7 @@ def cast_locale(obj, locale, attr):
"""
if callable(locale):
try:
locale = locale(obj, attr.key)
locale = locale(obj, get_key(attr))
except TypeError:
try:
locale = locale(obj)
Expand Down Expand Up @@ -83,26 +95,26 @@ def getter_factory(self, attr):
def getter(obj):
current_locale = cast_locale(obj, self.current_locale, attr)
try:
return getattr(obj, attr.key)[current_locale]
return getattr(obj, get_key(attr))[current_locale]
except (TypeError, KeyError):
default_locale = cast_locale(obj, self.default_locale, attr)
try:
return getattr(obj, attr.key)[default_locale]
return getattr(obj, get_key(attr))[default_locale]
except (TypeError, KeyError):
return self.default_value
return getter

def setter_factory(self, attr):
def setter(obj, value):
if getattr(obj, attr.key) is None:
setattr(obj, attr.key, {})
if getattr(obj, get_key(attr)) is None:
setattr(obj, get_key(attr), {})
locale = cast_locale(obj, self.current_locale, attr)
getattr(obj, attr.key)[locale] = value
getattr(obj, get_key(attr))[locale] = value
return setter

def expr_factory(self, attr):
def expr(cls):
cls_attr = getattr(cls, attr.key)
cls_attr = getattr(cls, get_key(attr))
current_locale = cast_locale_expr(cls, self.current_locale, attr)
default_locale = cast_locale_expr(cls, self.default_locale, attr)
return sa.func.coalesce(
Expand Down