@@ 5,6 5,8 @@ import "log"
import "net/http"
import "os"
+var proxied bool = false
+
func fatal(err error) {
fmt.Fprintln(os.Stderr, "fatal:", err)
os.Exit(2)
@@ 15,6 17,42 @@ func usage() {
os.Exit(1)
}
+type SavedWriter struct {
+ w http.ResponseWriter
+ statusCode int
+}
+func (s *SavedWriter) Header() http.Header {
+ return s.w.Header()
+}
+func (s *SavedWriter) Write(b []byte) (int, error) {
+ if s.statusCode == 0 {
+ s.statusCode = 200
+ }
+
+ return s.w.Write(b)
+}
+func (s *SavedWriter) WriteHeader(statusCode int) {
+ s.statusCode = statusCode
+ s.w.WriteHeader(statusCode)
+}
+
+type Handler struct { inner http.Handler }
+func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if proxied {
+ xff := r.Header.Get("X-Forwarded-For")
+ if xff != "" {
+ r.RemoteAddr = xff
+ }
+ }
+
+ s := SavedWriter { w, 0 }
+ h.inner.ServeHTTP(&s, r)
+
+ log.Printf(
+ "%d - %s %s %s - %s",
+ s.statusCode, r.Method, r.RequestURI, r.Proto, r.RemoteAddr)
+}
+
func main() {
if 2 > len(os.Args) || len(os.Args) > 3 {
usage()
@@ 51,7 89,9 @@ func main() {
http.HandleFunc("/", NotFoundGet)
- fmt.Println("initialized!")
+ handler := Handler { http.DefaultServeMux }
+
+ log.Println("initialized!")
- log.Fatal(http.ListenAndServe(addr, nil))
+ log.Fatal(http.ListenAndServe(addr, handler))
}
@@ 3,7 3,9 @@ package main
import "embed"
import "fmt"
import "html/template"
+import "log"
import "net/http"
+import "net/url"
import "reflect"
import "strconv"
import "strings"
@@ 55,27 57,58 @@ func NotImplemented(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not Implemented!", 400)
}
+func TestSecHeaders(r *http.Request) bool {
+ sfm := r.Header.Get("Sec-Fetch-Mode")
+ if sfm == "websocket" {
+ return false
+ } else if sfm == "cors" || sfm == "no-cors" {
+ sfs := r.Header.Get("Sec-Fetch-Site")
+ return sfs != "cross-site"
+ } else {
+ return true
+ }
+}
+func Login(r *http.Request) (jrnl *Jrnl, user string) {
+ if !TestSecHeaders(r) {
+ return nil, ""
+ } else if r.Host != "" && r.Method == "POST" {
+ origin := r.Header.Get("Origin")
+ if origin == "" {
+ origin = r.Header.Get("Referer")
+ }
+
+ if origin != "" {
+ uri, err := url.Parse(origin)
+ if err != nil && uri.Host != r.Host {
+ return nil, ""
+ }
+ }
+ }
+
-func Login(r *http.Request) *Jrnl {
user, pass, ok := r.BasicAuth()
if !ok {
- return nil
+ return nil, ""
}
- jrnl, ok := Jrnls[user]
+ jrnl, ok = Jrnls[user]
if !ok {
- return nil
+ return nil, user
}
if !jrnl.Passhash.Check(pass) {
- return nil
+ return nil, user
} else {
- return jrnl
+ return jrnl, user
}
}
func MustLogin(w http.ResponseWriter, r *http.Request) *Jrnl {
- jrnl := Login(r)
+ jrnl, user := Login(r)
if jrnl == nil {
+ if user != "" {
+ log.Printf("failed login attempt as %s from %s", user, r.RemoteAddr)
+ }
+
w.Header().Add("WWW-Authenticate", "Basic realm=\"jrnl\"")
http.Error(w, "Unauthorized!", 401)
return nil
@@ 91,13 124,17 @@ type Err500Page struct {
Error error
}
func Err500(w http.ResponseWriter, r *http.Request, jrnl *Jrnl, err error) {
+ ty := reflect.TypeOf(err)
+
+ log.Printf("error 500 serving %s for %s: %s %s", r.URL.Path, jrnl.Id, ty, err)
+
w.WriteHeader(500)
ReplyTemplate(
w, r, jrnl, "500.html",
Err500Page {
Links: BuildLinks(r, jrnl),
Title: "Jrnl Server Error",
- Type: reflect.TypeOf(err),
+ Type: ty,
Error: err } )
}
@@ 128,7 165,7 @@ func ReplyTemplate(w http.ResponseWriter, r *http.Request, jrnl *Jrnl, tpl strin
err := Templates.ExecuteTemplate(&buf, tpl, ctx)
if err != nil {
msg := fmt.Sprintf("error in template %s: %s", tpl, err)
- fmt.Println(msg)
+ log.Println(msg)
Err500(w, r, jrnl, err)
} else {
w.Write([]byte(buf.String()))