App下載

如何使用python在2D中計(jì)算多邊形IoU

猿友 2021-07-21 10:07:07 瀏覽數(shù) (3162)
反饋

本文W3Cschool和大家介紹一下如何使用python在2D中計(jì)算多邊形IoU。假設(shè)多邊形不是自相交的,即圓圓周圍的點(diǎn)的順序是單調(diào)的,那么我相信有一個(gè)相對(duì)簡(jiǎn)單的方法來(lái)確定 IoU 值,而不需要一個(gè)一般的形狀包。

  1. 假設(shè)每個(gè)多邊形的點(diǎn)是順時(shí)針在圓圈中排列的。如果我們發(fā)現(xiàn)簽名區(qū)域?yàn)樨?fù)值,我們可以通過(guò)增加角度 w.r.t x 軸或倒車點(diǎn)來(lái)確保這一點(diǎn)。
  2. 將兩個(gè)多邊形的點(diǎn)合并到單個(gè)列表中,跟蹤每個(gè)點(diǎn)屬于哪個(gè)多邊形。我們還需要能夠確定每個(gè)點(diǎn)在原始多邊形中的上一點(diǎn)和下一點(diǎn)。L
  3. 通過(guò)增加角度來(lái)排序w.r.t x軸。L
  4. 如果輸入多邊形相交,則從一個(gè)多邊形到另一個(gè)多邊形的過(guò)渡次數(shù)將大于兩個(gè)。L
  5. 遍歷。如果連續(xù)點(diǎn)屬于不同的多邊形,則第一點(diǎn)與其下一點(diǎn)和第二點(diǎn)之間的線的交叉點(diǎn)及其前一點(diǎn)將屬于兩個(gè)多邊形之間的交點(diǎn)。L
  6. 將步驟 4 中確定的每個(gè)點(diǎn)添加到新的多邊形中。將按順序遇到積分。II
  7. 每個(gè)多邊形的面積之和將等于其聯(lián)盟加上交叉口的區(qū)域,因?yàn)檫@將計(jì)算兩次。
  8. 因此,將的價(jià)值由兩個(gè)多邊形的面積之和減去面積之和來(lái)表示。IoUII

唯一需要的幾何形狀是使用Shoelace 公式計(jì)算簡(jiǎn)單多邊形的面積,并確定步驟 5 所需的兩條線段之間的交匯點(diǎn)。

這里有一些Java代碼(Ideone)來(lái)說(shuō)明 - 你也許可以使它在Python更緊湊。

double[][] coords = {{-0.708, 0.707, 0.309, -0.951, 0.587, -0.809},
                       {1, 0, 0, 1, -1, 0, 0, -1, 0.708, -0.708}};

double areaSum = 0;
List<CPoint> pts = new ArrayList<>();
for(int p=0; p<coords.length; p++)
{
    List<CPoint> poly = new ArrayList<>();
    for(int j=0; j<coords[p].length; j+=2)
    {
        poly.add(new CPoint(p, coords[p][j], coords[p][j+1]));
    }
    
    double area = area(poly);
    if(area < 0)
    {
        area = -area;
        Collections.reverse(poly);
    }
    areaSum += area;
    
    pts.addAll(poly);

    int n = poly.size();
    for(int i=0, j=n-1; i<n; j=i++)
    {
        poly.get(i).prev = poly.get(j);
        poly.get(j).next = poly.get(i);             
    }
}       
        
pts.sort((a, b) -> Double.compare(a.theta, b.theta));
        
List<Point2D> intersections = new ArrayList<>();
int n = pts.size();
for(int i=0, j=n-1; i<n; j=i++)
{
    if(pts.get(j).id != pts.get(i).id)
    {
        intersections.add(intersect(pts.get(j), pts.get(j).next, pts.get(i).prev, pts.get(i)));
    }
}

double areaInt = area(intersections);
double iou = areaInt/(areaSum - areaInt);
System.out.println(iou);

輸出:

0.12403616470027268

和支持代碼:

static class CPoint extends Point2D.Double
{
    int id;
    double theta;
    CPoint prev, next;
    
    public CPoint(int id, double x, double y)
    {
        super(x, y);
        this.id = id;
        theta = Math.atan2(y, x);
        if(theta < 0) theta = 2*Math.PI + theta;
    }
}   

static double area(List<? extends Point2D> poly)
{
    double area = 0;
    for(int i=0, j=poly.size()-1; i<poly.size(); j=i++)
        area += (poly.get(j).getX() * poly.get(i).getY()) - (poly.get(i).getX() * poly.get(j).getY());
    return Math.abs(area)/2;
}

// https://rosettacode.org/wiki/Find_the_intersection_of_two_lines#Java
static Point2D intersect(Point2D p1, Point2D p2, Point2D p3, Point2D p4)
{
  double a1 = p2.getY() - p1.getY();
  double b1 = p1.getX() - p2.getX();
  double c1 = a1 * p1.getX() + b1 * p1.getY();

  double a2 = p4.getY() - p3.getY();
  double b2 = p3.getX() - p4.getX();
  double c2 = a2 * p3.getX() + b2 * p3.getY();

0 人點(diǎn)贊