diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..3380c4e --- /dev/null +++ b/config/config.go @@ -0,0 +1,123 @@ +package config + +import ( + "errors" + "flag" + "fmt" + "log" + "os" + "strconv" + "time" +) + +type configOption[T any] struct { + flagName string + flagDefaultValue T + flagHelpBase string + envVarName string +} + +type configOptions struct { + userID configOption[string] + token configOption[string] + room configOption[string] + printHomeserverMemberCount configOption[bool] + homeserverVersionInfoTimeout configOption[time.Duration] +} + +type Config struct { + UserID string + Token string + Room string + PrintHomeserverMemberCount bool + HomeserverVersionInfoTimeout time.Duration +} + +var configOpts = configOptions{ + userID: configOption[string]{"user-id", "", "The Matrix user ID to use.", "MRVC_USER_ID"}, + token: configOption[string]{"token", "", "The Matrix access token to use.", "MRVC_TOKEN"}, + room: configOption[string]{"room", "", "The Matrix room to check.", "MRVC_ROOM"}, + printHomeserverMemberCount: configOption[bool]{"print-homeserver-member-count", false, "Print the member count for each homeserver.", "MRVC_PRINT_HOMESERVER_MEMBER_COUNT"}, + homeserverVersionInfoTimeout: configOption[time.Duration]{"homeserver-version-info-timeout", time.Second * 5, "Timeout for getting the homeservers version information per homeserver.", "MRVC_HOMESERVER_VERSION_INFO_TIMEOUT"}, +} + +func (configOpt configOption[T]) getFlagHelp() string { + return fmt.Sprintf("%s (EnvVar: %s)", configOpt.flagHelpBase, configOpt.envVarName) +} + +func (configOpt configOption[T]) getFlagArgs() (string, T, string) { + return configOpt.flagName, configOpt.flagDefaultValue, configOpt.getFlagHelp() +} + +// Functions for getting config values from flags and environment variables. +// Flags take precedence over environment variables. +func (configOpt configOption[T]) getConfigValueWithDefault(configFlag *T, visitedFlags map[string]bool, envVarParser func(string) T) T { + if visitedFlags[configOpt.flagName] { + return *configFlag + } else if envVar, ok := os.LookupEnv(configOpt.envVarName); ok { + return envVarParser(envVar) + } else { + return configOpt.flagDefaultValue + } +} + +// This function can be used to ensure some configuration options got explicitly set. +func (configOpt configOption[T]) getConfigValueWithError(configFlag *T, visitedFlags map[string]bool, envVarParser func(string) T) (T, error) { + if visitedFlags[configOpt.flagName] { + return *configFlag, nil + } else if envVar, ok := os.LookupEnv(configOpt.envVarName); ok { + return envVarParser(envVar), nil + } else { + return configOpt.flagDefaultValue, errors.New("no command-line flag or environment variable set") + } +} + +var userIdFlag = flag.String(configOpts.userID.getFlagArgs()) +var tokenFlag = flag.String(configOpts.token.getFlagArgs()) +var roomFlag = flag.String(configOpts.room.getFlagArgs()) +var printHomeserverMemberCountFlag = flag.Bool(configOpts.printHomeserverMemberCount.getFlagArgs()) +var homeserverVersionInfoTimeoutFlag = flag.Duration(configOpts.homeserverVersionInfoTimeout.getFlagArgs()) + +func Get() Config { + flag.Parse() + + var config Config + + visitedFlags := make(map[string]bool) + flag.Visit(func(flag *flag.Flag) { + visitedFlags[flag.Name] = true + }) + + var err error + + config.UserID, err = configOpts.userID.getConfigValueWithError(userIdFlag, visitedFlags, func(envVar string) string { return envVar }) + if err != nil { + log.Fatal("A Matrix user ID must be provided.") + } + config.Token, err = configOpts.token.getConfigValueWithError(tokenFlag, visitedFlags, func(envVar string) string { return envVar }) + if err != nil { + log.Fatal("A Matrix access token must be provided.") + } + config.Room, err = configOpts.room.getConfigValueWithError(roomFlag, visitedFlags, func(envVar string) string { return envVar }) + if err != nil { + log.Fatal("A Matrix room must be provided.") + } + config.PrintHomeserverMemberCount = configOpts.printHomeserverMemberCount.getConfigValueWithDefault(printHomeserverMemberCountFlag, visitedFlags, func(envVar string) bool { + parsedEnvVar, err := strconv.ParseBool(envVar) + if err != nil { + log.Printf("Error parsing %s:\n", configOpts.printHomeserverMemberCount.envVarName) + log.Fatal(err) + } + return parsedEnvVar + }) + config.HomeserverVersionInfoTimeout = configOpts.homeserverVersionInfoTimeout.getConfigValueWithDefault(homeserverVersionInfoTimeoutFlag, visitedFlags, func(envVar string) time.Duration { + parsedEnvVar, err := time.ParseDuration(envVar) + if err != nil { + log.Printf("Error parsing %s:\n", configOpts.homeserverVersionInfoTimeout.envVarName) + log.Fatal(err) + } + return parsedEnvVar + }) + + return config +} diff --git a/main.go b/main.go index 1884cf8..27e7ff5 100644 --- a/main.go +++ b/main.go @@ -2,22 +2,20 @@ package main import ( "context" - "flag" "fmt" "log" - "os" "slices" "sort" - "strconv" "strings" "sync" - "time" "github.com/hashicorp/go-version" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" + + "codeberg.org/june64/mrvc/config" ) type HomeserverServerVersionInfo struct { @@ -36,12 +34,6 @@ type VersionPath struct { Version string } -var userIdFlag = flag.String("user-id", "", "The Matrix user ID to use. (EnvVar: MRVC_USER_ID)") -var tokenFlag = flag.String("token", "", "The Matrix access token to use. (EnvVar: MRVC_TOKEN)") -var roomFlag = flag.String("room", "", "The Matrix room to check. (EnvVar: MRVC_ROOM)") -var printHomeserverMemberCountFlag = flag.Bool("print-homeserver-member-count", false, "Print the member count for each homeserver. (EnvVar: MRVC_PRINT_HOMESERVER_MEMBER_COUNT)") -var homeserverVersionInfoTimeoutFlag = flag.Duration("homeserver-version-info-timeout", time.Second*5, "Timeout for getting the homeservers version information per homeserver. (EnvVar: MRVC_HOMESERVER_VERSION_INFO_TIMEOUT)") - var unknownServerVersionInfo = fclient.Version{ Server: struct { Name string `json:"name"` @@ -289,70 +281,9 @@ func compareVersionStrings(a, b string) int { } func main() { - flag.Parse() + config := config.Get() - // Configuration variables. - var userIdString, token, room string - var printHomeserverMemberCount bool - var homeserverVersionInfoTimeout time.Duration - - visitedFlags := make(map[string]bool) - flag.Visit(func(visitedFlag *flag.Flag) { - visitedFlags[visitedFlag.Name] = true - }) - - // Assign flag and environment variable values to configuration variables. - // Flags take precedence over environment variables. - // This also ensures some configuration options got explicitly set. - if visitedFlags["user-id"] { - userIdString = *userIdFlag - } else if userIdEnvVar, ok := os.LookupEnv("MRVC_USER_ID"); ok { - userIdString = userIdEnvVar - } else { - log.Fatal("A Matrix user ID must be provided.") - } - if visitedFlags["token"] { - token = *tokenFlag - } else if tokenEnvVar, ok := os.LookupEnv("MRVC_TOKEN"); ok { - token = tokenEnvVar - } else { - log.Fatal("A Matrix access token must be provided.") - } - if visitedFlags["room"] { - room = *roomFlag - } else if roomEnvVar, ok := os.LookupEnv("MRVC_ROOM"); ok { - room = roomEnvVar - } else { - log.Fatal("A Matrix room must be provided.") - } - if visitedFlags["print-homeserver-member-count"] { - printHomeserverMemberCount = *printHomeserverMemberCountFlag - } else if printHomeserverMemberCountEnvVar, ok := os.LookupEnv("MRVC_PRINT_HOMESERVER_MEMBER_COUNT"); ok { - parsedPrintHomeserverMemberCountEnvVar, err := strconv.ParseBool(printHomeserverMemberCountEnvVar) - if err != nil { - log.Println("Error parsing MRVC_PRINT_HOMESERVER_MEMBER_COUNT:") - log.Fatal(err) - } - printHomeserverMemberCount = parsedPrintHomeserverMemberCountEnvVar - } else { - // The flag holds the default value. - printHomeserverMemberCount = *printHomeserverMemberCountFlag - } - if visitedFlags["homeserver-version-info-timeout"] { - homeserverVersionInfoTimeout = *homeserverVersionInfoTimeoutFlag - } else if homeserverVersionInfoTimeoutEnvVar, ok := os.LookupEnv("MRVC_HOMESERVER_VERSION_INFO_TIMEOUT"); ok { - parsedHomeserverVersionInfoTimeoutEnvVar, err := time.ParseDuration(homeserverVersionInfoTimeoutEnvVar) - if err != nil { - log.Println("Error parsing MRVC_HOMESERVER_VERSION_INFO_TIMEOUT:") - log.Fatal(err) - } - homeserverVersionInfoTimeout = parsedHomeserverVersionInfoTimeoutEnvVar - } else { - // The flag holds the default value. - homeserverVersionInfoTimeout = *homeserverVersionInfoTimeoutFlag - } - - userId := id.UserID(userIdString) + userId := id.UserID(config.UserID) _, homeserver, err := userId.ParseAndValidate() if err != nil { log.Fatal(err) @@ -360,17 +291,17 @@ func main() { client, err := mautrix.NewClient( homeserver, userId, - token, + config.Token, ) if err != nil { log.Fatal(err) } federationClient := fclient.NewClient( fclient.WithWellKnownSRVLookups(true), - fclient.WithTimeout(homeserverVersionInfoTimeout), + fclient.WithTimeout(config.HomeserverVersionInfoTimeout), ) - joinedMembers, err := client.JoinedMembers(context.Background(), id.RoomID(room)) + joinedMembers, err := client.JoinedMembers(context.Background(), id.RoomID(config.Room)) if err != nil { log.Fatal(err) } @@ -424,7 +355,7 @@ func main() { } fmt.Println("Room:") - fmt.Printf("%s -> %d\n\n", room, len(joinedMembers.Joined)) + fmt.Printf("%s -> %d\n\n", config.Room, len(joinedMembers.Joined)) fmt.Println("Version Support:") @@ -458,7 +389,7 @@ func main() { fmt.Printf(" %s -> %d\n", versionKey, membersByVersionPath[VersionPath{maxRoomVersionKey, serverKey, versionKey}]) - if printHomeserverMemberCount { + if config.PrintHomeserverMemberCount { homeserverKeys := make([]string, 0, len(versionValue)) for key := range versionValue { homeserverKeys = append(homeserverKeys, key)