package functional
import (
"cmp"
"context"
"slices"
"github.com/jake-scott/go-functional/iter/channel"
"github.com/jake-scott/go-functional/iter/slice"
)
// FilterFunc is a predicate applied to each element of a stage. It reports
// whether the element should be kept (true) or dropped (false).
//
// A non-nil error is handed to the stage's error handler, which decides
// whether processing carries on or stops.
//
// Example:
//
//	func findEvenInts(i int) (bool, error) {
//		return i%2 == 0, nil
//	}
type FilterFunc[T any] func(T) (bool, error)

// Filter is the function form of Stage.Filter; it forwards to the method on s.
func Filter[T any](s *Stage[T], f FilterFunc[T], opts ...StageOption) *Stage[T] {
	next := s.Filter(f, opts...)
	return next
}
// Filter runs f over every input element of this stage and returns a new
// stage that yields only the elements for which f returned true.
//
// The options passed here override this stage's options for this call only
// (they are applied to a copy). Whether they also carry over to the returned
// stage is decided by nextStage (see InheritOptions).
//
// In batch mode (the default), Filter returns once every input element has
// been processed, and the surviving elements are handed to the next stage as
// a slice.
//
// In streaming mode, Filter returns immediately after starting background
// goroutine(s). The next stage reads results from a channel as they are
// produced.
//
// An unrecognised StageType is treated as BatchStage, which matches the
// documented default. It no longer produces a stage with a nil iterator.
func (s *Stage[T]) Filter(f FilterFunc[T], opts ...StageOption) *Stage[T] {
	// Work on a copy so the per-call options don't leak back into s.
	merged := *s
	merged.opts.processOptions(opts...)

	t := merged.tracer("Filter")
	defer t.end()

	var i Iterator[T]
	switch merged.opts.stageType {
	case StreamingStage:
		// Returns immediately; processing continues in the background.
		i = merged.filterStreaming(t, f)
	default:
		// BatchStage, and a safe fallback for any unknown stage type.
		i = merged.filterBatch(t, f)
	}

	return s.nextStage(i, opts...)
}
// filterBatch filters the whole input before returning and produces a
// slice-backed iterator for the next stage.
//
// If the input iterator reports its size, that size replaces the size hint
// used for allocation. When more than one worker is allowed, the work is
// delegated to parallelBatchFilter.
//
// A filter error is passed to the error handler. If the handler returns
// false, processing stops and the elements kept so far are returned. An
// iterator error is also passed to the handler. If the handler returns false
// for it, the partial result is discarded and an empty iterator is returned.
func (s *Stage[T]) filterBatch(t tracer, f FilterFunc[T]) Iterator[T] {
	tBatch := t.subTracer("batch")
	defer tBatch.end()

	if sh, ok := s.i.(Size[T]); ok {
		s.opts.sizeHint = sh.Size()
	}

	// handle parallel filters separately..
	if s.opts.maxParallelism > 1 {
		return s.parallelBatchFilter(tBatch, f)
	}

	// Trace through the batch sub-tracer, like the parallel path does.
	tBatch.msg("Sequential processing")

	out := make([]T, 0, s.opts.sizeHint)

filterLoop:
	for s.i.Next(s.opts.ctx) {
		item := s.i.Get()
		keep, err := f(item)
		switch {
		case err != nil:
			if !s.opts.onError(ErrorContextFilterFunction, err) {
				tBatch.msg("filter done due to error: %s", err)
				break filterLoop
			}
		case keep:
			out = append(out, item)
		}
	}

	if err := s.i.Error(); err != nil {
		// An iterator read error: drop the partial result unless the
		// handler says to carry on.
		if !s.opts.onError(ErrorContextItertator, err) {
			out = []T{}
		}
	}

	i := slice.New(out)
	return &i
}
// parallelBatchFilter filters the input using several workers and gathers all
// surviving elements into a slice before returning.
//
// With PreserveOrder set, each element travels as an item[T] carrying the
// index it was read at. The results are sorted by that index afterwards.
// Otherwise the results keep whatever order the workers produced them in.
func (s *Stage[T]) parallelBatchFilter(t tracer, f FilterFunc[T]) Iterator[T] {
	var (
		pt     tracer
		result []T
	)

	if !s.opts.preserveOrder {
		// Unordered: values pass through the workers as plain T.
		pt = t.subTracer("parallel, unordered")
		result = parallelBatchFilterProcessor(s, pt, f, unorderedWrapper[T], unorderedUnwrapper[T])
	} else {
		pt = t.subTracer("parallel, ordered")
		// Ordered: values pass through as item[T] so they can be re-sorted.
		tagged := parallelBatchFilterProcessor(s, pt, f, orderedWrapper[T], orderedUnwrapper[T])
		slices.SortFunc(tagged, func(x, y item[T]) int {
			return cmp.Compare(x.idx, y.idx)
		})
		result = make([]T, 0, len(tagged))
		for _, tv := range tagged {
			result = append(result, tv.item)
		}
	}

	it := slice.New(result)
	pt.end()
	return &it
}
// parallelBatchFilterProcessor runs f over the stage's input on up to
// min(sizeHint, maxParallelism) workers. It returns the kept elements in
// their wrapped form TW.
//
// The wrapper/unwrapper pair lets the caller attach extra data to each
// element, such as its input index for later sorting.
//
// After the workers finish, an iterator error is passed to the error
// handler. If the handler returns false, an empty result is returned.
//
// T: input type; TW: wrapped input type
func parallelBatchFilterProcessor[T any, TW any](s *Stage[T], t tracer, f FilterFunc[T],
	wrapper wrapItemFunc[T, TW], unwrapper unwrapItemFunc[TW, T]) []TW {
	workers := min(s.opts.sizeHint, s.opts.maxParallelism)
	t = t.subTracer("parallelization=%d", workers)
	defer t.end()

	// push: wrap each input value and hand it to the workers.
	push := func(ctx context.Context, idx uint, v T, toWorkers chan TW) {
		wrapped := wrapper(idx, v)
		select {
		case toWorkers <- wrapped:
		case <-ctx.Done():
		}
	}

	// pull: forward the wrapped value only when the predicate keeps it.
	// For a filter the output type is the wrapped input type.
	pull := func(wrapped TW, results chan TW) error {
		keep, err := f(unwrapper(wrapped))
		switch {
		case err != nil:
			return err
		case !keep:
			return nil
		}
		select {
		case results <- wrapped:
			return nil
		case <-s.opts.ctx.Done():
			return s.opts.ctx.Err()
		}
	}

	results := parallelProcessor(s.opts, workers, s.i, t, push, pull)

	// Drain results until parallelProcessor closes the channel.
	collected := make([]TW, 0, s.opts.sizeHint)
	for r := range results {
		collected = append(collected, r)
	}

	if err := s.i.Error(); err != nil && !s.opts.onError(ErrorContextItertator, err) {
		// Told not to continue after an iterator error: return nothing.
		return []TW{}
	}
	return collected
}
// filterStreaming filters the input in the background. It returns at once
// with a channel-backed iterator that the next stage reads as results arrive.
//
// If the input iterator reports its size, that replaces the size hint. When
// more than one worker is allowed, the work is delegated to
// parallelStreamingFilter. Otherwise a single goroutine filters the elements
// in order and closes the channel when done.
func (s *Stage[T]) filterStreaming(t tracer, f FilterFunc[T]) Iterator[T] {
	if sized, ok := s.i.(Size[T]); ok {
		s.opts.sizeHint = sized.Size()
	}

	// Concurrent filtering has its own implementation.
	if s.opts.maxParallelism > 1 {
		return s.parallelStreamingFilter(t, f)
	}

	t = t.subTracer("streaming, sequential")
	results := make(chan T)

	go func() {
		pt := s.tracer("processor")
		defer pt.end()

		// consume runs the filter over the input. Returning early stops
		// consumption without skipping the clean-up that follows.
		consume := func() {
			for s.i.Next(s.opts.ctx) {
				elem := s.i.Get()
				keep, err := f(elem)
				if err != nil {
					if !s.opts.onError(ErrorContextFilterFunction, err) {
						pt.msg("filter done due to error: %s", err)
						return
					}
					continue
				}
				if !keep {
					continue
				}
				select {
				case results <- elem:
				case <-s.opts.ctx.Done():
					pt.msg("Cancelled")
					s.opts.onError(ErrorContextOther, s.opts.ctx.Err())
					return
				}
			}
		}
		consume()

		// Report an iterator error. The next stage cannot be aborted from
		// here, but closing the channel stops it receiving further items.
		if err := s.i.Error(); err != nil {
			s.opts.onError(ErrorContextItertator, err)
		}
		close(results)
	}()

	it := channel.New(results)
	t.end()
	return &it
}
// parallelStreamingFilter filters the input on several background workers.
// It returns at once with an iterator over the workers' output channel.
// Output order is not guaranteed.
func (s *Stage[T]) parallelStreamingFilter(t tracer, f FilterFunc[T]) Iterator[T] {
	workers := min(s.opts.sizeHint, s.opts.maxParallelism)
	t = t.subTracer("streaming, parallel=%d", workers)
	defer t.end()

	// Input values are passed to the workers unwrapped.
	feed := func(ctx context.Context, _ uint, v T, toWorkers chan T) {
		select {
		case toWorkers <- v:
		case <-ctx.Done():
		}
	}

	// Each worker forwards a value only if the predicate keeps it.
	sieve := func(v T, out chan T) error {
		keep, err := f(v)
		if err != nil {
			return err
		}
		if !keep {
			return nil
		}
		select {
		case out <- v:
		case <-s.opts.ctx.Done():
			return s.opts.ctx.Err()
		}
		return nil
	}

	it := channel.New(parallelProcessor(s.opts, workers, s.i, t, feed, sieve))
	return &it
}
// Package channel implements an iterator that reads a stream of values from
// the supplied channel.
package channel
import "context"
// Iterator yields the values received from a channel, one per call to Next,
// until the channel is closed or the context passed to Next is done.
type Iterator[T any] struct {
	ch chan T
	// last value received; nil until the first successful Next
	item *T
	// set when Next stops because its context was done
	err error
}

// New returns an Iterator that reads from ch until ch is closed or the
// context given to Next is done.
//
// The returned Iterator has no Size method, so it cannot act as its own size
// hint.
func New[T any](ch chan T) Iterator[T] {
	var it Iterator[T]
	it.ch = ch
	return it
}

// Next waits for the next value on the channel and stores it for Get.
//
// It returns true when a value was received. It returns false once the
// channel has been closed and drained, or when ctx is done; in the latter
// case Error reports ctx.Err(). If a value is ready and ctx is already done,
// select may choose either outcome.
func (i *Iterator[T]) Next(ctx context.Context) bool {
	select {
	case <-ctx.Done():
		i.err = ctx.Err()
		return false
	case v, open := <-i.ch:
		if !open {
			// Receive on a closed, empty channel: end of input.
			return false
		}
		i.item = &v
		return true
	}
}

// Get returns the value stored by the most recent successful Next call, or
// the zero value of T if nothing has been received yet.
//
// Get does not use a context.
func (i *Iterator[T]) Get() T {
	if i.item != nil {
		return *i.item
	}
	var zero T
	return zero
}

// Error returns the context error recorded when a Next call stopped because
// its context was done, or nil otherwise.
func (i *Iterator[T]) Error() error {
	return i.err
}
// Package scanner implements a stream tokenizer iterator.
//
// The package makes use of the standard library bufio.Scanner to buffer and
// split data read from an io.Reader. Scanner has a set of standard splitters
// for words, lines and runes and supports custom split functions as well.
package scanner
import (
"context"
"fmt"
)
// Iterator walks the tokens (such as words or lines) produced by a Scanner,
// typically a *bufio.Scanner reading from an io.Reader.
//
// Iterator has no Size method, so it cannot act as its own size hint.
type Iterator struct {
	scanner Scanner
	// panic or context error recorded by Next
	err error
}

// Scanner is the subset of *bufio.Scanner's methods that Iterator uses. It
// exists mainly so tests can substitute a fake scanner.
type Scanner interface {
	Scan() bool
	Text() string
	Err() error
}

// ErrTooManyTokens is the error Next records when the scanner's Scan method
// panics. bufio.Scanner does this when a split function keeps returning
// empty tokens without advancing through the input.
type ErrTooManyTokens struct {
	// set when the panic value was not an error
	panicMessage string
	// set when the panic value was an error
	err error
}

// Error describes the panic, preferring the wrapped error when there is one.
func (e ErrTooManyTokens) Error() string {
	if e.err != nil {
		return fmt.Sprintf("too many tokens: %s", e.err)
	}
	return "too many tokens: " + e.panicMessage
}

// Unwrap returns the error the scanner panicked with, if any.
func (e ErrTooManyTokens) Unwrap() error {
	return e.err
}

// New returns an Iterator over the tokens produced by scanner, for example a
// *bufio.Scanner reading lines or words from a file.
func New(scanner Scanner) Iterator {
	return Iterator{scanner: scanner}
}

// Next advances to the next token by calling Scan. It returns false at the
// end of input, when the scanner fails, or when ctx is already done; in the
// last case Error reports ctx.Err().
//
// If Scan panics, Next recovers, returns false, and Error reports an
// ErrTooManyTokens describing the panic.
func (i *Iterator) Next(ctx context.Context) (ret bool) {
	defer func() {
		p := recover()
		if p == nil {
			return
		}
		if e, ok := p.(error); ok {
			i.err = ErrTooManyTokens{err: e}
		} else {
			i.err = ErrTooManyTokens{panicMessage: fmt.Sprintf("%v", p)}
		}
		ret = false
	}()

	select {
	case <-ctx.Done():
		i.err = ctx.Err()
		return false
	default:
		return i.scanner.Scan()
	}
}

// Get returns the token found by the most recent Next call, as a string.
//
// Get does not use a context.
func (i *Iterator) Get() string {
	return i.scanner.Text()
}

// Error returns the error recorded by Next (a recovered scanner panic or the
// context's error), if any. Otherwise it returns the scanner's own Err(),
// which is nil at a clean end of input.
func (i *Iterator) Error() error {
	if err := i.err; err != nil {
		return err
	}
	return i.scanner.Err()
}
// Package slice implements an iterator that traverses a generic slice of
// elements in one direction, from first to last.
//
// The slice iterator provides a Size method, so it can act as its own size
// hint.
package slice
import "context"
// Iterator walks the elements of a slice of T from first to last.
type Iterator[T any] struct {
	s []T
	// number of elements consumed so far; s[pos-1] is the current element
	pos int
	// context error recorded by Next
	err error
}

// New returns an Iterator over s. The Iterator has a Size method, so it can
// act as its own size hint.
func New[T any](s []T) Iterator[T] {
	return Iterator[T]{s: s}
}

// Size reports the length of the underlying slice.
func (it *Iterator[T]) Size() uint {
	n := len(it.s)
	return uint(n)
}

// Next moves to the next element. It returns false once every element has
// been visited, or when ctx is done; in that case Error returns ctx.Err().
// The end-of-slice check is made before the context is consulted.
func (it *Iterator[T]) Next(ctx context.Context) bool {
	if it.pos < len(it.s) {
		select {
		case <-ctx.Done():
			it.err = ctx.Err()
			return false
		default:
			it.pos++
			return true
		}
	}
	return false
}

// Get returns the current element, or the zero value of T before the first
// successful Next.
//
// Get does not use a context.
func (it *Iterator[T]) Get() T {
	if it.pos > 0 {
		return it.s[it.pos-1]
	}
	var zero T
	return zero
}

// Error returns the context error recorded by Next, or nil.
func (it *Iterator[T]) Error() error {
	return it.err
}
package functional
import (
"cmp"
"context"
"slices"
"github.com/jake-scott/go-functional/iter/channel"
"github.com/jake-scott/go-functional/iter/slice"
)
// MapFunc transforms one element of type T into one element of type M. A
// non-nil error is passed to the stage's error handler.
//
// Example:
//
//	func hostAddrs(host string) ([]string, error) {
//		return net.LookupHost(host)
//	}
type MapFunc[T any, M any] func(T) (M, error)

// Map applies m to each input element of the stage and returns a new stage
// holding the mapped values. The values keep the element type T; when m
// produces a different type, use the function form Map instead.
//
// An element whose mapping returns an error is passed to the error handler
// and is left out of the output.
//
// In batch mode (the default), Map returns after all input elements have been
// processed and hands the results to the next stage as a slice.
//
// In streaming mode, Map returns immediately after starting background
// goroutine(s). The returned stage reads the results from a channel as they
// are produced.
func (s *Stage[T]) Map(m MapFunc[T, T], opts ...StageOption) *Stage[T] {
	next := Map(s, m, opts...)
	return next
}
// Map is the function form of Stage.Map. It must be used when the map
// function returns a type M that differs from the input type T, because Go
// methods cannot declare their own type parameters.
//
// The options passed here override the stage's options for this call only
// (they are applied to a copy). An unrecognised StageType is treated as
// BatchStage, which matches the documented default. It no longer produces a
// stage with a nil iterator.
func Map[T, M any](s *Stage[T], m MapFunc[T, M], opts ...StageOption) *Stage[M] {
	// Work on a copy so the per-call options don't leak back into s.
	merged := *s
	merged.opts.processOptions(opts...)

	t := merged.tracer("Map")
	defer t.end()

	var i Iterator[M]
	switch merged.opts.stageType {
	case StreamingStage:
		// Returns immediately; processing continues in the background.
		i = mapStreaming(t, &merged, m)
	default:
		// BatchStage, and a safe fallback for any unknown stage type.
		i = mapBatch(t, &merged, m)
	}

	return nextStage(s, i, opts...)
}
// mapBatch maps the whole input before returning and produces a slice-backed
// iterator for the next stage.
//
// If the input iterator reports its size, that replaces the size hint. When
// more than one worker is allowed, the work goes to mapBatchParallel.
//
// Elements whose mapping fails are skipped. If the error handler returns
// false, processing stops with the results so far. If the handler rejects an
// iterator error, the output is emptied.
//
// T: input type; M: mapped type
func mapBatch[T, M any](t tracer, s *Stage[T], m MapFunc[T, M]) Iterator[M] {
	t = t.subTracer("batch")
	defer t.end()

	if sized, ok := s.i.(Size[T]); ok {
		s.opts.sizeHint = sized.Size()
	}

	if s.opts.maxParallelism > 1 {
		return mapBatchParallel(t, s, m)
	}

	t.msg("Sequential processing")
	mapped := make([]M, 0, s.opts.sizeHint)

	for s.i.Next(s.opts.ctx) {
		v, err := m(s.i.Get())
		if err == nil {
			mapped = append(mapped, v)
			continue
		}
		if !s.opts.onError(ErrorContextMapFunction, err) {
			t.msg("map done due to error: %s", err)
			break
		}
	}

	if iterErr := s.i.Error(); iterErr != nil && !s.opts.onError(ErrorContextItertator, iterErr) {
		mapped = []M{}
	}

	out := slice.New(mapped)
	return &out
}
// mapBatchParallel maps the input on several workers and collects every
// result into a slice before returning.
//
// With PreserveOrder set, results travel as item[M] carrying their input
// index and are sorted back into input order. Otherwise they keep whatever
// order the workers produced.
//
// T: input type; M: mapped type
func mapBatchParallel[T, M any](t tracer, s *Stage[T], m MapFunc[T, M]) Iterator[M] {
	var (
		pt     tracer
		result []M
	)

	if s.opts.preserveOrder {
		pt = t.subTracer("parallel, ordered")
		tagged := mapBatchParallelProcessor(s, pt, m, orderedWrapper[T], orderedUnwrapper[T], orderedSwitcher[T, M])
		slices.SortFunc(tagged, func(x, y item[M]) int {
			return cmp.Compare(x.idx, y.idx)
		})
		// unwrap the sorted item[M] values
		result = make([]M, 0, len(tagged))
		for _, tv := range tagged {
			result = append(result, tv.item)
		}
	} else {
		pt = t.subTracer("parallel, unordered")
		result = mapBatchParallelProcessor(s, pt, m, unorderedWrapper[T], unorderedUnwrapper[T], unorderedSwitcher[T, M])
	}

	it := slice.New(result)
	pt.end()
	return &it
}
// mapBatchParallelProcessor runs m over the stage's input on up to
// min(sizeHint, maxParallelism) workers. It returns every mapped result in
// its wrapped form MW.
//
// wrapper and unwrapper move input values into and out of the wrapped form TW
// used on the worker channel. switcher builds the wrapped output from the
// wrapped input and the mapped value, e.g. to carry the input index forward.
//
// Unlike parallelBatchFilterProcessor, this function does not inspect the
// iterator's error after the workers finish.
//
// T: input type; TW: wrapped input type; M: mapped type; MW: wrapped map type
func mapBatchParallelProcessor[T, TW, M, MW any](s *Stage[T], t tracer, m MapFunc[T, M],
	wrapper wrapItemFunc[T, TW], unwrapper unwrapItemFunc[TW, T], switcher switcherFunc[TW, M, MW]) []MW {
	workers := min(s.opts.sizeHint, s.opts.maxParallelism)
	t = t.subTracer("parallelization=%d", workers)
	defer t.end()

	// push: wrap each input value and hand it to the workers.
	push := func(ctx context.Context, idx uint, v T, toWorkers chan TW) {
		wrapped := wrapper(idx, v)
		select {
		case toWorkers <- wrapped:
		case <-ctx.Done():
		}
	}

	// pull: map the unwrapped value and emit the wrapped result.
	pull := func(wrapped TW, results chan MW) error {
		v, err := m(unwrapper(wrapped))
		if err != nil {
			return err
		}
		out := switcher(wrapped, v)
		select {
		case results <- out:
			return nil
		case <-s.opts.ctx.Done():
			return s.opts.ctx.Err()
		}
	}

	results := parallelProcessor(s.opts, workers, s.i, t, push, pull)

	// Drain until parallelProcessor closes the channel after its workers finish.
	collected := make([]MW, 0, s.opts.sizeHint)
	for r := range results {
		collected = append(collected, r)
	}
	return collected
}
// mapStreaming maps the input in the background. It returns at once with a
// channel-backed iterator that the next stage reads as results arrive.
//
// If the input iterator reports its size, that replaces the size hint. When
// more than one worker is allowed, the work is delegated to
// mapStreamingParallel. Otherwise a single goroutine maps the elements in
// order and closes the channel when done.
//
// Map function errors are reported to the error handler with
// ErrorContextMapFunction; a false return stops processing.
//
// T: input type; M: mapped type
func mapStreaming[T, M any](t tracer, s *Stage[T], m MapFunc[T, M]) Iterator[M] {
	if sh, ok := s.i.(Size[T]); ok {
		s.opts.sizeHint = sh.Size()
	}

	// handle parallel maps separately..
	if s.opts.maxParallelism > 1 {
		return mapStreamingParallel(t, s, m)
	}

	// otherwise run a simple serial map in a background goroutine
	t = t.subTracer("streaming, sequential")
	ch := make(chan M)

	go func() {
		t := s.tracer("Map, streaming, sequential background")
		defer t.end()

	readLoop:
		for s.i.Next(s.opts.ctx) {
			item := s.i.Get()
			mapped, err := m(item)
			switch {
			case err != nil:
				// Label map errors as such, not as filter errors.
				if !s.opts.onError(ErrorContextMapFunction, err) {
					t.msg("map done due to error: %s", err)
					break readLoop
				}
			default:
				select {
				case ch <- mapped:
				case <-s.opts.ctx.Done():
					t.msg("Cancelled")
					s.opts.onError(ErrorContextOther, s.opts.ctx.Err())
					break readLoop
				}
			}
		}

		// Report an iterator error. The next stage cannot be aborted from
		// here, but closing ch stops it receiving further items.
		if s.i.Error() != nil {
			s.opts.onError(ErrorContextItertator, s.i.Error())
		}
		close(ch)
	}()

	i := channel.New(ch)
	t.end()
	return &i
}
// mapStreamingParallel maps the input on several background workers. It
// returns at once with an iterator over the workers' output channel. Result
// order is not guaranteed.
//
// T: input type; M: mapped type
func mapStreamingParallel[T, M any](t tracer, s *Stage[T], m MapFunc[T, M]) Iterator[M] {
	workers := min(s.opts.sizeHint, s.opts.maxParallelism)
	t = t.subTracer("streaming, parallel=%d", workers)
	defer t.end()

	// Inputs are passed to the workers unwrapped.
	feed := func(ctx context.Context, _ uint, v T, toWorkers chan T) {
		select {
		case toWorkers <- v:
		case <-ctx.Done():
		}
	}

	// Each worker maps its input and forwards the result.
	transform := func(v T, out chan M) error {
		mapped, err := m(v)
		if err != nil {
			return err
		}
		select {
		case out <- mapped:
			return nil
		case <-s.opts.ctx.Done():
			return s.opts.ctx.Err()
		}
	}

	it := channel.New(parallelProcessor(s.opts, workers, s.i, t, feed, transform))
	return &it
}
package functional
// ReduceFunc folds one element of type T into an accumulator of type A and
// returns the updated accumulator.
//
// Example:
//
//	func add(a, i int) (int, error) {
//		return a + i, nil
//	}
type ReduceFunc[T any, A any] func(A, T) (A, error)

// Reduce folds all of the stage's input elements into a single value of type
// T. It starts from initial and calls r once per element.
//
// Use the function form Reduce when the accumulator type differs from the
// element type.
//
// Reduce always processes the elements sequentially and returns only after
// the input is exhausted or processing is stopped.
func (s *Stage[T]) Reduce(initial T, r ReduceFunc[T, T], opts ...StageOption) T {
	result := Reduce(s, initial, r, opts...)
	return result
}
// Reduce is the function form of Stage.Reduce. It must be used when the
// accumulator type A differs from the element type T, because Go methods
// cannot declare their own type parameters.
//
// Each call r(accum, element) replaces the accumulator. The accumulator r
// returns is stored even when r also returns an error.
//
// An error from r is passed to the error handler with
// ErrorContextReduceFunction. If the handler returns false, reduction stops
// and the current accumulator is returned.
//
// After the loop, any iterator error and any context error are reported to
// the handler; their return values are ignored. The accumulator is returned
// in every case.
func Reduce[T, A any](s *Stage[T], initial A, r ReduceFunc[T, A], opts ...StageOption) A {
	// opts for this Reduce are the stage options overridden by the reduce
	// specific options passed to this call
	merged := *s
	merged.opts.processOptions(opts...)

	t := merged.tracer("Reduce")
	defer t.end()

	accum := initial
	for s.i.Next(merged.opts.ctx) {
		var err error
		accum, err = r(accum, s.i.Get())
		if err != nil && !merged.opts.onError(ErrorContextReduceFunction, err) {
			t.msg("reduce done due to error: %s", err)
			break
		}
	}

	if err := s.i.Error(); err != nil {
		merged.opts.onError(ErrorContextItertator, err)
	}
	if err := merged.opts.ctx.Err(); err != nil {
		merged.opts.onError(ErrorContextOther, err)
	}

	return accum
}
// SliceFromIterator is a ReduceFunc-compatible helper that appends each
// element to the accumulator slice. Reducing a stage with it (starting from
// a nil or empty slice) collects the stage's elements into a slice. It never
// returns an error.
func SliceFromIterator[T any](a []T, t T) ([]T, error) {
	grown := append(a, t)
	return grown, nil
}
// Package functional provides highly performant functional primitives
// for Go. It supports streaming and parallel execution of stages
// of a processing pipeline.
package functional
import (
"context"
"fmt"
"sync"
"sync/atomic"
"github.com/jake-scott/go-functional/iter/channel"
"github.com/jake-scott/go-functional/iter/scanner"
"github.com/jake-scott/go-functional/iter/slice"
)
// DefaultSizeHint is the initial allocation size used by batch processing
// when the iterator cannot report its own size and no stage-specific
// SizeHint option was given.
var DefaultSizeHint = uint(100)

// StageType selects how a pipeline stage processes its elements.
type StageType int

const (
	// BatchStage processes every input element before handing the results
	// to the next stage.
	BatchStage StageType = iota

	// StreamingStage hands each result to the next stage as soon as it is
	// produced, while other elements are still being processed.
	StreamingStage
)

// String returns the name of the stage type, or "unknown" for values that
// are not one of the declared constants.
func (t StageType) String() string {
	switch t {
	case BatchStage:
		return "Batch"
	case StreamingStage:
		return "Streaming"
	}
	return "unknown"
}
// ErrorContext tells an ErrorHandler where in processing an error occurred.
type ErrorContext int

const (
	// ErrorContextItertator means the error occurred while reading an
	// iterator. The misspelled name is kept for compatibility; new code can
	// use the ErrorContextIterator alias.
	ErrorContextItertator ErrorContext = iota
	// ErrorContextFilterFunction means the error came from a filter func.
	ErrorContextFilterFunction
	// ErrorContextMapFunction means the error came from a map func.
	ErrorContextMapFunction
	// ErrorContextReduceFunction means the error came from a reduce func.
	ErrorContextReduceFunction
	// ErrorContextOther is used when the phase is unknown or is not one of
	// the above, for example context cancellation.
	ErrorContextOther
)

// ErrorContextIterator is the correctly spelled name for
// ErrorContextItertator; both have the same value.
const ErrorContextIterator = ErrorContextItertator

// ErrorHandler is the signature of callbacks that receive errors raised
// during pipeline processing. Install one with WithErrorHandler; without one,
// errors are ignored and processing continues.
//
// Parameters:
//   - where says in which phase the error occurred
//   - err is the error to be handled
//
// Return true to continue processing regardless, or false to stop.
type ErrorHandler func(where ErrorContext, err error) bool

// nullErrorHandler is the default handler. It ignores the error and asks for
// processing to continue.
func nullErrorHandler(ErrorContext, error) bool {
	return true
}
// stageCounter is incremented atomically to give every new stage (created by
// NewStage or nextStage) its ID. The ID is passed to the tracer to label the
// stage's trace output.
var stageCounter atomic.Uint32

// Stage represents one processing phase of a larger pipeline.
//
// The processing methods of a stage read input elements through the stage's
// Iterator. Processing methods such as Filter and Map return a new Stage
// whose iterator yields their results.
type Stage[T any] struct {
	// source of this stage's input elements
	i Iterator[T]
	// stage ID taken from stageCounter; passed to the tracer
	id uint32
	// shared between a stage and the stages derived from it via nextStage.
	// NOTE(review): no Add/Wait on it is visible here — confirm its purpose.
	wg *sync.WaitGroup
	// settings controlling how this stage processes elements
	opts stageOptions
}
// stageOptions holds the settings that govern a stage's processing
// functions. StageOption values modify it.
type stageOptions struct {
	// batch or streaming processing
	stageType StageType
	// maximum number of workers; 0 or 1 means sequential processing
	maxParallelism uint
	// keep input order in concurrent batch stages
	preserveOrder bool
	// copy these options to the next stage
	inheritOptions bool
	// expected number of elements; used for allocation and worker counts
	sizeHint uint
	// custom trace function; nil means DefaultTracer
	tracer TraceFunc
	// whether tracing is enabled
	tracing bool
	// context observed while processing
	ctx context.Context
	// callback invoked for errors
	onError ErrorHandler
}

// StageOption customizes how a stage's processing functions operate. Options
// are accepted by NewStage and by processing calls such as Filter and Map.
type StageOption func(g *stageOptions)

// ProcessingType selects whether the stage runs in batch or streaming mode.
// Stages use batch mode unless configured otherwise.
func ProcessingType(t StageType) StageOption {
	return func(o *stageOptions) { o.stageType = t }
}

// Parallelism sets the maximum number of workers the stage may use. Values
// of 0 or 1 (0 is the default) process elements sequentially.
func Parallelism(n uint) StageOption {
	return func(o *stageOptions) { o.maxParallelism = n }
}

// SizeHint tells the processing functions roughly how many elements to
// expect. It matters mainly for iterators that cannot report their own size;
// an iterator that does report a size takes precedence over the hint.
//
// Without this option, and when the iterator cannot report a size,
// DefaultSizeHint is used.
func SizeHint(hint uint) StageOption {
	return func(o *stageOptions) { o.sizeHint = hint }
}

// PreserveOrder makes concurrent batch stages keep their elements in input
// order. Sequential stages always keep order, and concurrent streaming stages
// cannot. In a concurrent batch stage, preserving order costs an extra sort
// of the results.
//
// Order is not preserved by default.
func PreserveOrder(preserve bool) StageOption {
	return func(o *stageOptions) { o.preserveOrder = preserve }
}

// WithContext sets the context the stage observes while processing.
func WithContext(ctx context.Context) StageOption {
	return func(o *stageOptions) { o.ctx = ctx }
}

// WithTraceFunc sets the function that receives the stage's trace output.
// Tracing itself is switched on and off with WithTracing.
func WithTraceFunc(f TraceFunc) StageOption {
	return func(o *stageOptions) { o.tracer = f }
}

// WithTracing turns tracing on or off for the stage. Unless a trace function
// was set with WithTraceFunc, messages go to DefaultTracer, which writes to
// stderr unless it has been replaced.
func WithTracing(enable bool) StageOption {
	return func(o *stageOptions) { o.tracing = enable }
}

// WithErrorHandler installs handler as the stage's error callback. The
// processing functions call it when a filter, map or reduce function, or an
// iterator, reports an error.
//
// The handler returns true to continue processing or false to stop. It may
// also record the error for the pipeline's caller to inspect.
func WithErrorHandler(handler ErrorHandler) StageOption {
	return func(o *stageOptions) { o.onError = handler }
}

// InheritOptions controls whether this stage's options carry over to the
// next stage, where they can be overridden. Passing false stops further
// inheritance.
//
// Options are not inherited by default.
func InheritOptions(inherit bool) StageOption {
	return func(o *stageOptions) { o.inheritOptions = inherit }
}

// processOptions applies each option to o, in order.
func (o *stageOptions) processOptions(opts ...StageOption) {
	for idx := range opts {
		opts[idx](o)
	}
}
// NewStage creates a pipeline stage that reads its input from i.
//
// The stage starts with a background context, DefaultSizeHint, and an error
// handler that ignores errors. opts are then applied on top of these
// defaults.
func NewStage[T any](i Iterator[T], opts ...StageOption) *Stage[T] {
	id := stageCounter.Add(1)

	settings := stageOptions{
		ctx:      context.Background(),
		sizeHint: DefaultSizeHint,
		onError:  nullErrorHandler,
	}
	settings.processOptions(opts...)

	return &Stage[T]{
		i:    i,
		opts: settings,
		id:   id,
		wg:   &sync.WaitGroup{},
	}
}

// NewSliceStage creates a pipeline stage whose input is the elements of s,
// read through a slice iterator.
func NewSliceStage[T any](s []T, opts ...StageOption) *Stage[T] {
	it := slice.New(s)
	return NewStage(&it, opts...)
}

// NewChannelStage creates a pipeline stage whose input is the values
// received from ch, read through a channel iterator.
func NewChannelStage[T any](ch chan T, opts ...StageOption) *Stage[T] {
	it := channel.New(ch)
	return NewStage(&it, opts...)
}

// NewScannerStage creates a pipeline stage whose input is the tokens
// produced by s, read through a scanner iterator.
func NewScannerStage(s scanner.Scanner, opts ...StageOption) *Stage[string] {
	it := scanner.New(s)
	return NewStage(&it, opts...)
}
// Iterator returns the stage's underlying iterator. It is most useful on the
// final stage of a pipeline, where the caller uses it to read the results.
func (s *Stage[T]) Iterator() Iterator[T] {
	return s.i
}

// tracer returns a tracer for this stage. When tracing is enabled, the
// tracer's description is prefixed with the element type and formatted with
// v. When tracing is disabled, a no-op tracer is returned.
func (s *Stage[T]) tracer(description string, v ...any) tracer {
	if !s.opts.tracing {
		return nullTracer{}
	}
	var zero T
	label := fmt.Sprintf("(%T) %s", zero, description)
	return newTracer(s.id, label, s.opts.tracer, v...)
}

// nextStage builds the stage that follows s and reads from i; see the
// package-level nextStage.
func (s *Stage[T]) nextStage(i Iterator[T], opts ...StageOption) *Stage[T] {
	return nextStage[T, T](s, i, opts...)
}
// nextStage creates the stage that follows s. The new stage reads its input
// from i and shares s's WaitGroup.
//
// Options for the new stage are chosen as follows:
//   - If s has InheritOptions(true), the new stage starts with a copy of s's
//     options. Otherwise it starts from the defaults: a background context,
//     DefaultSizeHint, and an error handler that ignores errors.
//   - opts (the options given to the processing call that produced i) are
//     applied on top only when they themselves include InheritOptions(true).
func nextStage[T, U any](s *Stage[T], i Iterator[U], opts ...StageOption) *Stage[U] {
	id := stageCounter.Add(1)

	var settings stageOptions
	if s.opts.inheritOptions {
		settings = s.opts
	} else {
		settings = stageOptions{
			ctx:      context.Background(),
			sizeHint: DefaultSizeHint,
			onError:  nullErrorHandler,
		}
	}

	// Apply the call's options to a scratch value first, only to find out
	// whether they request inheritance.
	var probe stageOptions
	probe.processOptions(opts...)
	if probe.inheritOptions {
		settings.processOptions(opts...)
	}

	return &Stage[U]{
		i:    i,
		id:   id,
		wg:   s.wg,
		opts: settings,
	}
}
// parallelProcessor fans the elements of iter out to numParallel worker
// goroutines and returns a channel carrying their results in no particular
// order.
//
//   - A reader goroutine (1) calls iter.Next/Get. For each element it calls
//     push with the element's input position. push should send the (possibly
//     wrapped) element to the worker channel, giving up if ctx is done.
//   - Each worker (2) receives elements and calls pull, which may send zero
//     or more results to the output channel. A pull error goes to the error
//     handler; if the handler returns false, the shared context is cancelled.
//   - A supervisor goroutine (3) waits for the workers, closes the output
//     channel, then waits for the reader and releases the context.
//
// A numParallel of 0 is raised to 1. This can happen when the size hint is
// 0 for a non-empty source, for example SizeHint(0) on a channel stage.
// Otherwise no worker would exist: input would be silently dropped and the
// reader could block in push forever.
//
// T: source item type
// TW: wrapped source item type
// MW: wrapped result item type (same as TW for filters, possibly different than TW for maps)
func parallelProcessor[T, TW, MW any](opts stageOptions, numParallel uint, iter Iterator[T], t tracer,
	push func(context.Context, uint, T, chan TW), pull func(TW, chan MW) error) chan MW {
	numParallel = max(numParallel, 1)

	chWorker := make(chan TW) // channel towards the workers
	chOut := make(chan MW)    // worker output channel
	ctx, cancel := context.WithCancel(opts.ctx)

	wgReader := sync.WaitGroup{}
	wgReader.Add(1)
	chStop := make(chan struct{})

	// (1) Read items from the iterator in a separate goroutine until it is
	// exhausted, the context is done, or the workers have all exited. Pass
	// each item to push along with its input position.
	go func() {
		t := t.subTracer("reader")
		defer wgReader.Done()
		defer t.end()
		// Closing chWorker tells the workers that no more items are coming;
		// they terminate once they have processed what is left.
		defer close(chWorker)

		i := 0
	iterLoop:
		for iter.Next(ctx) {
			select {
			case <-chStop: // if the workers terminate first, this tells us to stop
				break iterLoop
			default:
				// push should write the item to chWorker
				push(ctx, uint(i), iter.Get(), chWorker)
			}
			i++
		}

		// Report an iterator error to the error handler. The following stages
		// cannot be aborted from here, but at least no new items are sent.
		if iter.Error() != nil {
			opts.onError(ErrorContextItertator, iter.Error())
		}
	}()

	// (2) Start the worker goroutines. They read items from chWorker until
	// the reader (1) closes it, or until the context is done.
	wgWorker := sync.WaitGroup{}
	for i := uint(0); i < numParallel; i++ {
		wgWorker.Add(1)
		i := i
		go func() {
			t := t.subTracer("processor %d", i)
			defer wgWorker.Done()
			defer t.end()

		readLoop:
			for {
				select {
				case item, ok := <-chWorker:
					if !ok {
						// empty, closed channel: no more work
						break readLoop
					}
					// pull selectively writes items to chOut
					if err := pull(item, chOut); err != nil {
						if !opts.onError(ErrorContextOther, err) {
							cancel()
							continue
						}
					}
				case <-ctx.Done():
					t.msg("cancelled")
					opts.onError(ErrorContextOther, ctx.Err())
					break readLoop
				}
			}
		}()
	}

	// (3) Wait for the workers in a separate goroutine and close the result
	// channel once they are all done.
	//
	// NOTE(review): chOut is closed before the reader is known to have
	// finished. After an early cancellation, a caller that drains chOut and
	// then calls iter.Error() may race with the reader goroutine.
	go func() {
		t := t.subTracer("wait for processors")
		wgWorker.Wait()
		close(chOut)
		close(chStop)
		wgReader.Wait()
		t.end()
		cancel()
	}()

	return chOut
}
// closeChanIfOpen closes ch unless it is already closed.
//
// It first probes ch with a non-blocking receive. If the probe reports the
// channel as closed, nothing else happens. Otherwise ch is closed. Note that
// when a value was ready to be received, the probe consumes and discards
// that one value. A panic raised by close (e.g. a nil channel, or another
// goroutine closing ch after the probe) is recovered and ignored.
func closeChanIfOpen[T any](ch chan T) {
	select {
	case _, open := <-ch:
		if !open {
			return // already closed
		}
	default:
		// nothing to receive: ch is still open (or nil)
	}

	defer func() {
		_ = recover() // best effort: ignore a racing/invalid close
	}()
	close(ch)
}
package functional
import (
"fmt"
"os"
"slices"
"strconv"
"strings"
"sync/atomic"
"time"
)
// tracer is the internal tracing abstraction used by stages. realTracer
// emits trace lines; nullTracer discards everything.
type tracer interface {
	// subTracer derives a child tracer; desc is formatted with args.
	subTracer(desc string, args ...any) tracer
	// msg emits a message formatted from format and args.
	msg(format string, args ...any)
	// end marks the end of the traced operation.
	end()
}
// TraceFunc is the signature of a tracing function: it receives an
// fmt-style format string followed by its arguments.
// A per-stage trace function can be configured using WithTraceFunc.
type TraceFunc func(formatStr string, args ...any)
// DefaultTracer is the package-wide default trace function; newTracer falls
// back to it when no trace function is supplied. Each message is written to
// stderr, prefixed with "<TRACE> " and terminated by a newline.
// Replace DefaultTracer with another function to affect all stages.
var DefaultTracer = func(fmtStr string, args ...any) {
	fmt.Fprintf(os.Stderr, "<TRACE> "+fmtStr+"\n", args...)
}
// realTracer is the tracer implementation that formats START, MSG and END
// lines and hands them to traceFunc. It holds an atomic counter and so must
// not be copied after first use.
type realTracer struct {
	begin       time.Time     // set by start()
	description string        // description, including those of parent tracers
	ids         []uint32      // ID path; rendered joined by "." by id()
	subids      atomic.Uint32 // counter handing out sub IDs to child tracers
	traceFunc   TraceFunc     // receives the formatted trace lines
}
// newTracer creates and starts a root tracer whose ID path is just id.
// description is used as an fmt format string, formatted with the optional
// v parameters. Trace lines go to f, or to DefaultTracer when f is nil.
// The START line is emitted before newTracer returns.
//
// Usually one tracer is created per transaction, and nested operations
// derive their own tracers from it with subTracer().
//
// Example:
//
//	parentTracer := newTracer(1, "parent", nil)
func newTracer(id uint32, description string, f TraceFunc, v ...any) *realTracer {
	sink := f
	if sink == nil {
		sink = DefaultTracer
	}

	t := &realTracer{
		begin:       time.Now(), // reset by start() below
		description: fmt.Sprintf(description, v...),
		ids:         []uint32{id},
		traceFunc:   sink,
	}
	t.start()
	return t
}
// id renders the tracer's ID path as dot-separated decimal numbers,
// e.g. "1.2.3".
func (t *realTracer) id() string {
	var sb strings.Builder
	for pos, part := range t.ids {
		if pos > 0 {
			sb.WriteByte('.')
		}
		sb.WriteString(strconv.Itoa(int(part)))
	}
	return sb.String()
}
// start records the current time as the tracer's begin time and emits a
// START line through traceFunc.
func (t *realTracer) start() {
	now := time.Now()
	t.begin = now
	stamp := now.Format(time.RFC3339)
	t.traceFunc("%s: START [stage #%s] %s", stamp, t.id(), t.description)
}
// subTracer returns a new, started tracer derived from t. The next sub ID
// from t's counter is appended to t's ID path. description is formatted
// with the optional v parameters and appended to the parent's description
// after " / ".
//
// subTracer may be called concurrently on the same parent (for example,
// parallelProcessor creates one sub-tracer per worker goroutine from a shared
// tracer). The child is therefore built field by field instead of copying *t.
// Copying would read the parent's atomic sub-ID counter non-atomically while
// other goroutines increment it, which is a data race and is also flagged by
// go vet's copylocks check.
//
// Example:
//
//	childTracer1 := parentTracer.subTracer("child %d", 1)
//	childTracer2 := parentTracer.subTracer("child %d", 2)
//	childTracer2a := childTracer2.subTracer("grandchild %d", 1)
func (t *realTracer) subTracer(description string, v ...any) tracer {
	subID := t.subids.Add(1)

	// The child starts with a fresh (zero) sub-ID counter; start() sets begin.
	child := &realTracer{
		description: t.description + fmt.Sprintf(" / "+description, v...),
		ids:         append(slices.Clone(t.ids), subID),
		traceFunc:   t.traceFunc,
	}
	child.start()
	return child
}
// msg emits a MSG line through traceFunc: the current time, the tracer's ID
// and description, followed by format formatted with v.
func (t *realTracer) msg(format string, v ...any) {
	header := []any{time.Now().Format(time.RFC3339), t.id(), t.description}
	t.traceFunc("%s: MSG [stage #%s] %s: "+format, append(header, v...)...)
}
// end emits an END line through traceFunc with the current time and the
// tracer's ID and description.
func (t *realTracer) end() {
	stamp := time.Now().Format(time.RFC3339)
	t.traceFunc("%s: END [stage #%s] %s", stamp, t.id(), t.description)
}
// nullTracer is a tracer whose methods do nothing; its sub-tracers are
// nullTracers as well.
type nullTracer struct{}

func (nullTracer) subTracer(string, ...any) tracer { return nullTracer{} }

func (nullTracer) msg(string, ...any) {}

func (nullTracer) end() {}
package functional
/* The helpers here are used in the parallel batch processors in order to
* optimize the unordered processing case (streaming processors are always
* unordered).
*
* In the ordered case, every input element needs to be tagged
* with its index so that the results can be sorted after the processing
* goroutines are done. We do that by wrapping the elements in an
* item[T] struct. The wrapped element is passed to the processing
* goroutines which retrieve the original element from the struct.
*
* In the unordered case we would like to avoid the overhead of wrapping
* and unwrapping the elements. To do this, we define wrap and unwrap
* function templates (wrapItemFunc, unwrapItemFunc) that the processing
* functions call. In the unordered case, the concrete functions that
* implement these templates are no-ops (they just return the original element),
* and in the ordered case the functions wrap the input element in an
* item struct.
*
* The switcher functions are similar except that they return a different
* type than the input; these are used in map functions.
*/
// item pairs an element with its position in the input. Values of this type
// travel between the producer and consumer goroutines of an ordered parallel
// operation so that the original order can be restored afterwards.
type item[T any] struct {
	idx  uint // position of the element in the input
	item T    // the element itself
}

// Type parameter naming used by the function templates below:
//
//	T  - the original, unwrapped element type (e.g. string)
//	TW - the wrapped form of T (e.g. item[string]; just T when unordered)
//	S  - the type elements are mapped to (e.g. int)
//	SW - the wrapped form of S (e.g. item[int]; just S when unordered)

// wrapItemFunc builds a (possibly) wrapped value from an element and its index.
type wrapItemFunc[T, TW any] func(idx uint, elem T) TW

// unwrapItemFunc recovers the original element from a (possibly) wrapped value.
type unwrapItemFunc[TW, T any] func(wrapped TW) T

// switcherFunc replaces the element carried by the wrapped value w with x, of
// a new type S, producing a wrapped value of the new type.
type switcherFunc[TW, S, SW any] func(w TW, x S) SW
// unorderedWrapper is the wrapItemFunc for unordered processing: the index is
// ignored and the element is passed through untouched, so TW is inferred to
// be T by the compiler.
func unorderedWrapper[T any](_ uint, elem T) T {
	return elem
}
// unorderedUnwrapper is the unwrapItemFunc for unordered processing: since
// nothing was wrapped, the value is returned unchanged and the output element
// type T is inferred to be TW by the compiler.
func unorderedUnwrapper[TW any](wrapped TW) TW {
	return wrapped
}
// unorderedSwitcher is the switcherFunc for unordered processing: the old
// value is dropped and the new value is returned as-is, so SW is inferred to
// be S by the compiler.
func unorderedSwitcher[TW any, S any](_ TW, replacement S) S {
	return replacement
}
// orderedWrapper is the wrapItemFunc for ordered processing: it packs the
// element together with its input index into an item, so TW is inferred to
// be item[T] by the compiler.
func orderedWrapper[T any](i uint, t T) item[T] {
	var wrapped item[T]
	wrapped.idx, wrapped.item = i, t
	return wrapped
}
// orderedUnwrapper is the unwrapItemFunc for ordered processing: it returns
// the element carried by an item and drops its index. TW is inferred to be
// item[T] by the compiler.
func orderedUnwrapper[T any](wrapped item[T]) T {
	elem := wrapped.item
	return elem
}
// orderedSwitcher is the switcherFunc for ordered processing: it returns a
// new item carrying the mapped value s (of type S) under the same index as
// the original wrapped item tw.
func orderedSwitcher[T any, S any](tw item[T], s S) item[S] {
	out := item[S]{item: s}
	out.idx = tw.idx
	return out
}