Skip to content

Commit

Permalink
don't check authorization in view
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Feb 22, 2024
1 parent 37239d8 commit a2c38eb
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public class CascadesContext implements ScheduleContext {
private boolean isLeadingJoin = false;

private final Map<String, Hint> hintMap = Maps.newLinkedHashMap();
private final boolean shouldCheckRelationAuthentication;

/**
* Constructor of OptimizerContext.
Expand All @@ -125,7 +126,7 @@ public class CascadesContext implements ScheduleContext {
*/
private CascadesContext(Optional<CascadesContext> parent, Optional<CTEId> currentTree,
StatementContext statementContext, Plan plan, Memo memo,
CTEContext cteContext, PhysicalProperties requireProperties) {
CTEContext cteContext, PhysicalProperties requireProperties, boolean shouldCheckRelationAuthentication) {
this.parent = Objects.requireNonNull(parent, "parent should not null");
this.currentTree = Objects.requireNonNull(currentTree, "currentTree should not null");
this.statementContext = Objects.requireNonNull(statementContext, "statementContext should not null");
Expand All @@ -138,6 +139,7 @@ private CascadesContext(Optional<CascadesContext> parent, Optional<CTEId> curren
this.currentJobContext = new JobContext(this, requireProperties, Double.MAX_VALUE);
this.subqueryExprIsAnalyzed = new HashMap<>();
this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable());
this.shouldCheckRelationAuthentication = shouldCheckRelationAuthentication;
}

/**
Expand All @@ -146,7 +148,13 @@ private CascadesContext(Optional<CascadesContext> parent, Optional<CTEId> curren
public static CascadesContext initContext(StatementContext statementContext,
Plan initPlan, PhysicalProperties requireProperties) {
return newContext(Optional.empty(), Optional.empty(), statementContext,
initPlan, new CTEContext(), requireProperties);
initPlan, new CTEContext(), requireProperties, true);
}

public static CascadesContext initViewContext(StatementContext statementContext,
Plan initPlan, PhysicalProperties requireProperties) {
return newContext(Optional.empty(), Optional.empty(), statementContext,
initPlan, new CTEContext(), requireProperties, false);
}

/**
Expand All @@ -155,13 +163,14 @@ public static CascadesContext initContext(StatementContext statementContext,
public static CascadesContext newContextWithCteContext(CascadesContext cascadesContext,
Plan initPlan, CTEContext cteContext) {
return newContext(Optional.of(cascadesContext), Optional.empty(),
cascadesContext.getStatementContext(), initPlan, cteContext, PhysicalProperties.ANY);
cascadesContext.getStatementContext(), initPlan, cteContext, PhysicalProperties.ANY,
cascadesContext.shouldCheckRelationAuthentication);
}

public static CascadesContext newCurrentTreeContext(CascadesContext context) {
return CascadesContext.newContext(context.getParent(), context.getCurrentTree(), context.getStatementContext(),
context.getRewritePlan(), context.getCteContext(),
context.getCurrentJobContext().getRequiredProperties());
context.getCurrentJobContext().getRequiredProperties(), context.shouldCheckRelationAuthentication);
}

/**
Expand All @@ -170,13 +179,14 @@ public static CascadesContext newCurrentTreeContext(CascadesContext context) {
public static CascadesContext newSubtreeContext(Optional<CTEId> subtree, CascadesContext context,
Plan plan, PhysicalProperties requireProperties) {
return CascadesContext.newContext(Optional.of(context), subtree, context.getStatementContext(),
plan, context.getCteContext(), requireProperties);
plan, context.getCteContext(), requireProperties, context.shouldCheckRelationAuthentication);
}

private static CascadesContext newContext(Optional<CascadesContext> parent, Optional<CTEId> subtree,
StatementContext statementContext, Plan initPlan,
CTEContext cteContext, PhysicalProperties requireProperties) {
return new CascadesContext(parent, subtree, statementContext, initPlan, null, cteContext, requireProperties);
StatementContext statementContext, Plan initPlan, CTEContext cteContext,
PhysicalProperties requireProperties, boolean shouldCheckRelationAuthentication) {
return new CascadesContext(parent, subtree, statementContext, initPlan, null,
cteContext, requireProperties, shouldCheckRelationAuthentication);
}

public CascadesContext getRoot() {
Expand Down Expand Up @@ -622,6 +632,10 @@ public void setLeadingJoin(boolean leadingJoin) {
isLeadingJoin = leadingJoin;
}

public boolean shouldCheckRelationAuthentication() {
return shouldCheckRelationAuthentication;
}

public Map<String, Hint> getHintMap() {
return hintMap;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.ResolveOrdinalInOrderByAndGroupBy;
import org.apache.doris.nereids.rules.analysis.SubqueryToApply;
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
import org.apache.doris.nereids.rules.rewrite.JoinCommute;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;

Expand Down Expand Up @@ -115,8 +114,7 @@ private static List<RewriteJob> buildAnalyzeViewJobs(Optional<CustomTableResolve
topDown(new AnalyzeCTE()),
bottomUp(
new BindRelation(customTableResolver),
new CheckPolicy(),
new UserAuthentication()
new CheckPolicy()
)
);
}
Expand All @@ -128,8 +126,7 @@ private static List<RewriteJob> buildAnalyzeJobs(Optional<CustomTableResolver> c
topDown(new EliminateLogicalSelectHint()),
bottomUp(
new BindRelation(customTableResolver),
new CheckPolicy(),
new UserAuthentication()
new CheckPolicy()
),
bottomUp(new BindExpression()),
topDown(new BindSink()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ private LogicalPlan bindWithCurrentDb(CascadesContext cascadesContext, UnboundRe
}

// TODO: should generate different Scan sub class according to table's type
LogicalPlan scan = getLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
LogicalPlan scan = getAndCheckLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
if (cascadesContext.isLeadingJoin()) {
LeadingHint leading = (LeadingHint) cascadesContext.getHintMap().get("Leading");
leading.putRelationIdAndTableName(Pair.of(unboundRelation.getRelationId(), tableName));
Expand All @@ -178,7 +178,7 @@ private LogicalPlan bind(CascadesContext cascadesContext, UnboundRelation unboun
if (table == null) {
table = RelationUtil.getTable(tableQualifier, cascadesContext.getConnectContext().getEnv());
}
return getLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
return getAndCheckLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
}

private LogicalPlan makeOlapScan(TableIf table, UnboundRelation unboundRelation, List<String> tableQualifier) {
Expand Down Expand Up @@ -234,7 +234,17 @@ private LogicalPlan makeOlapScan(TableIf table, UnboundRelation unboundRelation,
return scan;
}

private LogicalPlan getLogicalPlan(TableIf table, UnboundRelation unboundRelation, List<String> tableQualifier,
private LogicalPlan getAndCheckLogicalPlan(TableIf table, UnboundRelation unboundRelation, List<String> tableQualifier,
CascadesContext cascadesContext) {
// if current context is in the view, we can skip check authentication because
// the view already checked authentication
if (cascadesContext.shouldCheckRelationAuthentication()) {
UserAuthentication.checkPermission(table, cascadesContext.getConnectContext());
}
return doGetLogicalPlan(table, unboundRelation, tableQualifier, cascadesContext);
}

private LogicalPlan doGetLogicalPlan(TableIf table, UnboundRelation unboundRelation, List<String> tableQualifier,
CascadesContext cascadesContext) {
switch (table.getType()) {
case OLAP:
Expand Down Expand Up @@ -288,7 +298,7 @@ private Plan parseAndAnalyzeView(String ddlSql, CascadesContext parentContext) {
if (parsedViewPlan instanceof UnboundResultSink) {
parsedViewPlan = (LogicalPlan) ((UnboundResultSink<?>) parsedViewPlan).child();
}
CascadesContext viewContext = CascadesContext.initContext(
CascadesContext viewContext = CascadesContext.initViewContext(
parentContext.getStatementContext(), parsedViewPlan, PhysicalProperties.ANY);
viewContext.newAnalyzer(true, customTableResolver).analyze();
// we should remove all group expression of the plan which in other memo, so the groupId would not conflict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,55 +23,39 @@
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.qe.ConnectContext;

/**
* Check whether a user is permitted to scan specific tables.
*/
public class UserAuthentication extends OneAnalysisRuleFactory {

@Override
public Rule build() {
return logicalRelation()
.when(CatalogRelation.class::isInstance)
.thenApply(ctx -> checkPermission((CatalogRelation) ctx.root, ctx.connectContext))
.toRule(RuleType.RELATION_AUTHENTICATION);
}

private Plan checkPermission(CatalogRelation relation, ConnectContext connectContext) {
public class UserAuthentication {
public static void checkPermission(TableIf table, ConnectContext connectContext) {
if (table == null) {
return;
}
// do not check priv when replaying dump file
if (connectContext.getSessionVariable().isPlayNereidsDump()) {
return null;
}
TableIf table = relation.getTable();
if (table == null) {
return null;
return;
}
String tableName = table.getName();
DatabaseIf db = table.getDatabase();
// when table inatanceof FunctionGenTable,db will be null
if (db == null) {
return null;
return;
}
String dbName = db.getFullName();
CatalogIf catalog = db.getCatalog();
if (catalog == null) {
return null;
return;
}
String ctlName = catalog.getName();
// TODO: 2023/7/19 checkColumnsPriv
if (!connectContext.getEnv().getAccessManager().checkTblPriv(connectContext, ctlName, dbName,
tableName, PrivPredicate.SELECT)) {
tableName, PrivPredicate.SELECT)) {
String message = ErrorCode.ERR_TABLEACCESS_DENIED_ERROR.formatErrorMsg("SELECT",
ConnectContext.get().getQualifiedUser(), ConnectContext.get().getRemoteIP(),
ctlName + ": " + dbName + ": " + tableName);
ConnectContext.get().getQualifiedUser(), ConnectContext.get().getRemoteIP(),
ctlName + ": " + dbName + ": " + tableName);
throw new AnalysisException(message);
}
return null;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("view_authorization") {
def db = context.config.getDbNameByFile(context.file)
def user1 = "test_view_auth_user1"
def baseTable = "test_view_auth_base_table"
def view1 = "test_view_auth_view1"
def view2 = "test_view_auth_view2"
def view3 = "test_view_auth_view3"


sql "drop table if exists ${baseTable}"
sql "drop view if exists ${view1}"
sql "drop view if exists ${view2}"
sql "drop view if exists ${view3}"
sql "drop user if exists ${user1}"

sql """
CREATE TABLE ${baseTable} (id INT, name TEXT)
DISTRIBUTED BY HASH(`id`)
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""

sql "insert into ${baseTable} values(1, 'hello'), (2, 'world'), (3, 'doris');"
sql "create view ${view1} as select *, concat(name, '_', id) from ${db}.${baseTable} where id=1;"
sql "create view ${view2} as select *, concat(name, '_', id) as xxx from ${db}.${baseTable} where id != 1;"
sql "create view ${view3} as select xxx, 100 from ${db}.${view2} where id=3"

sql "create user ${user1}"
sql "grant SELECT_PRIV on ${db}.${view1} to '${user1}'@'%';"
sql "grant SELECT_PRIV on ${db}.${view3} to '${user1}'@'%';"

def defaultDbUrl = context.config.jdbcUrl.substring(0, context.config.jdbcUrl.lastIndexOf("/"))
logger.info("connect to ${defaultDbUrl}".toString())
connect(user = user1, password = null, url = defaultDbUrl) {
sql "set enable_fallback_to_original_planner=false"

// no privilege to base table
test {
sql "select * from ${db}.${baseTable}"
exception "SELECT command denied to user '${user1}'"
}

// has privilega to view1
test {
sql "select * from ${db}.${view1}"
result([[1, 'hello', 'hello_1']])
}

// no privilega to view2
test {
sql "select * from ${db}.${view2}"
exception "SELECT command denied to user '${user1}'"
}

// nested view
// has privilega to view3
test {
sql "select * from ${db}.${view3}"
result([['doris_3', 100]])
}
}
}

0 comments on commit a2c38eb

Please sign in to comment.