diff --git a/scrapy_djangoitem/__init__.py b/scrapy_djangoitem/__init__.py index dc28cc8..6604194 100644 --- a/scrapy_djangoitem/__init__.py +++ b/scrapy_djangoitem/__init__.py @@ -3,7 +3,6 @@ from scrapy.item import Field, Item, ItemMeta - class DjangoItemMeta(ItemMeta): def __new__(mcs, class_name, bases, attrs): @@ -11,13 +10,13 @@ def __new__(mcs, class_name, bases, attrs): cls.fields = cls.fields.copy() if cls.django_model: - cls._model_fields = [] + cls._model_fields = {} cls._model_meta = cls.django_model._meta for model_field in cls._model_meta.fields: if not model_field.auto_created: if model_field.name not in cls.fields: cls.fields[model_field.name] = Field() - cls._model_fields.append(model_field.name) + cls._model_fields[model_field.name] = model_field return cls @@ -67,7 +66,12 @@ def _get_errors(self, exclude=None): @property def instance(self): if self._instance is None: - modelargs = dict((k, self.get(k)) for k in self._values - if k in self._model_fields) + modelargs = {} + for k in self._values: + if k in self._model_fields: + if self._model_fields[k].is_relation: + modelargs[k] = self._model_fields[k].related_model(pk=self.get(k)) + else: + modelargs[k] = self.get(k) self._instance = self.django_model(**modelargs) return self._instance diff --git a/tests/models.py b/tests/models.py index 7435727..f85eb9a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -16,3 +16,12 @@ class IdentifiedPerson(models.Model): class Meta: app_label = 'test_djangoitem' + + +class Property(models.Model): + person = models.ForeignKey(Person) + name = models.CharField(max_length=255) + description = models.TextField() + + class Meta: + app_label = 'test_djangoitem' diff --git a/tests/test_djangoitem.py b/tests/test_djangoitem.py index 80897f5..1ee817e 100644 --- a/tests/test_djangoitem.py +++ b/tests/test_djangoitem.py @@ -6,7 +6,7 @@ django.setup() from scrapy_djangoitem import DjangoItem, Field -from tests.models import Person, IdentifiedPerson +from tests.models import Person, IdentifiedPerson, Property class BasePersonItem(DjangoItem): @@ -25,6 +25,10 @@ class IdentifiedPersonItem(DjangoItem): django_model = IdentifiedPerson +class PropertyItem(DjangoItem): + django_model = Property + + class DjangoItemTest(unittest.TestCase): def assertSortedEqual(self, first, second, msg=None): @@ -100,3 +104,11 @@ def test_default_field_values(self): i = BasePersonItem() person = i.save(commit=False) self.assertEqual(person.name, 'Robot') + + def test_foreign_key(self): + i = PropertyItem() + i['name'] = 'White House' + i['description'] = 'White House' + i['person'] = 1 + p = i.save(commit=False) + self.assertTrue(p)