diff --git a/dialect_common.go b/dialect_common.go index f3e598f0..3ab9cfe0 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -29,17 +29,22 @@ func (commonDialect) Lock(fn func() error) error { return fn() } -func (commonDialect) Quote(key string) string { - parts := strings.Split(key, ".") +func quoteIdentifiers(s, quote string) string { + parts := strings.Split(s, ".") for i, part := range parts { - part = strings.Trim(part, `"`) + part = strings.Trim(part, quote) part = strings.TrimSpace(part) - parts[i] = fmt.Sprintf(`"%v"`, part) + parts[i] = quote + part + quote } return strings.Join(parts, ".") + +} + +func (commonDialect) Quote(key string) string { + return quoteIdentifiers(key, `"`) } func genericCreate(c *Connection, model *Model, cols columns.Columns, quoter quotable) error { diff --git a/dialect_mysql.go b/dialect_mysql.go index 0a233889..9afe5a83 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -43,7 +43,7 @@ func (m *mysql) DefaultDriver() string { } func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) + return quoteIdentifiers(key, "`") } func (m *mysql) Details() *ConnectionDetails { diff --git a/dialect_mysql_test.go b/dialect_mysql_test.go index e70d1f74..882b0966 100644 --- a/dialect_mysql_test.go +++ b/dialect_mysql_test.go @@ -244,3 +244,12 @@ func (s *MySQLSuite) Test_MySQL_DDL_Schema() { err = PDB.Dialect.DumpSchema(f) r.Error(err) } + +func Test_MySQL_Quote(t *testing.T) { + r := require.New(t) + + m := &mysql{} + r.Equal("`table_name`", m.Quote("table_name")) + r.Equal("`schema`.`table_name`", m.Quote("schema.table_name")) + r.Equal("`schema`.`table_name`", m.Quote(m.Quote("schema.table_name"))) +}