diff --git a/.golangci.yml b/.golangci.yml index d775fb76..ba826d15 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -56,7 +56,6 @@ linters: - contextcheck - durationcheck - errorlint - - goconst - goimports - revive - misspell diff --git a/api/api.go b/api/api.go index db941b99..6edcb19d 100644 --- a/api/api.go +++ b/api/api.go @@ -39,7 +39,6 @@ import ( "github.com/satisfactorymodding/smr-api/config" "github.com/satisfactorymodding/smr-api/dataloader" "github.com/satisfactorymodding/smr-api/db" - "github.com/satisfactorymodding/smr-api/db/postgres" "github.com/satisfactorymodding/smr-api/generated" "github.com/satisfactorymodding/smr-api/gql" "github.com/satisfactorymodding/smr-api/migrations" @@ -79,7 +78,6 @@ func Initialize(baseCtx context.Context) context.Context { } redis.InitializeRedis(ctx) - postgres.InitializePostgres(ctx) ctx, err := db.WithDB(ctx) if err != nil { diff --git a/conversion/ent_to_graphql.go b/conversion/ent_to_graphql.go index 1ea5e3ae..8f954f92 100644 --- a/conversion/ent_to_graphql.go +++ b/conversion/ent_to_graphql.go @@ -83,7 +83,7 @@ type Mod interface { // goverter:extend TimeToString UIntToInt Int64ToInt type Version interface { // goverter:map Edges.Targets Targets - // goverter:ignore Link Mod Dependencies Size Hash + // goverter:ignore Link Mod Dependencies Convert(source *ent.Version) *generated.Version ConvertSlice(source []*ent.Version) []*generated.Version diff --git a/db/mod.go b/db/mod.go new file mode 100644 index 00000000..61b6f134 --- /dev/null +++ b/db/mod.go @@ -0,0 +1,75 @@ +package db + +import ( + "strings" + + "entgo.io/ent/dialect/sql" + + "github.com/satisfactorymodding/smr-api/generated" + "github.com/satisfactorymodding/smr-api/generated/ent" + "github.com/satisfactorymodding/smr-api/generated/ent/mod" + "github.com/satisfactorymodding/smr-api/generated/ent/modtag" + "github.com/satisfactorymodding/smr-api/models" +) + +func ConvertModFilter(query *ent.ModQuery, filter *models.ModFilter, count bool, unapproved bool) *ent.ModQuery { + query = query.WithTags() + + if len(filter.Ids) > 0 { + query = query.Where(mod.IDIn(filter.Ids...)) + } else if len(filter.References) > 0 { + query = query.Where(mod.ModReferenceIn(filter.References...)) + } else if filter != nil { + query = query. + Limit(*filter.Limit). + Offset(*filter.Offset) + + if *filter.OrderBy != generated.ModFieldsSearch { + if string(*filter.OrderBy) == "last_version_date" { + query = query.Modify(func(s *sql.Selector) { + s.OrderExpr(sql.ExprP("case when last_version_date is null then 1 else 0 end, last_version_date")) + }).Clone() + } else { + query = query.Order(sql.OrderByField( + filter.OrderBy.String(), + OrderToOrder(filter.Order.String()), + ).ToFunc()) + } + } + + if filter.Search != nil && *filter.Search != "" { + cleanSearch := strings.ReplaceAll(strings.TrimSpace(*filter.Search), " ", " & ") + + query = query.Where(func(s *sql.Selector) { + join := sql.SelectExpr(sql.ExprP("id, (similarity(name, ?) * 2 + similarity(short_description, ?) + similarity(full_description, ?) * 0.5) as s", cleanSearch, cleanSearch, cleanSearch)) + join.From(sql.Table(mod.Table)).As("t1") + s.Join(join).On(s.C(mod.FieldID), join.C("id")) + }) + + query = query.Where(func(s *sql.Selector) { + s.Where(sql.ExprP(`"t1"."s" > 0.2`)) + }) + + if !count && *filter.OrderBy == generated.ModFieldsSearch { + query = query.Order(func(s *sql.Selector) { + s.OrderExpr(sql.ExprP(`"t1"."s" DESC`)) + }) + } + } + + if filter.Hidden == nil || !(*filter.Hidden) { + query = query.Where(mod.Hidden(false)) + } + + if filter.TagIDs != nil && len(filter.TagIDs) > 0 { + query = query.Where(func(s *sql.Selector) { + t := sql.Table(modtag.Table) + s.Join(t).OnP(sql.ExprP("mod_tags.tag_id in ? AND mod_tags.mod_id = mods.id", filter.TagIDs)) + }) + } + } + + query = query.Where(mod.Approved(!unapproved), mod.Denied(false)) + + return query +} diff --git a/db/oauth.go b/db/oauth.go new file mode 100644 index 00000000..d34976ce --- /dev/null +++ b/db/oauth.go @@ -0,0 +1,103 @@ +package db + +import ( + "bytes" + "context" + + "github.com/satisfactorymodding/smr-api/generated/ent" + "github.com/satisfactorymodding/smr-api/generated/ent/user" + "github.com/satisfactorymodding/smr-api/oauth" + "github.com/satisfactorymodding/smr-api/storage" + "github.com/satisfactorymodding/smr-api/util" +) + +func CompleteOAuthFlow(ctx context.Context, u *oauth.UserData, userAgent string) (*string, error) { + avatarURL := u.Avatar + u.Avatar = "" + + find := From(ctx).User.Query().Where(user.Email(u.Email)) + + if u.Site == oauth.SiteGithub { + find = find.Where(user.GithubID(u.ID)) + } else if u.Site == oauth.SiteGoogle { + find = find.Where(user.GoogleID(u.ID)) + } else if u.Site == oauth.SiteFacebook { + find = find.Where(user.FacebookID(u.ID)) + } + + found, err := find.First(ctx) + if err != nil && !ent.IsNotFound(err) { + return nil, err + } + + newUser := false + if ent.IsNotFound(err) { + var err error + create := From(ctx).User. + Create(). + SetEmail(u.Email). + SetAvatar(u.Avatar). + SetJoinedFrom(string(u.Site)). + SetUsername(u.Username) + + if u.Site == oauth.SiteGithub { + create = create.SetGithubID(u.ID) + } else if u.Site == oauth.SiteGoogle { + create = create.SetGoogleID(u.ID) + } else if u.Site == oauth.SiteFacebook { + create = create.SetFacebookID(u.ID) + } + + found, err = create.Save(ctx) + if err != nil { + return nil, err + } + + newUser = true + } + + if !newUser { + var update *ent.UserUpdateOne + if u.Site == oauth.SiteGithub && found.GithubID == "" { + update = found.Update().SetGithubID(u.ID) + } else if u.Site == oauth.SiteGoogle && found.GoogleID == "" { + update = found.Update().SetGoogleID(u.ID) + } else if u.Site == oauth.SiteFacebook && found.FacebookID == "" { + update = found.Update().SetFacebookID(u.ID) + } + + if update != nil { + if err := update.Exec(ctx); err != nil { + return nil, err + } + } + } + + // TODO Archive old deleted sessions to cold storage + + session, err := From(ctx).UserSession. + Create(). + SetUserID(found.ID). + SetToken(util.GenerateUserToken()). + SetUserAgent(userAgent). + Save(ctx) + if err != nil { + return nil, err + } + + if avatarURL != "" && newUser { + avatarData, err := util.LinkToWebp(ctx, avatarURL) + if err != nil { + return nil, err + } + + success, avatarKey := storage.UploadUserAvatar(ctx, found.ID, bytes.NewReader(avatarData)) + if success { + if err := found.Update().SetAvatar(storage.GenerateDownloadLink(avatarKey)).Exec(ctx); err != nil { + return nil, err + } + } + } + + return &session.Token, nil +} diff --git a/db/postgres/mod.go b/db/postgres/mod.go deleted file mode 100644 index 4dbf1cb0..00000000 --- a/db/postgres/mod.go +++ /dev/null @@ -1,119 +0,0 @@ -package postgres - -import ( - "context" - "fmt" - "strings" - - "github.com/patrickmn/go-cache" -) - -func GetModByID(ctx context.Context, modID string) *Mod { - cacheKey := "GetModById_" + modID - if mod, ok := dbCache.Get(cacheKey); ok { - return mod.(*Mod) - } - - return GetModByIDNoCache(ctx, modID) -} - -func GetModByIDNoCache(ctx context.Context, modID string) *Mod { - var mod Mod - DBCtx(ctx).Preload("Tags").Preload("Versions.Targets").Find(&mod, "id = ?", modID) - - if mod.ID == "" { - return nil - } - - dbCache.Set("GetModById_"+modID, &mod, cache.DefaultExpiration) - - return &mod -} - -func GetModsByID(ctx context.Context, modIds []string) []Mod { - cacheKey := "GetModsById_" + strings.Join(modIds, ":") - if mods, ok := dbCache.Get(cacheKey); ok { - return mods.([]Mod) - } - - var mods []Mod - DBCtx(ctx).Preload("Tags").Find(&mods, "id in (?)", modIds) - - if len(modIds) != len(mods) { - return nil - } - - dbCache.Set(cacheKey, mods, cache.DefaultExpiration) - - return mods -} - -func GetModCount(ctx context.Context, search string, unapproved bool) int64 { - cacheLey := "GetModCount_" + search + "_" + fmt.Sprint(unapproved) - if count, ok := dbCache.Get(cacheLey); ok { - return count.(int64) - } - - var modCount int64 - query := DBCtx(ctx).Model(Mod{}).Where("approved = ? AND denied = ?", !unapproved, false) - - if search != "" { - query = query.Where("to_tsvector(name) @@ to_tsquery(?)", strings.ReplaceAll(search, " ", " & ")) - } - - query.Count(&modCount) - - dbCache.Set(cacheLey, modCount, cache.DefaultExpiration) - - return modCount -} - -func IncrementModViews(ctx context.Context, mod *Mod) { - DBCtx(ctx).Model(mod).Update("views", mod.Views+1) -} - -func GetMods(ctx context.Context, limit int, offset int, orderBy string, order string, search string, unapproved bool) []Mod { - cacheKey := "GetMods_" + fmt.Sprint(limit) + "_" + fmt.Sprint(offset) + "_" + orderBy + "_" + order + "_" + search + "_" + fmt.Sprint(unapproved) - if mods, ok := dbCache.Get(cacheKey); ok { - return mods.([]Mod) - } - - var mods []Mod - query := DBCtx(ctx).Limit(limit).Offset(offset) - - if orderBy == "last_version_date" { - query = query.Order("case when last_version_date is null then 1 else 0 end, last_version_date") - } else { - query = query.Order(orderBy + " " + order) - } - - query = query.Where("approved = ? AND denied = ?", !unapproved, false) - - if search != "" { - query = query.Where("to_tsvector(name) @@ to_tsquery(?)", strings.ReplaceAll(search, " ", " & ")) - } - - query.Find(&mods) - - dbCache.Set(cacheKey, mods, cache.DefaultExpiration) - - return mods -} - -func GetModByIDOrReference(ctx context.Context, modIDOrReference string) *Mod { - cacheKey := "GetModByIDOrReference_" + modIDOrReference - if mod, ok := dbCache.Get(cacheKey); ok { - return mod.(*Mod) - } - - var mod Mod - DBCtx(ctx).Preload("Tags").Preload("Versions.Targets").Find(&mod, "mod_reference = ? OR id = ?", modIDOrReference, modIDOrReference) - - if mod.ID == "" { - return nil - } - - dbCache.Set(cacheKey, &mod, cache.DefaultExpiration) - - return &mod -} diff --git a/db/postgres/otel/callbacks.go b/db/postgres/otel/callbacks.go deleted file mode 100644 index cc691f7d..00000000 --- a/db/postgres/otel/callbacks.go +++ /dev/null @@ -1,77 +0,0 @@ -package otel - -import ( - "strings" - - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - semconv "go.opentelemetry.io/otel/semconv/v1.7.0" - oteltrace "go.opentelemetry.io/otel/trace" - "gorm.io/gorm" -) - -const ( - spanName = "gorm.query" - - dbTableKey = attribute.Key("db.table") - dbCountKey = attribute.Key("db.count") - dbOperationKey = semconv.DBOperationKey - dbStatementKey = semconv.DBStatementKey -) - -func dbTable(name string) attribute.KeyValue { - return dbTableKey.String(name) -} - -func dbStatement(stmt string) attribute.KeyValue { - return dbStatementKey.String(stmt) -} - -func dbCount(n int64) attribute.KeyValue { - return dbCountKey.Int64(n) -} - -func dbOperation(op string) attribute.KeyValue { - return dbOperationKey.String(op) -} - -func (op *Plugin) before(tx *gorm.DB) { - tx.Statement.Context, _ = op.tracer. - Start(tx.Statement.Context, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient)) -} - -func extractQuery(tx *gorm.DB) string { - return tx.Dialector.Explain(tx.Statement.SQL.String(), tx.Statement.Vars...) -} - -func (op *Plugin) after(operation string) gormHookFunc { - return func(tx *gorm.DB) { - span := oteltrace.SpanFromContext(tx.Statement.Context) - if !span.IsRecording() { - // skip the reporting if not recording - return - } - defer span.End() - - // Error - if tx.Error != nil { - span.SetStatus(codes.Error, tx.Error.Error()) - } - - // extract the db operation - query := extractQuery(tx) - if operation == "" { - operation = strings.ToUpper(strings.Split(query, " ")[0]) - } - - if tx.Statement.Table != "" { - span.SetAttributes(dbTable(tx.Statement.Table)) - } - - span.SetAttributes( - dbStatement(query), - dbOperation(operation), - dbCount(tx.Statement.RowsAffected), - ) - } -} diff --git a/db/postgres/otel/config.go b/db/postgres/otel/config.go deleted file mode 100644 index 996e6806..00000000 --- a/db/postgres/otel/config.go +++ /dev/null @@ -1,28 +0,0 @@ -package otel - -import ( - oteltrace "go.opentelemetry.io/otel/trace" -) - -type config struct { - tracerProvider oteltrace.TracerProvider - serviceName string -} - -// Option is used to configure the client. -type Option func(*config) - -// WithTracerProvider specifies a tracer provider to use for creating a tracer. -// If none is specified, the global provider is used. -func WithTracerProvider(provider oteltrace.TracerProvider) Option { - return func(cfg *config) { - cfg.tracerProvider = provider - } -} - -// WithServiceName sets the service name. -func WithServiceName(serviceName string) Option { - return func(cfg *config) { - cfg.serviceName = serviceName - } -} diff --git a/db/postgres/otel/plugin.go b/db/postgres/otel/plugin.go deleted file mode 100644 index f4a0dd5b..00000000 --- a/db/postgres/otel/plugin.go +++ /dev/null @@ -1,98 +0,0 @@ -package otel - -import ( - "fmt" - - "go.opentelemetry.io/contrib" - "go.opentelemetry.io/otel" - oteltrace "go.opentelemetry.io/otel/trace" - "gorm.io/gorm" -) - -const ( - defaultTracerName = "go.opentelemetry.io/contrib/instrumentation/github.com/go-gorm/gorm/otelgorm" - defaultServiceName = "gorm" - - callBackBeforeName = "otel:before" - callBackAfterName = "otel:after" -) - -type gormHookFunc func(tx *gorm.DB) - -type Plugin struct { - cfg *config - tracer oteltrace.Tracer -} - -func (op *Plugin) Name() string { - return "OpenTelemetryPlugin" -} - -// NewPlugin initialize a new gorm.DB plugin that traces queries -// You may pass optional Options to the function -func NewPlugin(opts ...Option) *Plugin { - cfg := &config{} - for _, o := range opts { - o(cfg) - } - - if cfg.serviceName == "" { - cfg.serviceName = defaultServiceName - } - - if cfg.tracerProvider == nil { - cfg.tracerProvider = otel.GetTracerProvider() - } - - return &Plugin{ - cfg: cfg, - tracer: cfg.tracerProvider.Tracer( - defaultTracerName, - oteltrace.WithInstrumentationVersion(contrib.Version()), - ), - } -} - -type registerCallback interface { - Register(name string, fn func(*gorm.DB)) error -} - -func beforeName(name string) string { - return callBackBeforeName + "_" + name -} - -func afterName(name string) string { - return callBackAfterName + "_" + name -} - -func (op *Plugin) Initialize(db *gorm.DB) error { - registerHooks := []struct { - callback registerCallback - hook gormHookFunc - name string - }{ - // before hooks - {db.Callback().Create().Before("gorm:before_create"), op.before, beforeName("create")}, - {db.Callback().Query().Before("gorm:query"), op.before, beforeName("query")}, - {db.Callback().Delete().Before("gorm:before_delete"), op.before, beforeName("delete")}, - {db.Callback().Update().Before("gorm:before_update"), op.before, beforeName("update")}, - {db.Callback().Row().Before("gorm:row"), op.before, beforeName("row")}, - {db.Callback().Raw().Before("gorm:raw"), op.before, beforeName("raw")}, - - // after hooks - {db.Callback().Create().After("gorm:after_create"), op.after("INSERT"), afterName("create")}, - {db.Callback().Query().After("gorm:after_query"), op.after("SELECT"), afterName("select")}, - {db.Callback().Delete().After("gorm:after_delete"), op.after("DELETE"), afterName("delete")}, - {db.Callback().Update().After("gorm:after_update"), op.after("UPDATE"), afterName("update")}, - {db.Callback().Row().After("gorm:row"), op.after(""), afterName("row")}, - {db.Callback().Raw().After("gorm:raw"), op.after(""), afterName("raw")}, - } - - for _, h := range registerHooks { - if err := h.callback.Register(h.name, h.hook); err != nil { - return fmt.Errorf("register %s hook: %w", h.name, err) - } - } - - return nil -} diff --git a/db/postgres/postgres.go b/db/postgres/postgres.go deleted file mode 100644 index 4c6b7a78..00000000 --- a/db/postgres/postgres.go +++ /dev/null @@ -1,140 +0,0 @@ -package postgres - -import ( - "context" - "fmt" - "log/slog" - "time" - - "github.com/Vilsol/slox" - "github.com/patrickmn/go-cache" - "github.com/spf13/viper" - "gorm.io/driver/postgres" - "gorm.io/gorm" - "gorm.io/gorm/logger" - "gorm.io/gorm/utils" - - "github.com/satisfactorymodding/smr-api/db/postgres/otel" -) - -var ( - db *gorm.DB - dbCache *cache.Cache -) - -type GormLogger struct { - SlowThreshold time.Duration - Debug bool -} - -func (l *GormLogger) LogMode(mode logger.LogLevel) logger.Interface { - return &GormLogger{ - SlowThreshold: l.SlowThreshold, - Debug: mode >= 4, - } -} - -func (*GormLogger) Info(ctx context.Context, msg string, data ...interface{}) { - slox.Info(ctx, fmt.Sprintf(msg, data...), slog.String("file", utils.FileWithLineNum())) -} - -func (*GormLogger) Warn(ctx context.Context, msg string, data ...interface{}) { - slox.Warn(ctx, fmt.Sprintf(msg, data...), slog.String("file", utils.FileWithLineNum())) -} - -func (*GormLogger) Error(ctx context.Context, msg string, data ...interface{}) { - slox.Error(ctx, fmt.Sprintf(msg, data...), slog.String("file", utils.FileWithLineNum())) -} - -func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { - since := time.Since(begin) - elapsed := float64(since.Nanoseconds()) / 1e6 - - sql, rows := fc() - - level := slog.LevelInfo - attrs := make([]slog.Attr, 0) - toLog := false - switch { - case err != nil: - level = slog.LevelError - attrs = append(attrs, slog.Any("err", err)) - toLog = true - case since > l.SlowThreshold && l.SlowThreshold != 0: - level = slog.LevelWarn - toLog = true - case l.Debug: - level = slog.LevelInfo - toLog = true - } - - if toLog { - attrs = append(attrs, slog.Float64("elapsed", elapsed)) - attrs = append(attrs, slog.Int64("rows", rows)) - slog.LogAttrs(ctx, level, sql, attrs...) - } -} - -var debugEnabled = false - -func InitializePostgres(ctx context.Context) { - connection := postgres.Open(fmt.Sprintf( - "sslmode=disable host=%s port=%d user=%s dbname=%s password=%s", - viper.GetString("database.postgres.host"), - viper.GetInt("database.postgres.port"), - viper.GetString("database.postgres.user"), - viper.GetString("database.postgres.db"), - viper.GetString("database.postgres.pass"), - )) - - dbInit, err := gorm.Open(connection, &gorm.Config{ - Logger: &GormLogger{ - SlowThreshold: time.Millisecond * 50, - }, - }) - if err != nil { - panic(err) - } - - err = dbInit.Use(otel.NewPlugin()) - if err != nil { - panic(err) - } - - db = dbInit - - if debugEnabled { - db = db.Debug() - } - - dbCache = cache.New(time.Second*5, time.Second*10) - - // TODO Create search indexes - - slox.Info(ctx, "Postgres initialized") -} - -func Save(ctx context.Context, object interface{}) { - DBCtx(ctx).Save(object) -} - -func DBCtx(ctx context.Context) *gorm.DB { - if ctx != nil { - dbCtx := DBFromContext(ctx) - if dbCtx != nil { - return dbCtx - } - - return db.WithContext(ctx) - } - - return db -} - -func EnableDebug() { - if db != nil { - db = db.Debug() - } - - debugEnabled = true -} diff --git a/db/postgres/postgres_types.go b/db/postgres/postgres_types.go deleted file mode 100644 index ab3a9503..00000000 --- a/db/postgres/postgres_types.go +++ /dev/null @@ -1,164 +0,0 @@ -package postgres - -import ( - "time" - - "gorm.io/gorm" -) - -type Tabler interface { - TableName() string -} - -type SMRDates struct { - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index"` -} - -type SMRModel struct { - ID string `gorm:"primary_key;type:varchar(14)"` - SMRDates -} - -type User struct { - GithubID *string - GoogleID *string - FacebookID *string - SMRModel - Email string `gorm:"type:varchar(256);unique_index"` - Username string `gorm:"type:varchar(32)"` - Avatar string - JoinedFrom string - Mods []Mod `gorm:"many2many:user_mods;"` - Banned bool `gorm:"default:false;not null"` -} - -type UserSession struct { - SMRModel - UserID string - Token string `gorm:"type:varchar(256);unique_index"` - UserAgent string - User User -} - -type Mod struct { - LastVersionDate *time.Time - Compatibility *CompatibilityInfo `gorm:"serializer:json"` - SMRModel - CreatorID string - Logo string - SourceURL string - FullDescription string - ShortDescription string `gorm:"type:varchar(128)"` - Name string `gorm:"type:varchar(32)"` - ModReference string - Versions []Version - Tags []Tag `gorm:"many2many:mod_tags"` - Users []User `gorm:"many2many:user_mods;"` - Downloads uint - Popularity uint - Hotness uint - Views uint - Hidden bool - Denied bool `gorm:"default:false;not null"` - Approved bool `gorm:"default:false;not null"` -} - -type UserMod struct { - UserID string `gorm:"primary_key"` - ModID string `gorm:"primary_key"` - Role string -} - -// If updated, update dataloader -type Version struct { - Metadata *string - Hash *string - Size *int64 - VersionPatch *int - VersionMinor *int - VersionMajor *int - ModReference *string - SMRModel - Changelog string - Stability string `gorm:"default:'alpha'" sql:"type:version_stability"` - Key string - SMLVersion string `gorm:"type:varchar(16)"` - Version string `gorm:"type:varchar(16)"` - ModID string - Targets []VersionTarget `gorm:"foreignKey:VersionID"` - Hotness uint - Downloads uint - Denied bool `gorm:"default:false;not null"` - Approved bool `gorm:"default:false;not null"` -} - -type TinyVersion struct { - Hash *string - Size *int64 - SMRModel - SMLVersion string `gorm:"type:varchar(16)"` - Version string `gorm:"type:varchar(16)"` - Targets []VersionTarget `gorm:"foreignKey:VersionID;preload:true"` - Dependencies []VersionDependency `gorm:"foreignKey:VersionID"` -} - -func (TinyVersion) TableName() string { - return "versions" -} - -type SMLVersion struct { - Date time.Time - BootstrapVersion *string - SMRModel - Version string `gorm:"type:varchar(32);unique_index"` - Stability string `sql:"type:version_stability"` - Link string - Changelog string - EngineVersion string - Targets []SMLVersionTarget `gorm:"foreignKey:VersionID"` - SatisfactoryVersion int -} - -type VersionDependency struct { - SMRDates - - VersionID string `gorm:"primary_key;type:varchar(14)"` - ModID string `gorm:"primary_key;type:varchar(14)"` - - Condition string `gorm:"type:varchar(64)"` - Optional bool -} - -type Tag struct { - SMRModel - - Name string `gorm:"type:varchar(24)"` - - Mods []Mod `gorm:"many2many:mod_tags"` -} - -type CompatibilityInfo struct { - Ea Compatibility `gorm:"type:compatibility" json:"EA"` - Exp Compatibility `gorm:"type:compatibility" json:"EXP"` -} - -type Compatibility struct { - State string - Note string -} - -type VersionTarget struct { - VersionID string `gorm:"primary_key;type:varchar(14)"` - TargetName string `gorm:"primary_key;type:varchar(16)"` - Key string - Hash string - Size int64 -} - -type SMLVersionTarget struct { - VersionID string `gorm:"primary_key;type:varchar(14)"` - TargetName string `gorm:"primary_key;type:varchar(16)"` - Link string -} diff --git a/db/postgres/sml_version.go b/db/postgres/sml_version.go deleted file mode 100644 index 2d7d68ad..00000000 --- a/db/postgres/sml_version.go +++ /dev/null @@ -1,15 +0,0 @@ -package postgres - -import ( - "context" -) - -func GetSMLLatestVersions(ctx context.Context) *[]SMLVersion { - var smlVersions []SMLVersion - - DBCtx(ctx).Preload("Targets").Select("distinct on (stability) *"). - Order("stability, created_at desc"). - Find(&smlVersions) - - return &smlVersions -} diff --git a/db/postgres/user.go b/db/postgres/user.go deleted file mode 100644 index ce12df66..00000000 --- a/db/postgres/user.go +++ /dev/null @@ -1,139 +0,0 @@ -package postgres - -import ( - "context" - - "github.com/satisfactorymodding/smr-api/oauth" - "github.com/satisfactorymodding/smr-api/util" -) - -func GetUserSession(ctx context.Context, oauthUser *oauth.UserData, userAgent string) (*UserSession, *User, bool) { - user := User{ - Email: oauthUser.Email, - Avatar: oauthUser.Avatar, - JoinedFrom: string(oauthUser.Site), - Username: oauthUser.Username, - } - - // Find or create the user by email - find := DBCtx(ctx).Where(&User{Email: oauthUser.Email}) - - if oauthUser.Site == oauth.SiteGithub { - find = find.Or(&User{GithubID: &oauthUser.ID}) - } else if oauthUser.Site == oauth.SiteGoogle { - find = find.Or(&User{GoogleID: &oauthUser.ID}) - } else if oauthUser.Site == oauth.SiteFacebook { - find = find.Or(&User{FacebookID: &oauthUser.ID}) - } - - find.Find(&user) - - newUser := false - - if user.ID == "" { - user.ID = util.GenerateUniqueID() - - if oauthUser.Site == oauth.SiteGithub { - user.GithubID = &oauthUser.ID - } else if oauthUser.Site == oauth.SiteGoogle { - user.GoogleID = &oauthUser.ID - } else if oauthUser.Site == oauth.SiteFacebook { - user.FacebookID = &oauthUser.ID - } - - DBCtx(ctx).Create(&user) - newUser = true - } - - if !newUser { - newID := false - if oauthUser.Site == oauth.SiteGithub && user.GithubID == nil { - user.GithubID = &oauthUser.ID - newID = true - } else if oauthUser.Site == oauth.SiteGoogle && user.GoogleID == nil { - user.GoogleID = &oauthUser.ID - newID = true - } else if oauthUser.Site == oauth.SiteFacebook && user.FacebookID == nil { - user.FacebookID = &oauthUser.ID - newID = true - } - - if newID { - Save(ctx, &user) - } - } - - // TODO Archive old deleted sessions to cold storage - // DBCtx(ctx).Delete(&UserSession{UserAgent: userAgent}) - - session := UserSession{ - User: user, - Token: util.GenerateUserToken(), - UserAgent: userAgent, - } - - session.ID = util.GenerateUniqueID() - - // Create a new session - DBCtx(ctx).Create(&session) - - return &session, &user, newUser -} - -func LogoutSession(ctx context.Context, token string) { - // TODO Archive old deleted sessions to cold storage - DBCtx(ctx).Delete(&UserSession{Token: token}) -} - -func GetUserByToken(ctx context.Context, token string) *User { - // TODO Merge into a single query - var session UserSession - DBCtx(ctx).Find(&session, UserSession{Token: token}) - - if session.ID == "" { - return nil - } - - var user User - DBCtx(ctx).Find(&user, "id = ?", session.UserID) - - if user.ID == "" { - return nil - } - - return &user -} - -func GetUserByID(ctx context.Context, userID string) *User { - var user User - DBCtx(ctx).Find(&user, "id = ?", userID) - - if user.ID == "" { - return nil - } - - return &user -} - -func GetUsersByID(ctx context.Context, userIds []string) *[]User { - var users []User - DBCtx(ctx).Find(&users, "id in (?)", userIds) - - if len(userIds) != len(users) { - return nil - } - - return &users -} - -func GetUserMods(ctx context.Context, userID string) []UserMod { - var mods []UserMod - DBCtx(ctx).Raw("SELECT * from \"user_mods\" as tdm WHERE user_id = ? AND mod_id = (SELECT id FROM mods WHERE id = tdm.mod_id AND deleted_at is NULL LIMIT 1)", userID).Find(&mods) - return mods -} - -func GetModAuthors(ctx context.Context, modID string) []UserMod { - var authors []UserMod - DBCtx(ctx).Find(&authors, "mod_id = ?", modID) - return authors -} diff --git a/db/postgres/utils.go b/db/postgres/utils.go deleted file mode 100644 index 44f119a5..00000000 --- a/db/postgres/utils.go +++ /dev/null @@ -1,21 +0,0 @@ -package postgres - -import ( - "context" - - "gorm.io/gorm" -) - -type ( - ContextDB struct{} -) - -func DBFromContext(ctx context.Context) *gorm.DB { - value := ctx.Value(ContextDB{}) - - if value == nil { - return nil - } - - return value.(*gorm.DB) -} diff --git a/db/postgres/version.go b/db/postgres/version.go deleted file mode 100644 index b0b9a184..00000000 --- a/db/postgres/version.go +++ /dev/null @@ -1,148 +0,0 @@ -package postgres - -import ( - "context" - "fmt" - "strings" - - "github.com/patrickmn/go-cache" - - "github.com/satisfactorymodding/smr-api/models" -) - -func GetModsLatestVersions(ctx context.Context, modIds []string, unapproved bool) *[]Version { - cacheKey := "GetModsLatestVersions_" + strings.Join(modIds, ":") + "_" + fmt.Sprint(unapproved) - if versions, ok := dbCache.Get(cacheKey); ok { - return versions.(*[]Version) - } - - var versions []Version - - DBCtx(ctx).Preload("Targets").Select("distinct on (mod_id, stability) *"). - Where("mod_id in (?)", modIds). - Where("approved = ? AND denied = ?", !unapproved, false). - Order("mod_id, stability, created_at desc"). - Find(&versions) - - dbCache.Set(cacheKey, &versions, cache.DefaultExpiration) - - return &versions -} - -func GetModVersions(ctx context.Context, modID string, limit int, offset int, orderBy string, order string, unapproved bool) []Version { - cacheKey := "GetModVersions_" + modID + "_" + fmt.Sprint(limit) + "_" + fmt.Sprint(offset) + "_" + orderBy + "_" + order + "_" + fmt.Sprint(unapproved) - if versions, ok := dbCache.Get(cacheKey); ok { - return versions.([]Version) - } - - var versions []Version - DBCtx(ctx).Preload("Targets").Limit(limit).Offset(offset).Order(orderBy+" "+order).Where("approved = ? AND denied = ?", !unapproved, false).Find(&versions, "mod_id = ?", modID) - - dbCache.Set(cacheKey, versions, cache.DefaultExpiration) - - return versions -} - -func GetAllModVersionsWithDependencies(ctx context.Context, modID string) []TinyVersion { - cacheKey := "GetAllModVersionsWithDependencies_" + modID - if versions, ok := dbCache.Get(cacheKey); ok { - return versions.([]TinyVersion) - } - - var versions []TinyVersion - DBCtx(ctx). - Preload("Dependencies"). - Preload("Targets"). - Where("approved = ? AND denied = ?", true, false). - Find(&versions, "mod_id = ?", modID) - - dbCache.Set(cacheKey, versions, cache.DefaultExpiration) - - return versions -} - -func GetModVersionsNew(ctx context.Context, modID string, filter *models.VersionFilter, unapproved bool) []Version { - hash, err := filter.Hash() - cacheKey := "" - if err == nil { - cacheKey = "GetModVersionsNew_" + modID + "_" + hash + "_" + fmt.Sprint(unapproved) - if versions, ok := dbCache.Get(cacheKey); ok { - return versions.([]Version) - } - } - - var versions []Version - query := DBCtx(ctx).Preload("Targets") - - if filter != nil { - query = query.Limit(*filter.Limit). - Offset(*filter.Offset). - Order(string(*filter.OrderBy) + " " + string(*filter.Order)) - } - - query.Preload("Targets").Where("approved = ? AND denied = ?", !unapproved, false).Find(&versions, "mod_id = ?", modID) - - if cacheKey != "" { - dbCache.Set(cacheKey, versions, cache.DefaultExpiration) - } - - return versions -} - -func GetModVersion(ctx context.Context, modID string, versionID string) *Version { - cacheKey := "GetModVersion_" + modID + "_" + versionID - if version, ok := dbCache.Get(cacheKey); ok { - return version.(*Version) - } - - var version Version - DBCtx(ctx).Preload("Targets").First(&version, "mod_id = ? AND id = ?", modID, versionID) - - if version.ID == "" { - return nil - } - - dbCache.Set(cacheKey, &version, cache.DefaultExpiration) - - return &version -} - -func IncrementVersionDownloads(ctx context.Context, version *Version) { - DBCtx(ctx).Model(version).Update("downloads", version.Downloads+1) -} - -func GetVersion(ctx context.Context, versionID string) *Version { - cacheKey := "GetVersion_" + versionID - if version, ok := dbCache.Get(cacheKey); ok { - return version.(*Version) - } - - var version Version - DBCtx(ctx).Preload("Targets").First(&version, "id = ?", versionID) - - if version.ID == "" { - return nil - } - - dbCache.Set(cacheKey, &version, cache.DefaultExpiration) - - return &version -} - -func GetVersionTarget(ctx context.Context, versionID string, target string) *VersionTarget { - cacheKey := "GetVersionTarget_" + versionID + "_" + target - if versionTarget, ok := dbCache.Get(cacheKey); ok { - return versionTarget.(*VersionTarget) - } - - var versionTarget VersionTarget - DBCtx(ctx).First(&versionTarget, "version_id = ? AND target_name = ?", versionID, target) - - if versionTarget.VersionID == "" { - return nil - } - - dbCache.Set(cacheKey, &versionTarget, cache.DefaultExpiration) - - return &versionTarget -} diff --git a/generated/conv/version.go b/generated/conv/version.go index fe15fc4b..0e30d461 100755 --- a/generated/conv/version.go +++ b/generated/conv/version.go @@ -34,6 +34,10 @@ func (c *VersionImpl) Convert(source *ent.Version) *generated.Version { generatedVersion.Targets = pGeneratedVersionTargetList pString := (*source).Metadata generatedVersion.Metadata = &pString + pInt := conversion.Int64ToInt((*source).Size) + generatedVersion.Size = &pInt + pString2 := (*source).Hash + generatedVersion.Hash = &pString2 pGeneratedVersion = &generatedVersion } return pGeneratedVersion diff --git a/go.mod b/go.mod index bdd8a732..8b7fbad0 100755 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/satisfactorymodding/smr-api go 1.21 +toolchain go1.21.4 + require ( ariga.io/entcache v0.1.0 entgo.io/ent v0.12.4 @@ -33,7 +35,6 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/mapstructure v1.5.0 github.com/o1egl/paseto v1.0.0 - github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/russross/blackfriday v1.6.0 github.com/sizeofint/gif-to-webp v0.0.0-20210224202734-e9d7ed071591 @@ -44,7 +45,6 @@ require ( github.com/vmihailenco/taskq/extra/taskqotel/v3 v3.2.9 github.com/vmihailenco/taskq/v3 v3.2.9 github.com/xeipuuv/gojsonschema v1.2.0 - go.opentelemetry.io/contrib v1.20.0 go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.45.0 go.opentelemetry.io/otel v1.19.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 @@ -58,8 +58,6 @@ require ( google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 gopkg.in/go-playground/validator.v9 v9.31.0 - gorm.io/driver/postgres v1.5.3 - gorm.io/gorm v1.25.5 ) require ( @@ -122,8 +120,6 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.17.1 // indirect diff --git a/go.sum b/go.sum index 8abf08c9..6756e3d7 100644 --- a/go.sum +++ b/go.sum @@ -375,10 +375,6 @@ github.com/jamillosantos/macchiato v0.0.0-20171220130318-3be045cc5033 h1:R0efOJW github.com/jamillosantos/macchiato v0.0.0-20171220130318-3be045cc5033/go.mod h1:JHpPOBFu/UpmWT79z9fw5lQn7Oem6lnkS3jN4ZQdfLQ= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= -github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -491,8 +487,6 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.0.2 h1:9yCKha/T5XdGtO0q9Q9a6T5NUCsTn/DrBg0D7ufOcFM= github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= -github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= -github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -619,8 +613,6 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= -go.opentelemetry.io/contrib v1.20.0 h1:oXUiIQLlkbi9uZB/bt5B1WRLsrTKqb7bPpAQ+6htn2w= -go.opentelemetry.io/contrib v1.20.0/go.mod h1:gIzjwWFoGazJmtCaDgViqOSJPde2mCWzv60o0bWPcZs= go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.45.0 h1:JJCIHAxGCB5HM3NxeIwFjHc087Xwk96TG9kaZU6TAec= go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.45.0/go.mod h1:Px9kH7SJ+NhsgWRtD/eMcs15Tyt4uL3rM7X54qv6pfA= go.opentelemetry.io/contrib/propagators/b3 v1.20.0 h1:Yty9Vs4F3D6/liF1o6FNt0PvN85h/BJJ6DQKJ3nrcM0= @@ -1045,10 +1037,6 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.3 h1:qKGY5CPHOuj47K/VxbCXJfFvIUeqMSXXadqdCY+MbBU= -gorm.io/driver/postgres v1.5.3/go.mod h1:F+LtvlFhZT7UBiA81mC9W6Su3D4WUhSboc/36QZU0gk= -gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= -gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/gql/resolver_announcements.go b/gql/resolver_announcements.go index ead4eb83..4c021b36 100644 --- a/gql/resolver_announcements.go +++ b/gql/resolver_announcements.go @@ -85,7 +85,7 @@ func (r *queryResolver) GetAnnouncements(ctx context.Context) ([]*generated.Anno wrapper, ctx := WrapQueryTrace(ctx, "getAnnouncements") defer wrapper.end() - result, err := db.From(ctx).Debug().Announcement.Query().All(ctx) + result, err := db.From(ctx).Announcement.Query().All(ctx) if err != nil { return nil, err } diff --git a/gql/resolver_mods.go b/gql/resolver_mods.go index 12e5d1f1..13538041 100644 --- a/gql/resolver_mods.go +++ b/gql/resolver_mods.go @@ -23,7 +23,6 @@ import ( "github.com/satisfactorymodding/smr-api/generated/conv" "github.com/satisfactorymodding/smr-api/generated/ent" "github.com/satisfactorymodding/smr-api/generated/ent/mod" - "github.com/satisfactorymodding/smr-api/generated/ent/modtag" "github.com/satisfactorymodding/smr-api/generated/ent/usermod" "github.com/satisfactorymodding/smr-api/generated/ent/version" "github.com/satisfactorymodding/smr-api/integrations" @@ -425,8 +424,8 @@ func (r *getModsResolver) Mods(ctx context.Context, _ *generated.GetMods) ([]*ge modFilter.AddField(field.Name) } - query := db.From(ctx).Debug().Mod.Query() - query = convertModFilter(query, modFilter, false, unapproved) + query := db.From(ctx).Mod.Query() + query = db.ConvertModFilter(query, modFilter, false, unapproved) result, err := query.All(ctx) if err != nil { @@ -449,7 +448,7 @@ func (r *getModsResolver) Count(ctx context.Context, _ *generated.GetMods) (int, } query := db.From(ctx).Mod.Query() - query = convertModFilter(query, modFilter, false, unapproved) + query = db.ConvertModFilter(query, modFilter, false, unapproved) result, err := query.Count(ctx) if err != nil { @@ -478,7 +477,7 @@ func (r *getMyModsResolver) Mods(ctx context.Context, _ *generated.GetMyMods) ([ } query := db.From(ctx).Mod.Query() - query = convertModFilter(query, modFilter, false, unapproved) + query = db.ConvertModFilter(query, modFilter, false, unapproved) result, err := query.All(ctx) if err != nil { @@ -501,7 +500,7 @@ func (r *getMyModsResolver) Count(ctx context.Context, _ *generated.GetMyMods) ( } query := db.From(ctx).Mod.Query() - query = convertModFilter(query, modFilter, false, unapproved) + query = db.ConvertModFilter(query, modFilter, false, unapproved) result, err := query.Count(ctx) if err != nil { @@ -736,65 +735,3 @@ func (r *queryResolver) ResolveModVersions(ctx context.Context, filter []*genera return modVersions, nil } - -func convertModFilter(query *ent.ModQuery, filter *models.ModFilter, count bool, unapproved bool) *ent.ModQuery { - query = query.WithTags() - - if len(filter.Ids) > 0 { - query = query.Where(mod.IDIn(filter.Ids...)) - } else if len(filter.References) > 0 { - query = query.Where(mod.ModReferenceIn(filter.References...)) - } else if filter != nil { - query = query. - Limit(*filter.Limit). - Offset(*filter.Offset) - - if *filter.OrderBy != generated.ModFieldsSearch { - if string(*filter.OrderBy) == "last_version_date" { - query = query.Modify(func(s *sql.Selector) { - s.OrderExpr(sql.ExprP("case when last_version_date is null then 1 else 0 end, last_version_date")) - }).Clone() - } else { - query = query.Order(sql.OrderByField( - filter.OrderBy.String(), - db.OrderToOrder(filter.Order.String()), - ).ToFunc()) - } - } - - if filter.Search != nil && *filter.Search != "" { - cleanSearch := strings.ReplaceAll(strings.TrimSpace(*filter.Search), " ", " & ") - - query = query.Where(func(s *sql.Selector) { - join := sql.SelectExpr(sql.ExprP("id, (similarity(name, ?) * 2 + similarity(short_description, ?) + similarity(full_description, ?) * 0.5) as s", cleanSearch, cleanSearch, cleanSearch)) - join.From(sql.Table(mod.Table)).As("t1") - s.Join(join).On(s.C(mod.FieldID), join.C("id")) - }) - - query = query.Where(func(s *sql.Selector) { - s.Where(sql.ExprP(`"t1"."s" > 0.2`)) - }) - - if !count && *filter.OrderBy == generated.ModFieldsSearch { - query = query.Order(func(s *sql.Selector) { - s.OrderExpr(sql.ExprP(`"t1"."s" DESC`)) - }) - } - } - - if filter.Hidden == nil || !(*filter.Hidden) { - query = query.Where(mod.Hidden(false)) - } - - if filter.TagIDs != nil && len(filter.TagIDs) > 0 { - query = query.Where(func(s *sql.Selector) { - t := sql.Table(modtag.Table) - s.Join(t).OnP(sql.ExprP("mod_tags.tag_id in ? AND mod_tags.mod_id = mods.id", filter.TagIDs)) - }) - } - } - - query = query.Where(mod.Approved(!unapproved), mod.Denied(false)) - - return query -} diff --git a/gql/resolver_oauth.go b/gql/resolver_oauth.go index 2df28efd..40942289 100644 --- a/gql/resolver_oauth.go +++ b/gql/resolver_oauth.go @@ -1,7 +1,6 @@ package gql import ( - "bytes" "context" "fmt" "net/http" @@ -11,10 +10,7 @@ import ( "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/generated" - "github.com/satisfactorymodding/smr-api/generated/ent" - "github.com/satisfactorymodding/smr-api/generated/ent/user" "github.com/satisfactorymodding/smr-api/oauth" - "github.com/satisfactorymodding/smr-api/storage" "github.com/satisfactorymodding/smr-api/util" ) @@ -52,7 +48,7 @@ func (r *mutationResolver) OAuthGithub(ctx context.Context, code string, state s header := ctx.Value(util.ContextHeader{}).(http.Header) userAgent := header.Get("User-Agent") - token, err := completeOAuthFlow(ctx, u, userAgent) + token, err := db.CompleteOAuthFlow(ctx, u, userAgent) if err != nil { return nil, err } @@ -78,7 +74,7 @@ func (r *mutationResolver) OAuthGoogle(ctx context.Context, code string, state s header := ctx.Value(util.ContextHeader{}).(http.Header) userAgent := header.Get("User-Agent") - token, err := completeOAuthFlow(ctx, u, userAgent) + token, err := db.CompleteOAuthFlow(ctx, u, userAgent) if err != nil { return nil, err } @@ -104,7 +100,7 @@ func (r *mutationResolver) OAuthFacebook(ctx context.Context, code string, state header := ctx.Value(util.ContextHeader{}).(http.Header) userAgent := header.Get("User-Agent") - token, err := completeOAuthFlow(ctx, u, userAgent) + token, err := db.CompleteOAuthFlow(ctx, u, userAgent) if err != nil { return nil, err } @@ -113,94 +109,3 @@ func (r *mutationResolver) OAuthFacebook(ctx context.Context, code string, state Token: *token, }, nil } - -func completeOAuthFlow(ctx context.Context, u *oauth.UserData, userAgent string) (*string, error) { - avatarURL := u.Avatar - u.Avatar = "" - - find := db.From(ctx).User.Query().Where(user.Email(u.Email)) - - if u.Site == oauth.SiteGithub { - find = find.Where(user.GithubID(u.ID)) - } else if u.Site == oauth.SiteGoogle { - find = find.Where(user.GoogleID(u.ID)) - } else if u.Site == oauth.SiteFacebook { - find = find.Where(user.FacebookID(u.ID)) - } - - found, err := find.First(ctx) - if err != nil && !ent.IsNotFound(err) { - return nil, err - } - - newUser := false - if ent.IsNotFound(err) { - var err error - create := db.From(ctx).User. - Create(). - SetEmail(u.Email). - SetAvatar(u.Avatar). - SetJoinedFrom(string(u.Site)). - SetUsername(u.Username) - - if u.Site == oauth.SiteGithub { - create = create.SetGithubID(u.ID) - } else if u.Site == oauth.SiteGoogle { - create = create.SetGoogleID(u.ID) - } else if u.Site == oauth.SiteFacebook { - create = create.SetFacebookID(u.ID) - } - - found, err = create.Save(ctx) - if err != nil { - return nil, err - } - - newUser = true - } - - if !newUser { - var update *ent.UserUpdateOne - if u.Site == oauth.SiteGithub && found.GithubID == "" { - update = found.Update().SetGithubID(u.ID) - } else if u.Site == oauth.SiteGoogle && found.GoogleID == "" { - update = found.Update().SetGoogleID(u.ID) - } else if u.Site == oauth.SiteFacebook && found.FacebookID == "" { - update = found.Update().SetFacebookID(u.ID) - } - - if update != nil { - if err := update.Exec(ctx); err != nil { - return nil, err - } - } - } - - // TODO Archive old deleted sessions to cold storage - - session, err := db.From(ctx).UserSession. - Create(). - SetUserID(found.ID). - SetToken(util.GenerateUserToken()). - SetUserAgent(userAgent). - Save(ctx) - if err != nil { - return nil, err - } - - if avatarURL != "" && newUser { - avatarData, err := util.LinkToWebp(ctx, avatarURL) - if err != nil { - return nil, err - } - - success, avatarKey := storage.UploadUserAvatar(ctx, found.ID, bytes.NewReader(avatarData)) - if success { - if err := found.Update().SetAvatar(storage.GenerateDownloadLink(avatarKey)).Exec(ctx); err != nil { - return nil, err - } - } - } - - return &session.Token, nil -} diff --git a/gql/resolver_versions.go b/gql/resolver_versions.go index 3f4acb86..c38e2696 100644 --- a/gql/resolver_versions.go +++ b/gql/resolver_versions.go @@ -242,7 +242,7 @@ func (r *queryResolver) GetVersion(ctx context.Context, versionID string) (*gene wrapper, ctx := WrapQueryTrace(ctx, "getVersion") defer wrapper.end() - result, err := db.From(ctx).Version.Get(ctx, versionID) + result, err := db.From(ctx).Version.Query().WithTargets().Where(version.ID(versionID)).First(ctx) if err != nil { return nil, err } diff --git a/migrations/code.go b/migrations/code.go index fc5e13ff..9f34a1f5 100644 --- a/migrations/code.go +++ b/migrations/code.go @@ -2,14 +2,17 @@ package migrations import ( "context" + "database/sql" + "fmt" "log/slog" "os" "strings" "github.com/lab259/go-migration" + "github.com/spf13/viper" - postgres2 "github.com/satisfactorymodding/smr-api/db/postgres" - + // Import pgx + _ "github.com/jackc/pgx/v5/stdlib" // Import all migrations _ "github.com/satisfactorymodding/smr-api/migrations/code" ) @@ -24,13 +27,24 @@ func (c codeMigrationLogger) Write(p []byte) (int, error) { return len(p), nil } -func codeMigrations(ctx context.Context) { +func codeMigrations(_ context.Context) { source := migration.DefaultCodeSource() // TODO Custom reporter, this one's very ugly reporter := migration.NewDefaultReporterWithParams(codeMigrationLogger{}, os.Exit) - db, _ := postgres2.DBCtx(ctx).DB() + db, err := sql.Open("pgx", fmt.Sprintf( + "sslmode=disable host=%s port=%d user=%s dbname=%s password=%s", + viper.GetString("database.postgres.host"), + viper.GetInt("database.postgres.port"), + viper.GetString("database.postgres.user"), + viper.GetString("database.postgres.db"), + viper.GetString("database.postgres.pass"), + )) + if err != nil { + panic(err) + } + manager := migration.NewDefaultManager(migration.NewPostgreSQLTarget(db), source) runner := migration.NewArgsRunnerCustom(reporter, manager, os.Exit, "migrate") runner.Run(db) diff --git a/migrations/code/20200426221900_test_new_migration.go b/migrations/code/20200426221900_test_new_migration.go index 488054c0..9f66ec92 100644 --- a/migrations/code/20200426221900_test_new_migration.go +++ b/migrations/code/20200426221900_test_new_migration.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, nil) }, ) } diff --git a/migrations/code/20200501224200_parse_paks.go b/migrations/code/20200501224200_parse_paks.go index 488054c0..9f66ec92 100644 --- a/migrations/code/20200501224200_parse_paks.go +++ b/migrations/code/20200501224200_parse_paks.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, nil) }, ) } diff --git a/migrations/code/20200524203800_after_reference_fix.go b/migrations/code/20200524203800_after_reference_fix.go index 311c6ff2..1fe62003 100644 --- a/migrations/code/20200524203800_after_reference_fix.go +++ b/migrations/code/20200524203800_after_reference_fix.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), false, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, false, nil, nil) }, ) } diff --git a/migrations/code/20200621195500_after_id_length_fix.go b/migrations/code/20200621195500_after_id_length_fix.go index 311c6ff2..1fe62003 100644 --- a/migrations/code/20200621195500_after_id_length_fix.go +++ b/migrations/code/20200621195500_after_id_length_fix.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), false, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, false, nil, nil) }, ) } diff --git a/migrations/code/20200622003600_after_validation_disable.go b/migrations/code/20200622003600_after_validation_disable.go index 488054c0..9f66ec92 100644 --- a/migrations/code/20200622003600_after_validation_disable.go +++ b/migrations/code/20200622003600_after_validation_disable.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, nil) }, ) } diff --git a/migrations/code/20200629093800_copy_to_new_bucket.go b/migrations/code/20200629093800_copy_to_new_bucket.go index e16ce85b..efe9855e 100644 --- a/migrations/code/20200629093800_copy_to_new_bucket.go +++ b/migrations/code/20200629093800_copy_to_new_bucket.go @@ -5,6 +5,7 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/redis/jobs" "github.com/satisfactorymodding/smr-api/storage" ) @@ -12,8 +13,12 @@ import ( func init() { migration.NewCodeMigration( func(executionContext interface{}) error { + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } storage.ScheduleCopyAllObjectsFromOldBucket(func(key string) { - jobs.SubmitJobCopyObjectFromOldBucketTask(context.TODO(), key) + jobs.SubmitJobCopyObjectFromOldBucketTask(ctx, key) }) return nil }, diff --git a/migrations/code/20200707150700_after_sml_version_fix.go b/migrations/code/20200707150700_after_sml_version_fix.go index 9326e632..e61130dc 100644 --- a/migrations/code/20200707150700_after_sml_version_fix.go +++ b/migrations/code/20200707150700_after_sml_version_fix.go @@ -5,17 +5,21 @@ import ( "github.com/lab259/go-migration" - "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/satisfactorymodding/smr-api/db" + "github.com/satisfactorymodding/smr-api/generated/ent" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), false, nil, func(version postgres.Version) bool { - return version.SMLVersion == "" + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, false, nil, func(version *ent.Version) bool { + return version.SmlVersion == "" }) - return nil }, ) } diff --git a/migrations/code/20200829171600_after_db_dirty.go b/migrations/code/20200829171600_after_db_dirty.go index 488054c0..9f66ec92 100644 --- a/migrations/code/20200829171600_after_db_dirty.go +++ b/migrations/code/20200829171600_after_db_dirty.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, nil) }, ) } diff --git a/migrations/code/20200829225100_after_broken_datajson.go b/migrations/code/20200829225100_after_broken_datajson.go index 9da63802..0e179c5d 100644 --- a/migrations/code/20200829225100_after_broken_datajson.go +++ b/migrations/code/20200829225100_after_broken_datajson.go @@ -5,17 +5,21 @@ import ( "github.com/lab259/go-migration" - "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/satisfactorymodding/smr-api/db" + "github.com/satisfactorymodding/smr-api/generated/ent" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, func(version postgres.Version) bool { - return version.Hash == nil || *version.Hash == "" + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, func(version *ent.Version) bool { + return version.Hash == "" }) - return nil }, ) } diff --git a/migrations/code/20200830011200_after_gorm_hotfix.go b/migrations/code/20200830011200_after_gorm_hotfix.go index 488054c0..9f66ec92 100644 --- a/migrations/code/20200830011200_after_gorm_hotfix.go +++ b/migrations/code/20200830011200_after_gorm_hotfix.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, nil) }, ) } diff --git a/migrations/code/20201014162200_after_bp_fix.go b/migrations/code/20201014162200_after_bp_fix.go index 488054c0..9f66ec92 100644 --- a/migrations/code/20201014162200_after_bp_fix.go +++ b/migrations/code/20201014162200_after_bp_fix.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, nil) }, ) } diff --git a/migrations/code/20201016202600_after_body_enable.go b/migrations/code/20201016202600_after_body_enable.go index 488054c0..9f66ec92 100644 --- a/migrations/code/20201016202600_after_body_enable.go +++ b/migrations/code/20201016202600_after_body_enable.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, nil) }, ) } diff --git a/migrations/code/20201019203800_after_remove_filter.go b/migrations/code/20201019203800_after_remove_filter.go index 488054c0..9f66ec92 100644 --- a/migrations/code/20201019203800_after_remove_filter.go +++ b/migrations/code/20201019203800_after_remove_filter.go @@ -5,14 +5,18 @@ import ( "github.com/lab259/go-migration" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/migrations/utils" ) func init() { migration.NewCodeMigration( func(executionContext interface{}) error { - utils.ReindexAllModFiles(context.TODO(), true, nil, nil) - return nil + ctx, err := db.WithDB(context.Background()) + if err != nil { + return err + } + return utils.ReindexAllModFiles(ctx, true, nil, nil) }, ) } diff --git a/migrations/migrations.go b/migrations/migrations.go index 68b6b49e..053ff64d 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -2,6 +2,7 @@ package migrations import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -9,11 +10,12 @@ import ( "github.com/Vilsol/slox" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" - - postgres2 "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/spf13/viper" // Import migrations _ "github.com/golang-migrate/migrate/v4/source/file" + // Import pgx + _ "github.com/jackc/pgx/v5/stdlib" ) func RunMigrations(ctx context.Context) { @@ -29,8 +31,19 @@ func SetMigrationDir(newMigrationDir string) { } func databaseMigrations(ctx context.Context) { - db, _ := postgres2.DBCtx(ctx).DB() - driver, err := postgres.WithInstance(db, &postgres.Config{}) + connection, err := sql.Open("pgx", fmt.Sprintf( + "sslmode=disable host=%s port=%d user=%s dbname=%s password=%s", + viper.GetString("database.postgres.host"), + viper.GetInt("database.postgres.port"), + viper.GetString("database.postgres.user"), + viper.GetString("database.postgres.db"), + viper.GetString("database.postgres.pass"), + )) + if err != nil { + panic(err) + } + + driver, err := postgres.WithInstance(connection, &postgres.Config{}) if err != nil { panic(err) } diff --git a/migrations/utils/utils.go b/migrations/utils/utils.go index 2c8f2815..f15f9002 100644 --- a/migrations/utils/utils.go +++ b/migrations/utils/utils.go @@ -3,58 +3,58 @@ package utils import ( "context" - "github.com/satisfactorymodding/smr-api/db/postgres" - "github.com/satisfactorymodding/smr-api/generated" - "github.com/satisfactorymodding/smr-api/models" + "entgo.io/ent/dialect/sql" + + "github.com/satisfactorymodding/smr-api/db" + "github.com/satisfactorymodding/smr-api/generated/ent" + "github.com/satisfactorymodding/smr-api/generated/ent/mod" + "github.com/satisfactorymodding/smr-api/generated/ent/version" "github.com/satisfactorymodding/smr-api/redis/jobs" ) -func ReindexAllModFiles(ctx context.Context, withMetadata bool, modFilter func(postgres.Mod) bool, versionFilter func(version postgres.Version) bool) { +func ReindexAllModFiles(ctx context.Context, withMetadata bool, modFilter func(*ent.Mod) bool, versionFilter func(version *ent.Version) bool) error { offset := 0 - limit := 100 - createdAt := generated.VersionFieldsCreatedAt - orderDesc := generated.OrderDesc - for { - mods := postgres.GetMods(ctx, 100, offset, "created_at", "asc", "", false) - offset += 100 + mods, err := db.From(ctx).Mod.Query().Limit(100).Offset(100).Order(mod.ByCreatedAt(sql.OrderDesc())).All(ctx) + if err != nil { + return err + } + offset += len(mods) if len(mods) == 0 { break } - for _, mod := range mods { + for _, m := range mods { versionOffset := 0 if modFilter != nil { - if !modFilter(mod) { + if !modFilter(m) { continue } } for { - versions := postgres.GetModVersionsNew(ctx, mod.ID, &models.VersionFilter{ - Limit: &limit, - Offset: &versionOffset, - OrderBy: &createdAt, - Order: &orderDesc, - }, false) + versions, err := m.QueryVersions().Limit(100).Offset(versionOffset).Order(version.ByCreatedAt(sql.OrderDesc())).All(ctx) + if err != nil { + return err + } versionOffset += len(versions) if len(versions) > 0 { - for _, version := range versions { + for _, v := range versions { if versionFilter != nil { - if !versionFilter(version) { + if !versionFilter(v) { continue } } if withMetadata { - jobs.SubmitJobUpdateDBFromModVersionFileTask(ctx, mod.ID, version.ID) + jobs.SubmitJobUpdateDBFromModVersionFileTask(ctx, m.ID, v.ID) } else { - jobs.SubmitJobUpdateDBFromModVersionJSONFileTask(ctx, mod.ID, version.ID) + jobs.SubmitJobUpdateDBFromModVersionJSONFileTask(ctx, m.ID, v.ID) } } } else { @@ -63,4 +63,6 @@ func ReindexAllModFiles(ctx context.Context, withMetadata bool, modFilter func(p } } } + + return nil } diff --git a/nodes/mod.go b/nodes/mod.go index ee3b53c6..3eb28980 100644 --- a/nodes/mod.go +++ b/nodes/mod.go @@ -1,12 +1,21 @@ package nodes import ( + "log/slog" "strings" "time" + "entgo.io/ent/dialect/sql" + "github.com/Vilsol/slox" "github.com/labstack/echo/v4" - "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/satisfactorymodding/smr-api/db" + "github.com/satisfactorymodding/smr-api/generated" + "github.com/satisfactorymodding/smr-api/generated/conv" + mod2 "github.com/satisfactorymodding/smr-api/generated/ent/mod" + version2 "github.com/satisfactorymodding/smr-api/generated/ent/version" + "github.com/satisfactorymodding/smr-api/generated/ent/versiontarget" + "github.com/satisfactorymodding/smr-api/models" "github.com/satisfactorymodding/smr-api/redis" "github.com/satisfactorymodding/smr-api/storage" "github.com/satisfactorymodding/smr-api/util" @@ -31,14 +40,27 @@ func getMods(c echo.Context) (interface{}, *ErrorResponse) { order := util.OneOf(c, "order", []string{"asc", "desc"}, "desc") search := c.QueryParam("search") - mods := postgres.GetMods(c.Request().Context(), limit, offset, orderBy, order, search, false) + modFilter := models.DefaultModFilter() + modFilter.Limit = &limit + modFilter.Offset = &offset - converted := make([]*Mod, len(mods)) - for k, v := range mods { - converted[k] = ModToMod(&v, true) + orderByGen := generated.ModFields(orderBy) + modFilter.OrderBy = &orderByGen + + orderGen := generated.Order(order) + modFilter.Order = &orderGen + modFilter.Search = &search + + query := db.From(c.Request().Context()).Mod.Query() + query = db.ConvertModFilter(query, modFilter, false, false) + + mods, err := query.All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed retrieving mods", slog.Any("err", err)) + return nil, &ErrorVersionNotFound } - return converted, nil + return (*conv.ModImpl)(nil).ConvertSlice(mods), nil } // @Summary Retrieve a count of Mods @@ -51,7 +73,20 @@ func getMods(c echo.Context) (interface{}, *ErrorResponse) { // @Router /mods/count [get] func getModCount(c echo.Context) (interface{}, *ErrorResponse) { search := c.QueryParam("search") - return postgres.GetModCount(c.Request().Context(), search, false), nil + + modFilter := models.DefaultModFilter() + modFilter.Search = &search + + query := db.From(c.Request().Context()).Mod.Query() + query = db.ConvertModFilter(query, modFilter, true, false) + + count, err := query.Count(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed retrieving mod count", slog.Any("err", err)) + return nil, &ErrorVersionNotFound + } + + return count, nil } // @Summary Retrieve a Mod @@ -65,7 +100,14 @@ func getModCount(c echo.Context) (interface{}, *ErrorResponse) { func getMod(c echo.Context) (interface{}, *ErrorResponse) { modID := c.Param("modId") - mod := postgres.GetModByID(c.Request().Context(), modID) + mod, err := db.From(c.Request().Context()).Mod.Query(). + WithTags(). + Where(mod2.ID(modID)). + First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mod", slog.Any("err", err)) + return nil, &ErrorModNotFound + } if mod == nil { return nil, &ErrorModNotFound @@ -73,11 +115,11 @@ func getMod(c echo.Context) (interface{}, *ErrorResponse) { if _, ok := c.QueryParams()["view"]; ok { if redis.CanIncrement(c.RealIP(), "view", "mod:"+modID, time.Hour*4) { - postgres.IncrementModViews(c.Request().Context(), mod) + _ = mod.Update().AddViews(1).Exec(c.Request().Context()) } } - return ModToMod(mod, false), nil + return (*conv.ModImpl)(nil).Convert(mod), nil } // @Summary Retrieve a list of Mods by ID @@ -94,18 +136,20 @@ func getModsByIds(c echo.Context) (interface{}, *ErrorResponse) { // TODO limit amount of users requestable - mods := postgres.GetModsByID(c.Request().Context(), modIDSplit) + mods, err := db.From(c.Request().Context()).Mod.Query(). + WithTags(). + Where(mod2.IDIn(modIDSplit...)). + All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mods", slog.Any("err", err)) + return nil, &ErrorVersionNotFound + } if mods == nil { return nil, &ErrorModNotFound } - converted := make([]*Mod, len(mods)) - for k, v := range mods { - converted[k] = ModToMod(&v, true) - } - - return converted, nil + return (*conv.ModImpl)(nil).ConvertSlice(mods), nil } // @Summary Retrieve a list of latest versions for a mod @@ -119,16 +163,27 @@ func getModsByIds(c echo.Context) (interface{}, *ErrorResponse) { func getModLatestVersions(c echo.Context) (interface{}, *ErrorResponse) { modID := c.Param("modId") - versions := postgres.GetModsLatestVersions(c.Request().Context(), []string{modID}, false) + versions, err := db.From(c.Request().Context()).Version.Query(). + WithTargets(). + Modify(func(s *sql.Selector) { + s.SelectExpr(sql.ExprP("distinct on (mod_id, stability) *")) + }). + Where(version2.Approved(true), version2.Denied(false), version2.ModID(modID)). + Order(version2.ByStability(sql.OrderDesc()), version2.ByCreatedAt(sql.OrderDesc())). + All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching versions", slog.Any("err", err)) + return nil, &ErrorVersionNotFound + } if versions == nil { return nil, &ErrorVersionNotFound } - result := make(map[string]*Version) + result := make(map[string]*generated.Version) - for _, v := range *versions { - result[v.Stability] = VersionToVersion(&v) + for _, v := range versions { + result[string(v.Stability)] = (*conv.VersionImpl)(nil).Convert(v) } return result, nil @@ -148,19 +203,30 @@ func getModsLatestVersions(c echo.Context) (interface{}, *ErrorResponse) { // TODO limit amount of mods requestable - versions := postgres.GetModsLatestVersions(c.Request().Context(), modIDSplit, false) + versions, err := db.From(c.Request().Context()).Version.Query(). + WithTargets(). + Modify(func(s *sql.Selector) { + s.SelectExpr(sql.ExprP("distinct on (mod_id, stability) *")) + }). + Where(version2.Approved(true), version2.Denied(false), version2.ModIDIn(modIDSplit...)). + Order(version2.ByStability(sql.OrderDesc()), version2.ByCreatedAt(sql.OrderDesc())). + All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching versions", slog.Any("err", err)) + return nil, &ErrorVersionNotFound + } if versions == nil { return nil, &ErrorVersionNotFound } - result := make(map[string]map[string]*Version) + result := make(map[string]map[string]*generated.Version) - for _, v := range *versions { + for _, v := range versions { if _, ok := result[v.ModID]; !ok { - result[v.ModID] = make(map[string]*Version) + result[v.ModID] = make(map[string]*generated.Version) } - result[v.ModID][v.Stability] = VersionToVersion(&v) + result[v.ModID][string(v.Stability)] = (*conv.VersionImpl)(nil).Convert(v) } return result, nil @@ -186,20 +252,33 @@ func getModVersions(c echo.Context) (interface{}, *ErrorResponse) { modID := c.Param("modId") - mod := postgres.GetModByID(c.Request().Context(), modID) + mod, err := db.From(c.Request().Context()).Mod.Query(). + WithTags(). + Where(mod2.ID(modID)). + First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mod", slog.Any("err", err)) + return nil, &ErrorVersionNotFound + } if mod == nil { return nil, &ErrorModNotFound } - versions := postgres.GetModVersions(c.Request().Context(), mod.ID, limit, offset, orderBy, order, false) - - converted := make([]*Version, len(versions)) - for k, v := range versions { - converted[k] = VersionToVersion(&v) + versions, err := mod.QueryVersions(). + WithDependencies(). + WithTargets(). + Limit(limit). + Offset(offset). + Order(sql.OrderByField(orderBy, db.OrderToOrder(order)).ToFunc()). + Where(version2.Approved(true), version2.Denied(false)). + All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching versions", slog.Any("err", err)) + return nil, &ErrorVersionNotFound } - return converted, nil + return (*conv.VersionImpl)(nil).ConvertSlice(versions), nil } // @Summary Retrieve a Mod Authors @@ -213,20 +292,26 @@ func getModVersions(c echo.Context) (interface{}, *ErrorResponse) { func getModAuthors(c echo.Context) (interface{}, *ErrorResponse) { modID := c.Param("modId") - mod := postgres.GetModByID(c.Request().Context(), modID) + mod, err := db.From(c.Request().Context()).Mod.Query(). + WithTags(). + Where(mod2.ID(modID)). + First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mod", slog.Any("err", err)) + return nil, &ErrorModNotFound + } if mod == nil { return nil, &ErrorModNotFound } - authors := postgres.GetModAuthors(c.Request().Context(), mod.ID) - - converted := make([]*ModUser, len(authors)) - for k, v := range authors { - converted[k] = ModUserToModUser(&v) + authors, err := mod.QueryUserMods().All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching authors", slog.Any("err", err)) + return nil, &ErrorModNotFound } - return converted, nil + return (*conv.UserModImpl)(nil).ConvertSlice(authors), nil } // @Summary Retrieve a Mod Version @@ -242,19 +327,30 @@ func getModVersion(c echo.Context) (interface{}, *ErrorResponse) { modID := c.Param("modId") versionID := c.Param("versionId") - mod := postgres.GetModByID(c.Request().Context(), modID) + mod, err := db.From(c.Request().Context()).Mod.Query(). + WithTags(). + Where(mod2.ID(modID)). + First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mod", slog.Any("err", err)) + return nil, &ErrorVersionNotFound + } if mod == nil { return nil, &ErrorModNotFound } - version := postgres.GetModVersion(c.Request().Context(), mod.ID, versionID) + version, err := mod.QueryVersions().Where(version2.ID(versionID)).First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching version", slog.Any("err", err)) + return nil, &ErrorVersionNotFound + } if version == nil { return nil, &ErrorVersionNotFound } - return VersionToVersion(version), nil + return (*conv.VersionImpl)(nil).Convert(version), nil } // @Summary Download a Mod Version @@ -270,20 +366,31 @@ func downloadModVersion(c echo.Context) error { modID := c.Param("modId") versionID := c.Param("versionId") - mod := postgres.GetModByID(c.Request().Context(), modID) + mod, err := db.From(c.Request().Context()).Mod.Query(). + WithTags(). + Where(mod2.ID(modID)). + First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mod", slog.Any("err", err)) + return c.String(404, "mod not found, modID:"+modID) + } if mod == nil { return c.String(404, "mod not found, modID:"+modID) } - version := postgres.GetModVersion(c.Request().Context(), mod.ID, versionID) + version, err := mod.QueryVersions().Where(version2.ID(versionID)).First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching version", slog.Any("err", err)) + return c.String(404, "version not found, modID:"+modID+" versionID:"+versionID) + } if version == nil { return c.String(404, "version not found, modID:"+modID+" versionID:"+versionID) } if redis.CanIncrement(c.RealIP(), "download", "version:"+versionID, time.Hour*4) { - postgres.IncrementVersionDownloads(c.Request().Context(), version) + _ = version.Update().AddDownloads(1).Exec(c.Request().Context()) } return c.Redirect(302, storage.GenerateDownloadLink(version.Key)) @@ -304,26 +411,41 @@ func downloadModVersionTarget(c echo.Context) error { versionID := c.Param("versionId") target := c.Param("target") - mod := postgres.GetModByID(c.Request().Context(), modID) + mod, err := db.From(c.Request().Context()).Mod.Query(). + WithTags(). + Where(mod2.ID(modID)). + First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mod", slog.Any("err", err)) + return c.String(404, "mod not found, modID:"+modID) + } if mod == nil { return c.String(404, "mod not found, modID:"+modID) } - version := postgres.GetModVersion(c.Request().Context(), mod.ID, versionID) + version, err := mod.QueryVersions().Where(version2.ID(versionID)).First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching version", slog.Any("err", err)) + return err + } if version == nil { return c.String(404, "version not found, modID:"+modID+" versionID:"+versionID) } - versionTarget := postgres.GetVersionTarget(c.Request().Context(), versionID, target) + versionTarget, err := version.QueryTargets().Where(versiontarget.TargetName(target)).First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching target", slog.Any("err", err)) + return err + } if versionTarget == nil { return c.String(404, "target not found, modID:"+modID+" versionID:"+versionID+" target:"+target) } if redis.CanIncrement(c.RealIP(), "download", "version:"+versionID, time.Hour*4) { - postgres.IncrementVersionDownloads(c.Request().Context(), version) + _ = version.Update().AddDownloads(1).Exec(c.Request().Context()) } return c.Redirect(302, storage.GenerateDownloadLink(versionTarget.Key)) @@ -340,18 +462,29 @@ func downloadModVersionTarget(c echo.Context) error { func getAllModVersions(c echo.Context) (interface{}, *ErrorResponse) { modID := c.Param("modId") - mod := postgres.GetModByIDOrReference(c.Request().Context(), modID) + mod, err := db.From(c.Request().Context()).Mod.Query(). + WithTags(). + Where(mod2.Or(mod2.ID(modID), mod2.ModReference(modID))). + First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mod", slog.Any("err", err)) + return nil, &ErrorModNotFound + } if mod == nil { return nil, &ErrorModNotFound } - versions := postgres.GetAllModVersionsWithDependencies(c.Request().Context(), mod.ID) - - converted := make([]*Version, len(versions)) - for k, v := range versions { - converted[k] = TinyVersionToVersion(&v) + versions, err := mod.QueryVersions(). + WithDependencies(). + WithTargets(). + Where(version2.Approved(true), version2.Denied(false)). + Select(version2.FieldHash, version2.FieldSize, version2.FieldSmlVersion, version2.FieldVersion). + All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching versions", slog.Any("err", err)) + return nil, &ErrorVersionNotFound } - return converted, nil + return (*conv.VersionImpl)(nil).ConvertSlice(versions), nil } diff --git a/nodes/mod_types.go b/nodes/mod_types.go deleted file mode 100644 index 25cab078..00000000 --- a/nodes/mod_types.go +++ /dev/null @@ -1,178 +0,0 @@ -package nodes - -import ( - "time" - - "github.com/satisfactorymodding/smr-api/db/postgres" -) - -type Mod struct { - UpdatedAt time.Time `json:"updated_at"` - CreatedAt time.Time `json:"created_at"` - CreatorID string `json:"creator_id"` - FullDescription string `json:"full_description"` - Logo string `json:"logo"` - SourceURL string `json:"source_url"` - ID string `json:"id"` - ShortDescription string `json:"short_description"` - Name string `json:"name"` - Views uint `json:"views"` - Downloads uint `json:"downloads"` - Hotness uint `json:"hotness"` - Popularity uint `json:"popularity"` - Approved bool `json:"approved"` -} - -func ModToMod(mod *postgres.Mod, short bool) *Mod { - result := Mod{ - ID: mod.ID, - Name: mod.Name, - ShortDescription: mod.ShortDescription, - Logo: mod.Logo, - SourceURL: mod.SourceURL, - CreatorID: mod.CreatorID, - Approved: mod.Approved, - Views: mod.Views, - Downloads: mod.Downloads, - Hotness: mod.Hotness, - Popularity: mod.Popularity, - UpdatedAt: mod.UpdatedAt, - CreatedAt: mod.CreatedAt, - } - - if !short { - result.FullDescription = mod.FullDescription - } - - return &result -} - -type Version struct { - UpdatedAt time.Time `json:"updated_at,omitempty"` - CreatedAt time.Time `json:"created_at,omitempty"` - ID string `json:"id,omitempty"` - Version string `json:"version,omitempty"` - SMLVersion string `json:"sml_version,omitempty"` - Changelog string `json:"changelog,omitempty"` - Stability string `json:"stability,omitempty"` - ModID string `json:"mod_id,omitempty"` - Dependencies []VersionDependency `json:"dependencies,omitempty"` - Targets []VersionTarget `json:"targets,omitempty"` - Downloads uint `json:"downloads,omitempty"` - Approved bool `json:"approved,omitempty"` -} - -type VersionDependency struct { - ModID string `json:"mod_id"` - Condition string `json:"condition"` - Optional bool `json:"optional"` -} - -type VersionTarget struct { - VersionID string `json:"version_id"` - TargetName string `json:"target_name"` - Key string `json:"key"` - Hash string `json:"hash"` - Size int64 `json:"size"` -} - -func TinyVersionToVersion(version *postgres.TinyVersion) *Version { - var dependencies []VersionDependency - if version.Dependencies != nil { - dependencies = make([]VersionDependency, len(version.Dependencies)) - for i, v := range version.Dependencies { - dependencies[i] = VersionDependencyToVersionDependency(v) - } - } - - var targets []VersionTarget - if version.Targets != nil { - targets = make([]VersionTarget, len(version.Targets)) - for i, v := range version.Targets { - targets[i] = VersionTargetToVersionTarget(v) - } - } - - return &Version{ - UpdatedAt: version.UpdatedAt, - CreatedAt: version.CreatedAt, - ID: version.ID, - Version: version.Version, - SMLVersion: version.SMLVersion, - Dependencies: dependencies, - Targets: targets, - } -} - -func VersionToVersion(version *postgres.Version) *Version { - return &Version{ - ID: version.ID, - Version: version.Version, - SMLVersion: version.SMLVersion, - Changelog: version.Changelog, - Downloads: version.Downloads, - Stability: version.Stability, - Approved: version.Approved, - UpdatedAt: version.UpdatedAt, - CreatedAt: version.CreatedAt, - ModID: version.ModID, - } -} - -func VersionDependencyToVersionDependency(version postgres.VersionDependency) VersionDependency { - return VersionDependency{ - ModID: version.ModID, - Condition: version.Condition, - Optional: version.Optional, - } -} - -func VersionTargetToVersionTarget(version postgres.VersionTarget) VersionTarget { - return VersionTarget{ - VersionID: version.VersionID, - TargetName: version.TargetName, - Key: version.Key, - Hash: version.Hash, - Size: version.Size, - } -} - -type ModUser struct { - UserID string `json:"user_id"` - Role string `json:"role"` -} - -func ModUserToModUser(userMod *postgres.UserMod) *ModUser { - return &ModUser{ - UserID: userMod.UserID, - Role: userMod.Role, - } -} - -type SMLVersion struct { - Date time.Time `json:"date"` - UpdatedAt time.Time `json:"updated_at"` - CreatedAt time.Time `json:"created_at"` - BootstrapVersion *string `json:"bootstrap_version"` - ID string `json:"id"` - Version string `json:"version"` - Stability string `json:"stability"` - Link string `json:"link"` - Changelog string `json:"changelog"` - SatisfactoryVersion int `json:"satisfactory_version"` -} - -func SMLVersionToSMLVersion(version *postgres.SMLVersion) *SMLVersion { - return &SMLVersion{ - ID: version.ID, - Version: version.Version, - SatisfactoryVersion: version.SatisfactoryVersion, - BootstrapVersion: version.BootstrapVersion, - Stability: version.Stability, - Date: version.Date, - Link: version.Link, - Changelog: version.Changelog, - UpdatedAt: version.UpdatedAt, - CreatedAt: version.CreatedAt, - } -} diff --git a/nodes/oauth.go b/nodes/oauth.go index 9b8877d7..b5852b65 100644 --- a/nodes/oauth.go +++ b/nodes/oauth.go @@ -1,15 +1,12 @@ package nodes import ( - "bytes" "net/url" "github.com/labstack/echo/v4" - "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/satisfactorymodding/smr-api/db" "github.com/satisfactorymodding/smr-api/oauth" - "github.com/satisfactorymodding/smr-api/storage" - "github.com/satisfactorymodding/smr-api/util" ) // @Summary Retrieve a list of OAuth methods @@ -56,23 +53,12 @@ func getGithub(c echo.Context) (interface{}, *ErrorResponse) { userAgent := c.Request().Header.Get("User-Agent") - avatarURL := user.Avatar - user.Avatar = "" - - session, dbUser, newUser := postgres.GetUserSession(c.Request().Context(), user, userAgent) - - if avatarURL != "" && newUser { - avatarData, err := util.LinkToWebp(c.Request().Context(), avatarURL) - if err != nil { - return nil, GenericUserError(err) - } - - success, avatarKey := storage.UploadUserAvatar(c.Request().Context(), session.UserID, bytes.NewReader(avatarData)) - if success { - dbUser.Avatar = storage.GenerateDownloadLink(avatarKey) - postgres.Save(c.Request().Context(), &dbUser) - } + token, err := db.CompleteOAuthFlow(c.Request().Context(), user, userAgent) + if err != nil { + return nil, GenericUserError(err) } - return SessionToSession(session), nil + return &UserSession{ + Token: *token, + }, nil } diff --git a/nodes/oauth_types.go b/nodes/oauth_types.go index 8ea05403..2a550b8c 100755 --- a/nodes/oauth_types.go +++ b/nodes/oauth_types.go @@ -1,15 +1,5 @@ package nodes -import ( - "github.com/satisfactorymodding/smr-api/db/postgres" -) - type UserSession struct { Token string `json:"token"` } - -func SessionToSession(session *postgres.UserSession) *UserSession { - return &UserSession{ - Token: session.Token, - } -} diff --git a/nodes/shared.go b/nodes/shared.go index ddcd60fb..87abce59 100644 --- a/nodes/shared.go +++ b/nodes/shared.go @@ -3,7 +3,7 @@ package nodes import ( "github.com/labstack/echo/v4" - "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/satisfactorymodding/smr-api/generated/ent" ) type DataFunction func(c echo.Context) (data interface{}, err *ErrorResponse) @@ -25,7 +25,7 @@ func dataWrapper(nested DataFunction) func(c echo.Context) error { } } -type AuthorizedDataFunction func(user *postgres.User, c echo.Context) (data interface{}, err *ErrorResponse) +type AuthorizedDataFunction func(user *ent.User, c echo.Context) (data interface{}, err *ErrorResponse) func authorized(nested AuthorizedDataFunction) DataFunction { return func(c echo.Context) (interface{}, *ErrorResponse) { diff --git a/nodes/sml.go b/nodes/sml.go index cc983dce..7f043a4d 100644 --- a/nodes/sml.go +++ b/nodes/sml.go @@ -1,9 +1,15 @@ package nodes import ( + "log/slog" + + "entgo.io/ent/dialect/sql" + "github.com/Vilsol/slox" "github.com/labstack/echo/v4" - "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/satisfactorymodding/smr-api/db" + "github.com/satisfactorymodding/smr-api/generated/conv" + "github.com/satisfactorymodding/smr-api/generated/ent/smlversion" ) // @Summary Retrieve a list of latest versions for sml @@ -14,17 +20,17 @@ import ( // @Success 200 // @Router /sml/latest-versions [get] func getSMLLatestVersions(c echo.Context) (interface{}, *ErrorResponse) { - smlVersions := postgres.GetSMLLatestVersions(c.Request().Context()) - - if smlVersions == nil { + smlVersions, err := db.From(c.Request().Context()).SmlVersion.Query(). + WithTargets(). + Modify(func(s *sql.Selector) { + s.SelectExpr(sql.ExprP("distinct on (stability) *")) + }). + Order(smlversion.ByStability(sql.OrderDesc()), smlversion.ByCreatedAt(sql.OrderDesc())). + All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching sml versions", slog.Any("err", err)) return nil, &ErrorVersionNotFound } - result := make(map[string]*SMLVersion) - - for _, v := range *smlVersions { - result[v.Stability] = SMLVersionToSMLVersion(&v) - } - - return result, nil + return (*conv.SMLVersionImpl)(nil).ConvertSlice(smlVersions), nil } diff --git a/nodes/user.go b/nodes/user.go index a029ee49..aadb375b 100644 --- a/nodes/user.go +++ b/nodes/user.go @@ -1,21 +1,34 @@ package nodes import ( + "log/slog" "strings" + "github.com/Vilsol/slox" "github.com/labstack/echo/v4" - "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/satisfactorymodding/smr-api/db" + "github.com/satisfactorymodding/smr-api/generated/conv" + "github.com/satisfactorymodding/smr-api/generated/ent" + "github.com/satisfactorymodding/smr-api/generated/ent/user" + "github.com/satisfactorymodding/smr-api/generated/ent/usermod" + "github.com/satisfactorymodding/smr-api/generated/ent/usersession" ) -func userFromContext(c echo.Context) *postgres.User { +func userFromContext(c echo.Context) *ent.User { authorization := c.Request().Header.Get("Authorization") if authorization == "" { return nil } - user := postgres.GetUserByToken(c.Request().Context(), authorization) + user, err := db.From(c.Request().Context()).User.Query(). + Where(user.HasSessionsWith(usersession.Token(authorization))). + First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mods", slog.Any("err", err)) + return nil + } if user == nil { return nil @@ -31,8 +44,14 @@ func userFromContext(c echo.Context) *postgres.User { // @Produce json // @Success 200 // @Router /user/me [get] -func getMe(user *postgres.User, _ echo.Context) (interface{}, *ErrorResponse) { - return UserToPrivateUser(user), nil +func getMe(user *ent.User, _ echo.Context) (interface{}, *ErrorResponse) { + return &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Avatar: user.Avatar, + CreatedAt: user.CreatedAt, + }, nil } // @Summary Log Out Current User @@ -42,8 +61,13 @@ func getMe(user *postgres.User, _ echo.Context) (interface{}, *ErrorResponse) { // @Produce json // @Success 200 // @Router /user/me/logout [get] -func getLogout(_ *postgres.User, c echo.Context) (interface{}, *ErrorResponse) { - postgres.LogoutSession(c.Request().Context(), c.Request().Header.Get("Authorization")) +func getLogout(_ *ent.User, c echo.Context) (interface{}, *ErrorResponse) { + if _, err := db.From(c.Request().Context()).UserSession.Delete(). + Where(usersession.Token(c.Request().Header.Get("Authorization"))). + Exec(c.Request().Context()); err != nil { + slox.Error(c.Request().Context(), "failed deleting session", slog.Any("err", err)) + return nil, &ErrorUserNotFound + } return nil, nil } @@ -54,15 +78,14 @@ func getLogout(_ *postgres.User, c echo.Context) (interface{}, *ErrorResponse) { // @Produce json // @Success 200 // @Router /user/me/mods [get] -func getMyMods(user *postgres.User, c echo.Context) (interface{}, *ErrorResponse) { - mods := postgres.GetUserMods(c.Request().Context(), user.ID) - - converted := make([]*UserMod, len(mods)) - for k, v := range mods { - converted[k] = UserModToUserMod(&v) +func getMyMods(user *ent.User, c echo.Context) (interface{}, *ErrorResponse) { + mods, err := db.From(c.Request().Context()).UserMod.Query().Where(usermod.UserID(user.ID)).All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching authors", slog.Any("err", err)) + return nil, &ErrorUserNotFound } - return converted, nil + return (*conv.UserModImpl)(nil).ConvertSlice(mods), nil } // @Summary Retrieve a list of Users @@ -80,18 +103,13 @@ func getUsers(c echo.Context) (interface{}, *ErrorResponse) { // TODO limit amount of users requestable - users := postgres.GetUsersByID(c.Request().Context(), userIDSplit) - - if users == nil { - return nil, &ErrorUserNotFound + users, err := db.From(c.Request().Context()).User.Query().Where(user.IDIn(userIDSplit...)).All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching users", slog.Any("err", err)) + return nil, nil } - converted := make([]*PublicUser, len(*users)) - for k, v := range *users { - converted[k] = UserToPublicUser(&v) - } - - return converted, nil + return (*conv.UserImpl)(nil).ConvertSlice(users), nil } // @Summary Retrieve a Users Mods @@ -105,20 +123,23 @@ func getUsers(c echo.Context) (interface{}, *ErrorResponse) { func getUserMods(c echo.Context) (interface{}, *ErrorResponse) { userID := c.Param("userId") - user := postgres.GetUserByID(c.Request().Context(), userID) + user, err := db.From(c.Request().Context()).User.Get(c.Request().Context(), userID) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mods", slog.Any("err", err)) + return nil, &ErrorUserNotFound + } if user == nil { return nil, &ErrorUserNotFound } - mods := postgres.GetUserMods(c.Request().Context(), user.ID) - - converted := make([]*UserMod, len(mods)) - for k, v := range mods { - converted[k] = UserModToUserMod(&v) + mods, err := db.From(c.Request().Context()).UserMod.Query().Where(usermod.UserID(user.ID)).All(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mods", slog.Any("err", err)) + return nil, &ErrorModNotFound } - return converted, nil + return (*conv.UserModImpl)(nil).ConvertSlice(mods), nil } // @Summary Retrieve a User @@ -132,11 +153,20 @@ func getUserMods(c echo.Context) (interface{}, *ErrorResponse) { func getUser(c echo.Context) (interface{}, *ErrorResponse) { userID := c.Param("userId") - user := postgres.GetUserByID(c.Request().Context(), userID) + user, err := db.From(c.Request().Context()).User.Get(c.Request().Context(), userID) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching mods", slog.Any("err", err)) + return nil, &ErrorUserNotFound + } if user == nil { return nil, &ErrorUserNotFound } - return UserToPublicUser(user), nil + return &PublicUser{ + ID: user.ID, + Username: user.Username, + Avatar: user.Avatar, + CreatedAt: user.CreatedAt, + }, nil } diff --git a/nodes/user_types.go b/nodes/user_types.go index f929f531..dcfd4512 100644 --- a/nodes/user_types.go +++ b/nodes/user_types.go @@ -2,8 +2,6 @@ package nodes import ( "time" - - "github.com/satisfactorymodding/smr-api/db/postgres" ) type User struct { @@ -14,40 +12,9 @@ type User struct { Avatar string `json:"avatar"` } -func UserToPrivateUser(user *postgres.User) *User { - return &User{ - ID: user.ID, - Email: user.Email, - Username: user.Username, - Avatar: user.Avatar, - CreatedAt: user.CreatedAt, - } -} - type PublicUser struct { CreatedAt time.Time `json:"created_at"` ID string `json:"id"` Username string `json:"username"` Avatar string `json:"avatar"` } - -func UserToPublicUser(user *postgres.User) *PublicUser { - return &PublicUser{ - ID: user.ID, - Username: user.Username, - Avatar: user.Avatar, - CreatedAt: user.CreatedAt, - } -} - -type UserMod struct { - ModID string `json:"mod_id"` - Role string `json:"role"` -} - -func UserModToUserMod(mod *postgres.UserMod) *UserMod { - return &UserMod{ - ModID: mod.ModID, - Role: mod.Role, - } -} diff --git a/nodes/version.go b/nodes/version.go index 1da2eeef..3302e0e2 100644 --- a/nodes/version.go +++ b/nodes/version.go @@ -1,11 +1,15 @@ package nodes import ( + "log/slog" "time" + "github.com/Vilsol/slox" "github.com/labstack/echo/v4" - "github.com/satisfactorymodding/smr-api/db/postgres" + "github.com/satisfactorymodding/smr-api/db" + "github.com/satisfactorymodding/smr-api/generated/conv" + "github.com/satisfactorymodding/smr-api/generated/ent/versiontarget" "github.com/satisfactorymodding/smr-api/redis" "github.com/satisfactorymodding/smr-api/storage" ) @@ -21,13 +25,17 @@ import ( func getVersion(c echo.Context) (interface{}, *ErrorResponse) { versionID := c.Param("versionId") - version := postgres.GetVersion(c.Request().Context(), versionID) + version, err := db.From(c.Request().Context()).Version.Get(c.Request().Context(), versionID) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching version", slog.Any("err", err)) + return nil, &ErrorVersionNotFound + } if version == nil { return nil, &ErrorVersionNotFound } - return VersionToVersion(version), nil + return (*conv.VersionImpl)(nil).Convert(version), nil } // @Summary Download a Version @@ -41,14 +49,18 @@ func getVersion(c echo.Context) (interface{}, *ErrorResponse) { func downloadVersion(c echo.Context) error { versionID := c.Param("versionId") - version := postgres.GetVersion(c.Request().Context(), versionID) + version, err := db.From(c.Request().Context()).Version.Get(c.Request().Context(), versionID) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching version", slog.Any("err", err)) + return err + } if version == nil { return c.String(404, "version not found") } if redis.CanIncrement(c.RealIP(), "download", "version:"+versionID, time.Hour*4) { - postgres.IncrementVersionDownloads(c.Request().Context(), version) + _ = version.Update().AddDownloads(1).Exec(c.Request().Context()) } return c.Redirect(302, storage.GenerateDownloadLink(version.Key)) @@ -68,20 +80,28 @@ func downloadModTarget(c echo.Context) error { versionID := c.Param("versionId") target := c.Param("target") - version := postgres.GetVersion(c.Request().Context(), versionID) + version, err := db.From(c.Request().Context()).Version.Get(c.Request().Context(), versionID) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching version", slog.Any("err", err)) + return err + } if version == nil { return c.String(404, "version not found, versionID:"+versionID) } - versionTarget := postgres.GetVersionTarget(c.Request().Context(), versionID, target) + versionTarget, err := version.QueryTargets().Where(versiontarget.TargetName(target)).First(c.Request().Context()) + if err != nil { + slox.Error(c.Request().Context(), "failed fetching target", slog.Any("err", err)) + return err + } if versionTarget == nil { return c.String(404, "target not found, versionID:"+versionID+" target:"+target) } if redis.CanIncrement(c.RealIP(), "download", "version:"+versionID, time.Hour*4) { - postgres.IncrementVersionDownloads(c.Request().Context(), version) + _ = version.Update().AddDownloads(1).Exec(c.Request().Context()) } return c.Redirect(302, storage.GenerateDownloadLink(versionTarget.Key)) diff --git a/tests/announcements_test.go b/tests/announcements_test.go index c880b350..608fb703 100644 --- a/tests/announcements_test.go +++ b/tests/announcements_test.go @@ -8,7 +8,6 @@ import ( "github.com/satisfactorymodding/smr-api/config" "github.com/satisfactorymodding/smr-api/db" - "github.com/satisfactorymodding/smr-api/db/postgres" "github.com/satisfactorymodding/smr-api/generated" "github.com/satisfactorymodding/smr-api/migrations" ) @@ -16,7 +15,6 @@ import ( func init() { migrations.SetMigrationDir("../migrations") config.SetConfigDir("../") - postgres.EnableDebug() db.EnableDebug() } diff --git a/tests/guides_test.go b/tests/guides_test.go index 6544423c..da5d8369 100644 --- a/tests/guides_test.go +++ b/tests/guides_test.go @@ -8,7 +8,6 @@ import ( "github.com/satisfactorymodding/smr-api/config" "github.com/satisfactorymodding/smr-api/db" - "github.com/satisfactorymodding/smr-api/db/postgres" "github.com/satisfactorymodding/smr-api/generated" "github.com/satisfactorymodding/smr-api/migrations" ) @@ -16,7 +15,6 @@ import ( func init() { migrations.SetMigrationDir("../migrations") config.SetConfigDir("../") - postgres.EnableDebug() db.EnableDebug() } diff --git a/tests/mod_test.go b/tests/mod_test.go index 1b0733e9..7006e507 100644 --- a/tests/mod_test.go +++ b/tests/mod_test.go @@ -8,7 +8,6 @@ import ( "github.com/satisfactorymodding/smr-api/config" "github.com/satisfactorymodding/smr-api/db" - "github.com/satisfactorymodding/smr-api/db/postgres" "github.com/satisfactorymodding/smr-api/generated" "github.com/satisfactorymodding/smr-api/migrations" ) @@ -16,7 +15,6 @@ import ( func init() { migrations.SetMigrationDir("../migrations") config.SetConfigDir("../") - postgres.EnableDebug() db.EnableDebug() } diff --git a/tests/sml_versions_test.go b/tests/sml_versions_test.go index e981a12e..768530a9 100644 --- a/tests/sml_versions_test.go +++ b/tests/sml_versions_test.go @@ -9,7 +9,6 @@ import ( "github.com/satisfactorymodding/smr-api/config" "github.com/satisfactorymodding/smr-api/db" - "github.com/satisfactorymodding/smr-api/db/postgres" "github.com/satisfactorymodding/smr-api/generated" "github.com/satisfactorymodding/smr-api/migrations" ) @@ -17,7 +16,6 @@ import ( func init() { migrations.SetMigrationDir("../migrations") config.SetConfigDir("../") - postgres.EnableDebug() db.EnableDebug() } diff --git a/tests/utils.go b/tests/utils.go index cac3229c..240737cf 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -2,19 +2,24 @@ package tests import ( "context" + "database/sql" + "fmt" "log/slog" "sync" "github.com/Vilsol/slox" "github.com/machinebox/graphql" + "github.com/spf13/viper" smr "github.com/satisfactorymodding/smr-api/api" "github.com/satisfactorymodding/smr-api/auth" "github.com/satisfactorymodding/smr-api/db" - "github.com/satisfactorymodding/smr-api/db/postgres" "github.com/satisfactorymodding/smr-api/redis" "github.com/satisfactorymodding/smr-api/util" "github.com/satisfactorymodding/smr-api/validation" + + // Import pgx + _ "github.com/jackc/pgx/v5/stdlib" ) func setup() (context.Context, *graphql.Client, func()) { @@ -30,14 +35,39 @@ func setup() (context.Context, *graphql.Client, func()) { TableName string } - // TODO Replace with ENT - err := postgres.DBCtx(ctx).Raw(`SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'`).Scan(&out).Error + connection, err := sql.Open("pgx", fmt.Sprintf( + "sslmode=disable host=%s port=%d user=%s dbname=%s password=%s", + viper.GetString("database.postgres.host"), + viper.GetInt("database.postgres.port"), + viper.GetString("database.postgres.user"), + viper.GetString("database.postgres.db"), + viper.GetString("database.postgres.pass"), + )) + if err != nil { + panic(err) + } + + query, err := connection.Query(`SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'`) if err != nil { panic(err) } + defer query.Close() + + for query.Next() { + row := struct { + TableName string + }{} + + err = query.Scan(&row.TableName) + if err != nil { + panic(err) + } + + out = append(out, row) + } for _, name := range out { - err := postgres.DBCtx(ctx).Exec(`DROP TABLE IF EXISTS ` + name.TableName + ` CASCADE`).Error + _, err = connection.Exec(`DROP TABLE IF EXISTS ` + name.TableName + ` CASCADE`) if err != nil { panic(err) } diff --git a/tests/version_test.go b/tests/version_test.go index bcc023a3..552667fe 100644 --- a/tests/version_test.go +++ b/tests/version_test.go @@ -21,7 +21,6 @@ import ( "github.com/satisfactorymodding/smr-api/config" "github.com/satisfactorymodding/smr-api/db" - "github.com/satisfactorymodding/smr-api/db/postgres" "github.com/satisfactorymodding/smr-api/generated" "github.com/satisfactorymodding/smr-api/migrations" ) @@ -29,7 +28,6 @@ import ( func init() { migrations.SetMigrationDir("../migrations") config.SetConfigDir("../") - postgres.EnableDebug() db.EnableDebug() } diff --git a/validation/validation.go b/validation/validation.go index 1a9113b2..50df00b4 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -170,6 +170,8 @@ func ExtractModInfo(ctx context.Context, body []byte, withMetadata bool, withVal break } } + } else { + slox.Warn(ctx, "no database context provided to validator") } slox.Info(ctx, "decided engine version", slog.String("version", engineVersion))