/* This application ensures the required environment variables are set Then it uses the included countersql package to test a connection to the database It then caches the number of unique visits in memory by querying the database Finally, it sets up a web server with the endpoint http://localhost:8080/ For each unique source IP address which makes a GET request to this endpoint, the count of unique visitors is incremented by 1, both in the database and in memory cache. */ package main import ( "log" "net/http" "os" "strconv" "strings" "deadbeef.codes/steven/siteviewcounter/countersql" "github.com/go-sql-driver/mysql" ) var ( database countersql.Configuration // Configuration information required to connect to the database, using the countersql package uniqueVisits int // In memory cache of unique visitors ) // Application Startup func init() { envVars := make(map[string]string) envVars["dbusername"] = os.Getenv("dbusername") envVars["dbpassword"] = os.Getenv("dbpassword") envVars["dbhostname"] = os.Getenv("dbhostname") envVars["dbname"] = os.Getenv("dbname") envVars["timezone"] = os.Getenv("timezone") for key, value := range envVars { if value == "" { log.Fatalf("shell environment variable %s is not set", key) } } // Database Config dbConfig := mysql.Config{} dbConfig.User = envVars["dbusername"] dbConfig.Passwd = envVars["dbpassword"] dbConfig.Addr = envVars["dbhostname"] dbConfig.DBName = envVars["dbname"] dbConfig.Net = "tcp" dbConfig.ParseTime = true dbConfig.AllowNativePasswords = true database = countersql.Configuration{} database.DSN = dbConfig.FormatDSN() // Test database online at startup and get count of visits dbConn, err := database.Connect() if err != nil { log.Fatalf("failed to connect to database: %v", err) } // Check if database needs to be initialized (create tables) // does nothing if tables exist err = dbConn.InitializeDatabase() if err != nil { log.Fatalf("failed to initialize database: %v", err) } uniqueVisits, err = dbConn.GetUniqueVisits() if err != nil { log.Fatalf("failed to get number of unique visits from database: %v", err) } dbConn.DB.Close() } // HTTP Routing func main() { // API Handlers http.HandleFunc("/", countHandler) log.Print("Service listening on :8080") log.Fatal(http.ListenAndServe(":8080", nil)) } // HTTP handler function func countHandler(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { // CORS header change required. //TBD wildcard is bad because it could allow illegitmate visits to be recorded if someone was nefarious and embedded // front end code on a different website than your own. Need to implement environment variable to set allowed origin. w.Header().Set("Access-Control-Allow-Origin", "*") w.Write([]byte(strconv.Itoa(uniqueVisits))) // Connect to database dbConn, err := database.Connect() if err != nil { log.Printf("failed to connect to database: %v", err) w.WriteHeader(http.StatusFailedDependency) return } defer dbConn.DB.Close() // We now get the source IP address of this request var ipAddress string // Check if we're behind a reverse proxy / WAF if len(r.Header.Get("X-Forwarded-For")) > 0 { ipAddress = r.Header.Get("X-Forwarded-For") } else { ipAddress = r.RemoteAddr } ipAddress = strings.Split(ipAddress, ":")[0] // Check if this is the first time this IP address has visited returnVisitor, err := dbConn.HasIPVisited(ipAddress) if err != nil { log.Printf("failed to determine if this is a return visitor, no data is being logged: %v", err) return } if returnVisitor { // Log their visit err = dbConn.IncrementVisitor(ipAddress) log.Printf("return visitor from %s", ipAddress) } else { // Insert a new visitor row in the database err = dbConn.AddVisitor(ipAddress) uniqueVisits++ log.Printf("new visitor from %s", ipAddress) } if err != nil { log.Printf("failed to add/update visit record in database: %v", err) return } } else { // Needs to be GET method w.WriteHeader(http.StatusMethodNotAllowed) } }