diff --git a/internal/event/target/postgresql.go b/internal/event/target/postgresql.go index bb10fdaf6..40e4ba8c9 100644 --- a/internal/event/target/postgresql.go +++ b/internal/event/target/postgresql.go @@ -26,9 +26,11 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strconv" "strings" "time" + "unicode" _ "github.com/lib/pq" // Register postgres driver @@ -101,6 +103,10 @@ func (p PostgreSQLArgs) Validate() error { if p.Table == "" { return fmt.Errorf("empty table name") } + if err := validatePsqlTableName(p.Table); err != nil { + return err + } + if p.Format != "" { f := strings.ToLower(p.Format) if f != event.NamespaceFormat && f != event.AccessFormat { @@ -444,3 +450,43 @@ func NewPostgreSQLTarget(id string, args PostgreSQLArgs, loggerOnce logger.LogOn return target, nil } + +var errInvalidPsqlTablename = errors.New("invalid PostgreSQL table") + +func validatePsqlTableName(name string) error { + // check for quoted string (string may not contain a quote) + if match, err := regexp.MatchString("^\"[^\"]+\"$", name); err != nil { + return err + } else if match { + return nil + } + + // normalize the name to letters, digits, _ or $ + valid := true + cleaned := strings.Map(func(r rune) rune { + switch { + case unicode.IsLetter(r): + return 'a' + case unicode.IsDigit(r): + return '0' + case r == '_', r == '$': + return r + default: + valid = false + return -1 + } + }, name) + + if valid { + // check for simple name or quoted name + // - letter/underscore followed by one or more letter/digit/underscore + // - any text between quotes (text cannot contain a quote itself) + if match, err := regexp.MatchString("^[a_][a0_$]*$", cleaned); err != nil { + return err + } else if match { + return nil + } + } + + return errInvalidPsqlTablename +} diff --git a/internal/event/target/postgresql_test.go b/internal/event/target/postgresql_test.go index 0ec94f6f1..cd03c7134 100644 --- a/internal/event/target/postgresql_test.go +++ b/internal/event/target/postgresql_test.go @@ -36,3 +36,19 @@ func TestPostgreSQLRegistration(t *testing.T) { t.Fatal("postgres driver not registered") } } + +func TestPsqlTableNameValidation(t *testing.T) { + validTables := []string{"táblë", "table", "TableName", "\"Table name\"", "\"✅✅\"", "table$one", "\"táblë\""} + invalidTables := []string{"table name", "table \"name\"", "✅✅", "$table$"} + + for _, name := range validTables { + if err := validatePsqlTableName(name); err != nil { + t.Errorf("Should be valid: %s - %s", name, err) + } + } + for _, name := range invalidTables { + if err := validatePsqlTableName(name); err != errInvalidPsqlTablename { + t.Errorf("Should be invalid: %s - %s", name, err) + } + } +}