diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a289a26..93c0756 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,7 @@ jobs: with: go-version: ${{ matrix.go-version }} - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - uses: actions/cache@v3 with: path: ~/go/pkg/mod diff --git a/README.md b/README.md index ad7a723..afa6156 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Zero Allocation JSON Logger -[![godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/rs/zerolog) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/rs/zerolog/master/LICENSE) [![Build Status](https://travis-ci.org/rs/zerolog.svg?branch=master)](https://travis-ci.org/rs/zerolog) [![Go Coverage](https://github.com/rs/zerolog/wiki/coverage.svg)](https://raw.githack.com/wiki/rs/zerolog/coverage.html) +[![godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/rs/zerolog) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/rs/zerolog/master/LICENSE) [![Build Status](https://github.com/rs/zerolog/actions/workflows/test.yml/badge.svg)](https://github.com/rs/zerolog/actions/workflows/test.yml) [![Go Coverage](https://github.com/rs/zerolog/wiki/coverage.svg)](https://raw.githack.com/wiki/rs/zerolog/coverage.html) The zerolog package provides a fast and simple logger dedicated to JSON output. @@ -547,7 +547,7 @@ and facilitates the unification of logging and tracing in some systems: type TracingHook struct{} func (h TracingHook) Run(e *zerolog.Event, level zerolog.Level, msg string) { - ctx := e.Ctx() + ctx := e.GetCtx() spanId := getSpanIdFromContext(ctx) // as per your tracing framework e.Str("span-id", spanId) } diff --git a/benchmark_test.go b/benchmark_test.go index b2bb79a..39a35da 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -3,7 +3,7 @@ package zerolog import ( "context" "errors" - "io/ioutil" + "io" "net" "testing" "time" @@ -15,7 +15,7 @@ var ( ) func BenchmarkLogEmpty(b *testing.B) { - logger := New(ioutil.Discard) + logger := New(io.Discard) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -25,7 +25,7 @@ func BenchmarkLogEmpty(b *testing.B) { } func BenchmarkDisabled(b *testing.B) { - logger := New(ioutil.Discard).Level(Disabled) + logger := New(io.Discard).Level(Disabled) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -35,7 +35,7 @@ func BenchmarkDisabled(b *testing.B) { } func BenchmarkInfo(b *testing.B) { - logger := New(ioutil.Discard) + logger := New(io.Discard) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -45,7 +45,7 @@ func BenchmarkInfo(b *testing.B) { } func BenchmarkContextFields(b *testing.B) { - logger := New(ioutil.Discard).With(). + logger := New(io.Discard).With(). Str("string", "four!"). Time("time", time.Time{}). Int("int", 123). @@ -60,7 +60,7 @@ func BenchmarkContextFields(b *testing.B) { } func BenchmarkContextAppend(b *testing.B) { - logger := New(ioutil.Discard).With(). + logger := New(io.Discard).With(). Str("foo", "bar"). Logger() b.ResetTimer() @@ -72,7 +72,7 @@ func BenchmarkContextAppend(b *testing.B) { } func BenchmarkLogFields(b *testing.B) { - logger := New(ioutil.Discard) + logger := New(io.Discard) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -102,7 +102,7 @@ func BenchmarkLogArrayObject(b *testing.B) { obj1 := obj{"a", "b", 2} obj2 := obj{"c", "d", 3} obj3 := obj{"e", "f", 4} - logger := New(ioutil.Discard) + logger := New(io.Discard) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { @@ -224,7 +224,7 @@ func BenchmarkLogFieldType(b *testing.B) { return e.Object("k", objects[0]) }, } - logger := New(ioutil.Discard) + logger := New(io.Discard) b.ResetTimer() for name := range types { f := types[name] @@ -358,7 +358,7 @@ func BenchmarkContextFieldType(b *testing.B) { return c.Timestamp() }, } - logger := New(ioutil.Discard) + logger := New(io.Discard) b.ResetTimer() for name := range types { f := types[name] diff --git a/binary_test.go b/binary_test.go index b4882d7..6559735 100644 --- a/binary_test.go +++ b/binary_test.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" - // "io/ioutil" stdlog "log" "time" ) diff --git a/cmd/prettylog/prettylog.go b/cmd/prettylog/prettylog.go index 61bf8b8..9b8b43c 100644 --- a/cmd/prettylog/prettylog.go +++ b/cmd/prettylog/prettylog.go @@ -1,6 +1,8 @@ package main import ( + "bufio" + "errors" "fmt" "io" "os" @@ -14,13 +16,39 @@ func isInputFromPipe() bool { } func main() { - if !isInputFromPipe() { - fmt.Println("The command is intended to work with pipes.") - fmt.Println("Usage: app_with_zerolog | 2> >(prettylog)") + writer := zerolog.NewConsoleWriter() + + if isInputFromPipe() { + _, _ = io.Copy(writer, os.Stdin) + } else if len(os.Args) > 1 { + for _, filename := range os.Args[1:] { + // Scan each line from filename and write it into writer + r, err := os.Open(filename) + if err != nil { + fmt.Printf("%s open: %v", filename, err) + os.Exit(1) + } + scanner := bufio.NewScanner(r) + for scanner.Scan() { + _, err = writer.Write(scanner.Bytes()) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + fmt.Printf("%s write: %v", filename, err) + os.Exit(1) + } + } + if err := scanner.Err(); err != nil { + fmt.Printf("%s scan: %v", filename, err) + os.Exit(1) + } + } + } else { + fmt.Println("Usage:") + fmt.Println(" app_with_zerolog | 2> >(prettylog)") + fmt.Println(" prettylog zerolog_output.jsonl") os.Exit(1) return } - - writer := zerolog.NewConsoleWriter() - _, _ = io.Copy(writer, os.Stdin) } diff --git a/console.go b/console.go index 8b0e0c6..e8eeaa3 100644 --- a/console.go +++ b/console.go @@ -76,6 +76,8 @@ type ConsoleWriter struct { FormatErrFieldValue Formatter FormatExtra func(map[string]interface{}, *bytes.Buffer) error + + FormatPrepare func(map[string]interface{}) error } // NewConsoleWriter creates and initializes a new ConsoleWriter. @@ -124,6 +126,13 @@ func (w ConsoleWriter) Write(p []byte) (n int, err error) { return n, fmt.Errorf("cannot decode event: %s", err) } + if w.FormatPrepare != nil { + err = w.FormatPrepare(evt) + if err != nil { + return n, err + } + } + for _, p := range w.PartsOrder { w.writePart(buf, evt, p) } @@ -272,7 +281,7 @@ func (w ConsoleWriter) writePart(buf *bytes.Buffer, evt map[string]interface{}, } case MessageFieldName: if w.FormatMessage == nil { - f = consoleDefaultFormatMessage + f = consoleDefaultFormatMessage(w.NoColor, evt[LevelFieldName]) } else { f = w.FormatMessage } @@ -310,8 +319,13 @@ func needsQuote(s string) bool { return false } -// colorize returns the string s wrapped in ANSI code c, unless disabled is true. +// colorize returns the string s wrapped in ANSI code c, unless disabled is true or c is 0. func colorize(s interface{}, c int, disabled bool) string { + e := os.Getenv("NO_COLOR") + if e != "" || c == 0 { + disabled = true + } + if disabled { return fmt.Sprintf("%s", s) } @@ -373,27 +387,16 @@ func consoleDefaultFormatLevel(noColor bool) Formatter { return func(i interface{}) string { var l string if ll, ok := i.(string); ok { - switch ll { - case LevelTraceValue: - l = colorize("TRC", colorMagenta, noColor) - case LevelDebugValue: - l = colorize("DBG", colorYellow, noColor) - case LevelInfoValue: - l = colorize("INF", colorGreen, noColor) - case LevelWarnValue: - l = colorize("WRN", colorRed, noColor) - case LevelErrorValue: - l = colorize(colorize("ERR", colorRed, noColor), colorBold, noColor) - case LevelFatalValue: - l = colorize(colorize("FTL", colorRed, noColor), colorBold, noColor) - case LevelPanicValue: - l = colorize(colorize("PNC", colorRed, noColor), colorBold, noColor) - default: - l = colorize(ll, colorBold, noColor) + level, _ := ParseLevel(ll) + fl, ok := FormattedLevels[level] + if ok { + l = colorize(fl, LevelColors[level], noColor) + } else { + l = strings.ToUpper(ll)[0:3] } } else { if i == nil { - l = colorize("???", colorBold, noColor) + l = "???" } else { l = strings.ToUpper(fmt.Sprintf("%s", i))[0:3] } @@ -420,11 +423,18 @@ func consoleDefaultFormatCaller(noColor bool) Formatter { } } -func consoleDefaultFormatMessage(i interface{}) string { - if i == nil { - return "" +func consoleDefaultFormatMessage(noColor bool, level interface{}) Formatter { + return func(i interface{}) string { + if i == nil || i == "" { + return "" + } + switch level { + case LevelInfoValue, LevelWarnValue, LevelErrorValue, LevelFatalValue, LevelPanicValue: + return colorize(fmt.Sprintf("%s", i), colorBold, noColor) + default: + return fmt.Sprintf("%s", i) + } } - return fmt.Sprintf("%s", i) } func consoleDefaultFormatFieldName(noColor bool) Formatter { @@ -445,6 +455,6 @@ func consoleDefaultFormatErrFieldName(noColor bool) Formatter { func consoleDefaultFormatErrFieldValue(noColor bool) Formatter { return func(i interface{}) string { - return colorize(fmt.Sprintf("%s", i), colorRed, noColor) + return colorize(colorize(fmt.Sprintf("%s", i), colorBold, noColor), colorRed, noColor) } } diff --git a/console_test.go b/console_test.go index 18c2db7..acf1ccb 100644 --- a/console_test.go +++ b/console_test.go @@ -3,7 +3,7 @@ package zerolog_test import ( "bytes" "fmt" - "io/ioutil" + "io" "os" "strings" "testing" @@ -97,13 +97,32 @@ func TestConsoleWriter(t *testing.T) { t.Errorf("Unexpected error when writing output: %s", err) } - expectedOutput := "\x1b[90m\x1b[0m \x1b[31mWRN\x1b[0m Foobar\n" + expectedOutput := "\x1b[90m\x1b[0m \x1b[33mWRN\x1b[0m \x1b[1mFoobar\x1b[0m\n" actualOutput := buf.String() if actualOutput != expectedOutput { t.Errorf("Unexpected output %q, want: %q", actualOutput, expectedOutput) } }) + t.Run("NO_COLOR = true", func(t *testing.T) { + os.Setenv("NO_COLOR", "anything") + + buf := &bytes.Buffer{} + w := zerolog.ConsoleWriter{Out: buf} + + _, err := w.Write([]byte(`{"level": "warn", "message": "Foobar"}`)) + if err != nil { + t.Errorf("Unexpected error when writing output: %s", err) + } + + expectedOutput := " WRN Foobar\n" + actualOutput := buf.String() + if actualOutput != expectedOutput { + t.Errorf("Unexpected output %q, want: %q", actualOutput, expectedOutput) + } + os.Unsetenv("NO_COLOR") + }) + t.Run("Write fields", func(t *testing.T) { buf := &bytes.Buffer{} w := zerolog.ConsoleWriter{Out: buf, NoColor: true} @@ -229,7 +248,7 @@ func TestConsoleWriter(t *testing.T) { t.Errorf("Unexpected error when writing output: %s", err) } - expectedOutput := "\x1b[90m\x1b[0m \x1b[31mWRN\x1b[0m Foobar \x1b[36mfoo=\x1b[0mbar\n" + expectedOutput := "\x1b[90m\x1b[0m \x1b[33mWRN\x1b[0m \x1b[1mFoobar\x1b[0m \x1b[36mfoo=\x1b[0mbar\n" actualOutput := buf.String() if actualOutput != expectedOutput { t.Errorf("Unexpected output %q, want: %q", actualOutput, expectedOutput) @@ -399,6 +418,29 @@ func TestConsoleWriterConfiguration(t *testing.T) { } }) + t.Run("Sets FormatPrepare", func(t *testing.T) { + buf := &bytes.Buffer{} + w := zerolog.ConsoleWriter{ + Out: buf, NoColor: true, PartsOrder: []string{"level", "message"}, + FormatPrepare: func(evt map[string]interface{}) error { + evt["message"] = fmt.Sprintf("msg=%s", evt["message"]) + return nil + }, + } + + evt := `{"level": "info", "message": "Foobar"}` + _, err := w.Write([]byte(evt)) + if err != nil { + t.Errorf("Unexpected error when writing output: %s", err) + } + + expectedOutput := "INF msg=Foobar\n" + actualOutput := buf.String() + if actualOutput != expectedOutput { + t.Errorf("Unexpected output %q, want: %q", actualOutput, expectedOutput) + } + }) + t.Run("Uses local time for console writer without time zone", func(t *testing.T) { // Regression test for issue #483 (check there for more details) @@ -432,7 +474,7 @@ func BenchmarkConsoleWriter(b *testing.B) { var msg = []byte(`{"level": "info", "foo": "bar", "message": "HELLO", "time": "1990-01-01"}`) - w := zerolog.ConsoleWriter{Out: ioutil.Discard, NoColor: false} + w := zerolog.ConsoleWriter{Out: io.Discard, NoColor: false} for i := 0; i < b.N; i++ { w.Write(msg) diff --git a/context.go b/context.go index 47989a5..23827b2 100644 --- a/context.go +++ b/context.go @@ -3,7 +3,7 @@ package zerolog import ( "context" "fmt" - "io/ioutil" + "io" "math" "net" "time" @@ -23,7 +23,7 @@ func (c Context) Logger() Logger { // Only map[string]interface{} and []interface{} are accepted. []interface{} must // alternate string keys and arbitrary values, and extraneous ones are ignored. func (c Context) Fields(fields interface{}) Context { - c.l.context = appendFields(c.l.context, fields) + c.l.context = appendFields(c.l.context, fields, c.l.stack) return c } @@ -57,7 +57,7 @@ func (c Context) Array(key string, arr LogArrayMarshaler) Context { // Object marshals an object that implement the LogObjectMarshaler interface. func (c Context) Object(key string, obj LogObjectMarshaler) Context { - e := newEvent(LevelWriterAdapter{ioutil.Discard}, 0) + e := newEvent(LevelWriterAdapter{io.Discard}, 0) e.Object(key, obj) c.l.context = enc.AppendObjectData(c.l.context, e.buf) putEvent(e) @@ -66,7 +66,7 @@ func (c Context) Object(key string, obj LogObjectMarshaler) Context { // EmbedObject marshals and Embeds an object that implement the LogObjectMarshaler interface. func (c Context) EmbedObject(obj LogObjectMarshaler) Context { - e := newEvent(LevelWriterAdapter{ioutil.Discard}, 0) + e := newEvent(LevelWriterAdapter{io.Discard}, 0) e.EmbedObject(obj) c.l.context = enc.AppendObjectData(c.l.context, e.buf) putEvent(e) @@ -163,6 +163,22 @@ func (c Context) Errs(key string, errs []error) Context { // Err adds the field "error" with serialized err to the logger context. func (c Context) Err(err error) Context { + if c.l.stack && ErrorStackMarshaler != nil { + switch m := ErrorStackMarshaler(err).(type) { + case nil: + case LogObjectMarshaler: + c = c.Object(ErrorStackFieldName, m) + case error: + if m != nil && !isNilValue(m) { + c = c.Str(ErrorStackFieldName, m.Error()) + } + case string: + c = c.Str(ErrorStackFieldName, m) + default: + c = c.Interface(ErrorStackFieldName, m) + } + } + return c.AnErr(ErrorFieldName, err) } @@ -385,10 +401,19 @@ func (c Context) Durs(key string, d []time.Duration) Context { // Interface adds the field key with obj marshaled using reflection. func (c Context) Interface(key string, i interface{}) Context { + if obj, ok := i.(LogObjectMarshaler); ok { + return c.Object(key, obj) + } c.l.context = enc.AppendInterface(enc.AppendKey(c.l.context, key), i) return c } +// Type adds the field key with val's type using reflection. +func (c Context) Type(key string, val interface{}) Context { + c.l.context = enc.AppendType(enc.AppendKey(c.l.context, key), val) + return c +} + // Any is a wrapper around Context.Interface. func (c Context) Any(key string, i interface{}) Context { return c.Interface(key, i) diff --git a/ctx_test.go b/ctx_test.go index 5bc41e5..397e966 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -1,14 +1,15 @@ package zerolog import ( + "bytes" "context" - "io/ioutil" + "io" "reflect" "testing" ) func TestCtx(t *testing.T) { - log := New(ioutil.Discard) + log := New(io.Discard) ctx := log.WithContext(context.Background()) log2 := Ctx(ctx) if !reflect.DeepEqual(log, *log2) { @@ -37,13 +38,13 @@ func TestCtx(t *testing.T) { } func TestCtxDisabled(t *testing.T) { - dl := New(ioutil.Discard).Level(Disabled) + dl := New(io.Discard).Level(Disabled) ctx := dl.WithContext(context.Background()) if ctx != context.Background() { t.Error("WithContext stored a disabled logger") } - l := New(ioutil.Discard).With().Str("foo", "bar").Logger() + l := New(io.Discard).With().Str("foo", "bar").Logger() ctx = l.WithContext(ctx) if !reflect.DeepEqual(Ctx(ctx), &l) { t.Error("WithContext did not store logger") @@ -68,3 +69,31 @@ func TestCtxDisabled(t *testing.T) { t.Error("WithContext did not override logger with a disabled logger") } } + +type logObjectMarshalerImpl struct { + name string + age int +} + +func (t logObjectMarshalerImpl) MarshalZerologObject(e *Event) { + e.Str("name", "custom_value").Int("age", t.age) +} + +func Test_InterfaceLogObjectMarshaler(t *testing.T) { + var buf bytes.Buffer + log := New(&buf) + ctx := log.WithContext(context.Background()) + + log2 := Ctx(ctx) + + withLog := log2.With().Interface("obj", &logObjectMarshalerImpl{ + name: "foo", + age: 29, + }).Logger() + + withLog.Info().Msg("test") + + if got, want := buf.String(), `{"level":"info","obj":{"name":"custom_value","age":29},"message":"test"}`+"\n"; got != want { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/diode/diode_test.go b/diode/diode_test.go index d0d0aff..29b3f86 100644 --- a/diode/diode_test.go +++ b/diode/diode_test.go @@ -3,7 +3,7 @@ package diode_test import ( "bytes" "fmt" - "io/ioutil" + "io" "log" "os" "testing" @@ -39,7 +39,7 @@ func TestClose(t *testing.T) { } func Benchmark(b *testing.B) { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) benchs := map[string]time.Duration{ "Waiter": 0, @@ -47,7 +47,7 @@ func Benchmark(b *testing.B) { } for name, interval := range benchs { b.Run(name, func(b *testing.B) { - w := diode.NewWriter(ioutil.Discard, 100000, interval, nil) + w := diode.NewWriter(io.Discard, 100000, interval, nil) log := zerolog.New(w) defer w.Close() diff --git a/event.go b/event.go index 2a5d3b0..5c949f8 100644 --- a/event.go +++ b/event.go @@ -164,7 +164,7 @@ func (e *Event) Fields(fields interface{}) *Event { if e == nil { return e } - e.buf = appendFields(e.buf, fields) + e.buf = appendFields(e.buf, fields, e.stack) return e } diff --git a/example.jsonl b/example.jsonl new file mode 100644 index 0000000..d73193d --- /dev/null +++ b/example.jsonl @@ -0,0 +1,7 @@ +{"time":"5:41PM","level":"info","message":"Starting listener","listen":":8080","pid":37556} +{"time":"5:41PM","level":"debug","message":"Access","database":"myapp","host":"localhost:4962","pid":37556} +{"time":"5:41PM","level":"info","message":"Access","method":"GET","path":"/users","pid":37556,"resp_time":23} +{"time":"5:41PM","level":"info","message":"Access","method":"POST","path":"/posts","pid":37556,"resp_time":532} +{"time":"5:41PM","level":"warn","message":"Slow request","method":"POST","path":"/posts","pid":37556,"resp_time":532} +{"time":"5:41PM","level":"info","message":"Access","method":"GET","path":"/users","pid":37556,"resp_time":10} +{"time":"5:41PM","level":"error","message":"Database connection lost","database":"myapp","pid":37556,"error":"connection reset by peer"} diff --git a/fields.go b/fields.go index c1eb5ce..23606dd 100644 --- a/fields.go +++ b/fields.go @@ -12,13 +12,13 @@ func isNilValue(i interface{}) bool { return (*[2]uintptr)(unsafe.Pointer(&i))[1] == 0 } -func appendFields(dst []byte, fields interface{}) []byte { +func appendFields(dst []byte, fields interface{}, stack bool) []byte { switch fields := fields.(type) { case []interface{}: if n := len(fields); n&0x1 == 1 { // odd number fields = fields[:n-1] } - dst = appendFieldList(dst, fields) + dst = appendFieldList(dst, fields, stack) case map[string]interface{}: keys := make([]string, 0, len(fields)) for key := range fields { @@ -28,13 +28,13 @@ func appendFields(dst []byte, fields interface{}) []byte { kv := make([]interface{}, 2) for _, key := range keys { kv[0], kv[1] = key, fields[key] - dst = appendFieldList(dst, kv) + dst = appendFieldList(dst, kv, stack) } } return dst } -func appendFieldList(dst []byte, kvList []interface{}) []byte { +func appendFieldList(dst []byte, kvList []interface{}, stack bool) []byte { for i, n := 0, len(kvList); i < n; i += 2 { key, val := kvList[i], kvList[i+1] if key, ok := key.(string); ok { @@ -74,6 +74,21 @@ func appendFieldList(dst []byte, kvList []interface{}) []byte { default: dst = enc.AppendInterface(dst, m) } + + if stack && ErrorStackMarshaler != nil { + dst = enc.AppendKey(dst, ErrorStackFieldName) + switch m := ErrorStackMarshaler(val).(type) { + case nil: + case error: + if m != nil && !isNilValue(m) { + dst = enc.AppendString(dst, m.Error()) + } + case string: + dst = enc.AppendString(dst, m) + default: + dst = enc.AppendInterface(dst, m) + } + } case []error: dst = enc.AppendArrayStart(dst) for i, err := range val { diff --git a/globals.go b/globals.go index e1067de..b38a7fc 100644 --- a/globals.go +++ b/globals.go @@ -108,6 +108,34 @@ var ( // DefaultContextLogger is returned from Ctx() if there is no logger associated // with the context. DefaultContextLogger *Logger + + // LevelColors are used by ConsoleWriter's consoleDefaultFormatLevel to color + // log levels. + LevelColors = map[Level]int{ + TraceLevel: colorBlue, + DebugLevel: 0, + InfoLevel: colorGreen, + WarnLevel: colorYellow, + ErrorLevel: colorRed, + FatalLevel: colorRed, + PanicLevel: colorRed, + } + + // FormattedLevels are used by ConsoleWriter's consoleDefaultFormatLevel + // for a short level name. + FormattedLevels = map[Level]string{ + TraceLevel: "TRC", + DebugLevel: "DBG", + InfoLevel: "INF", + WarnLevel: "WRN", + ErrorLevel: "ERR", + FatalLevel: "FTL", + PanicLevel: "PNC", + } + + // TriggerLevelWriterBufferReuseLimit is a limit in bytes that a buffer is dropped + // from the TriggerLevelWriter buffer pool if the buffer grows above the limit. + TriggerLevelWriterBufferReuseLimit = 64 * 1024 ) var ( diff --git a/go.mod b/go.mod index 6ec5abc..4385ff9 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.15 require ( github.com/coreos/go-systemd/v22 v22.5.0 - github.com/mattn/go-colorable v0.1.12 + github.com/mattn/go-colorable v0.1.13 + github.com/mattn/go-isatty v0.0.19 // indirect github.com/pkg/errors v0.9.1 github.com/rs/xid v1.5.0 + golang.org/x/sys v0.12.0 // indirect ) diff --git a/go.sum b/go.sum index 50b6327..f2b8d84 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,13 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= @@ -12,3 +17,7 @@ github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 h1:foEbQz/B0Oz6YIqu/69kfXPYeFQAuuMYFkjaqXzl5Wo= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/hlog/hlog.go b/hlog/hlog.go index 3773400..06ca4ad 100644 --- a/hlog/hlog.go +++ b/hlog/hlog.go @@ -3,7 +3,9 @@ package hlog import ( "context" + "net" "net/http" + "strings" "time" "github.com/rs/xid" @@ -89,6 +91,35 @@ func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler { } } +func getHost(hostPort string) string { + if hostPort == "" { + return "" + } + + host, _, err := net.SplitHostPort(hostPort) + if err != nil { + return hostPort + } + return host +} + +// RemoteIPHandler is similar to RemoteAddrHandler, but logs only +// an IP, not a port. +func RemoteIPHandler(fieldKey string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := getHost(r.RemoteAddr) + if ip != "" { + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, ip) + }) + } + next.ServeHTTP(w, r) + }) + } +} + // UserAgentHandler adds the request's user-agent as a field to the context's logger // using fieldKey as field key. func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler { @@ -135,6 +166,21 @@ func ProtoHandler(fieldKey string) func(next http.Handler) http.Handler { } } +// HTTPVersionHandler is similar to ProtoHandler, but it does not store the "HTTP/" +// prefix in the protocol name. +func HTTPVersionHandler(fieldKey string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proto := strings.TrimPrefix(r.Proto, "HTTP/") + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, proto) + }) + next.ServeHTTP(w, r) + }) + } +} + type idKey struct{} // IDFromRequest returns the unique id associated to the request if any. @@ -205,14 +251,76 @@ func CustomHeaderHandler(fieldKey, header string) func(next http.Handler) http.H } } +// EtagHandler adds Etag header from response's header as a field to +// the context's logger using fieldKey as field key. +func EtagHandler(fieldKey string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + etag := w.Header().Get("Etag") + if etag != "" { + etag = strings.ReplaceAll(etag, `"`, "") + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, etag) + }) + } + }() + next.ServeHTTP(w, r) + }) + } +} + +func ResponseHeaderHandler(fieldKey, headerName string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + value := w.Header().Get(headerName) + if value != "" { + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, value) + }) + } + }() + next.ServeHTTP(w, r) + }) + } +} + // AccessHandler returns a handler that call f after each request. func AccessHandler(f func(r *http.Request, status, size int, duration time.Duration)) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() lw := mutil.WrapWriter(w) + defer func() { + f(r, lw.Status(), lw.BytesWritten(), time.Since(start)) + }() next.ServeHTTP(lw, r) - f(r, lw.Status(), lw.BytesWritten(), time.Since(start)) + }) + } +} + +// HostHandler adds the request's host as a field to the context's logger +// using fieldKey as field key. If trimPort is set to true, then port is +// removed from the host. +func HostHandler(fieldKey string, trimPort ...bool) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var host string + if len(trimPort) > 0 && trimPort[0] { + host = getHost(r.Host) + } else { + host = r.Host + } + if host != "" { + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, host) + }) + } + next.ServeHTTP(w, r) }) } } diff --git a/hlog/hlog_test.go b/hlog/hlog_test.go index c584e98..445d6b6 100644 --- a/hlog/hlog_test.go +++ b/hlog/hlog_test.go @@ -7,7 +7,7 @@ import ( "bytes" "context" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "net/url" @@ -122,6 +122,38 @@ func TestRemoteAddrHandlerIPv6(t *testing.T) { } } +func TestRemoteIPHandler(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{ + RemoteAddr: "1.2.3.4:1234", + } + h := RemoteIPHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"ip":"1.2.3.4"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + +func TestRemoteIPHandlerIPv6(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{ + RemoteAddr: "[2001:db8:a0b:12f0::1]:1234", + } + h := RemoteIPHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + func TestUserAgentHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ @@ -201,6 +233,46 @@ func TestCustomHeaderHandler(t *testing.T) { } } +func TestEtagHandler(t *testing.T) { + out := &bytes.Buffer{} + w := httptest.NewRecorder() + r := &http.Request{} + h := EtagHandler("etag")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Etag", `"abcdef"`) + w.WriteHeader(http.StatusOK) + })) + h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + l := FromRequest(r) + l.Log().Msg("") + }) + h3 := NewHandler(zerolog.New(out))(h2) + h3.ServeHTTP(w, r) + if want, got := `{"etag":"abcdef"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + +func TestResponseHeaderHandler(t *testing.T) { + out := &bytes.Buffer{} + w := httptest.NewRecorder() + r := &http.Request{} + h := ResponseHeaderHandler("encoding", "Content-Encoding")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Encoding", `gzip`) + w.WriteHeader(http.StatusOK) + })) + h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + l := FromRequest(r) + l.Log().Msg("") + }) + h3 := NewHandler(zerolog.New(out))(h2) + h3.ServeHTTP(w, r) + if want, got := `{"encoding":"gzip"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + func TestProtoHandler(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ @@ -217,6 +289,22 @@ func TestProtoHandler(t *testing.T) { } } +func TestHTTPVersionHandler(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{ + Proto: "HTTP/1.1", + } + h := HTTPVersionHandler("proto")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"proto":"1.1"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + func TestCombinedHandlers(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{ @@ -245,10 +333,10 @@ func BenchmarkHandlers(b *testing.B) { })) h2 := MethodHandler("method")(RequestHandler("request")(h1)) handlers := map[string]http.Handler{ - "Single": NewHandler(zerolog.New(ioutil.Discard))(h1), - "Combined": NewHandler(zerolog.New(ioutil.Discard))(h2), - "SingleDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h1), - "CombinedDisabled": NewHandler(zerolog.New(ioutil.Discard).Level(zerolog.Disabled))(h2), + "Single": NewHandler(zerolog.New(io.Discard))(h1), + "Combined": NewHandler(zerolog.New(io.Discard))(h2), + "SingleDisabled": NewHandler(zerolog.New(io.Discard).Level(zerolog.Disabled))(h1), + "CombinedDisabled": NewHandler(zerolog.New(io.Discard).Level(zerolog.Disabled))(h2), } for name := range handlers { h := handlers[name] @@ -292,3 +380,56 @@ func TestCtxWithID(t *testing.T) { t.Errorf("CtxWithID() = %v, want %v", got, want) } } + +func TestHostHandler(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{Host: "example.com:8080"} + h := HostHandler("host")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"host":"example.com:8080"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + +func TestHostHandlerWithoutPort(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{Host: "example.com:8080"} + h := HostHandler("host", true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"host":"example.com"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + +func TestGetHost(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", ""}, + {"example.com:8080", "example.com"}, + {"example.com", "example.com"}, + {"invalid", "invalid"}, + {"192.168.0.1:8080", "192.168.0.1"}, + {"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"}, + {"こんにちは.com:8080", "こんにちは.com"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.input, func(t *testing.T) { + result := getHost(tt.input) + if tt.expected != result { + t.Errorf("Invalid log output, got: %s, want: %s", result, tt.expected) + } + }) + } +} diff --git a/hook_test.go b/hook_test.go index 0a15f74..b011816 100644 --- a/hook_test.go +++ b/hook_test.go @@ -3,7 +3,7 @@ package zerolog import ( "bytes" "context" - "io/ioutil" + "io" "testing" ) @@ -278,7 +278,7 @@ func TestPrehook(t *testing.T) { } func BenchmarkHooks(b *testing.B) { - logger := New(ioutil.Discard) + logger := New(io.Discard) b.ResetTimer() b.Run("Nop/Single", func(b *testing.B) { log := logger.Hook(nopHook) diff --git a/log.go b/log.go index 11b10b5..9a83e37 100644 --- a/log.go +++ b/log.go @@ -126,7 +126,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "os" "strconv" "strings" @@ -255,7 +254,7 @@ type Logger struct { // you may consider using sync wrapper. func New(w io.Writer) Logger { if w == nil { - w = ioutil.Discard + w = io.Discard } lw, ok := w.(LevelWriter) if !ok { @@ -338,10 +337,13 @@ func (l Logger) Sample(s Sampler) Logger { } // Hook returns a logger with the h Hook. -func (l Logger) Hook(h Hook) Logger { - newHooks := make([]Hook, len(l.hooks), len(l.hooks)+1) +func (l Logger) Hook(hooks ...Hook) Logger { + if len(hooks) == 0 { + return l + } + newHooks := make([]Hook, len(l.hooks), len(l.hooks)+len(hooks)) copy(newHooks, l.hooks) - l.hooks = append(newHooks, h) + l.hooks = append(newHooks, hooks...) return l } @@ -469,6 +471,14 @@ func (l *Logger) Printf(format string, v ...interface{}) { } } +// Println sends a log event using debug level and no extra field. +// Arguments are handled in the manner of fmt.Println. +func (l *Logger) Println(v ...interface{}) { + if e := l.Debug(); e.Enabled() { + e.CallerSkipFrame(1).Msg(fmt.Sprintln(v...)) + } +} + // Write implements the io.Writer interface. This is useful to set as a writer // for the standard library log. func (l Logger) Write(p []byte) (n int, err error) { @@ -510,6 +520,9 @@ func (l *Logger) newEvent(level Level, done func(string)) *Event { // should returns true if the log event should be logged. func (l *Logger) should(lvl Level) bool { + if l.w == nil { + return false + } if lvl < l.level || lvl < GlobalLevel() { return false } diff --git a/log_example_test.go b/log_example_test.go index 70caa66..c48a1e3 100644 --- a/log_example_test.go +++ b/log_example_test.go @@ -72,7 +72,7 @@ func ExampleLogger_Hook() { var levelNameHook LevelNameHook var messageHook MessageHook = "The message" - log := zerolog.New(os.Stdout).Hook(levelNameHook).Hook(messageHook) + log := zerolog.New(os.Stdout).Hook(levelNameHook, messageHook) log.Info().Msg("hello world") @@ -95,6 +95,14 @@ func ExampleLogger_Printf() { // Output: {"level":"debug","message":"hello world"} } +func ExampleLogger_Println() { + log := zerolog.New(os.Stdout) + + log.Println("hello world") + + // Output: {"level":"debug","message":"hello world\n"} +} + func ExampleLogger_Trace() { log := zerolog.New(os.Stdout) diff --git a/pkgerrors/stacktrace_test.go b/pkgerrors/stacktrace_test.go index 5a13832..787a096 100644 --- a/pkgerrors/stacktrace_test.go +++ b/pkgerrors/stacktrace_test.go @@ -28,6 +28,22 @@ func TestLogStack(t *testing.T) { } } +func TestLogStackFields(t *testing.T) { + zerolog.ErrorStackMarshaler = MarshalStack + + out := &bytes.Buffer{} + log := zerolog.New(out) + + err := fmt.Errorf("from error: %w", errors.New("error message")) + log.Log().Stack().Fields([]interface{}{"error", err}).Msg("") + + got := out.String() + want := `\{"error":"from error: error message","stack":\[\{"func":"TestLogStackFields","line":"37","source":"stacktrace_test.go"\},.*\]\}\n` + if ok, _ := regexp.MatchString(want, got); !ok { + t.Errorf("invalid log output:\ngot: %v\nwant: %v", got, want) + } +} + func TestLogStackFromContext(t *testing.T) { zerolog.ErrorStackMarshaler = MarshalStack @@ -38,7 +54,23 @@ func TestLogStackFromContext(t *testing.T) { log.Log().Err(err).Msg("") // not explicitly calling Stack() got := out.String() - want := `\{"stack":\[\{"func":"TestLogStackFromContext","line":"37","source":"stacktrace_test.go"\},.*\],"error":"from error: error message"\}\n` + want := `\{"stack":\[\{"func":"TestLogStackFromContext","line":"53","source":"stacktrace_test.go"\},.*\],"error":"from error: error message"\}\n` + if ok, _ := regexp.MatchString(want, got); !ok { + t.Errorf("invalid log output:\ngot: %v\nwant: %v", got, want) + } +} + +func TestLogStackFromContextWith(t *testing.T) { + zerolog.ErrorStackMarshaler = MarshalStack + + err := fmt.Errorf("from error: %w", errors.New("error message")) + out := &bytes.Buffer{} + log := zerolog.New(out).With().Stack().Err(err).Logger() // calling Stack() on log context instead of event + + log.Error().Msg("") + + got := out.String() + want := `\{"level":"error","stack":\[\{"func":"TestLogStackFromContextWith","line":"66","source":"stacktrace_test.go"\},.*\],"error":"from error: error message"\}\n` if ok, _ := regexp.MatchString(want, got); !ok { t.Errorf("invalid log output:\ngot: %v\nwant: %v", got, want) } diff --git a/pretty.png b/pretty.png index 2420336..1449e45 100644 Binary files a/pretty.png and b/pretty.png differ diff --git a/syslog_test.go b/syslog_test.go index c168ba6..e889b01 100644 --- a/syslog_test.go +++ b/syslog_test.go @@ -100,7 +100,7 @@ func TestSyslogWriter_WithCEE(t *testing.T) { sw := testCEEwriter{&buf} log := New(SyslogCEEWriter(sw)) log.Info().Str("key", "value").Msg("message string") - got := string(buf.Bytes()) + got := buf.String() want := "@cee:{" if !strings.HasPrefix(got, want) { t.Errorf("Bad CEE message start: want %v, got %v", want, got) diff --git a/writer.go b/writer.go index 9b9ef88..50d7653 100644 --- a/writer.go +++ b/writer.go @@ -180,3 +180,135 @@ func (w *FilteredLevelWriter) WriteLevel(level Level, p []byte) (int, error) { } return len(p), nil } + +var triggerWriterPool = &sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, 1024)) + }, +} + +// TriggerLevelWriter buffers log lines at the ConditionalLevel or below +// until a trigger level (or higher) line is emitted. Log lines with level +// higher than ConditionalLevel are always written out to the destination +// writer. If trigger never happens, buffered log lines are never written out. +// +// It can be used to configure "log level per request". +type TriggerLevelWriter struct { + // Destination writer. If LevelWriter is provided (usually), its WriteLevel is used + // instead of Write. + io.Writer + + // ConditionalLevel is the level (and below) at which lines are buffered until + // a trigger level (or higher) line is emitted. Usually this is set to DebugLevel. + ConditionalLevel Level + + // TriggerLevel is the lowest level that triggers the sending of the conditional + // level lines. Usually this is set to ErrorLevel. + TriggerLevel Level + + buf *bytes.Buffer + triggered bool + mu sync.Mutex +} + +func (w *TriggerLevelWriter) WriteLevel(l Level, p []byte) (n int, err error) { + w.mu.Lock() + defer w.mu.Unlock() + + // At first trigger level or above log line, we flush the buffer and change the + // trigger state to triggered. + if !w.triggered && l >= w.TriggerLevel { + err := w.trigger() + if err != nil { + return 0, err + } + } + + // Unless triggered, we buffer everything at and below ConditionalLevel. + if !w.triggered && l <= w.ConditionalLevel { + if w.buf == nil { + w.buf = triggerWriterPool.Get().(*bytes.Buffer) + } + + // We prefix each log line with a byte with the level. + // Hopefully we will never have a level value which equals a newline + // (which could interfere with reconstruction of log lines in the trigger method). + w.buf.WriteByte(byte(l)) + w.buf.Write(p) + return len(p), nil + } + + // Anything above ConditionalLevel is always passed through. + // Once triggered, everything is passed through. + if lw, ok := w.Writer.(LevelWriter); ok { + return lw.WriteLevel(l, p) + } + return w.Write(p) +} + +// trigger expects lock to be held. +func (w *TriggerLevelWriter) trigger() error { + if w.triggered { + return nil + } + w.triggered = true + + if w.buf == nil { + return nil + } + + p := w.buf.Bytes() + for len(p) > 0 { + // We do not use bufio.Scanner here because we already have full buffer + // in the memory and we do not want extra copying from the buffer to + // scanner's token slice, nor we want to hit scanner's token size limit, + // and we also want to preserve newlines. + i := bytes.IndexByte(p, '\n') + line := p[0 : i+1] + p = p[i+1:] + // We prefixed each log line with a byte with the level. + level := Level(line[0]) + line = line[1:] + var err error + if lw, ok := w.Writer.(LevelWriter); ok { + _, err = lw.WriteLevel(level, line) + } else { + _, err = w.Write(line) + } + if err != nil { + return err + } + } + + return nil +} + +// Trigger forces flushing the buffer and change the trigger state to +// triggered, if the writer has not already been triggered before. +func (w *TriggerLevelWriter) Trigger() error { + w.mu.Lock() + defer w.mu.Unlock() + + return w.trigger() +} + +// Close closes the writer and returns the buffer to the pool. +func (w *TriggerLevelWriter) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + + if w.buf == nil { + return nil + } + + // We return the buffer only if it has not grown above the limit. + // This prevents accumulation of large buffers in the pool just + // because occasionally a large buffer might be needed. + if w.buf.Cap() <= TriggerLevelWriterBufferReuseLimit { + w.buf.Reset() + triggerWriterPool.Put(w.buf) + } + w.buf = nil + + return nil +} diff --git a/writer_test.go b/writer_test.go index 60595ba..f2a61df 100644 --- a/writer_test.go +++ b/writer_test.go @@ -195,3 +195,58 @@ func TestFilteredLevelWriter(t *testing.T) { t.Errorf("Expected %q, got %q.", want, p) } } + +type testWrite struct { + Level + Line []byte +} + +func TestTriggerLevelWriter(t *testing.T) { + tests := []struct { + write []testWrite + want []byte + all []byte + }{{ + []testWrite{ + {DebugLevel, []byte("no\n")}, + {InfoLevel, []byte("yes\n")}, + }, + []byte("yes\n"), + []byte("yes\nno\n"), + }, { + []testWrite{ + {DebugLevel, []byte("yes1\n")}, + {InfoLevel, []byte("yes2\n")}, + {ErrorLevel, []byte("yes3\n")}, + {DebugLevel, []byte("yes4\n")}, + }, + []byte("yes2\nyes1\nyes3\nyes4\n"), + []byte("yes2\nyes1\nyes3\nyes4\n"), + }} + + for k, tt := range tests { + t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + buf := bytes.Buffer{} + writer := TriggerLevelWriter{Writer: LevelWriterAdapter{&buf}, ConditionalLevel: DebugLevel, TriggerLevel: ErrorLevel} + t.Cleanup(func() { writer.Close() }) + for _, w := range tt.write { + _, err := writer.WriteLevel(w.Level, w.Line) + if err != nil { + t.Error(err) + } + } + p := buf.Bytes() + if want := tt.want; !bytes.Equal([]byte(want), p) { + t.Errorf("Expected %q, got %q.", want, p) + } + err := writer.Trigger() + if err != nil { + t.Error(err) + } + p = buf.Bytes() + if want := tt.all; !bytes.Equal([]byte(want), p) { + t.Errorf("Expected %q, got %q.", want, p) + } + }) + } +}