Skip to content

Commit

Permalink
Add support for table valued functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pmiddleton committed Mar 25, 2018
1 parent 034fb68 commit f5421c0
Show file tree
Hide file tree
Showing 6 changed files with 448 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1105,10 +1105,11 @@ var equalityExpression
{
var newArguments = Visit(dbFunctionExpression.Arguments);

if (newArguments.Any(a => a == null))
//I dont think I need this anymore
/* if (newArguments.Any(a => a == null))
{
return null;
}
}*/

//TODO - can you custom translate here?
return //dbFunctionExpression.Translate(newArguments)
Expand Down
75 changes: 68 additions & 7 deletions src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ var newProjection
break;


// TODO: Visit sub-query (SelectExpression) here?
// TODO: Visit sub-query (SelectExpression) here?
}
}

Expand Down Expand Up @@ -1312,6 +1312,9 @@ protected override void OptimizeQueryModel(

var typeIsExpressionTranslatingVisitor = new TypeIsExpressionTranslatingVisitor(QueryCompilationContext.Model);
queryModel.TransformExpressions(typeIsExpressionTranslatingVisitor.Visit);

var dbFunctionSrouceSubqueryInjector = new DbFunctionSourceSubqueryInjector();
queryModel.SelectClause.TransformExpressions(dbFunctionSrouceSubqueryInjector.Visit);
}

/// <summary>
Expand All @@ -1328,6 +1331,60 @@ protected virtual void WarnClientEval(
QueryCompilationContext.Logger.QueryClientEvaluationWarning(queryModel, queryModelElement);
}

private class DbFunctionSourceSubqueryInjector : ExpressionVisitorBase
{
private bool _shouldInject;

protected override Expression VisitNew(NewExpression expression)
{
_shouldInject = true;

try
{
return base.VisitNew(expression);
}
finally
{
_shouldInject = false;
}
}

protected override Expression VisitSubQuery(SubQueryExpression subQueryExpression)
{
var shouldInject = _shouldInject;
_shouldInject = false;

try
{
return base.VisitSubQuery(subQueryExpression);
}
finally
{
_shouldInject = shouldInject;
}
}

protected override Expression VisitExtension(Expression extensionExpression)
{
if (_shouldInject && extensionExpression is DbFunctionSourceExpression dbf)
{
return InjectSubquery(dbf);
}

return base.VisitExtension(extensionExpression);
}

private static Expression InjectSubquery(DbFunctionSourceExpression expression)
{
var targetType = expression.ReturnType;
var mainFromClause = new MainFromClause(targetType.Name.Substring(0, 1).ToLowerInvariant(), targetType, expression);
var selector = new QuerySourceReferenceExpression(mainFromClause);

var subqueryModel = new QueryModel(mainFromClause, new SelectClause(selector));
return new SubQueryExpression(subqueryModel);
}
}

private class TypeIsExpressionTranslatingVisitor : ExpressionVisitorBase
{
private readonly IModel _model;
Expand Down Expand Up @@ -1491,8 +1548,9 @@ var joinExpression
= correlated
? QueryCompilationContext.IsLateralJoinOuterSupported
&& innerShapedQuery?.Method.MethodIsClosedFormOf(LinqOperatorProvider.DefaultIfEmpty) == true
&& innerSelectExpression.Tables.First() is SelectExpression s
&& s.Tables.First() is TableValuedSqlFunctionExpression
&& ((innerSelectExpression.Tables.First() is SelectExpression s
&& s.Tables.First() is TableValuedSqlFunctionExpression)
|| innerSelectExpression.Tables.First() is TableValuedSqlFunctionExpression)
? outerSelectExpression.AddCrossJoinLateralOuter(
innerSelectExpression.Tables.First(),
innerSelectExpression.Projection)
Expand Down Expand Up @@ -2216,10 +2274,13 @@ var parameterWithSamePrefixCount

_injectedParameters[parameterName] = propertyExpression;

Expression
= CreateInjectParametersExpression(
Expression,
new Dictionary<string, Expression> { [parameterName] = propertyExpression });
if(Expression != null)
{
Expression
= CreateInjectParametersExpression(
Expression,
new Dictionary<string, Expression> { [parameterName] = propertyExpression });
}

return Expression.Parameter(
property.ClrType,
Expand Down
14 changes: 9 additions & 5 deletions src/EFCore/DbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1421,9 +1421,11 @@ public virtual Task<TEntity> FindAsync<TEntity>([CanBeNull] object[] keyValues,
/// <typeparam name="U">todo</typeparam>
/// <param name="dbFuncCall">todo</param>
/// <returns>todo</returns>
protected virtual T ExecuteScalarMethod<U, T>(Expression<Func<U, T>> dbFuncCall)
protected virtual T ExecuteScalarMethod<U, T>([NotNull] Expression<Func<U, T>> dbFuncCall)
where U : DbContext
{
Check.NotNull(dbFuncCall, nameof(dbFuncCall));

//todo - verify dbFuncCall contains a method call expression
var dbFuncFac = InternalServiceProvider.GetRequiredService<IDbFunctionSourceFactory>();
var resultsQuery = DbContextDependencies.QueryProvider.Execute(dbFuncFac.GenerateDbFunctionSource(dbFuncCall.Body as MethodCallExpression, Model)) as IEnumerable<T>;
Expand All @@ -1443,9 +1445,11 @@ protected virtual T ExecuteScalarMethod<U, T>(Expression<Func<U, T>> dbFuncCall)
/// <typeparam name="T">todo</typeparam>
/// <param name="dbFuncCall">todo</param>
/// <returns>todo</returns>
protected IQueryable<T> ExecuteTableValuedFunction<U, T>(Expression<Func<U, IQueryable<T>>> dbFuncCall)
protected virtual IQueryable<T> ExecuteTableValuedFunction<U, T>([NotNull] Expression<Func<U, IQueryable<T>>> dbFuncCall)
where U : DbContext
{
Check.NotNull(dbFuncCall, nameof(dbFuncCall));

var dbFuncFac = InternalServiceProvider.GetRequiredService<IDbFunctionSourceFactory>();

//todo - verify dbFuncCall contains a method call expression
Expand All @@ -1454,14 +1458,14 @@ protected IQueryable<T> ExecuteTableValuedFunction<U, T>(Expression<Func<U, IQue
return DbContextDependencies.QueryProvider.CreateQuery<T>(resultsQuery);
}

/// <summary>
/* /// <summary>
/// todo
/// </summary>
/// <typeparam name="T">todo</typeparam>
/// <param name="callingMethod">todo</param>
/// <param name="methodParams">todo</param>
/// <returns>todo</returns>
protected IQueryable<T> ExecuteTableValuedFunction<T>(MethodInfo callingMethod, params object[] methodParams)
protected IQueryable<T> ExecuteTableValuedFunction<T>([NotNull] MethodInfo callingMethod, params object[] methodParams)
{
var c = Expression.Call(Expression.Constant(this),
callingMethod,
Expand Down Expand Up @@ -1489,7 +1493,7 @@ protected IQueryable<T> ExecuteTableValuedFunction<T>(MethodInfo callingMethod,
Expression.Constant(this),
callingMethod,
paramExps));*/
}
// }

#region Hidden System.Object members

Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/Internal/DbFunctionSourceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Text;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata;

namespace Microsoft.EntityFrameworkCore.Internal
Expand All @@ -16,7 +17,7 @@ public class DbFunctionSourceFactory : IDbFunctionSourceFactory
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model)
public virtual Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model)
{
throw new NotImplementedException();
}
Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/Internal/IDbFunctionSourceFactory.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata;

namespace Microsoft.EntityFrameworkCore.Internal
Expand All @@ -13,6 +14,6 @@ public interface IDbFunctionSourceFactory
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
Expression GenerateDbFunctionSource(MethodCallExpression methodCall, IModel model);
Expression GenerateDbFunctionSource([NotNull] MethodCallExpression methodCall, [NotNull] IModel model);
}
}
Loading

0 comments on commit f5421c0

Please sign in to comment.