Skip to content
This repository has been archived by the owner on Jul 29, 2021. It is now read-only.

Commit

Permalink
add upsert/merge feature.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkjeff committed Jun 16, 2016
1 parent a5abc50 commit 743740e
Show file tree
Hide file tree
Showing 11 changed files with 319 additions and 14 deletions.
Binary file not shown.
Binary file added EntityFramework.Utilities/.vs/Tests/Tests.scgdat
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ public class ColumnMapping
public string DataType { get; set; }

public bool IsPrimaryKey { get; set; }
public bool IsStoreGeneratedIdentity { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ public interface IEFBatchOperationBase<TContext, T> where T : class
/// <param name="connection">The DbConnection to use for the insert. Only needed when for example a profiler wraps the connection. Then you need to provide a connection of the type the provider use.</param>
/// <param name="batchSize">The size of each batch. Default depends on the provider. SqlProvider uses 15000 as default</param>
void UpdateAll<TEntity>(IEnumerable<TEntity> items, Action<UpdateSpecification<TEntity>> updateSpecification, DbConnection connection = null, int? batchSize = null) where TEntity : class, T;

/// <summary>
/// provider batch upsert operation
/// SQL:
/// merge into [(the table of source entity)] as Target
/// using (tempTable) as Source
/// on <paramref name="identitySpecification"/>
/// when matched then
/// update set <paramref name="whenMatchedUpdateSpecification"/>
/// when not matched then
/// insert ...;
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="items"></param>
/// <param name="identitySpecification">match identity specification. if parameter is null, use primary key as default</param>
/// <param name="whenMatchedUpdateSpecification">update specification when matched by <paramref name="identitySpecification"/>. if parameter is null, update all columns except primary key</param>
/// <param name="connection"></param>
/// <param name="batchSize"></param>
void MergeAll<TEntity>(IEnumerable<TEntity> items, Action<IdentitySpecification<TEntity>> identitySpecification = null,
Action<UpdateSpecification<TEntity>> whenMatchedUpdateSpecification = null, DbConnection connection = null, int? batchSize = null) where TEntity : class, T;
}

public class UpdateSpecification<T>
Expand All @@ -52,6 +72,22 @@ public UpdateSpecification<T> ColumnsToUpdate(params Expression<Func<T, object>>
public Expression<Func<T, object>>[] Properties { get; set; }
}

public class IdentitySpecification<T>
{
/// <summary>
/// Set each column you use to identity.
/// </summary>
/// <param name="properties"></param>
/// <returns></returns>
public IdentitySpecification<T> ColumnsToIdentity(params Expression<Func<T, object>>[] properties)
{
Properties = properties;
return this;
}

public Expression<Func<T, object>>[] Properties { get; set; }
}

public interface IEFBatchOperationFiltered<TContext, T>
{
int Delete();
Expand Down Expand Up @@ -161,6 +197,7 @@ public void UpdateAll<TEntity>(IEnumerable<TEntity> items, Action<UpdateSpecific
NameInDatabase = p.ColumnName,
NameOnObject = p.PropertyName,
DataType = p.DataTypeFull,
IsStoreGeneratedIdentity = p.IsStoreGeneratedIdentity,
IsPrimaryKey = p.IsPrimaryKey
}).ToList();

Expand All @@ -175,6 +212,70 @@ public void UpdateAll<TEntity>(IEnumerable<TEntity> items, Action<UpdateSpecific
}
}

public void MergeAll<TEntity>(IEnumerable<TEntity> items, Action<IdentitySpecification<TEntity>> identitySpecification, Action<UpdateSpecification<TEntity>> updateSpecification, DbConnection connection, int? batchSize)
where TEntity : class, T
{
var con = context.Connection as EntityConnection;
if (con == null && connection == null)
{
Configuration.Log("No provider could be found because the Connection didn't implement System.Data.EntityClient.EntityConnection");
throw new InvalidOperationException("No provider supporting the upsert operation");
}

var connectionToUse = connection ?? con.StoreConnection;
var currentType = typeof(TEntity);
var provider = Configuration.Providers.FirstOrDefault(p => p.CanHandle(connectionToUse));
if (provider != null && provider.CanBulkUpdate)
{

var mapping = EfMappingFactory.GetMappingsForContext(this.dbContext);
var typeMapping = mapping.TypeMappings[typeof(T)];
var tableMapping = typeMapping.TableMappings.First();

var properties = tableMapping.PropertyMappings
.Where(p => currentType.IsSubclassOf(p.ForEntityType) || p.ForEntityType == currentType)
.Select(p => new ColumnMapping
{
NameInDatabase = p.ColumnName,
NameOnObject = p.PropertyName,
DataType = p.DataTypeFull,
IsPrimaryKey = p.IsPrimaryKey,
IsStoreGeneratedIdentity = p.IsStoreGeneratedIdentity,
}).ToList();

HashSet<string> columnsToMatch;
if (identitySpecification != null)
{
var identity = new IdentitySpecification<TEntity>();
identitySpecification(identity);
columnsToMatch = new HashSet<string>(identity.Properties.Select(p => p.GetPropertyName()));
}
else
{
columnsToMatch = new HashSet<string>(properties.Where(p => p.IsPrimaryKey).Select(p => p.NameOnObject));
}

HashSet<string> columnsToUpdate;
if (updateSpecification != null)
{
var spec = new UpdateSpecification<TEntity>();
updateSpecification(spec);
columnsToUpdate = new HashSet<string>(spec.Properties.Select(p => p.GetPropertyName()));
}
else
{
columnsToUpdate = new HashSet<string>(properties.Where(p => !p.IsPrimaryKey).Select(p => p.NameOnObject));
}

provider.UpsertImtes(items, tableMapping.Schema, tableMapping.TableName, properties, connectionToUse, batchSize, columnsToMatch, columnsToUpdate);
}
else
{
Configuration.Log("Found provider: " + (provider == null ? "[]" : provider.GetType().Name) + " for " + connectionToUse.GetType().Name);
throw new InvalidOperationException("No provider supporting the upsert operation");
}
}

public IEFBatchOperationFiltered<TContext, T> Where(Expression<Func<T, bool>> predicate)
{
this.predicate = predicate;
Expand Down Expand Up @@ -247,6 +348,6 @@ public int Update<TP>(Expression<Func<T, TP>> prop, Expression<Func<T, TP>> modi
}



}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
using System;
using System.Collections.Generic;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
using System.Text;

namespace EntityFramework.Utilities
{
Expand All @@ -17,6 +14,7 @@ public interface IQueryProvider
string GetUpdateQuery(QueryInformation predicateQueryInfo, QueryInformation modificationQueryInfo);
void InsertItems<T>(IEnumerable<T> items, string schema, string tableName, IList<ColumnMapping> properties, DbConnection storeConnection, int? batchSize);
void UpdateItems<T>(IEnumerable<T> items, string schema, string tableName, IList<ColumnMapping> properties, DbConnection storeConnection, int? batchSize, UpdateSpecification<T> updateSpecification);
void UpsertImtes<T>(IEnumerable<T> items, string schema, string tableName, IList<ColumnMapping> properties, DbConnection storeConnection, int? batchSize, HashSet<string> identitySpecification, HashSet<string> updateSpecification);

bool CanHandle(DbConnection storeConnection);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
using System.Data.Entity.Infrastructure;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Xml;
using System.Xml.Linq;
using System.Reflection;

namespace EntityFramework.Utilities
{
Expand Down Expand Up @@ -94,6 +94,7 @@ public class PropertyMapping
public bool IsPrimaryKey { get; set; }

public string DataTypeFull { get; set; }
public bool IsStoreGeneratedIdentity { get; set; }
}

/// <summary>
Expand Down Expand Up @@ -122,7 +123,7 @@ public EfMapping(DbContext db)
var conceptualContainer = metadata.GetItems<EntityContainer>(DataSpace.CSpace).Single();

// Storage part of the model has info about the shape of our tables
var storeContainer = metadata.GetItems<EntityContainer>(DataSpace.SSpace).Single();
var storeContainer = metadata.GetItems(DataSpace.SSpace).OfType<EntityType>();

// Object part of the model that contains info about the actual CLR types
var objectItemCollection = ((ObjectItemCollection)metadata.GetItemCollection(DataSpace.OSpace));
Expand Down Expand Up @@ -220,6 +221,11 @@ public EfMapping(DbContext db)
if ((mappingToLookAt.EntityType ?? mappingToLookAt.IsOfEntityTypes[0]).KeyProperties.Any(p => p.Name == item.PropertyName))
{
item.IsPrimaryKey = true;
item.IsStoreGeneratedIdentity = storeContainer.FirstOrDefault(t => t.Name == item.ForEntityType.Name)
?.Properties
?.FirstOrDefault(p => p.Name == item.ColumnName)
?.IsStoreGeneratedIdentity
?? false;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System.Data.Common;
using System.Data.SqlClient;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;

namespace EntityFramework.Utilities
Expand Down Expand Up @@ -83,7 +82,7 @@ public void InsertItems<T>(IEnumerable<T> items, string schema, string tableName

public void UpdateItems<T>(IEnumerable<T> items, string schema, string tableName, IList<ColumnMapping> properties, DbConnection storeConnection, int? batchSize, UpdateSpecification<T> updateSpecification)
{
var tempTableName = "temp_" + tableName + "_" + DateTime.Now.Ticks;
var tempTableName = "#temp_" + tableName + "_" + DateTime.Now.Ticks;
var columnsToUpdate = updateSpecification.Properties.Select(p => p.GetPropertyName()).ToDictionary(x => x);
var filtered = properties.Where(p => columnsToUpdate.ContainsKey(p.NameOnObject) || p.IsPrimaryKey).ToList();
var columns = filtered.Select(c => "[" + c.NameInDatabase + "] " + c.DataType);
Expand All @@ -110,9 +109,14 @@ INNER JOIN
ON
{2}", tableName, tempTableName, filter, setters);

var dropCommand = $@"IF Object_id('tempdb..{tempTableName}') IS NOT NULL
BEGIN DROP TABLE {tempTableName} END
ELSE
BEGIN THROW 51000,'Drop temp table {tempTableName} fail.',1; END";

using (var createCommand = new SqlCommand(str, con))
using (var mCommand = new SqlCommand(mergeCommand, con))
using (var dCommand = new SqlCommand(string.Format("DROP table {0}.[{1}]", schema, tempTableName), con))
using (var dCommand = new SqlCommand(dropCommand, con))
{
createCommand.ExecuteNonQuery();
InsertItems(items, schema, tempTableName, filtered, storeConnection, batchSize);
Expand All @@ -123,6 +127,56 @@ INNER JOIN

}

public void UpsertImtes<T>(IEnumerable<T> items, string schema, string tableName, IList<ColumnMapping> properties, DbConnection storeConnection, int? batchSize, HashSet<string> columnsToIdentity, HashSet<string> columnsToUpdate)
{
var tempTableName = "#temp_" + tableName + "_" + DateTime.Now.Ticks;

var str = $@"CREATE TABLE {schema}.[{tempTableName}] (
{string.Join(", ", properties.Select(c => "[" + c.NameInDatabase + "] " + c.DataType))},
PRIMARY KEY ({string.Join(", ", properties.Where(p => p.IsPrimaryKey).Select(c => "[" + c.NameInDatabase + "]"))})
)";

var con = storeConnection as SqlConnection;
if (con.State != System.Data.ConnectionState.Open)
{
con.Open();
}

var insertProperties = properties.Where(p => !p.IsStoreGeneratedIdentity).Select(p => p.NameInDatabase).ToArray();
string mergeCommand =
$@"merge into [{tableName}] as Target
using {tempTableName} as Source
on {string.Join(" and ", properties
.Where(p => columnsToIdentity.Contains(p.NameOnObject))
.Select(p => $"Target.{p.NameInDatabase}=Source.{p.NameInDatabase}"))}
when matched then
update set {string.Join(",", properties
.Where(p => columnsToUpdate.Contains(p.NameOnObject) && !p.IsPrimaryKey)
.Select(p => "Target.[" + p.NameInDatabase + "] = Source.[" + p.NameInDatabase + "]"))}
when not matched then
insert (
{string.Join(",", insertProperties)}
) values (
{string.Join(",", insertProperties.Select(p=>$"Source.{p}"))}
);";


var dropCommand = $@"IF Object_id('tempdb..{tempTableName}') IS NOT NULL
BEGIN DROP TABLE {tempTableName} END
ELSE
BEGIN THROW 51000,'Drop temp table {tempTableName} fail.',1; END";

using (var createCommand = new SqlCommand(str, con))
using (var mCommand = new SqlCommand(mergeCommand, con))
using (var dCommand = new SqlCommand(dropCommand, con))
{
createCommand.ExecuteNonQuery();
InsertItems(items, schema, tempTableName, properties, storeConnection, batchSize);
mCommand.ExecuteNonQuery();
dCommand.ExecuteNonQuery();
}
}


public bool CanHandle(System.Data.Common.DbConnection storeConnection)
{
Expand Down
1 change: 0 additions & 1 deletion EntityFramework.Utilities/Tests/InsertTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tests.FakeDomain;
using Tests.FakeDomain.Models;
using System;

namespace Tests
{
Expand Down
Loading

0 comments on commit 743740e

Please sign in to comment.