用wait/notify来实现两个线程交替执行

2018 Aug 07 See all posts


写代码久了,往往各种工具类和包用得飞起,却忘了最基础的概念。温故而知新,还是非常重要的。

前几天看到一个有趣的题目:

两个线程,一个会输出1,3,5,7,9,另外一个输出2,4,6,8,10,请只用wait/notify来控制两个线程交替执行,即输出1,2,3,4,5,6,7,8,9,10


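首先要了解一下基础知识:wait/notify 是 Java 基类 Object 的基础方法,作用分别是:

- wait():让当前线程释放该对象的锁并进入等待状态,直到其他线程调用同一对象的 notify()/notifyAll()(或当前线程被中断)时才可能返回;
- notify():唤醒一个正在该对象上等待的线程(具体唤醒哪一个是不确定的);
- notifyAll():唤醒所有正在该对象上等待的线程。

被唤醒的线程不会立刻执行,而是要重新竞争到该对象的锁之后,才能从 wait() 返回。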


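这道题重点考察的几个知识点:

- wait()/notify()/notifyAll() 必须在持有该对象锁(即在 synchronized(obj) 块内)时调用,否则会抛出 IllegalMonitorStateException;
- wait() 会释放锁,而 notify() 不会释放锁,被唤醒的线程要等通知方退出 synchronized 块之后才能拿到锁;
- 线程可能被虚假唤醒(spurious wakeup),也可能被本来发给别的线程的 notifyAll() 唤醒,所以 wait() 必须放在检查条件的 while 循环里,而不能用 if。Object.wait() 的 Javadoc 给出的标准写法如下: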

  synchronized (obj) {
      while (<condition does not hold>)
          obj.wait();
      ... // Perform action appropriate to condition
  }


以上知识点,记得多年前有本书《JAVA多线程设计模式》,里面有详细的解释和很多朴素的例子。

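网上有好多解答其实是错的,常见的错误及表现:

- 用 if 而不是 while 判断等待条件:被虚假唤醒后不再检查条件,导致输出乱序;
- 没有共享的状态标志,只靠 notify/wait 互相"打招呼":如果 notify 发生在对方 wait 之前,这个通知就丢失了,两个线程互相等待,程序卡住;
- 最后一轮输出完还去 wait:对方已经退出,再也没有人唤醒它,线程永远结束不了。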

话不多说,上代码:

package com.cs.exercise.concurrent.basic;

import java.util.ArrayList;
import java.util.List;

/**
 * Two threads add the numbers 1..END to {@link #result} alternately, coordinated only with
 * wait/notifyAll on {@link #lock}: {@link Thread1} adds the odd numbers, {@link Thread2} the even
 * ones.
 *
 * <p>The shared state ({@link #result} and the two round flags) is only touched while holding
 * {@link #lock}, or by the driver thread before {@code start()} / after {@code join()}.
 */
public class WaitNotify {
  private static final Object lock = new Object();
  private static final int END = 10;
  /** Set by Thread1 after adding a number; cleared by Thread2 once it has taken its turn. */
  private boolean th1FinishedCurRound = false;
  /** Set by Thread2 after adding a number; cleared by Thread1 when it hands the turn over. */
  private boolean th2FinishedCurRound = false;
  private static final List<Integer> result = new ArrayList<>(END);

  /** Adds the odd numbers 1, 3, 5, ... (up to END) to {@link #result}; it always goes first. */
  class Thread1 extends Thread {

    @Override
    public void run() {
      int i = 1;
      while (i <= END) {
        synchronized (lock) {
          result.add(i);
          i = i + 2;
          // Hand the turn to Thread2.
          th1FinishedCurRound = true;
          th2FinishedCurRound = false;
          lock.notifyAll();
          try {
            // Only wait for Thread2's reply while we still have numbers left: after our last
            // number Thread2 may already have exited (odd END) and nobody would wake us.
            while (!th2FinishedCurRound && i <= END) {
              lock.wait();
            }
          } catch (InterruptedException e) {
            // Stop rather than print out of turn; the peer may then wait forever.
            Thread.currentThread().interrupt();
            return;
          }
        }
      }
    }
  }

  /** Adds the even numbers 2, 4, 6, ... (up to END) to {@link #result}, each after Thread1's. */
  class Thread2 extends Thread {

    @Override
    public void run() {
      int i = 2;
      while (i <= END) {
        synchronized (lock) {
          try {
            // "while", not "if": re-check after spurious wakeups or notifyAll aimed elsewhere.
            while (!th1FinishedCurRound) {
              lock.wait();
            }
          } catch (InterruptedException e) {
            // Stop rather than print out of turn; the peer may then wait forever.
            Thread.currentThread().interrupt();
            return;
          }
          result.add(i);
          i = i + 2;
          // Hand the turn back to Thread1.
          th2FinishedCurRound = true;
          th1FinishedCurRound = false;
          lock.notifyAll();
        }
      }
    }
  }

  /**
   * Runs the two-thread exercise {@code ROUND_COUNT} times. Every failing round is printed; if all
   * rounds produced 1..END in order, one success line with the elapsed time is printed.
   *
   * @param ROUND_COUNT number of rounds to run
   */
  public void doRound(int ROUND_COUNT) {
    int successCount = 0;
    long start = System.currentTimeMillis();
    for (int i = 0; i < ROUND_COUNT; i++) {
      result.clear();
      // Start every round from a clean state instead of relying on the previous round's leftovers.
      th1FinishedCurRound = false;
      th2FinishedCurRound = false;
      Thread1 th1 = new Thread1();
      Thread2 th2 = new Thread2();
      th1.start();
      th2.start();
      try {
        th1.join();
        th2.join();
      } catch (InterruptedException e) {
        // The driver was asked to stop: preserve the interrupt and abandon the remaining rounds.
        Thread.currentThread().interrupt();
        return;
      }
      if (isSuccess(result)) {
        successCount++;
      } else {
        System.out.println("Round " + i + " error," + result.size() + result);
      }
    }
    long elapsed = System.currentTimeMillis() - start;
    if (successCount == ROUND_COUNT) {
      System.out.println("ALL rounds are SUCCESS, round count:" + ROUND_COUNT  +", time cost(ms): " + elapsed);
    }
  }

  /** Returns true iff {@code list} is exactly 1, 2, ..., END. */
  private static boolean isSuccess(List<Integer> list) {
    if (list.size() != END) {
      return false;
    }
    for (int i = 1; i <= END; i++) {
      if (!list.get(i - 1).equals(i)) {
        return false;
      }
    }
    return true;
  }

  public static void main(String[] args) {
    WaitNotify waitNotify = new WaitNotify();
    waitNotify.doRound(100000);
  }

}

进阶:如果线程更多呢,比如 3 个甚至更多的线程?只要理解了上面的代码,就可以很容易地写出更通用、更简洁的代码,N 个线程也不在话下。

package com.cs.exercise.concurrent.basic;

import java.util.ArrayList;
import java.util.List;

/**
 * N-thread generalization of the wait/notify exercise: thread {@code j} adds
 * {@code firstValue + j, firstValue + j + N, ...} (values up to END) to {@link #result}, and the
 * threads take turns in index order, coordinated only with wait/notifyAll on {@link #lock}.
 */
public class WaitNotifyGeneral {
  private static final Object lock = new Object();
  private static final int END = 100;
  /** Index of the thread whose turn it is; guarded by {@link #lock}. */
  private int flag = 0;
  private static final List<Integer> result = new ArrayList<>(END);

  /** Adds {@code first, first + step, ...} (only values <= END) to {@link #result}, one per turn. */
  class CounterThread extends Thread {
    private final int threadIndex;
    private final int threadCount;
    private final int first;
    private final int step;

    public CounterThread(int threadIndex, int threadCount, int first, int step) {
      this.setName("thread" + threadIndex);
      this.threadIndex = threadIndex;
      this.threadCount = threadCount;
      this.first = first;
      this.step = step;
    }

    @Override
    public void run() {
      int i = first;
      // Check the bound BEFORE adding: with more threads than numbers, a thread whose first value
      // is already past END must add nothing. Its turn only comes after every number is out, so
      // exiting immediately cannot stall the others.
      while (i <= END) {
        synchronized (lock) {
          try {
            while (flag != threadIndex) {
              lock.wait();
            }
            result.add(i);
            i = i + step;
            flag = (flag + 1) % threadCount;
            lock.notifyAll();
            // Wait for our next turn, but not after our last number: the remaining threads may
            // all finish and exit, leaving nobody to wake us.
            while (flag != threadIndex && i <= END) {
              lock.wait();
            }
          } catch (InterruptedException e) {
            // Stop rather than print out of turn; peers may then wait forever for this turn.
            Thread.currentThread().interrupt();
            return;
          }
        }
      }
    }
  }


  /**
   * Runs the N-thread exercise {@code ROUND_COUNT} times. Every failing round is printed; if all
   * rounds produced firstValue..END in order, one success line with the elapsed time is printed.
   *
   * @param firstValue first number to print (thread 0 starts with it)
   * @param threadCount number of cooperating threads, at least 1
   * @param ROUND_COUNT number of rounds to run
   * @throws IllegalArgumentException if {@code threadCount < 1}
   */
  public void doRound(int firstValue, int threadCount, int ROUND_COUNT) {
    if (threadCount < 1) {
      throw new IllegalArgumentException("threadCount must be >= 1: " + threadCount);
    }
    int successCount = 0;
    long start = System.currentTimeMillis();
    for (int i = 0; i < ROUND_COUNT; i++) {
      result.clear();
      flag = 0;
      Thread[] threads = new Thread[threadCount];
      for (int j = 0; j < threadCount; j++) {
        threads[j] = new CounterThread(j, threadCount, firstValue + j, threadCount);
      }
      for (Thread thread : threads) {
        thread.start();
      }
      for (Thread thread : threads) {
        try {
          thread.join();
        } catch (InterruptedException e) {
          // The driver was asked to stop: preserve the interrupt and abandon the remaining rounds.
          Thread.currentThread().interrupt();
          return;
        }
      }
      if (isSuccess(result, firstValue)) {
        successCount++;
      } else {
        System.out.println("Round " + i + " error," + result.size() + result);
      }
    }
    long elapsed = System.currentTimeMillis() - start;
    if (successCount == ROUND_COUNT) {
      System.out.println("ALL rounds are SUCCESS, round count:" + ROUND_COUNT + ", time cost(ms): " + elapsed);
    }
  }

  /** Returns true iff {@code list} is exactly firstValue, firstValue + 1, ..., END. */
  private static boolean isSuccess(List<Integer> list, int firstValue) {
    int expectedSize = Math.max(0, END - firstValue + 1);
    if (list.size() != expectedSize) {
      return false;
    }
    for (int k = 0; k < expectedSize; k++) {
      if (!list.get(k).equals(firstValue + k)) {
        return false;
      }
    }
    return true;
  }

  public static void main(String[] args) {
    WaitNotifyGeneral waitNotify = new WaitNotifyGeneral();
    waitNotify.doRound(1, 3, 10000);
  }

}

注意:线程执行完当前轮、等待下一轮时,很容易漏掉 i <= END 这个条件,只写成 flag != threadIndex。这样一来,已经输出完自己全部数字的线程还会继续阻塞等待轮次。可此时所有任务其实都已完成,不会再有别的线程唤醒它,这些线程就永远退出不了(程序卡死)。主线程的 join() 也会一直等下去,下一轮就没法进行了。

Back to top