Skip to content

Commit

Permalink
feat: implement whitelist
Browse files Browse the repository at this point in the history
  • Loading branch information
catalyst17 committed Oct 29, 2024
1 parent 6d744ac commit 9cb0aac
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 1 deletion.
49 changes: 49 additions & 0 deletions internal/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package common
import (
"fmt"
"math/big"
"regexp"
"strings"
"unicode"
)
Expand Down Expand Up @@ -169,3 +170,51 @@ func isType(word string) bool {

return types[word]
}

var allowedFunctions = map[string]struct{}{
"sum": {},
"count": {},
"reinterpretAsUInt256": {},
"reverse": {},
"unhex": {},
"substring": {},
"length": {},
"toUInt256": {},
"if": {},
}

var disallowedPatterns = []string{
`(?i)\b(UNION|INSERT|DELETE|UPDATE|DROP|CREATE|ALTER|TRUNCATE|EXEC|;|--)`,
}

// validateQuery checks the query for disallowed patterns and ensures only allowed functions are used.
func ValidateQuery(query string) error {
// Check for disallowed patterns
for _, pattern := range disallowedPatterns {
matched, err := regexp.MatchString(pattern, query)
if err != nil {
return fmt.Errorf("error checking disallowed patterns: %v", err)
}
if matched {
return fmt.Errorf("query contains disallowed keywords or patterns")
}
}

// Ensure the query is a SELECT statement
trimmedQuery := strings.TrimSpace(strings.ToUpper(query))
if !strings.HasPrefix(trimmedQuery, "SELECT") {
return fmt.Errorf("only SELECT queries are allowed")
}

// Extract function names and validate them
functionPattern := regexp.MustCompile(`(?i)(\b\w+\b)\s*\(`)
matches := functionPattern.FindAllStringSubmatch(query, -1)
for _, match := range matches {
funcName := match[1]
if _, ok := allowedFunctions[funcName]; !ok {
return fmt.Errorf("function '%s' is not allowed", funcName)
}
}

return nil
}
2 changes: 2 additions & 0 deletions internal/handlers/logs_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) {
aggregatesResult, err := mainStorage.GetAggregations("logs", qf)
if err != nil {
log.Error().Err(err).Msg("Error querying aggregates")
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
api.InternalErrorHandler(c)
return
}
Expand All @@ -180,6 +181,7 @@ func handleLogsRequest(c *gin.Context, contractAddress, signature string) {
logsResult, err := mainStorage.GetLogs(qf)
if err != nil {
log.Error().Err(err).Msg("Error querying logs")
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
api.InternalErrorHandler(c)
return
}
Expand Down
4 changes: 3 additions & 1 deletion internal/handlers/transactions_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string
aggregatesResult, err := mainStorage.GetAggregations("transactions", qf)
if err != nil {
log.Error().Err(err).Msg("Error querying aggregates")
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
api.InternalErrorHandler(c)
return
}
Expand All @@ -181,7 +182,8 @@ func handleTransactionsRequest(c *gin.Context, contractAddress, signature string
// Retrieve logs data
transactionsResult, err := mainStorage.GetTransactions(qf)
if err != nil {
log.Error().Err(err).Msg("Error querying tran")
log.Error().Err(err).Msg("Error querying transactions")
// TODO: might want to choose BadRequestError if it's due to not-allowed functions
api.InternalErrorHandler(c)
return
}
Expand Down
12 changes: 12 additions & 0 deletions internal/storage/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ func (c *ClickHouseConnector) GetBlocks(qf QueryFilter) (blocks []common.Block,

query += getLimitClause(int(qf.Limit))

if err := common.ValidateQuery(query); err != nil {
return nil, err
}
rows, err := c.conn.Query(context.Background(), query)
if err != nil {
return nil, err
Expand Down Expand Up @@ -369,6 +372,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
query += fmt.Sprintf(" GROUP BY %s", groupByColumns)
}

if err := common.ValidateQuery(query); err != nil {
return QueryResult[interface{}]{}, err
}
// Execute the query
rows, err := c.conn.Query(context.Background(), query)
if err != nil {
Expand Down Expand Up @@ -421,6 +427,9 @@ func (c *ClickHouseConnector) GetAggregations(table string, qf QueryFilter) (Que
func executeQuery[T any](c *ClickHouseConnector, table, columns string, qf QueryFilter, scanFunc func(driver.Rows) (T, error)) (QueryResult[T], error) {
query := c.buildQuery(table, columns, qf)

if err := common.ValidateQuery(query); err != nil {
return QueryResult[T]{}, err
}
rows, err := c.conn.Query(context.Background(), query)
if err != nil {
return QueryResult[T]{}, err
Expand Down Expand Up @@ -856,6 +865,9 @@ func (c *ClickHouseConnector) GetTraces(qf QueryFilter) (traces []common.Trace,

query += getLimitClause(int(qf.Limit))

if err := common.ValidateQuery(query); err != nil {
return nil, err
}
rows, err := c.conn.Query(context.Background(), query)
if err != nil {
return nil, err
Expand Down

0 comments on commit 9cb0aac

Please sign in to comment.