Hospital Plans from bad CAD

I’m pleased with where this is headed. Since moving back to medical work, we get a lot of large, questionable CAD files, but we need Revit. I’ve used EvolveLAB Helix a few times for conversion, but it was SLOW. Some plans were overnight runs. (I think I know why now.)
I whacked together a basic app to take the wall layers and make walls. It worked OK, but then I stepped back and took a look at what ChatGPT spit out: processing and checking endpoints, and lots of math to find candidate walls.

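Instead, just look at each line’s bounding box and make it bigger. Any line whose box overlaps that enlarged box is a candidate partner; everything else is skipped without any math. Fast. Like 100x faster.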

So I still have a bit more work to do, but I should be able to draw big floor plans zoom-zoom soon.
(Thanks for the Overkill script…)

The code below runs on just model lines from the CAD files, to get the concept down.

# -*- coding: utf-8 -*-
__title__ = "Trace Center (Offset + Overkill)"
__author__ = "OpenAI"

# pylint: disable=E0401,C0103

import math
from collections import namedtuple

from pyrevit.framework import List
from pyrevit import revit, DB, forms, script
from pyrevit.compat import get_elementid_value_func

# Revit handles and pyRevit output/logging objects, resolved once when the
# script starts.
doc = revit.doc
uidoc = revit.uidoc
active_view = doc.ActiveView
output = script.get_output()
logger = script.get_logger()

# Pairing tolerances; Revit internal length units are feet.
MAX_OFFSET = 1.5   # 1'-6" in internal units (feet)
ANG_TOL = 1e-6     # max |cross product| of two unit directions still treated as parallel
DIST_TOL = 1e-6    # projection tolerance and minimum gap between two paired lines
LEN_TOL = 1e-6     # curves at or below this length are ignored

# overkill configs
# str.format patterns used to round values before comparing them, so that
# nearly-equal points / directions / offsets land in the same group.
POINT_RESOLUTION = '{:.4f}'
DIRECTION_RESOLUTION = '{:.2f}'
OFFSET_RESOLUTION = '{:.4f}'


# -----------------------------------------------------------------------------
# Pairing helpers
# -----------------------------------------------------------------------------
def is_line_curve(curve):
    """Tell whether *curve* is a straight ``DB.Line``.

    Arcs, ellipses, splines and any other curve types yield False.
    """
    straight_type = DB.Line
    return isinstance(curve, straight_type)


def normalize(v):
    """Return *v* scaled to unit length.

    Returns None when *v* is None or its length is below 1e-12, where a
    direction is meaningless.
    """
    if v is None or v.GetLength() < 1e-12:
        return None
    return v.Normalize()


def are_parallel(v1, v2, tol=ANG_TOL):
    """Return True when two direction vectors are parallel or anti-parallel.

    The test is the length of the cross product, which is zero for collinear
    vectors. A missing (None) vector is never parallel to anything.
    """
    if v1 is None or v2 is None:
        return False
    cross_length = v1.CrossProduct(v2).GetLength()
    return cross_length <= tol


def get_curve_bbox_xyz(curve):
    """Axis-aligned bounding box spanned by a curve's two end points.

    Only the end points are considered, so the box is exact for straight lines
    only.

    Returns:
        tuple(DB.XYZ, DB.XYZ): (min corner, max corner).
    """
    start, end = curve.GetEndPoint(0), curve.GetEndPoint(1)
    xs = (start.X, end.X)
    ys = (start.Y, end.Y)
    zs = (start.Z, end.Z)
    low = DB.XYZ(min(xs), min(ys), min(zs))
    high = DB.XYZ(max(xs), max(ys), max(zs))
    return low, high


def expand_bbox(min_pt, max_pt, offset):
    """Grow a bounding box by *offset* on every side along X, Y and Z.

    Returns:
        tuple(DB.XYZ, DB.XYZ): the enlarged (min corner, max corner).
    """
    grow = DB.XYZ(offset, offset, offset)
    return min_pt - grow, max_pt + grow


def bboxes_overlap(min_a, max_a, min_b, max_b, tol=1e-9):
    """Return True when boxes A and B overlap (or touch within *tol*) on all axes.

    Each axis is checked in turn (X, then Y, then Z), and the first axis
    with separation ends the test early.
    """
    if not (min_a.X <= max_b.X + tol and max_a.X >= min_b.X - tol):
        return False
    if not (min_a.Y <= max_b.Y + tol and max_a.Y >= min_b.Y - tol):
        return False
    return min_a.Z <= max_b.Z + tol and max_a.Z >= min_b.Z - tol


def get_active_view_sketch_plane(document, view):
    """Return the view's sketch plane, creating and assigning one if missing.

    A new plane is built through ``view.Origin`` with ``view.ViewDirection``
    as its normal. It is created and assigned in its own transaction.

    Args:
        document: the Revit document owning *view*.
        view: the view whose sketch plane is wanted.

    Returns:
        DB.SketchPlane or None: None when creating/assigning the plane failed.
        In that case the transaction is rolled back rather than left open, and
        the caller's "could not obtain a sketch plane" path handles the None.
    """
    sp = view.SketchPlane
    if sp:
        return sp

    plane = DB.Plane.CreateByNormalAndOrigin(view.ViewDirection, view.Origin)

    t = DB.Transaction(document, "Create Active View Sketch Plane")
    t.Start()
    try:
        sp = DB.SketchPlane.Create(document, plane)
        view.SketchPlane = sp
        status = t.Commit()
    except Exception as ex:
        # Never leave the document with a dangling, still-open transaction.
        if t.HasStarted() and not t.HasEnded():
            t.RollBack()
        logger.error("Could not create a sketch plane for the active view | %s", ex)
        return None

    # Commit reports failure through its return status rather than raising.
    if status != DB.TransactionStatus.Committed:
        logger.error("Sketch plane transaction did not commit | %s", status)
        return None
    return sp


class LineStyleOption(forms.TemplateListItem):
    """Picker entry that shows a GraphicsStyle by its category's name."""

    @property
    def name(self):
        category = self.item.GraphicsStyleCategory
        if category:
            return category.Name
        return "<No Category>"


def get_model_line_styles(document):
    """Collect the projection GraphicsStyle of every Lines subcategory.

    Args:
        document: the Revit document to read line styles from.

    Returns:
        list: GraphicsStyle objects sorted by category name. The list is empty
        when the document has no Lines category.
    """
    styles = []
    lines_cat = document.Settings.Categories.get_Item(DB.BuiltInCategory.OST_Lines)
    if not lines_cat:
        return styles

    for subcat in lines_cat.SubCategories:
        try:
            gs = subcat.GetGraphicsStyle(DB.GraphicsStyleType.Projection)
        except Exception as ex:
            # Best effort: skip a subcategory whose style cannot be read, but
            # leave a trace instead of swallowing the error silently.
            logger.debug("Skipping line subcategory %s | %s", subcat.Name, ex)
            continue
        if gs:
            styles.append(gs)

    def _sort_key(style):
        # Same fallback label the picker shows for styles without a category,
        # instead of crashing on a None category.
        category = style.GraphicsStyleCategory
        return category.Name if category else "<No Category>"

    return sorted(styles, key=_sort_key)


def collect_model_lines_in_active_view(document, view):
    """Return the straight model curves collected from *view*.

    Uses a view-scoped FilteredElementCollector over CurveElements and keeps
    only ModelCurves whose geometry is a DB.Line longer than LEN_TOL. Elements
    whose geometry cannot be read are reported in the output window and
    skipped.

    Args:
        document: the Revit document.
        view: the view to collect from.

    Returns:
        list: the matching ModelCurve elements.
    """
    # ElementId.IntegerValue is deprecated since Revit 2024; use the pyRevit
    # compat accessor (as LinearCurveGroup does) so the error path cannot
    # itself fail on newer Revit versions.
    get_elementid_value = get_elementid_value_func()
    fec = DB.FilteredElementCollector(document, view.Id).OfClass(DB.CurveElement)
    model_lines = []

    for ce in fec:
        try:
            if not isinstance(ce, DB.ModelCurve):
                continue

            curve = ce.GeometryCurve
            if not curve or not is_line_curve(curve) or curve.Length <= LEN_TOL:
                continue
        except Exception as ex:
            output.print_md(
                "- Skipped curve element {} because geometry could not be read: `{}`".format(
                    get_elementid_value(ce.Id), ex
                )
            )
            continue

        model_lines.append(ce)

    return model_lines


def line_data(model_curve):
    """Precompute the geometry the pairing pass needs for one model line.

    Args:
        model_curve: a ModelCurve whose GeometryCurve is a straight line.

    Returns:
        dict with keys ``elem``, ``curve``, ``p0``, ``p1``, ``mid`` (midpoint),
        ``dir`` (unit direction, or None if degenerate), ``length``, ``id``
        (element id value), ``bb_min`` and ``bb_max`` (end-point bounding box).
    """
    # ElementId.IntegerValue is deprecated since Revit 2024; use the pyRevit
    # compat accessor already imported by this file.
    get_elementid_value = get_elementid_value_func()

    curve = model_curve.GeometryCurve
    p0 = curve.GetEndPoint(0)
    p1 = curve.GetEndPoint(1)
    direction = normalize(p1 - p0)
    bb_min, bb_max = get_curve_bbox_xyz(curve)
    midpoint = (p0 + p1) * 0.5

    return {
        "elem": model_curve,
        "curve": curve,
        "p0": p0,
        "p1": p1,
        "mid": midpoint,
        "dir": direction,
        "length": curve.Length,
        "id": get_elementid_value(model_curve.Id),
        "bb_min": bb_min,
        "bb_max": bb_max,
    }


def project_point_to_segment(seg_curve, test_point, tol=1e-6):
    """Project *test_point* onto *seg_curve* and report where it landed.

    Returns:
        tuple: ``(on_segment, projected_point, distance)``.
        ``on_segment`` is True when the distances from the projected point to
        both end points add up to the segment length within *tol*.
        ``distance`` runs from *test_point* to the projected point.
        ``(False, None, None)`` is returned when the projection fails or the
        segment is no longer than *tol*.

    NOTE(review): if ``Curve.Project`` already clamps to the bounded curve,
    ``on_segment`` is always True and callers must reject end-point hits
    themselves; confirm against the Revit API docs.
    """
    result = seg_curve.Project(test_point)
    if result is None:
        return False, None, None

    hit = result.XYZPoint
    start = seg_curve.GetEndPoint(0)
    end = seg_curve.GetEndPoint(1)

    length = start.DistanceTo(end)
    if length <= tol:
        return False, None, None

    detour = hit.DistanceTo(start) + hit.DistanceTo(end)
    inside = abs(detour - length) <= tol
    return inside, hit, test_point.DistanceTo(hit)


def build_center_line_from_pair(data_a, data_b, max_offset):
    """Build the center line between two parallel model lines, if they pair up.

    The shorter line (``data_a`` wins a length tie) is dropped through its
    midpoint onto the other line. The pair is accepted only when all of the
    following hold:
    - the projection lands on the other segment;
    - its length is above DIST_TOL and at most *max_offset*;
    - it is square to the short line (dot product within 1e-5).
    The short line is then shifted halfway across.

    Args:
        data_a, data_b: dicts produced by ``line_data``.
        max_offset: largest allowed gap between the two lines.

    Returns:
        DB.Line or None: the shifted copy of the shorter line, or None when
        the two lines do not form a valid pair.
    """
    if not are_parallel(data_a["dir"], data_b["dir"]):
        return None

    short_data, other_data = (
        (data_a, data_b) if data_a["length"] <= data_b["length"] else (data_b, data_a)
    )

    on_other, foot, gap = project_point_to_segment(
        other_data["curve"], short_data["mid"], DIST_TOL
    )
    if not on_other:
        return None
    if gap is None or gap <= DIST_TOL or gap > max_offset:
        return None

    across = foot - short_data["mid"]
    # Reject feet that are not square to the short line (e.g. a projection
    # that landed on an end point of the other segment).
    if abs(across.DotProduct(short_data["dir"])) > 1e-5:
        return None

    shift = across * 0.5
    start = short_data["p0"] + shift
    end = short_data["p1"] + shift
    if start.DistanceTo(end) <= LEN_TOL:
        return None
    return DB.Line.CreateBound(start, end)


# -----------------------------------------------------------------------------
# Overkill helpers adapted from your script
# -----------------------------------------------------------------------------
# Rounded 2D end point of a curve, tagged with the element id value (cid) of
# the curve it came from, so a merged group knows which curves to delete.
CurvePoint = namedtuple('CurvePoint', ['x', 'y', 'cid'])


class LinearCurveGroup(object):
    """Run of collinear straight curves that overkill can collapse into one.

    Curves are keyed by ``cgroup_id`` with four parts:
    - canonical direction x and canonical direction y, rounded with
      DIRECTION_RESOLUTION;
    - signed offset, rounded with OFFSET_RESOLUTION;
    - line weight, or None when styles are ignored.
    ``points`` holds every member's end points, rounded with POINT_RESOLUTION
    and tagged with the owning element id. ``overkill`` stretches the curve
    the group was created from (``dir_cid``) across all points and deletes
    the others.
    """
    supported_geoms = (DB.Line,)

    def __init__(self, curve, include_style=False):
        self.points = set()

        p1 = curve.GeometryCurve.GetEndPoint(0)
        p2 = curve.GeometryCurve.GetEndPoint(1)

        # The direction must be known first: the offset is measured against it.
        self.dir_x, self.dir_y = self.get_direction(curve)
        self.dir_offset = self.get_offset(p1, p2)
        self.weight = self.get_weight(curve) if include_style else None

        self.cgroup_id = (self.dir_x, self.dir_y, self.dir_offset, self.weight)

        get_elementid_value = get_elementid_value_func()
        self.dir_cid = get_elementid_value(curve.Id)
        self.add_points([
            self.get_point(p1.X, p1.Y),
            self.get_point(p2.X, p2.Y),
        ])

    @staticmethod
    def shortest_dist(x1, y1, x2, y2, x0, y0):
        """Unsigned distance from (x0, y0) to the infinite line through (x1, y1)-(x2, y2)."""
        return (
            abs(x0 * (y2 - y1) - y0 * (x2 - x1) + x2 * y1 - y2 * x1) /
            math.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2)
        )

    @staticmethod
    def signed_dist(x1, y1, x2, y2, x0, y0):
        """Signed distance from (x0, y0) to the line (x1, y1)->(x2, y2).

        Positive when the point lies to the left of the line's direction. The
        magnitude equals ``shortest_dist`` for the same arguments.
        """
        seg_x = x2 - x1
        seg_y = y2 - y1
        return (
            (seg_x * (y0 - y1) - seg_y * (x0 - x1)) /
            math.sqrt(seg_x ** 2 + seg_y ** 2)
        )

    @staticmethod
    def is_inside(min_cpoint, max_cpoint, cpoint):
        """Return True when *cpoint* lies on the run between the two end points.

        Axis-aligned runs only compare the varying coordinate. Other runs
        require the point's fractional position along x and along y to agree,
        and to fall within [0, 1].
        """
        x_range = max_cpoint.x - min_cpoint.x
        y_range = max_cpoint.y - min_cpoint.y

        if x_range == 0.0:
            low_y, high_y = sorted((min_cpoint.y, max_cpoint.y))
            return low_y <= cpoint.y <= high_y
        elif y_range == 0.0:
            low_x, high_x = sorted((min_cpoint.x, max_cpoint.x))
            return low_x <= cpoint.x <= high_x
        else:
            xr = LinearCurveGroup.get_float((cpoint.x - min_cpoint.x) / x_range)
            yr = LinearCurveGroup.get_float((cpoint.y - min_cpoint.y) / y_range)
            # Without the range check any point on the infinite line would
            # pass, and overkill would bridge gaps between disjoint collinear
            # diagonal segments.
            return xr == yr and 0.0 <= xr <= 1.0

    @staticmethod
    def get_style_weight(curve):
        graphic_style_cat = curve.LineStyle.GraphicsStyleCategory
        graphic_style_type = curve.LineStyle.GraphicsStyleType
        return graphic_style_cat.GetLineWeight(graphic_style_type)

    @staticmethod
    def get_float(float_value, res=POINT_RESOLUTION):
        """Round *float_value* by formatting it with *res* and parsing it back."""
        return float(res.format(float_value))

    def _position(self, cpoint):
        """Scalar position of *cpoint* along this group's canonical direction."""
        return cpoint.x * self.dir_x + cpoint.y * self.dir_y

    @property
    def max_point(self):
        """Member end point furthest along the canonical direction."""
        # Projecting onto the direction (rather than ranking by x + y) keeps
        # the ordering meaningful for anti-diagonal runs, where x + y is
        # constant along the line.
        return max(self.points, key=self._position)

    @property
    def min_point(self):
        """Member end point furthest against the canonical direction."""
        return min(self.points, key=self._position)

    def get_point(self, x, y):
        return CurvePoint(
            x=LinearCurveGroup.get_float(x, res=POINT_RESOLUTION),
            y=LinearCurveGroup.get_float(y, res=POINT_RESOLUTION),
            cid=self.dir_cid
        )

    def get_direction(self, curve):
        """Rounded XY direction of *curve*, flipped into a canonical half-plane.

        Opposite directions map to the same value, so a curve and its reverse
        share a group.
        """
        direction = curve.GeometryCurve.Direction
        # Round before deciding whether to flip, so float noise such as
        # (1.0, -1e-12) vs (1.0, 1e-12) cannot send parallel curves to
        # opposite canonical directions (and so different groups).
        dir_x = LinearCurveGroup.get_float(direction.X, res=DIRECTION_RESOLUTION)
        dir_y = LinearCurveGroup.get_float(direction.Y, res=DIRECTION_RESOLUTION)

        if (dir_x <= 0.0 and dir_y <= 0.0) or (dir_x > 0.0 and dir_y < 0.0):
            dir_x, dir_y = -dir_x, -dir_y

        # Collapse -0.0 to 0.0.
        dir_x = 0.0 if dir_x == 0.0 else dir_x
        dir_y = 0.0 if dir_y == 0.0 else dir_y

        return dir_x, dir_y

    def get_offset(self, p1, p2):
        """Signed, rounded distance of the curve's line from the point (dir_x, dir_y).

        The sign keeps parallel lines on opposite sides of that reference
        point apart; with an unsigned distance they would share a group id.
        The sign is oriented by the canonical direction, so it does not depend
        on which end the curve was drawn from.
        """
        x1 = LinearCurveGroup.get_float(p1.X, res=POINT_RESOLUTION)
        y1 = LinearCurveGroup.get_float(p1.Y, res=POINT_RESOLUTION)
        x2 = LinearCurveGroup.get_float(p2.X, res=POINT_RESOLUTION)
        y2 = LinearCurveGroup.get_float(p2.Y, res=POINT_RESOLUTION)

        if x1 == x2 and y1 == y2:
            # End points collapsed after rounding: measure against the line
            # through the single point along the canonical direction instead
            # of dividing by zero.
            x2, y2 = x1 + self.dir_x, y1 + self.dir_y

        dir_offset = LinearCurveGroup.signed_dist(x1, y1, x2, y2, self.dir_x, self.dir_y)
        if (x2 - x1) * self.dir_x + (y2 - y1) * self.dir_y < 0.0:
            dir_offset = -dir_offset
        return LinearCurveGroup.get_float(dir_offset, res=OFFSET_RESOLUTION)

    def get_weight(self, curve):
        return LinearCurveGroup.get_style_weight(curve)

    def add_points(self, curve_points):
        """Add every CurvePoint in *curve_points*, ignoring anything else."""
        for curve_point in curve_points:
            if isinstance(curve_point, CurvePoint):
                self.points.add(curve_point)

    def merge(self, cgroup):
        """Absorb *cgroup*'s points when the two runs overlap or touch.

        Overlap is checked in both directions. This also merges a run that
        fully contains this one, even though none of its end points fall
        inside this run.

        Returns:
            bool: True when the points were merged.
        """
        if not isinstance(cgroup, LinearCurveGroup):
            return False
        if not self.points or not cgroup.points:
            return False

        own_min, own_max = self.min_point, self.max_point
        other_min, other_max = cgroup.min_point, cgroup.max_point
        overlaps = (
            any(LinearCurveGroup.is_inside(own_min, own_max, p) for p in cgroup.points) or
            any(LinearCurveGroup.is_inside(other_min, other_max, p) for p in self.points)
        )
        if overlaps:
            self.add_points(cgroup.points)
        return overlaps

    def overkill(self, document=None):
        """Collapse the group onto its root curve and delete the other members.

        Runs only when *document* is given and the group holds more than two
        points. The root curve (``dir_cid``) is reset to a line from
        ``min_point`` to ``max_point`` at the root's original Z. Only if that
        succeeds are the other member curves deleted. The caller is expected
        to have a transaction open, because this modifies the document.

        Returns:
            int: number of curve element ids passed to ``Delete`` (0 if none).
        """
        bounded_curve_ids = set()

        if document and len(self.points) > 2:
            root_curve = document.GetElement(DB.ElementId(self.dir_cid))
            min_dpoint = self.min_point
            max_dpoint = self.max_point

            curve_z = root_curve.GeometryCurve.GetEndPoint(0).Z
            min_point = DB.XYZ(min_dpoint.x, min_dpoint.y, curve_z)
            max_point = DB.XYZ(max_dpoint.x, max_dpoint.y, curve_z)

            reset_dir_curve = False
            try:
                geom_curve = DB.Line.CreateBound(min_point, max_point)
                root_curve.SetGeometryCurve(geom_curve, overrideJoins=True)
                reset_dir_curve = True
            except Exception as set_ex:
                logger.debug('Failed re-setting root curve | %s', set_ex)

            if reset_dir_curve:
                bounded_curve_ids = {x.cid for x in self.points if x.cid != self.dir_cid}
                if bounded_curve_ids:
                    bounded_curve_ids = [DB.ElementId(x) for x in bounded_curve_ids]
                    document.Delete(List[DB.ElementId](bounded_curve_ids))

        return len(bounded_curve_ids)

class CurveGroupCollection(object):
    """Collection of LinearCurveGroups in which same-id groups never overlap.

    Every incoming curve either joins an existing collinear group it overlaps
    or starts a new group.
    """

    def __init__(self, include_style=False):
        self.curve_groups = []
        self._include_style = include_style

    def __iter__(self):
        return iter(self.curve_groups)

    def merge(self, cgroup):
        """Fold *cgroup* into the existing same-id group(s) it overlaps.

        The first existing group that overlaps *cgroup* absorbs it. Any later
        same-id group that now overlaps that grown group is absorbed as well
        and removed from the collection. Groups that overlap neither are left
        untouched and stay in the collection.

        Returns:
            bool: True when *cgroup* was absorbed, so the caller must not add
            it as a new group.
        """
        target = None
        same_id = [x for x in self.curve_groups if x.cgroup_id == cgroup.cgroup_id]
        for existing in same_id:
            if target is None:
                if existing.merge(cgroup):
                    target = existing
            elif target.merge(existing):
                # The grown target now bridges to this group: fold it in.
                self.curve_groups.remove(existing)
        return target is not None

    def extend(self, curve_element):
        """Add *curve_element* as a new group or merge it into an existing one.

        Returns:
            True when the element's geometry is a supported straight line;
            None (implicitly) otherwise, in which case it is ignored.
        """
        if isinstance(curve_element.GeometryCurve, LinearCurveGroup.supported_geoms):
            cgroup = LinearCurveGroup(curve_element, include_style=self._include_style)
            if not self.merge(cgroup):
                self.curve_groups.append(cgroup)
            return True
            return True


def overkill_created_curves(curve_elements, include_style=True):
    """Collapse overlapping collinear curves among *curve_elements*.

    Straight curves are grouped by direction and offset, plus line weight when
    *include_style* is True. Each group is then overkilled against the
    module-level ``doc``.

    Returns:
        int: total number of curve elements deleted.
    """
    collection = CurveGroupCollection(include_style=include_style)
    for element in curve_elements:
        collection.extend(element)
    return sum(group.overkill(document=doc) for group in collection)


# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
output.print_md("### Trace Center")
output.print_md("- Active view: `{}`".format(active_view.Name))
output.print_md("- View type: `{}`".format(active_view.ViewType))

# Let the user pick the line style for the generated center lines.
styles = get_model_line_styles(doc)
if not styles:
    forms.alert("No line styles were found under Lines subcategories.", exitscript=True)

style_options = [LineStyleOption(gs) for gs in styles]
selected_style = forms.SelectFromList.show(
    style_options,
    title="Select Model Line Style",
    button_name="Use Selected Style",
    multiselect=False,
)
if not selected_style:
    script.exit()

# New center lines are drawn on the active view's sketch plane.
sketch_plane = get_active_view_sketch_plane(doc, active_view)
if not sketch_plane:
    forms.alert("Could not obtain or create a sketch plane for the active view.", exitscript=True)

# Gather the straight model lines and precompute their geometry.
model_lines = collect_model_lines_in_active_view(doc, active_view)
if len(model_lines) < 2:
    forms.alert("Need at least two straight model lines in the active view.", exitscript=True)

data = list(map(line_data, model_lines))
line_count = len(data)
output.print_md("- Straight model lines found: `{}`".format(line_count))

# Number of unordered pairs, n choose 2.
total_pairs = line_count * (line_count - 1) // 2
output.print_md("- Total possible pairs: `{}`".format(total_pairs))

# Pair search with sweep-and-prune on X.
# Lines are visited in order of their bounding box's low X edge. Once a later
# line's box starts beyond the current line's expanded high X edge, neither it
# nor any line after it can pair with the current one. The inner scan can
# therefore stop there instead of visiting every remaining pair.
# Each surviving pair is still tested exactly as before, with the lower
# original index as the first argument, and the results are re-sorted into
# the original pair order afterwards.
broad_phase_hits = 0
exact_checks = 0
valid_pair_count = 0
log_step = 250

# Slack added to the sweep cut-off. It exceeds bboxes_overlap's default
# tolerance (1e-9), so the cut-off never prunes a pair that test would accept.
sweep_margin = 1e-6

expanded_boxes = [expand_bbox(d["bb_min"], d["bb_max"], MAX_OFFSET) for d in data]
sweep_order = sorted(range(len(data)), key=lambda k: data[k]["bb_min"].X)
sweep_count = len(sweep_order)
found_pairs = []  # ((first original index, second original index), center line)

for pos, idx_i in enumerate(sweep_order):
    sweep_limit = expanded_boxes[idx_i][1].X + sweep_margin

    for nxt in range(pos + 1, sweep_count):
        idx_j = sweep_order[nxt]
        if data[idx_j]["bb_min"].X > sweep_limit:
            break

        first_idx, second_idx = min(idx_i, idx_j), max(idx_i, idx_j)
        data_a = data[first_idx]
        data_b = data[second_idx]
        exp_min_a, exp_max_a = expanded_boxes[first_idx]

        if not bboxes_overlap(exp_min_a, exp_max_a, data_b["bb_min"], data_b["bb_max"]):
            continue

        broad_phase_hits += 1

        if not are_parallel(data_a["dir"], data_b["dir"]):
            continue

        exact_checks += 1

        if exact_checks % log_step == 0:
            output.print_md("- Exact checks after bbox prefilter: `{}`".format(exact_checks))

        try:
            center_line = build_center_line_from_pair(data_a, data_b, MAX_OFFSET)
            if center_line:
                valid_pair_count += 1
                found_pairs.append(((first_idx, second_idx), center_line))
        except Exception as ex:
            output.print_md(
                "- Failed pair {} / {}: `{}`".format(data_a["id"], data_b["id"], ex)
            )

# Restore the original (i, j) order, so creation order (and therefore which
# curve overkill keeps) does not depend on the sweep order.
found_pairs.sort(key=lambda entry: entry[0])
candidate_lines = [entry[1] for entry in found_pairs]

output.print_md("- BBox candidate hits: `{}`".format(broad_phase_hits))
output.print_md("- Exact geometry checks: `{}`".format(exact_checks))
output.print_md("- Raw candidate center lines: `{}`".format(len(candidate_lines)))
output.print_md("- Valid source pairs found: `{}`".format(valid_pair_count))

if not candidate_lines:
    forms.alert("No valid parallel line pairs were found within 1'-6\".", exitscript=True)

created_ids = []
failed = []
create_log_step = 100

with revit.Transaction("Create Centered Model Lines From Parallel Pairs", swallow_errors=True):
    for idx, line in enumerate(candidate_lines, 1):
        if idx % create_log_step == 0 or idx == len(candidate_lines):
            output.print_md("- Creation progress: `{}` / `{}`".format(idx, len(candidate_lines)))

        try:
            new_mc = doc.Create.NewModelCurve(line, sketch_plane)
            if isinstance(new_mc, DB.CurveElement):
                new_mc.LineStyle = selected_style
            created_ids.append(new_mc.Id)
        except Exception as ex:
            failed.append((idx, str(ex)))

# Resolve each id once. An id can resolve to nothing if the creation
# transaction was rolled back, so only curves that really exist are counted
# and passed on to overkill.
created_curves = []
for created_id in created_ids:
    created = doc.GetElement(created_id)
    if created is not None:
        created_curves.append(created)

deleted_count = 0
if created_curves:
    with revit.Transaction("Overkill Center Lines", swallow_errors=True):
        deleted_count = overkill_created_curves(created_curves, include_style=True)

final_created_count = len(created_curves) - deleted_count

msg = [
    "Processed {} model lines in active view.".format(len(model_lines)),
    "Total possible pairs: {}.".format(total_pairs),
    "BBox candidate hits: {}.".format(broad_phase_hits),
    "Exact checks after prefilter: {}.".format(exact_checks),
    "Valid source pairs found: {}.".format(valid_pair_count),
    "Raw candidate center lines created: {}.".format(len(created_curves)),
    "Overkill removed {} overlapping line(s).".format(deleted_count),
    "Final remaining center lines: {}.".format(final_created_count),
    "Failed {} line(s).".format(len(failed)),
]

forms.alert("\n".join(msg), title="Complete")

if failed:
    output.print_md("### Failed created lines")
    for idx, err in failed:
        output.print_md("- Raw candidate line {} failed: `{}`".format(idx, err))
2 Likes