Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b4529e018f | |||
| 9b998b7904 | |||
| ca2081c72c | |||
|
|
46e6c1272d | ||
|
|
87ae7265f4 |
18
.gitea/workflows/build.yaml
Executable file
18
.gitea/workflows/build.yaml
Executable file
@@ -0,0 +1,18 @@
|
||||
name: Go
|
||||
on: [push]
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: [ '1.19' ]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
- name: Build
|
||||
run: go build -v ./...
|
||||
- name: Test
|
||||
run: go test -v ./...
|
||||
101
README.md
Normal file → Executable file
101
README.md
Normal file → Executable file
@@ -1,46 +1,95 @@
|
||||
# migrate
|
||||
|
||||
`migrate` is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore) it is intended to keep its footprint small, requiring only an additional table in the database there is no rollback support as you should only ever roll forward. Sqlite3 support is provided, support for other datbases can be added by implementing the `Dialect` interface
|
||||
`migrate` is a package for SQL database migrations in the spirit of [dbstore](rsc.io/dbstore). It is intended to keep its footprint small, requiring only an additional table in the database. There is no rollback support as you should only ever roll forward. SQLite and PostgreSQL support is provided, support for other databases can be added by implementing the `Dialect` interface.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get github.com/jchenry/migrate
|
||||
go get git.sdf.org/jchenry/migrate
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### SQLite Example
|
||||
|
||||
```go
|
||||
...
|
||||
records :=
|
||||
[]Record{
|
||||
{
|
||||
Description: "create people table",
|
||||
F: func(ctx Context) (err error) {
|
||||
_, err = ctx.Exec(`
|
||||
import (
|
||||
"database/sql"
|
||||
"git.sdf.org/jchenry/migrate"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
db, _ := sql.Open("sqlite", "database.db")
|
||||
|
||||
changes := []migrate.Change{
|
||||
{
|
||||
Description: "create people table",
|
||||
Apply: func(ctx migrate.Context) error {
|
||||
_, err := ctx.Exec(`
|
||||
CREATE TABLE people (
|
||||
given_name VARCHAR(20),
|
||||
surname VARCHAR(30),
|
||||
gender CHAR(1),
|
||||
sex CHAR(1),
|
||||
age SMALLINT);
|
||||
`)
|
||||
return
|
||||
},
|
||||
},
|
||||
{
|
||||
Description: "Insert a person into people",
|
||||
F: func(ctx Context) (err error) {
|
||||
_, err = ctx.Exec(`INSERT INTO people VALUES('Henry','Colin','M', 42)`)
|
||||
return
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = migrate.Apply(db, migtate.Sqlite3(), records)
|
||||
...
|
||||
`)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
Description: "Insert a person into people",
|
||||
Apply: func(ctx migrate.Context) error {
|
||||
_, err := ctx.Exec(`INSERT INTO people VALUES('Henry','Colin','M', 42)`)
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := migrate.Apply(db, migrate.Sqlite3(), changes)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
```
|
||||
|
||||
### PostgreSQL Example
|
||||
|
||||
```go
|
||||
import (
|
||||
"database/sql"
|
||||
"git.sdf.org/jchenry/migrate"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
db, _ := sql.Open("postgres", "postgres://user:password@localhost/dbname?sslmode=disable")
|
||||
|
||||
changes := []migrate.Change{
|
||||
{
|
||||
Description: "create people table",
|
||||
Apply: func(ctx migrate.Context) error {
|
||||
_, err := ctx.Exec(`
|
||||
CREATE TABLE people (
|
||||
given_name VARCHAR(20),
|
||||
surname VARCHAR(30),
|
||||
sex CHAR(1),
|
||||
age SMALLINT);
|
||||
`)
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := migrate.Apply(db, migrate.Postgres(), changes)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
```
|
||||
|
||||
## Supported Databases
|
||||
|
||||
- **SQLite** - Use `migrate.Sqlite3()`
|
||||
- **PostgreSQL** - Use `migrate.Postgres()`
|
||||
|
||||
To add support for other databases, implement the `Dialect` interface.
|
||||
|
||||
## Contributing
|
||||
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
|
||||
|
||||
@@ -48,5 +97,3 @@ Please make sure to update tests as appropriate.
|
||||
|
||||
## License
|
||||
[MIT](https://choosealicense.com/licenses/mit/)
|
||||
|
||||
courtesey of https://www.makeareadme.com
|
||||
35
dialect.go
Normal file → Executable file
35
dialect.go
Normal file → Executable file
@@ -1,33 +1,18 @@
|
||||
package migrate
|
||||
|
||||
// Dialect defines the interface for database-specific SQL generation.
|
||||
// Implementations must provide SQL statements for creating and managing
|
||||
// the migration version table.
|
||||
type Dialect interface {
|
||||
// CreateTable returns SQL to create the migration version table
|
||||
CreateTable(table string) string
|
||||
|
||||
// TableExists returns SQL to check if the migration version table exists
|
||||
TableExists(table string) string
|
||||
|
||||
// CheckVersion returns SQL to get the current migration version
|
||||
CheckVersion(table string) string
|
||||
|
||||
// InsertVersion returns SQL to insert a new migration version record
|
||||
InsertVersion(table string) string
|
||||
}
|
||||
|
||||
func Sqlite3() Dialect {
|
||||
return sqlite3{}
|
||||
}
|
||||
|
||||
type sqlite3 struct{}
|
||||
|
||||
func (s sqlite3) CreateTable(table string) string {
|
||||
return "CREATE TABLE " + table + ` (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
description VARCHAR,
|
||||
applied TIMESTAMP);`
|
||||
}
|
||||
|
||||
func (s sqlite3) TableExists(table string) string {
|
||||
return "SELECT * FROM " + table + ";"
|
||||
}
|
||||
|
||||
func (s sqlite3) CheckVersion(table string) string {
|
||||
return "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
|
||||
}
|
||||
|
||||
func (s sqlite3) InsertVersion(table string) string {
|
||||
return "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
|
||||
}
|
||||
|
||||
2
doc.go
Normal file → Executable file
2
doc.go
Normal file → Executable file
@@ -1,6 +1,6 @@
|
||||
package migrate
|
||||
|
||||
// migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore)
|
||||
// migrate is a package for SQL database migrations in the spirit of dbstore(rsc.io/dbstore)
|
||||
// it is intended to keep its footprint small, requiring only an additional table in the database
|
||||
// there is no rollback support as you should only ever roll forward.
|
||||
// uses SQL99 compatible SQL only.
|
||||
|
||||
23
go.mod
Normal file → Executable file
23
go.mod
Normal file → Executable file
@@ -1,10 +1,23 @@
|
||||
module github.com/jchenry/migrate
|
||||
module git.sdf.org/jchenry/migrate
|
||||
|
||||
go 1.16
|
||||
go 1.25
|
||||
|
||||
require github.com/mattn/go-sqlite3 v1.14.7
|
||||
require modernc.org/sqlite v1.44.0
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
modernc.org/libc v1.67.4 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
)
|
||||
|
||||
retract (
|
||||
v0.0.1 // Published accidentally.
|
||||
v1.0.2 // Contains retractions only.
|
||||
v1.0.2 // Contains retractions only.
|
||||
v0.0.1 // Published accidentally.
|
||||
)
|
||||
|
||||
59
go.sum
Normal file → Executable file
59
go.sum
Normal file → Executable file
@@ -1,2 +1,57 @@
|
||||
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
|
||||
github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
||||
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
|
||||
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
||||
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.44.0 h1:YjCKJnzZde2mLVy0cMKTSL4PxCmbIguOq9lGp8ZvGOc=
|
||||
modernc.org/sqlite v1.44.0/go.mod h1:2Dq41ir5/qri7QJJJKNZcP4UF7TsX/KNeykYgPDtGhE=
|
||||
modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
|
||||
modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
|
||||
97
migrate.go
Normal file → Executable file
97
migrate.go
Normal file → Executable file
@@ -8,84 +8,79 @@ import (
|
||||
|
||||
const table = "dbversion"
|
||||
|
||||
type Error struct {
|
||||
description string
|
||||
wrapped error
|
||||
}
|
||||
|
||||
func (e Error) Error() string {
|
||||
return fmt.Sprintf("%s: %v", e.description, e.wrapped)
|
||||
}
|
||||
|
||||
func (e Error) Unwrap() error {
|
||||
return e.wrapped
|
||||
}
|
||||
|
||||
type Record struct {
|
||||
type Change struct {
|
||||
Description string
|
||||
F func(ctx Context) error
|
||||
Apply func(ctx Context) error
|
||||
}
|
||||
|
||||
type Context interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
func Apply(ctx Context, d Dialect, migrations []Record) (err error) {
|
||||
if err = initialize(ctx, d); err == nil {
|
||||
var currentVersion int64
|
||||
if currentVersion, err = dbVersion(ctx, d); err == nil {
|
||||
migrations = migrations[currentVersion:] // only apply what hasnt been been applied already
|
||||
for i, m := range migrations {
|
||||
if err = apply(ctx, d, m); err != nil {
|
||||
err = Error{
|
||||
description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description),
|
||||
wrapped: err,
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
func Apply(ctx Context, d Dialect, migrations []Change) error {
|
||||
if err := initialize(ctx, d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
currentVersion, err := dbVersion(ctx, d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrations = migrations[currentVersion:] // only apply what hasnt been been applied already
|
||||
for _, m := range migrations {
|
||||
if err := apply(ctx, d, m); err != nil {
|
||||
return fmt.Errorf("error performing migration \"%s\": %w", m.Description, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func initialize(ctx Context, d Dialect) (err error) {
|
||||
if noVersionTable(ctx, d) {
|
||||
func initialize(ctx Context, d Dialect) error {
|
||||
if !versionTableExists(ctx, d) {
|
||||
return createVersionTable(ctx, d)
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func noVersionTable(ctx Context, d Dialect) bool {
|
||||
func versionTableExists(ctx Context, d Dialect) bool {
|
||||
rows, table_check := ctx.Query(d.TableExists(table))
|
||||
if rows != nil {
|
||||
defer rows.Close()
|
||||
}
|
||||
return table_check != nil
|
||||
return table_check == nil
|
||||
}
|
||||
|
||||
func apply(ctx Context, d Dialect, r Record) (err error) {
|
||||
if err = r.F(ctx); err == nil {
|
||||
err = incrementVersion(ctx, d, r.Description)
|
||||
func apply(ctx Context, d Dialect, r Change) error {
|
||||
if err := r.Apply(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return
|
||||
return incrementVersion(ctx, d, r.Description)
|
||||
}
|
||||
|
||||
func createVersionTable(ctx Context, d Dialect) (err error) {
|
||||
_, err = ctx.Exec(d.CreateTable(table))
|
||||
return
|
||||
func createVersionTable(ctx Context, d Dialect) error {
|
||||
_, err := ctx.Exec(d.CreateTable(table))
|
||||
return err
|
||||
}
|
||||
|
||||
func incrementVersion(ctx Context, d Dialect, description string) (err error) {
|
||||
_, err = ctx.Exec(d.InsertVersion(table), description, time.Now())
|
||||
return
|
||||
func incrementVersion(ctx Context, d Dialect, description string) error {
|
||||
_, err := ctx.Exec(d.InsertVersion(table), description, time.Now())
|
||||
return err
|
||||
}
|
||||
|
||||
func dbVersion(ctx Context, d Dialect) (id int64, err error) {
|
||||
row, err := ctx.Query(d.CheckVersion(table))
|
||||
if row.Next() {
|
||||
err = row.Scan(&id)
|
||||
rows, err := ctx.Query(d.CheckVersion(table))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
err = rows.Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
64
migrate_test.go
Normal file → Executable file
64
migrate_test.go
Normal file → Executable file
@@ -3,28 +3,28 @@ package migrate
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestHelperFuncs(t *testing.T) {
|
||||
path, db, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = teardownTestDB(path, db); err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateVersionTable(t *testing.T) {
|
||||
path, db, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = createVersionTable(db, Sqlite3())
|
||||
@@ -33,14 +33,14 @@ func TestCreateVersionTable(t *testing.T) {
|
||||
}
|
||||
|
||||
if err = teardownTestDB(path, db); err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementVersion(t *testing.T) {
|
||||
path, db, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sl3 := Sqlite3()
|
||||
@@ -80,14 +80,14 @@ func TestIncrementVersion(t *testing.T) {
|
||||
}
|
||||
|
||||
if err = teardownTestDB(path, db); err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbVersion(t *testing.T) {
|
||||
path, db, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sl3 := Sqlite3()
|
||||
@@ -114,23 +114,23 @@ func TestDbVersion(t *testing.T) {
|
||||
// err = incrementVersion(db, d)
|
||||
|
||||
if err = teardownTestDB(path, db); err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApply(t *testing.T) {
|
||||
path, db, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sl3 := Sqlite3()
|
||||
|
||||
records :=
|
||||
[]Record{
|
||||
[]Change{
|
||||
{
|
||||
Description: "create people table",
|
||||
F: func(ctx Context) (err error) {
|
||||
Apply: func(ctx Context) (err error) {
|
||||
_, err = ctx.Exec(`
|
||||
CREATE TABLE people (
|
||||
given_name VARCHAR(20),
|
||||
@@ -143,7 +143,7 @@ func TestApply(t *testing.T) {
|
||||
},
|
||||
{
|
||||
Description: "Insert a person into people",
|
||||
F: func(ctx Context) (err error) {
|
||||
Apply: func(ctx Context) (err error) {
|
||||
_, err = ctx.Exec(`INSERT INTO people VALUES('Henry','Colin','M', 42)`)
|
||||
return
|
||||
},
|
||||
@@ -159,7 +159,9 @@ func TestApply(t *testing.T) {
|
||||
r := db.QueryRow("SELECT given_name FROM people")
|
||||
|
||||
var given_name string
|
||||
r.Scan(&given_name)
|
||||
if err := r.Scan(&given_name); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if given_name != "Henry" {
|
||||
t.Fatalf("second migration did not complete: %s != %s", given_name, "Henry")
|
||||
@@ -179,9 +181,9 @@ func TestApply(t *testing.T) {
|
||||
|
||||
ishouldntHideUserErrors := errors.New("I should fail")
|
||||
|
||||
records = append(records, Record{
|
||||
records = append(records, Change{
|
||||
Description: "Insert a person into people",
|
||||
F: func(ctx Context) (err error) {
|
||||
Apply: func(ctx Context) (err error) {
|
||||
return ishouldntHideUserErrors
|
||||
},
|
||||
})
|
||||
@@ -200,23 +202,27 @@ func TestApply(t *testing.T) {
|
||||
}
|
||||
|
||||
if err = teardownTestDB(path, db); err != nil {
|
||||
t.Fail()
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func createTestDB() (path string, db *sql.DB, err error) {
|
||||
if f, err := ioutil.TempFile(os.TempDir(), "migrate-test-db"); err == nil {
|
||||
f.Close()
|
||||
if db, err := sql.Open("sqlite3", f.Name()); err == nil {
|
||||
return f.Name(), db, err
|
||||
}
|
||||
f, err := os.CreateTemp(os.TempDir(), "migrate-test-db")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
func teardownTestDB(path string, db *sql.DB) (err error) {
|
||||
if err = db.Close(); err == nil {
|
||||
err = os.Remove(path)
|
||||
f.Close()
|
||||
|
||||
db, err = sql.Open("sqlite", f.Name())
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return
|
||||
return f.Name(), db, nil
|
||||
}
|
||||
func teardownTestDB(path string, db *sql.DB) error {
|
||||
if err := db.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
42
postgres.go
Normal file
42
postgres.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Postgres returns a PostgreSQL dialect implementation
|
||||
func Postgres() Dialect {
|
||||
return postgres{}
|
||||
}
|
||||
|
||||
type postgres struct{}
|
||||
|
||||
// quoteIdentifier safely quotes a SQL identifier to prevent SQL injection
|
||||
// PostgreSQL uses double quotes for identifiers and doubles them for escaping
|
||||
func (p postgres) quoteIdentifier(identifier string) string {
|
||||
// Replace any existing quotes with double quotes (SQL escape mechanism)
|
||||
escaped := strings.ReplaceAll(identifier, `"`, `""`)
|
||||
return fmt.Sprintf(`"%s"`, escaped)
|
||||
}
|
||||
|
||||
func (p postgres) CreateTable(table string) string {
|
||||
return fmt.Sprintf(`CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
description VARCHAR,
|
||||
applied TIMESTAMP);`, p.quoteIdentifier(table))
|
||||
}
|
||||
|
||||
func (p postgres) TableExists(table string) string {
|
||||
// PostgreSQL-specific way to check if table exists
|
||||
return fmt.Sprintf("SELECT 1 FROM %s LIMIT 1;", p.quoteIdentifier(table))
|
||||
}
|
||||
|
||||
func (p postgres) CheckVersion(table string) string {
|
||||
return fmt.Sprintf("SELECT id FROM %s ORDER BY id DESC LIMIT 1;", p.quoteIdentifier(table))
|
||||
}
|
||||
|
||||
func (p postgres) InsertVersion(table string) string {
|
||||
// PostgreSQL uses $1, $2 for placeholders instead of ?
|
||||
return fmt.Sprintf("INSERT INTO %s(description, applied) VALUES ($1, $2);", p.quoteIdentifier(table))
|
||||
}
|
||||
67
postgres_test.go
Normal file
67
postgres_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPostgresDialect(t *testing.T) {
|
||||
pg := Postgres()
|
||||
|
||||
t.Run("CreateTable", func(t *testing.T) {
|
||||
sql := pg.CreateTable("dbversion")
|
||||
if !strings.Contains(sql, "SERIAL PRIMARY KEY") {
|
||||
t.Errorf("Expected SERIAL PRIMARY KEY, got: %s", sql)
|
||||
}
|
||||
if !strings.Contains(sql, `"dbversion"`) {
|
||||
t.Errorf("Expected quoted table name, got: %s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TableExists", func(t *testing.T) {
|
||||
sql := pg.TableExists("dbversion")
|
||||
if !strings.Contains(sql, `"dbversion"`) {
|
||||
t.Errorf("Expected quoted table name, got: %s", sql)
|
||||
}
|
||||
if !strings.Contains(sql, "SELECT 1") {
|
||||
t.Errorf("Expected SELECT 1 for existence check, got: %s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CheckVersion", func(t *testing.T) {
|
||||
sql := pg.CheckVersion("dbversion")
|
||||
if !strings.Contains(sql, "ORDER BY id DESC") {
|
||||
t.Errorf("Expected ORDER BY id DESC, got: %s", sql)
|
||||
}
|
||||
if !strings.Contains(sql, "LIMIT 1") {
|
||||
t.Errorf("Expected LIMIT 1, got: %s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertVersion", func(t *testing.T) {
|
||||
sql := pg.InsertVersion("dbversion")
|
||||
// PostgreSQL uses $1, $2 placeholders
|
||||
if !strings.Contains(sql, "$1") || !strings.Contains(sql, "$2") {
|
||||
t.Errorf("Expected PostgreSQL placeholders ($1, $2), got: %s", sql)
|
||||
}
|
||||
if !strings.Contains(sql, `"dbversion"`) {
|
||||
t.Errorf("Expected quoted table name, got: %s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("QuoteIdentifier", func(t *testing.T) {
|
||||
pg := postgres{}
|
||||
|
||||
// Test normal identifier
|
||||
quoted := pg.quoteIdentifier("tablename")
|
||||
if quoted != `"tablename"` {
|
||||
t.Errorf("Expected quoted identifier, got: %s", quoted)
|
||||
}
|
||||
|
||||
// Test identifier with quotes (SQL injection attempt)
|
||||
quoted = pg.quoteIdentifier(`table"; DROP TABLE users; --`)
|
||||
if !strings.Contains(quoted, `""`) {
|
||||
t.Errorf("Expected escaped quotes, got: %s", quoted)
|
||||
}
|
||||
})
|
||||
}
|
||||
39
sqlite.go
Normal file
39
sqlite.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Sqlite3 returns a SQLite dialect implementation
|
||||
func Sqlite3() Dialect {
|
||||
return sqlite3{}
|
||||
}
|
||||
|
||||
type sqlite3 struct{}
|
||||
|
||||
// quoteIdentifier safely quotes a SQL identifier to prevent SQL injection
|
||||
func (s sqlite3) quoteIdentifier(identifier string) string {
|
||||
// Replace any existing quotes with double quotes (SQL escape mechanism)
|
||||
escaped := strings.ReplaceAll(identifier, `"`, `""`)
|
||||
return fmt.Sprintf(`"%s"`, escaped)
|
||||
}
|
||||
|
||||
func (s sqlite3) CreateTable(table string) string {
|
||||
return fmt.Sprintf(`CREATE TABLE %s (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
description VARCHAR,
|
||||
applied TIMESTAMP);`, s.quoteIdentifier(table))
|
||||
}
|
||||
|
||||
func (s sqlite3) TableExists(table string) string {
|
||||
return fmt.Sprintf("SELECT * FROM %s;", s.quoteIdentifier(table))
|
||||
}
|
||||
|
||||
func (s sqlite3) CheckVersion(table string) string {
|
||||
return fmt.Sprintf("SELECT id FROM %s ORDER BY id DESC LIMIT 0, 1;", s.quoteIdentifier(table))
|
||||
}
|
||||
|
||||
func (s sqlite3) InsertVersion(table string) string {
|
||||
return fmt.Sprintf("INSERT INTO %s(description, applied) VALUES (?,?);", s.quoteIdentifier(table))
|
||||
}
|
||||
67
sqlite_test.go
Normal file
67
sqlite_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSqlite3Dialect(t *testing.T) {
|
||||
sl3 := Sqlite3()
|
||||
|
||||
t.Run("CreateTable", func(t *testing.T) {
|
||||
sql := sl3.CreateTable("dbversion")
|
||||
if !strings.Contains(sql, "INTEGER PRIMARY KEY AUTOINCREMENT") {
|
||||
t.Errorf("Expected INTEGER PRIMARY KEY AUTOINCREMENT, got: %s", sql)
|
||||
}
|
||||
if !strings.Contains(sql, `"dbversion"`) {
|
||||
t.Errorf("Expected quoted table name, got: %s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TableExists", func(t *testing.T) {
|
||||
sql := sl3.TableExists("dbversion")
|
||||
if !strings.Contains(sql, `"dbversion"`) {
|
||||
t.Errorf("Expected quoted table name, got: %s", sql)
|
||||
}
|
||||
if !strings.Contains(sql, "SELECT *") {
|
||||
t.Errorf("Expected SELECT *, got: %s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CheckVersion", func(t *testing.T) {
|
||||
sql := sl3.CheckVersion("dbversion")
|
||||
if !strings.Contains(sql, "ORDER BY id DESC") {
|
||||
t.Errorf("Expected ORDER BY id DESC, got: %s", sql)
|
||||
}
|
||||
if !strings.Contains(sql, "LIMIT 0, 1") {
|
||||
t.Errorf("Expected LIMIT 0, 1, got: %s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertVersion", func(t *testing.T) {
|
||||
sql := sl3.InsertVersion("dbversion")
|
||||
// SQLite uses ? placeholders
|
||||
if !strings.Contains(sql, "VALUES (?,?)") {
|
||||
t.Errorf("Expected SQLite placeholders (?,?), got: %s", sql)
|
||||
}
|
||||
if !strings.Contains(sql, `"dbversion"`) {
|
||||
t.Errorf("Expected quoted table name, got: %s", sql)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("QuoteIdentifier", func(t *testing.T) {
|
||||
sl3 := sqlite3{}
|
||||
|
||||
// Test normal identifier
|
||||
quoted := sl3.quoteIdentifier("tablename")
|
||||
if quoted != `"tablename"` {
|
||||
t.Errorf("Expected quoted identifier, got: %s", quoted)
|
||||
}
|
||||
|
||||
// Test identifier with quotes (SQL injection attempt)
|
||||
quoted = sl3.quoteIdentifier(`table"; DROP TABLE users; --`)
|
||||
if !strings.Contains(quoted, `""`) {
|
||||
t.Errorf("Expected escaped quotes, got: %s", quoted)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user