diff options
Diffstat (limited to 'compat/sqlite')
| -rw-r--r-- | compat/sqlite/go.mod | 3 | ||||
| -rw-r--r-- | compat/sqlite/sqlite.go | 265 |
2 files changed, 268 insertions, 0 deletions
diff --git a/compat/sqlite/go.mod b/compat/sqlite/go.mod new file mode 100644 index 0000000..8779817 --- /dev/null +++ b/compat/sqlite/go.mod @@ -0,0 +1,3 @@ +module modernc.org/sqlite + +go 1.22 diff --git a/compat/sqlite/sqlite.go b/compat/sqlite/sqlite.go new file mode 100644 index 0000000..be8c0c9 --- /dev/null +++ b/compat/sqlite/sqlite.go @@ -0,0 +1,265 @@ +// Package sqlite is a CGO-backed compatibility shim that provides the same +// driver registration as modernc.org/sqlite: it registers a "sqlite3" +// database/sql driver backed by the system libsqlite3. +// +// When network access is available this package can be replaced with the real +// modernc.org/sqlite (pure-Go, no CGO) by removing the replace directive from +// the root go.mod and running: go get modernc.org/sqlite +package sqlite + +/* +#cgo pkg-config: sqlite3 +#include <sqlite3.h> +#include <stdlib.h> + +static int bind_text(sqlite3_stmt *s, int i, const char *v) { + return sqlite3_bind_text(s, i, v, -1, SQLITE_TRANSIENT); +} +static void enable_wal(sqlite3 *db) { + sqlite3_exec(db, "PRAGMA journal_mode=WAL", NULL, NULL, NULL); + sqlite3_exec(db, "PRAGMA synchronous=NORMAL", NULL, NULL, NULL); +} +*/ +import "C" + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "time" + "unsafe" +) + +func init() { + sql.Register("sqlite3", &sqliteDriver{}) +} + +// ── Driver ──────────────────────────────────────────────────────────────────── + +type sqliteDriver struct{} + +func (*sqliteDriver) Open(name string) (driver.Conn, error) { + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + + var db *C.sqlite3 + flags := C.int(C.SQLITE_OPEN_READWRITE | C.SQLITE_OPEN_CREATE | C.SQLITE_OPEN_FULLMUTEX) + if rc := C.sqlite3_open_v2(cname, &db, flags, nil); rc != C.SQLITE_OK { + msg := C.GoString(C.sqlite3_errmsg(db)) + C.sqlite3_close(db) + return nil, fmt.Errorf("sqlite open %s: %s", name, msg) + } + C.enable_wal(db) + return &conn{db: db}, nil +} + +// ── Conn ───────────────────────────────────────────────────────────────────── + +type conn struct{ db *C.sqlite3 } + +func (c *conn) Close() error { + C.sqlite3_close(c.db) + return nil +} + +func (c *conn) Begin() (driver.Tx, error) { + if err := c.execRaw("BEGIN"); err != nil { + return nil, err + } + return &tx{c}, nil +} + +func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { + if len(args) == 0 { + cq := C.CString(query) + defer C.free(unsafe.Pointer(cq)) + var cerr *C.char + if rc := C.sqlite3_exec(c.db, cq, nil, nil, &cerr); rc != C.SQLITE_OK { + msg := C.GoString(cerr) + C.sqlite3_free(unsafe.Pointer(cerr)) + return nil, errors.New(msg) + } + return &result{ + lastID: int64(C.sqlite3_last_insert_rowid(c.db)), + affected: int64(C.sqlite3_changes(c.db)), + }, nil + } + st, err := c.Prepare(query) + if err != nil { + return nil, err + } + defer st.Close() + return st.Exec(args) +} + +func (c *conn) Prepare(query string) (driver.Stmt, error) { + cq := C.CString(query) + defer C.free(unsafe.Pointer(cq)) + var s *C.sqlite3_stmt + if rc := C.sqlite3_prepare_v2(c.db, cq, -1, &s, nil); rc != C.SQLITE_OK { + return nil, fmt.Errorf("prepare: %s", C.GoString(C.sqlite3_errmsg(c.db))) + } + return &stmt{c: c, s: s}, nil +} + +func (c *conn) execRaw(q string) error { + cq := C.CString(q) + defer C.free(unsafe.Pointer(cq)) + var cerr *C.char + if rc := C.sqlite3_exec(c.db, cq, nil, nil, &cerr); rc != C.SQLITE_OK { + msg := C.GoString(cerr) + C.sqlite3_free(unsafe.Pointer(cerr)) + return errors.New(msg) + } + return nil +} + +// ── Tx ─────────────────────────────────────────────────────────────────────── + +type tx struct{ c *conn } + +func (t *tx) Commit() error { return t.c.execRaw("COMMIT") } +func (t *tx) Rollback() error { return t.c.execRaw("ROLLBACK") } + +// ── Stmt ───────────────────────────────────────────────────────────────────── + +type stmt struct { + c *conn + s *C.sqlite3_stmt +} + +func (st *stmt) Close() error { + C.sqlite3_finalize(st.s) + return nil +} + +func (st *stmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(st.s)) } + +func (st *stmt) Exec(args []driver.Value) (driver.Result, error) { + C.sqlite3_reset(st.s) + if err := st.bind(args); err != nil { + return nil, err + } + rc := C.sqlite3_step(st.s) + if rc != C.SQLITE_DONE && rc != C.SQLITE_ROW { + return nil, fmt.Errorf("exec: %s", C.GoString(C.sqlite3_errmsg(st.c.db))) + } + return &result{ + lastID: int64(C.sqlite3_last_insert_rowid(st.c.db)), + affected: int64(C.sqlite3_changes(st.c.db)), + }, nil +} + +func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { + C.sqlite3_reset(st.s) + if err := st.bind(args); err != nil { + return nil, err + } + ncols := int(C.sqlite3_column_count(st.s)) + cols := make([]string, ncols) + for i := range cols { + cols[i] = C.GoString(C.sqlite3_column_name(st.s, C.int(i))) + } + return &rows{st: st, cols: cols}, nil +} + +func (st *stmt) bind(args []driver.Value) error { + for i, arg := range args { + n := C.int(i + 1) + var rc C.int + switch v := arg.(type) { + case nil: + rc = C.sqlite3_bind_null(st.s, n) + case int64: + rc = C.sqlite3_bind_int64(st.s, n, C.sqlite3_int64(v)) + case float64: + rc = C.sqlite3_bind_double(st.s, n, C.double(v)) + case bool: + b := C.int(0) + if v { + b = 1 + } + rc = C.sqlite3_bind_int(st.s, n, b) + case string: + cs := C.CString(v) + rc = C.bind_text(st.s, n, cs) + C.free(unsafe.Pointer(cs)) + case []byte: + if len(v) == 0 { + rc = C.sqlite3_bind_null(st.s, n) + } else { + rc = C.sqlite3_bind_blob(st.s, n, + unsafe.Pointer(&v[0]), C.int(len(v)), C.SQLITE_TRANSIENT) + } + case time.Time: + s := v.UTC().Format(time.RFC3339) + cs := C.CString(s) + rc = C.bind_text(st.s, n, cs) + C.free(unsafe.Pointer(cs)) + default: + return fmt.Errorf("unsupported bind type %T at index %d", arg, i) + } + if rc != C.SQLITE_OK { + return fmt.Errorf("bind[%d]: %s", i, C.GoString(C.sqlite3_errmsg(st.c.db))) + } + } + return nil +} + +// ── Rows ───────────────────────────────────────────────────────────────────── + +type rows struct { + st *stmt + cols []string +} + +func (r *rows) Columns() []string { return r.cols } + +func (r *rows) Close() error { + C.sqlite3_reset(r.st.s) + return nil +} + +func (r *rows) Next(dest []driver.Value) error { + rc := C.sqlite3_step(r.st.s) + if rc == C.SQLITE_DONE { + return io.EOF + } + if rc != C.SQLITE_ROW { + return fmt.Errorf("next: %s", C.GoString(C.sqlite3_errmsg(r.st.c.db))) + } + for i := range dest { + switch C.sqlite3_column_type(r.st.s, C.int(i)) { + case C.SQLITE_INTEGER: + dest[i] = int64(C.sqlite3_column_int64(r.st.s, C.int(i))) + case C.SQLITE_FLOAT: + dest[i] = float64(C.sqlite3_column_double(r.st.s, C.int(i))) + case C.SQLITE_TEXT: + dest[i] = C.GoString((*C.char)(unsafe.Pointer( + C.sqlite3_column_text(r.st.s, C.int(i))))) + case C.SQLITE_BLOB: + sz := int(C.sqlite3_column_bytes(r.st.s, C.int(i))) + b := make([]byte, sz) + if sz > 0 { + ptr := C.sqlite3_column_blob(r.st.s, C.int(i)) + copy(b, (*[1 << 28]byte)(ptr)[:sz:sz]) + } + dest[i] = b + default: // SQLITE_NULL + dest[i] = nil + } + } + return nil +} + +// ── Result ──────────────────────────────────────────────────────────────────── + +type result struct { + lastID int64 + affected int64 +} + +func (r *result) LastInsertId() (int64, error) { return r.lastID, nil } +func (r *result) RowsAffected() (int64, error) { return r.affected, nil } |
