diff --git a/GraphDiff/GraphDiff.Tests/Models/TestModels.cs b/GraphDiff/GraphDiff.Tests/Models/TestModels.cs index 8988e6d..47468fc 100644 --- a/GraphDiff/GraphDiff.Tests/Models/TestModels.cs +++ b/GraphDiff/GraphDiff.Tests/Models/TestModels.cs @@ -13,7 +13,9 @@ namespace RefactorThis.GraphDiff.Tests.Models public class Entity { [Key] - public int Id { get; set; } + public int Id { get; set; } + + public Guid UniqueId { get; set; } [MaxLength(128)] public string Title { get; set; } diff --git a/GraphDiff/GraphDiff.Tests/Tests/OwnedCollectionBehaviours.cs b/GraphDiff/GraphDiff.Tests/Tests/OwnedCollectionBehaviours.cs index eb12535..8551e9f 100644 --- a/GraphDiff/GraphDiff.Tests/Tests/OwnedCollectionBehaviours.cs +++ b/GraphDiff/GraphDiff.Tests/Tests/OwnedCollectionBehaviours.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Data.Entity; using System.Collections.Generic; +using System; namespace RefactorThis.GraphDiff.Tests.Tests { @@ -341,5 +342,44 @@ public void ShouldMergeTwoCollectionsAndDecideOnUpdatesDeletesAndAdds() Assert.IsTrue(list[3].Title == "Finish"); } } + + [TestMethod] + public void ShouldUpdateItemInOwnedCollectionWithCustomKey() + { + var node1 = new TestNode + { + Title = "New Node", + OneToManyOwned = new List + { + new OneToManyOwnedModel { Title = "Hello", UniqueId = new Guid("DA6B78FF-BB7F-4FA1-8659-F64AC6457D14") } + } + }; + + int originalOwnedId; + using (var context = new TestDbContext()) + { + context.Nodes.Add(node1); + context.SaveChanges(); + originalOwnedId = node1.OneToManyOwned.First().Id; + } // Simulate detach + + node1.OneToManyOwned.First().Title = "What's up"; + node1.OneToManyOwned.First().Id = 0; //We will try to update on Guid + + using (var context = new TestDbContext()) + { + // Setup mapping + context.UpdateGraph(node1, map => map + .OwnedCollection(p => p.OneToManyOwned), + keysConfiguration: new KeysConfiguration() + .ForEntity(e => e.UniqueId)); + + context.SaveChanges(); + var node2 = context.Nodes.Include(p => p.OneToManyOwned).Single(p => p.Id == node1.Id); + Assert.IsNotNull(node2); + var owned = node2.OneToManyOwned.First(); + Assert.IsTrue(owned.OneParent == node2 && owned.Title == "What's up" && owned.Id == originalOwnedId); + } + } } } diff --git a/GraphDiff/GraphDiff/DbContextExtensions.cs b/GraphDiff/GraphDiff/DbContextExtensions.cs index 423a806..e29fdfe 100644 --- a/GraphDiff/GraphDiff/DbContextExtensions.cs +++ b/GraphDiff/GraphDiff/DbContextExtensions.cs @@ -7,11 +7,8 @@ using RefactorThis.GraphDiff.Internal; using RefactorThis.GraphDiff.Internal.Caching; using RefactorThis.GraphDiff.Internal.Graph; -using RefactorThis.GraphDiff.Internal.GraphBuilders; using System; -using System.Collections.Generic; using System.Data.Entity; -using System.Linq; using System.Linq.Expressions; namespace RefactorThis.GraphDiff @@ -26,10 +23,11 @@ public static class DbContextExtensions /// The root entity. /// The mapping configuration to define the bounds of the graph /// Update configuration overrides + /// The mapping configuration to define properties to use as key. The primary key is used if no other configuration is given. /// The attached entity graph - public static T UpdateGraph(this DbContext context, T entity, Expression, object>> mapping, UpdateParams updateParams = null) where T : class, new() + public static T UpdateGraph(this DbContext context, T entity, Expression, object>> mapping, UpdateParams updateParams = null, KeysConfiguration keysConfiguration = null) where T : class, new() { - return UpdateGraph(context, entity, mapping, null, updateParams); + return UpdateGraph(context, entity, mapping, null, updateParams, keysConfiguration); } /// @@ -40,10 +38,11 @@ public static class DbContextExtensions /// The root entity. /// Pre-configured mappingScheme /// Update configuration overrides + /// The mapping configuration to define properties to use as key. The primary key is used if no other configuration is given. /// The attached entity graph - public static T UpdateGraph(this DbContext context, T entity, string mappingScheme, UpdateParams updateParams = null) where T : class, new() + public static T UpdateGraph(this DbContext context, T entity, string mappingScheme, UpdateParams updateParams = null, KeysConfiguration keysConfiguration = null) where T : class, new() { - return UpdateGraph(context, entity, null, mappingScheme, updateParams); + return UpdateGraph(context, entity, null, mappingScheme, updateParams, keysConfiguration); } /// @@ -53,10 +52,11 @@ public static class DbContextExtensions /// The database context to attach / detach. /// The root entity. /// Update configuration overrides + /// The mapping configuration to define properties to use as key. The primary key is used if no other configuration is given. /// The attached entity graph - public static T UpdateGraph(this DbContext context, T entity, UpdateParams updateParams = null) where T : class, new() + public static T UpdateGraph(this DbContext context, T entity, UpdateParams updateParams = null, KeysConfiguration keysConfiguration = null) where T : class, new() { - return UpdateGraph(context, entity, null, null, updateParams); + return UpdateGraph(context, entity, null, null, updateParams, keysConfiguration); } /// @@ -69,7 +69,7 @@ public static class DbContextExtensions /// The aggregate loaded from the database public static T LoadAggregate(this DbContext context, Func keyPredicate, QueryMode queryMode = QueryMode.SingleQuery) where T : class { - var entityManager = new EntityManager(context); + var entityManager = new EntityManager(context, new KeysConfiguration()); var graph = new AggregateRegister(new CacheProvider()).GetEntityGraph(); var queryLoader = new QueryLoader(context, entityManager); @@ -85,12 +85,12 @@ public static T LoadAggregate(this DbContext context, Func keyPredic // other methods are convenience wrappers around this. private static T UpdateGraph(this DbContext context, T entity, Expression, object>> mapping, - string mappingScheme, UpdateParams updateParams) where T : class, new() + string mappingScheme, UpdateParams updateParams, KeysConfiguration keysConfiguration) where T : class, new() { GraphNode root; GraphDiffer differ; - var entityManager = new EntityManager(context); + var entityManager = new EntityManager(context, keysConfiguration ?? new KeysConfiguration()); var queryLoader = new QueryLoader(context, entityManager); var register = new AggregateRegister(new CacheProvider()); diff --git a/GraphDiff/GraphDiff/GraphDiff.csproj b/GraphDiff/GraphDiff/GraphDiff.csproj index be0b99c..6222b4b 100644 --- a/GraphDiff/GraphDiff/GraphDiff.csproj +++ b/GraphDiff/GraphDiff/GraphDiff.csproj @@ -72,6 +72,7 @@ + diff --git a/GraphDiff/GraphDiff/Internal/ChangeTracker.cs b/GraphDiff/GraphDiff/Internal/ChangeTracker.cs index 8fe557a..680b8db 100644 --- a/GraphDiff/GraphDiff/Internal/ChangeTracker.cs +++ b/GraphDiff/GraphDiff/Internal/ChangeTracker.cs @@ -83,9 +83,24 @@ public EntityState GetItemState(object item) public void UpdateItem(object from, object to, bool doConcurrencyCheck = false) { - if (doConcurrencyCheck && _context.Entry(to).State != EntityState.Added) + Type entityType = from.GetType(); + var toEntry = _context.Entry(to); + + if (doConcurrencyCheck && toEntry.State != EntityState.Added) + { + EnsureConcurrency(entityType, from, to); + } + + var metadata = _objectContext.MetadataWorkspace + .GetItems(DataSpace.OSpace) + .SingleOrDefault(p => p.FullName == entityType.FullName); + + // When a custom key is specified the primary key in the from object is ignored. + // We must set it to the actual value from database so it won't try to change the primary key + if (_entityManager.KeysConfiguration.HasConfigurationFor(entityType)) { - EnsureConcurrency(from, to); + // Copy inverted for primary key : from context entity to detached entity + _entityManager.CopyPrimaryKeyFields(entityType, from: to, to: from); } _context.Entry(to).CurrentValues.SetValues(from); @@ -171,10 +186,9 @@ public void AttachRequiredNavigationProperties(object updating, object persisted // Privates - private void EnsureConcurrency(object entity1, object entity2) + private void EnsureConcurrency(Type entityType, object entity1, object entity2) { // get concurrency properties of T - var entityType = ObjectContext.GetObjectType(entity1.GetType()); var metadata = _objectContext.MetadataWorkspace; var objType = metadata.GetItems(DataSpace.OSpace).Single(p => p.FullName == entityType.FullName); @@ -222,7 +236,7 @@ private object FindTrackedEntity(object entity) private object FindEntityByKey(object associatedEntity) { var associatedEntityType = ObjectContext.GetObjectType(associatedEntity.GetType()); - var keyFields = _entityManager.GetPrimaryKeyFieldsFor(associatedEntityType); + var keyFields = _entityManager.GetKeyFieldsFor(associatedEntityType); var keys = keyFields.Select(key => key.GetValue(associatedEntity, null)).ToArray(); return _context.Set(associatedEntityType).Find(keys); } diff --git a/GraphDiff/GraphDiff/Internal/EntityManager.cs b/GraphDiff/GraphDiff/Internal/EntityManager.cs index 6f0bd6f..e11dff6 100644 --- a/GraphDiff/GraphDiff/Internal/EntityManager.cs +++ b/GraphDiff/GraphDiff/Internal/EntityManager.cs @@ -15,6 +15,11 @@ namespace RefactorThis.GraphDiff.Internal /// internal interface IEntityManager { + /// + /// Gets custom key mappins for entities + /// + KeysConfiguration KeysConfiguration { get; } + /// /// Creates the unique entity key for an entity /// @@ -31,9 +36,14 @@ internal interface IEntityManager bool AreKeysIdentical(object entity1, object entity2); /// - /// Returns the primary key fields for a given entity type + /// Returns the key fields (using key configuration if available) for a given entity type + /// + IEnumerable GetKeyFieldsFor(Type entityType); + + /// + /// Copy primary key fields from an entity to another of the same type /// - IEnumerable GetPrimaryKeyFieldsFor(Type entityType); + void CopyPrimaryKeyFields(Type entityType, object from, object to); /// /// Retrieves the required navigation properties for the given type @@ -45,18 +55,27 @@ internal interface IEntityManager /// IEnumerable GetNavigationPropertiesForType(Type entityType); } - + internal class EntityManager : IEntityManager { private readonly DbContext _context; + private ObjectContext _objectContext { get { return ((IObjectContextAdapter)_context).ObjectContext; } } - public EntityManager(DbContext context) + public KeysConfiguration KeysConfiguration { get; private set; } + + public EntityManager(DbContext context, KeysConfiguration keysConfiguration) { + if (context == null) + throw new ArgumentNullException("context"); + if (keysConfiguration == null) + throw new ArgumentNullException("keysConfiguration"); + _context = context; + KeysConfiguration = keysConfiguration; } public EntityKey CreateEntityKey(object entity) @@ -66,7 +85,18 @@ public EntityKey CreateEntityKey(object entity) throw new ArgumentNullException("entity"); } - return _objectContext.CreateEntityKey(GetEntitySetName(entity.GetType()), entity); + var entityType = entity.GetType(); + var entitySetName = GetEntitySetName(entityType); + if (KeysConfiguration.HasConfigurationFor(entityType)) + { + var keyMembers = GetKeyFieldsFor(entityType) + .Select(p => new EntityKeyMember(p.Name, p.GetValue(entity, null))); + return new EntityKey(_objectContext.DefaultContainerName + "." + entitySetName, keyMembers); + } + else + { + return _objectContext.CreateEntityKey(entitySetName, entity); + } } public bool AreKeysIdentical(object newValue, object dbValue) @@ -81,16 +111,30 @@ public bool AreKeysIdentical(object newValue, object dbValue) public object CreateEmptyEntityWithKey(object entity) { - var instance = Activator.CreateInstance(entity.GetType()); - CopyPrimaryKeyFields(entity, instance); + var entityType = entity.GetType(); + var instance = Activator.CreateInstance(entityType); + CopyKeyFields(entityType, entity, instance); return instance; } + public IEnumerable GetKeyFieldsFor(Type entityType) + { + var keyColumns = KeysConfiguration.GetEntityKey(entityType); + if (keyColumns != null) + { + return keyColumns; + } + else + { + return GetPrimaryKeyFieldsFor(entityType); + } + } + public IEnumerable GetPrimaryKeyFieldsFor(Type entityType) { var metadata = _objectContext.MetadataWorkspace - .GetItems(DataSpace.OSpace) - .SingleOrDefault(p => p.FullName == entityType.FullName); + .GetItems(DataSpace.OSpace) + .SingleOrDefault(p => p.FullName == entityType.FullName); if (metadata == null) { @@ -134,9 +178,18 @@ private string GetEntitySetName(Type entityType) return set != null ? set.Name : null; } - private void CopyPrimaryKeyFields(object from, object to) + private void CopyKeyFields(Type entityType, object from, object to) + { + var keyProperties = GetKeyFieldsFor(entityType); + foreach (var keyProperty in keyProperties) + { + keyProperty.SetValue(to, keyProperty.GetValue(from, null), null); + } + } + + public void CopyPrimaryKeyFields(Type entityType, object from, object to) { - var keyProperties = GetPrimaryKeyFieldsFor(from.GetType()); + var keyProperties = GetPrimaryKeyFieldsFor(entityType); foreach (var keyProperty in keyProperties) { keyProperty.SetValue(to, keyProperty.GetValue(from, null), null); diff --git a/GraphDiff/GraphDiff/Internal/GraphDiffer.cs b/GraphDiff/GraphDiff/Internal/GraphDiffer.cs index e176387..7a4264a 100644 --- a/GraphDiff/GraphDiff/Internal/GraphDiffer.cs +++ b/GraphDiff/GraphDiff/Internal/GraphDiffer.cs @@ -55,8 +55,8 @@ public T Merge(T updating, QueryMode queryMode = QueryMode.SingleQuery) throw new InvalidOperationException("GraphDiff supports detached entities only at this time. Please try AsNoTracking() or detach your entites before calling the UpdateGraph method"); } - // Perform recursive update - var entityManager = new EntityManager(_dbContext); + // Perform recursive update + var entityManager = new EntityManager(_dbContext, _entityManager.KeysConfiguration); var changeTracker = new ChangeTracker(_dbContext, entityManager); _root.Update(changeTracker, entityManager, persisted, updating); diff --git a/GraphDiff/GraphDiff/Internal/QueryLoader.cs b/GraphDiff/GraphDiff/Internal/QueryLoader.cs index db4112f..a95843d 100644 --- a/GraphDiff/GraphDiff/Internal/QueryLoader.cs +++ b/GraphDiff/GraphDiff/Internal/QueryLoader.cs @@ -68,7 +68,7 @@ public T LoadEntity(Func keyPredicate, List includeStrings, private Func CreateKeyPredicateExpression(IObjectContextAdapter context, T entity) { // get key properties of T - var keyProperties = _entityManager.GetPrimaryKeyFieldsFor(typeof(T)).ToList(); + var keyProperties = _entityManager.GetKeyFieldsFor(typeof(T)).ToList(); ParameterExpression parameter = Expression.Parameter(typeof(T)); Expression expression = CreateEqualsExpression(entity, keyProperties[0], parameter); diff --git a/GraphDiff/GraphDiff/KeysConfiguration.cs b/GraphDiff/GraphDiff/KeysConfiguration.cs new file mode 100644 index 0000000..d2c1bfe --- /dev/null +++ b/GraphDiff/GraphDiff/KeysConfiguration.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; + +namespace RefactorThis.GraphDiff +{ + /// + /// Defines custom entity keys to use during merge instead of primary key + /// + public sealed class KeysConfiguration + { + private class PropertyInfoExpressionVisitor : ExpressionVisitor + { + public PropertyInfo PropertyInfo { get; private set; } + + protected override Expression VisitMember(MemberExpression node) + { + var pi = node.Member as PropertyInfo; + if (pi != null) + PropertyInfo = pi; + return base.VisitMember(node); + } + } + + private readonly Dictionary> _entityKeys = new Dictionary>(); + + /// + /// Defines a key configuration for an entity type. + /// Be careful about your key, you have to ensure uniqueness. + /// + /// Entity type + /// Path to entity key properties. Ensure that your key is unique. + /// Keys configuration to chain call + public KeysConfiguration ForEntity(params Expression>[] key) + { + if (_entityKeys.ContainsKey(typeof(T))) + throw new InvalidOperationException("A key configuration is already defined for entity type" + typeof(T).Name); + var propertyInfos = key.Select(e => GetPropertyInfo(e)); + _entityKeys.Add(typeof(T), propertyInfos.ToList()); + return this; + } + + private static PropertyInfo GetPropertyInfo(Expression> expression) + { + var visitor = new PropertyInfoExpressionVisitor(); + visitor.Visit(expression); + return visitor.PropertyInfo; + } + + internal IList GetEntityKey(Type entityType) + { + IList result; + if (_entityKeys.TryGetValue(entityType, out result)) + return result; + else + return null; + } + + internal bool HasConfigurationFor(Type entityType) + { + return _entityKeys.ContainsKey(entityType); + } + } +}