#!/bin/env python
# coding: utf-8

import Queue
import threading
from contextlib import contextmanager
import time

# Sentinel put on the task queue to tell a worker thread to exit. It is
# compared by identity, so it can never be confused with a real task tuple.
StopEvent = object()


class ThreadPool(object):
  """A small pool of worker threads fed from a bounded task queue.

  Worker threads are created lazily by run(), up to max_num of them. Each
  worker pulls (func, args, callback) tuples off the queue until it receives
  StopEvent or the pool is terminated.
  """

  def __init__(self, max_num):
    """Create an empty pool.

    Args:
      max_num: maximum number of worker threads. Also the capacity of the
        task queue, so run() blocks while max_num tasks are pending.
    """
    self.q = Queue.Queue(max_num)
    self.max_num = max_num
    self.cancel = False       # set by close()/terminate(): run() ignores new tasks
    self.terminal = False     # set by terminate(): workers stop after current task
    self.active_threads = []  # worker threads that have been started and not exited
    self.free_threads = []    # workers currently blocked waiting for a task

  def run(self, func, args, callback=None):
    """Queue func(*args) for execution on a worker thread.

    A new worker is started when no worker is idle and fewer than max_num
    exist. Blocks while the task queue is full. Does nothing once close() or
    terminate() has been called.

    Args:
      func: callable to execute.
      args: tuple of positional arguments for func.
      callback: optional callable invoked on the worker thread as
        callback(success, result). result is func's return value on success,
        or the raised exception on failure.
    """
    if self.cancel:
      return

    if not self.free_threads and len(self.active_threads) < self.max_num:
      self.new_thread()
    self.q.put((func, args, callback))

  def new_thread(self):
    """Start one worker thread running self.call."""
    t = threading.Thread(target=self.call)
    # Register before start() rather than from inside the worker. The next
    # run() then sees the thread immediately, so the pool cannot over-spawn
    # past max_num. close() also always sends this thread a StopEvent.
    self.active_threads.append(t)
    t.start()

  def call(self):
    """Worker loop: execute queued tasks until StopEvent or terminate()."""
    import traceback

    current_thread = threading.current_thread()
    try:
      while True:
        # The worker is listed in free_threads only while waiting for work.
        # run() uses that list to decide whether to spawn another worker.
        with self.worker_state(self.free_threads, current_thread):
          event = StopEvent if self.terminal else self.q.get()
        if event is StopEvent:
          break

        func, args, callback = event
        try:
          result = func(*args)
          success = True
        except Exception as e:
          result = e
          success = False
          if callback is None:
            # Nobody else will see this failure; report it instead of
            # swallowing it silently.
            traceback.print_exc()

        if callback is not None:
          try:
            callback(success, result)
          except Exception:
            # A faulty callback must not kill the worker thread.
            traceback.print_exc()
    finally:
      # Always unregister, even if the loop dies unexpectedly. Otherwise
      # close()/terminate() would keep counting a dead thread.
      self.active_threads.remove(current_thread)

  def close(self):
    """Stop accepting tasks and let workers exit once the queue drains.

    One StopEvent is queued per live worker, behind any pending tasks, so
    every task already submitted still runs. Does not wait for the workers.
    """
    self.cancel = True
    for _ in range(len(self.active_threads)):
      self.q.put(StopEvent)

  def terminate(self):
    """Stop all workers as soon as their current task finishes.

    Pending tasks that have not started are discarded, and no new tasks are
    accepted. Blocks until every worker thread other than the caller has
    exited.
    """
    self.cancel = True
    self.terminal = True
    # Drop pending tasks first, so idle workers receive a StopEvent right
    # away. This also makes room in the bounded queue for the puts below.
    self._clear_queue()
    current_thread = threading.current_thread()
    workers = list(self.active_threads)
    for _ in workers:
      self.q.put(StopEvent)
    for t in workers:
      if t is not current_thread:  # a task may itself call terminate()
        t.join()
    # Busy workers exit via the `terminal` flag without consuming their
    # StopEvent; discard the leftovers.
    self._clear_queue()

  def _clear_queue(self):
    """Empty the task queue under its lock and wake any blocked producers."""
    with self.q.mutex:
      self.q.queue.clear()
      self.q.not_full.notify_all()

  @contextmanager
  def worker_state(self, state_threads, worker_thread):
    """Keep worker_thread in state_threads for the duration of the with-block.

    Used to track which workers are currently idle (waiting for a task).
    """
    state_threads.append(worker_thread)
    try:
      yield
    finally:
      state_threads.remove(worker_thread)

pool = ThreadPool(5)


# User-defined task: print its argument.
def task(i):
  print(i)


if __name__ == "__main__":
  # Demo: submit 30 tasks to the pool. Guarded so that importing this
  # module does not run tasks, sleep, or print.
  for i in range(30):
    pool.run(task, (i,))

  # Give the workers time to start and drain the queue. close() only sends
  # a StopEvent to threads already recorded in pool.active_threads.
  time.sleep(3)

  pool.close()

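# --- Residue from the web page this script was copied from (not code) ---
# Logo
#
# CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!
# (CSDN and Geek Time jointly build a quality learning community for developers.)
#
# 更多推荐 (More recommendations)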