@@ -3,36 +3,47 @@ package client
33import (
44 "context"
55 "fmt"
6+ "os"
7+ "path/filepath"
68 "watchAlert/internal/global"
79 "watchAlert/internal/models"
810
911 "github.com/zeromicro/go-zero/core/logc"
1012 "gorm.io/driver/mysql"
13+ "gorm.io/driver/sqlite"
1114 "gorm.io/gorm"
1215 "gorm.io/gorm/logger"
1316)
1417
1518type DBConfig struct {
16- Host string
17- Port string
18- User string
19- Pass string
20- DBName string
21- Timeout string
19+ Type string // 数据库类型: mysql 或 sqlite
20+ Host string // MySQL 主机地址
21+ Port string // MySQL 端口
22+ User string // MySQL 用户名
23+ Pass string // MySQL 密码
24+ DBName string // MySQL 数据库名
25+ Timeout string // MySQL 连接超时
26+ Path string // SQLite 数据库文件路径
2227}
2328
2429func NewDBClient (config DBConfig ) * gorm.DB {
25- // 初始化本地 test. db 数据库文件
26- //db, err := gorm.Open(sqlite.Open("data/sql.db"), &gorm.Config{})
30+ var db * gorm. DB
31+ var err error
2732
28- dsn := fmt .Sprintf ("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4,utf8&parseTime=True&loc=Local&timeout=%s" ,
29- config .User ,
30- config .Pass ,
31- config .Host ,
32- config .Port ,
33- config .DBName ,
34- config .Timeout )
35- db , err := gorm .Open (mysql .Open (dsn ), & gorm.Config {})
33+ // 设置默认数据库类型为 mysql
34+ if config .Type == "" {
35+ config .Type = "mysql"
36+ }
37+
38+ switch config .Type {
39+ case "sqlite" :
40+ db , err = initSQLiteDB (config )
41+ case "mysql" :
42+ db , err = initMySQLDB (config )
43+ default :
44+ logc .Errorf (context .Background (), "unsupported database type: %s" , config .Type )
45+ return nil
46+ }
3647
3748 if err != nil {
3849 logc .Errorf (context .Background (), "failed to connect database: %s" , err .Error ())
@@ -83,3 +94,34 @@ func NewDBClient(config DBConfig) *gorm.DB {
8394
8495 return db
8596}
97+
98+ // initMySQLDB 初始化 MySQL 数据库连接
99+ func initMySQLDB (config DBConfig ) (* gorm.DB , error ) {
100+ dsn := fmt .Sprintf ("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4,utf8&parseTime=True&loc=Local&timeout=%s" ,
101+ config .User ,
102+ config .Pass ,
103+ config .Host ,
104+ config .Port ,
105+ config .DBName ,
106+ config .Timeout )
107+
108+ logc .Infof (context .Background (), "connecting to MySQL database: %s:%s/%s" , config .Host , config .Port , config .DBName )
109+ return gorm .Open (mysql .Open (dsn ), & gorm.Config {})
110+ }
111+
112+ // initSQLiteDB 初始化 SQLite 数据库连接
113+ func initSQLiteDB (config DBConfig ) (* gorm.DB , error ) {
114+ // 设置默认 SQLite 文件路径
115+ if config .Path == "" {
116+ config .Path = "data/watchalert.db"
117+ }
118+
119+ // 确保目录存在
120+ dir := filepath .Dir (config .Path )
121+ if err := os .MkdirAll (dir , 0755 ); err != nil {
122+ return nil , fmt .Errorf ("failed to create directory %s: %w" , dir , err )
123+ }
124+
125+ logc .Infof (context .Background (), "connecting to SQLite database: %s" , config .Path )
126+ return gorm .Open (sqlite .Open (config .Path ), & gorm.Config {})
127+ }
0 commit comments