diff --git a/apiserver/apiserver.yaml b/apiserver/apiserver.yaml new file mode 100644 index 0000000..456455d --- /dev/null +++ b/apiserver/apiserver.yaml @@ -0,0 +1,8 @@ +port: 8081 +vectorizer: + url: http://localhost:8080/vectorize +db: + user: faceserver + password: aaa + name: faceserver + host: localhost diff --git a/apiserver/apiserver/math.go b/apiserver/apiserver/math.go new file mode 100644 index 0000000..b0e0141 --- /dev/null +++ b/apiserver/apiserver/math.go @@ -0,0 +1,37 @@ +package apiserver + +import ( + "math" +) + +// replaced with single run version of CosinMetric +func FrobeniusNorm(arr []float64) float64 { + sum := 0.0 + for _, num := range arr { + anum := math.Abs(num) + sum += anum * anum + } + return math.Sqrt(sum) +} + +func CosinMetric(x []float64, y []float64) float64 { + l := len(x) + if l != len(y) { + return -1.0 + } + + xsum := 0.0 + ysum := 0.0 + sum := 0.0 + idx := 0 + for idx < l { + xabs := math.Abs(x[idx]) + yabs := math.Abs(y[idx]) + xsum += xabs * xabs + ysum += yabs * yabs + sum += x[idx] * y[idx] + idx++ + } + + return sum / (math.Sqrt(xsum) * math.Sqrt(ysum)) +} diff --git a/apiserver/apiserver/math_test.go b/apiserver/apiserver/math_test.go new file mode 100644 index 0000000..bc3d94f --- /dev/null +++ b/apiserver/apiserver/math_test.go @@ -0,0 +1,27 @@ +package apiserver + +import ( + "math" + "testing" +) + +const eps = 1e-8 + +func equaleps(a, b float64) bool { + return math.Abs(a-b) < eps +} + +func TestFrobeniusNorm(t *testing.T) { + norm := FrobeniusNorm([]float64{-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0}) + if !equaleps(norm, 7.745966692414834) { + t.Errorf("Norm was incorrect, got: %f, want: %f.", norm, 7.745966692414834) + } +} + +func TestCosinMetric(t *testing.T) { + norm := CosinMetric([]float64{0.46727048, 0.43004233, 0.27952332, 0.1524828, 0.47310451}, + []float64{0.03538705, 0.81665373, 0.15395064, 0.29546334, 0.50521321}) + if !equaleps(norm, 0.8004287073454146) { + t.Errorf("CosineMetric was incorrect, got: %f, want: %f.", norm, 0.8004287073454146) + } +} diff --git a/apiserver/apiserver/server.go b/apiserver/apiserver/server.go new file mode 100644 index 0000000..0d89246 --- /dev/null +++ b/apiserver/apiserver/server.go @@ -0,0 +1,224 @@ +package apiserver + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/google/uuid" + "github.com/spf13/viper" + "io" + "log" + "mime/multipart" + "net/http" + "os" + "path/filepath" +) + +var Dbo PgStorage + +type JsonPerson struct { + Id string `json:"id,omitempty"` + Box []uint32 `json:"box,omitempty"` + Score float64 `json:"score,omitempty"` + Probability float64 `json:"probability,omitempty"` +} + +type JsonResponse struct { + Status string `json:"status,omitempty"` + Url string `json:"url,omitempty"` + Filename string `json:"filename,omitempty"` + Directory string `json:"directory,omitempty"` + Persons []JsonPerson `json:"persons"` +} + +func sendError(w http.ResponseWriter, err error) { + log.Printf("%v\n", err) + jsonResponse(w, 400, JsonResponse{ + Status: err.Error(), + }) +} + +func Learn(w http.ResponseWriter, r *http.Request) { + filename, uid, result, err := uploadSave(w, r) + if err != nil { + sendError(w, err) + return + } + + if len(result) != 1 { + sendError(w, errors.New("More than one face detected.")) + return + } + + pid := r.FormValue("person") + if pid == "" { + sendError(w, errors.New("Person identification is required.")) + return + } + directory := r.FormValue("directory") + if directory == "" { + sendError(w, errors.New("Directory is required.")) + return + } + + person := Person{ + Id: pid, + Directory: directory, + Filename: filename, + FilenameUid: uid, + Score: result[0].Score, + Box: result[0].Box, + Vector: result[0].Vector, + } + err = Dbo.Store(person) + if err != nil { + sendError(w, err) + return + } + + jsonResponse(w, http.StatusCreated, JsonResponse{ + Status: "OK", + Url: "/files/" + uid, + Filename: filename, + Directory: directory, + Persons: []JsonPerson{{ + Id: pid, + Box: person.Box, + Score: person.Score, + }}, + }) +} + +func Recognize(w http.ResponseWriter, r *http.Request) { + filename, uid, result, err := uploadSave(w, r) + if err != nil { + sendError(w, err) + return + } + + directory := r.FormValue("directory") + if directory == "" { + sendError(w, errors.New("Directory is required.")) + return + } + + persons, err := Dbo.GetDirectory(directory) + if err != nil { + sendError(w, err) + return + } + + jp := []JsonPerson{} + for _, r := range result { + maxprob := -1.0 + var maxperson Person + for _, p := range persons { + cm := CosinMetric(r.Vector, p.Vector) + if cm > maxprob { + maxprob = cm + maxperson = p + } + } + jp = append(jp, JsonPerson{ + Id: maxperson.Id, + Box: r.Box, + Score: r.Score, + Probability: maxprob, + }) + } + + jsonResponse(w, http.StatusCreated, JsonResponse{ + Status: "OK", + Url: "/files/" + uid, + Filename: filename, + Directory: directory, + Persons: jp, + }) +} + +func uploadSave(w http.ResponseWriter, r *http.Request) (string, string, []VectorizerResult, error) { + if err := checkMethod(w, r); err != nil { + return "", "", nil, err + } + + if err := r.ParseMultipartForm(32 << 20); err != nil { + return "", "", nil, err + } + + file, handle, err := r.FormFile("file") + if err != nil { + return "", "", nil, err + } + defer file.Close() + + mimeType := handle.Header.Get("Content-Type") + if err := checkFileType(mimeType); err != nil { + return "", "", nil, err + } + + uid, err := saveFile(w, file, handle) + if err != nil { + return "", "", nil, err + } + + reader, err := os.Open("./files/" + uid) + if err != nil { + return "", "", nil, err + } + defer reader.Close() + results, err := Vectorize(uid, reader, viper.GetString("vectorizer.url")) + if err != nil { + return "", "", nil, err + } + + return handle.Filename, uid, results, nil +} + +func checkMethod(w http.ResponseWriter, r *http.Request) error { + if r.Method != http.MethodPost { + return errors.New("POST method required") + } + + return nil +} + +func checkFileType(mimeType string) error { + switch mimeType { + case "image/jpeg", "image/png": + return nil + default: + return errors.New(fmt.Sprintf("Invalid file format %s", mimeType)) + } +} + +func generateFilename(filename string) string { + e := filepath.Ext(filename) + uid := uuid.New().String() + return uid + e +} + +func saveFile(w http.ResponseWriter, file multipart.File, handle *multipart.FileHeader) (string, error) { + uid := generateFilename(handle.Filename) + f, err := os.OpenFile("./files/"+uid, os.O_WRONLY|os.O_CREATE, 0666) + if err != nil { + return "", err + } + defer f.Close() + + _, err = io.Copy(f, file) + if err != nil { + return "", err + } + + return uid, nil +} + +func jsonResponse(w http.ResponseWriter, code int, message JsonResponse) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + resp, err := json.Marshal(message) + if err != nil { + log.Fatalf("Cannot format %v", message) + } + w.Write(resp) +} diff --git a/apiserver/apiserver/storage.go b/apiserver/apiserver/storage.go new file mode 100644 index 0000000..242009b --- /dev/null +++ b/apiserver/apiserver/storage.go @@ -0,0 +1,94 @@ +package apiserver + +import ( + "database/sql" + "fmt" + "github.com/lib/pq" +) + +type Person struct { + Id string + Directory string + Filename string + FilenameUid string + Score float64 + Box []uint32 + Vector []float64 +} + +func (u Person) String() string { + return fmt.Sprintf("Person<%s %s %f %s %s %v %v>", u.Id, u.Directory, u.Score, u.Filename, u.FilenameUid, u.Box, u.Vector) +} + +type PgStorage struct { + db *sql.DB +} + +func NewStorage(user string, password string, database string, host string) (PgStorage, error) { + connStr := fmt.Sprintf("user=%s dbname=%s password=%s host=%s", user, database, password, host) + db, err := sql.Open("postgres", connStr) + if err != nil { + return PgStorage{}, err + } + + pgo := PgStorage{ + db: db, + } + + return pgo, nil +} + +func (pgo *PgStorage) CloseStorage() { + pgo.db.Close() +} + +func (pgo *PgStorage) Store(person Person) error { + _, err := pgo.db.Exec("insert into persons (id, directory, filename, filenameuid, score, box, vector) values ($1, $2, $3, $4, $5, $6, $7)", + person.Id, person.Directory, person.Filename, person.FilenameUid, person.Score, pq.Array(person.Box), pq.Array(person.Vector)) + + return err +} + +func (pgo *PgStorage) GetDirectory(directory string) ([]Person, error) { + var persons []Person + rows, err := pgo.db.Query("select id, filename, filenameuid, score, box, vector from persons where directory=$1", directory) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var id string + var filename string + var filenameuid string + var score float64 + var box []int64 + var vector []float64 + err = rows.Scan(&id, &filename, &filenameuid, &score, pq.Array(&box), pq.Array(&vector)) + if err != nil { + return nil, err + } + + rebox := make([]uint32, len(box)) + for i, v := range box { + rebox[i] = uint32(v) + } + + persons = append(persons, Person{ + Id: id, + Directory: directory, + Filename: filename, + FilenameUid: filenameuid, + Score: score, + Box: rebox, + Vector: vector, + }) + } + + err = rows.Err() + if err != nil { + return nil, err + } + + return persons, nil +} diff --git a/apiserver/apiserver/vectorizer.go b/apiserver/apiserver/vectorizer.go new file mode 100644 index 0000000..1403094 --- /dev/null +++ b/apiserver/apiserver/vectorizer.go @@ -0,0 +1,71 @@ +package apiserver + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "net/textproto" + "path/filepath" + "strings" +) + +type VectorizerResult struct { + Box []uint32 `json:"box"` + Vector []float64 `json:"vector"` + Score float64 `json:"score"` +} + +func Vectorize(filename string, reader io.Reader, vectorizerUrl string) ([]VectorizerResult, error) { + bodyBuf := &bytes.Buffer{} + bodyWriter := multipart.NewWriter(bodyBuf) + + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, "file", filepath.Base(filename))) + switch e := strings.ToLower(filepath.Ext(filename)); e { + case ".png": + h.Set("Content-Type", "image/png") + case ".jpg", ".jpeg": + h.Set("Content-Type", "image/jpeg") + default: + return nil, errors.New(fmt.Sprintf("Invalid extension %s", e)) + } + fileWriter, err := bodyWriter.CreatePart(h) + + if err != nil { + fmt.Println("error writing to buffer") + return nil, err + } + + //iocopy + _, err = io.Copy(fileWriter, reader) + if err != nil { + return nil, err + } + + contentType := bodyWriter.FormDataContentType() + bodyWriter.Close() + + resp, err := http.Post(vectorizerUrl, contentType, bodyBuf) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + resp_body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var result []VectorizerResult + err = json.Unmarshal(resp_body, &result) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/apiserver/go.mod b/apiserver/go.mod new file mode 100644 index 0000000..bd3ac83 --- /dev/null +++ b/apiserver/go.mod @@ -0,0 +1,9 @@ +module gitea.ehp.cz/Aprar/faceserver + +go 1.12 + +require ( + github.com/google/uuid v1.1.1 + github.com/lib/pq v1.2.0 + github.com/spf13/viper v1.4.0 +) diff --git a/apiserver/main.go b/apiserver/main.go new file mode 100644 index 0000000..2c63ad3 --- /dev/null +++ b/apiserver/main.go @@ -0,0 +1,39 @@ +package main + +import ( + "fmt" + "log" + "net/http" + "strconv" + + "gitea.ehp.cz/Aprar/faceserver/apiserver" + "github.com/spf13/viper" +) + +func main() { + viper.SetConfigName("apiserver") // name of config file (without extension) + viper.AddConfigPath("/etc/faceserver/") // path to look for the config file in + viper.AddConfigPath("$HOME/.faceserver") // call multiple times to add many search paths + viper.AddConfigPath(".") // optionally look for config in the working directory + viper.SetEnvPrefix("AS_") + viper.AutomaticEnv() + err := viper.ReadInConfig() // Find and read the config file + if err != nil { // Handle errors reading the config file + panic(fmt.Errorf("Fatal error config file: %s \n", err)) + } + + apiserver.Dbo, err = apiserver.NewStorage(viper.GetString("db.user"), viper.GetString("db.password"), viper.GetString("db.name"), viper.GetString("db.host")) + if err != nil { + panic(fmt.Errorf("Fatal error database connection: %s \n", err)) + } + + http.Handle("/", http.FileServer(http.Dir("./public"))) + + fs := http.FileServer(http.Dir("./files")) + http.Handle("/files/", http.StripPrefix("/files", fs)) + + http.HandleFunc("/learn", apiserver.Learn) + http.HandleFunc("/recognize", apiserver.Recognize) + log.Println("Running") + log.Fatal(http.ListenAndServe(":" + strconv.Itoa(viper.GetInt("port")), nil)) +}