Source file src/runtime/testdata/testprog/coro.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"iter"
    10  	"runtime"
    11  )
    12  
    13  func init() {
    14  	register("CoroLockOSThreadIterLock", func() {
    15  		println("expect: OK")
    16  		CoroLockOSThread(callerExhaust, iterLock)
    17  	})
    18  	register("CoroLockOSThreadIterLockYield", func() {
    19  		println("expect: OS thread locking must match")
    20  		CoroLockOSThread(callerExhaust, iterLockYield)
    21  	})
    22  	register("CoroLockOSThreadLock", func() {
    23  		println("expect: OK")
    24  		CoroLockOSThread(callerExhaustLocked, iterSimple)
    25  	})
    26  	register("CoroLockOSThreadLockIterNested", func() {
    27  		println("expect: OK")
    28  		CoroLockOSThread(callerExhaustLocked, iterNested)
    29  	})
    30  	register("CoroLockOSThreadLockIterLock", func() {
    31  		println("expect: OK")
    32  		CoroLockOSThread(callerExhaustLocked, iterLock)
    33  	})
    34  	register("CoroLockOSThreadLockIterLockYield", func() {
    35  		println("expect: OS thread locking must match")
    36  		CoroLockOSThread(callerExhaustLocked, iterLockYield)
    37  	})
    38  	register("CoroLockOSThreadLockIterYieldNewG", func() {
    39  		println("expect: OS thread locking must match")
    40  		CoroLockOSThread(callerExhaustLocked, iterYieldNewG)
    41  	})
    42  	register("CoroLockOSThreadLockAfterPull", func() {
    43  		println("expect: OS thread locking must match")
    44  		CoroLockOSThread(callerLockAfterPull, iterSimple)
    45  	})
    46  	register("CoroLockOSThreadStopLocked", func() {
    47  		println("expect: OK")
    48  		CoroLockOSThread(callerStopLocked, iterSimple)
    49  	})
    50  	register("CoroLockOSThreadStopLockedIterNested", func() {
    51  		println("expect: OK")
    52  		CoroLockOSThread(callerStopLocked, iterNested)
    53  	})
    54  }
    55  
    56  func CoroLockOSThread(driver func(iter.Seq[int]) error, seq iter.Seq[int]) {
    57  	if err := driver(seq); err != nil {
    58  		println("error:", err.Error())
    59  		return
    60  	}
    61  	println("OK")
    62  }
    63  
    64  func callerExhaust(i iter.Seq[int]) error {
    65  	next, _ := iter.Pull(i)
    66  	for {
    67  		v, ok := next()
    68  		if !ok {
    69  			break
    70  		}
    71  		if v != 5 {
    72  			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
    73  		}
    74  	}
    75  	return nil
    76  }
    77  
    78  func callerExhaustLocked(i iter.Seq[int]) error {
    79  	runtime.LockOSThread()
    80  	next, _ := iter.Pull(i)
    81  	for {
    82  		v, ok := next()
    83  		if !ok {
    84  			break
    85  		}
    86  		if v != 5 {
    87  			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
    88  		}
    89  	}
    90  	runtime.UnlockOSThread()
    91  	return nil
    92  }
    93  
    94  func callerLockAfterPull(i iter.Seq[int]) error {
    95  	n := 0
    96  	next, _ := iter.Pull(i)
    97  	for {
    98  		runtime.LockOSThread()
    99  		n++
   100  		v, ok := next()
   101  		if !ok {
   102  			break
   103  		}
   104  		if v != 5 {
   105  			return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
   106  		}
   107  	}
   108  	for range n {
   109  		runtime.UnlockOSThread()
   110  	}
   111  	return nil
   112  }
   113  
   114  func callerStopLocked(i iter.Seq[int]) error {
   115  	runtime.LockOSThread()
   116  	next, stop := iter.Pull(i)
   117  	v, _ := next()
   118  	stop()
   119  	if v != 5 {
   120  		return fmt.Errorf("bad iterator: wanted value %d, got %d", 5, v)
   121  	}
   122  	runtime.UnlockOSThread()
   123  	return nil
   124  }
   125  
   126  func iterSimple(yield func(int) bool) {
   127  	for range 3 {
   128  		if !yield(5) {
   129  			return
   130  		}
   131  	}
   132  }
   133  
   134  func iterNested(yield func(int) bool) {
   135  	next, stop := iter.Pull(iterSimple)
   136  	for {
   137  		v, ok := next()
   138  		if ok {
   139  			if !yield(v) {
   140  				stop()
   141  			}
   142  		} else {
   143  			return
   144  		}
   145  	}
   146  }
   147  
   148  func iterLock(yield func(int) bool) {
   149  	for range 3 {
   150  		runtime.LockOSThread()
   151  		runtime.UnlockOSThread()
   152  
   153  		if !yield(5) {
   154  			return
   155  		}
   156  	}
   157  }
   158  
   159  func iterLockYield(yield func(int) bool) {
   160  	for range 3 {
   161  		runtime.LockOSThread()
   162  		ok := yield(5)
   163  		runtime.UnlockOSThread()
   164  		if !ok {
   165  			return
   166  		}
   167  	}
   168  }
   169  
   170  func iterYieldNewG(yield func(int) bool) {
   171  	for range 3 {
   172  		done := make(chan struct{})
   173  		var ok bool
   174  		go func() {
   175  			ok = yield(5)
   176  			done <- struct{}{}
   177  		}()
   178  		<-done
   179  		if !ok {
   180  			return
   181  		}
   182  	}
   183  }
   184  

View as plain text